swe-gpt-wiki / save_model.py
birgermoell's picture
Added pytorch model
c8213eb
raw
history blame
302 Bytes
from transformers.modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model
from transformers import GPT2Config, GPT2Model
# Convert the Flax checkpoint that sits next to this script into a PyTorch
# checkpoint saved in the same directory.
MODEL_DIR = "./"
FLAX_CHECKPOINT = "./flax_model.msgpack"

# Read the model configuration from the working directory and build a
# PyTorch GPT-2 module from it; its weights are overwritten in the next step.
# NOTE(review): GPT2Model is the bare transformer (no language-modelling
# head) -- confirm that this is the class the checkpoint should be saved as.
config = GPT2Config.from_pretrained(MODEL_DIR)
model = GPT2Model(config)

# Load the serialized Flax parameters into `model`. The helper's return value
# is not used here; the script relies on `model` itself being updated.
load_flax_checkpoint_in_pytorch_model(model, FLAX_CHECKPOINT)

# Write the PyTorch model back into the working directory.
model.save_pretrained(MODEL_DIR)