feat: for converting v2, added lines to save model weights and print config
Browse files- convert_v2_weights.py +8 -1
convert_v2_weights.py
CHANGED
|
@@ -131,6 +131,12 @@ new_state_dict = remap_state_dict(state_dict, config)
|
|
| 131 |
flash_model = BertModel(config)
|
| 132 |
flash_model.load_state_dict(new_state_dict)
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en')
|
| 135 |
inp = tokenizer.batch_encode_plus(['Hello world', 'How is the weather today?', 'It is raining a lot in Berlin'], return_tensors='pt', padding=True).to('cuda')
|
| 136 |
v2_model.eval()
|
|
@@ -141,4 +147,5 @@ output_v2 = v2_model(**inp)
|
|
| 141 |
output_flash = flash_model(**inp)
|
| 142 |
x = output_v2.last_hidden_state
|
| 143 |
y = output_flash.last_hidden_state
|
| 144 |
-
print(torch.abs(x - y))
|
|
|
|
|
|
| 131 |
flash_model = BertModel(config)
|
| 132 |
flash_model.load_state_dict(new_state_dict)
|
| 133 |
|
| 134 |
+
|
| 135 |
+
torch.save(new_state_dict, 'converted_weights.bin')
|
| 136 |
+
print(config.to_json_string())
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
"""
|
| 140 |
tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en')
|
| 141 |
inp = tokenizer.batch_encode_plus(['Hello world', 'How is the weather today?', 'It is raining a lot in Berlin'], return_tensors='pt', padding=True).to('cuda')
|
| 142 |
v2_model.eval()
|
|
|
|
| 147 |
output_flash = flash_model(**inp)
|
| 148 |
x = output_v2.last_hidden_state
|
| 149 |
y = output_flash.last_hidden_state
|
| 150 |
+
print(torch.abs(x - y))
|
| 151 |
+
"""
|