|  |  | 
					
						
						|  | import tempfile | 
					
						
						|  |  | 
					
						
						|  | import jax | 
					
						
						|  | from jax import numpy as jnp | 
					
						
						|  | from transformers import AutoTokenizer, FlaxRobertaForMaskedLM, RobertaForMaskedLM | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def to_f32(t): | 
					
						
						|  | return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def main(): | 
					
						
						|  |  | 
					
						
						|  | tokenizer = AutoTokenizer.from_pretrained("./") | 
					
						
						|  | tokenizer.save_pretrained("./") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | tmp = tempfile.mkdtemp() | 
					
						
						|  | flax_model = FlaxRobertaForMaskedLM.from_pretrained("./") | 
					
						
						|  | flax_model.params = to_f32(flax_model.params) | 
					
						
						|  | flax_model.save_pretrained(tmp) | 
					
						
						|  |  | 
					
						
						|  | model = RobertaForMaskedLM.from_pretrained(tmp, from_flax=True) | 
					
						
						|  | model.save_pretrained("./", save_config=False) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if __name__ == "__main__": | 
					
						
						|  | main() | 
					
						
						|  |  |