#!/usr/bin/env python
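"""Cast the bfloat16 Flax RoBERTa MLM checkpoint in the current directory to
float32 and export it as a PyTorch checkpoint alongside the Flax weights."""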
import tempfile

import jax
from jax import numpy as jnp
from transformers import AutoTokenizer, FlaxRobertaForMaskedLM, RobertaForMaskedLM


def to_f32(t):
    # Cast every bfloat16 leaf of the parameter pytree to float32; leave other
    # dtypes untouched. jax.tree_util.tree_map works on both older and newer
    # JAX releases (the plain jax.tree_map alias was removed in newer ones).
    return jax.tree_util.tree_map(
        lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t
    )


def main():
    # Re-save the tokenizer files (tokenizer.json, etc.) next to the model.
    tokenizer = AutoTokenizer.from_pretrained("./")
    tokenizer.save_pretrained("./")

    # Temporarily save the Flax model with its bfloat16 params cast to float32.
    tmp = tempfile.mkdtemp()
    flax_model = FlaxRobertaForMaskedLM.from_pretrained("./")
    flax_model.params = to_f32(flax_model.params)
    flax_model.save_pretrained(tmp)

    # Convert the float32 Flax checkpoint to PyTorch and save it here;
    # save_config=False leaves the existing config.json untouched.
    model = RobertaForMaskedLM.from_pretrained(tmp, from_flax=True)
    model.save_pretrained("./", save_config=False)
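    # At this point the current directory holds the tokenizer files, the
    # original Flax weights, and the new float32 PyTorch weights.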


if __name__ == "__main__":
    main()