Upload ModularStarEncoder
Browse files- modularStarEncoder.py +4 -4
modularStarEncoder.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
from transformers import Starcoder2Model
|
| 2 |
import sys
|
| 3 |
-
from
|
| 4 |
import os
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from typing import Optional, Tuple, Union, List
|
|
@@ -307,10 +307,10 @@ class ModularStarEncoder(StarEncoder2PreTrainedModel):
|
|
| 307 |
output_attentions=output_attentions,
|
| 308 |
output_hidden_states=True,
|
| 309 |
return_dict=True,
|
| 310 |
-
)
|
| 311 |
|
| 312 |
|
| 313 |
-
DEVICE = source_embedding[-1].get_device()
|
| 314 |
if DEVICE<0:
|
| 315 |
DEVICE = "cpu"
|
| 316 |
|
|
@@ -334,7 +334,7 @@ class ModularStarEncoder(StarEncoder2PreTrainedModel):
|
|
| 334 |
|
| 335 |
pooled_and_normalized = []
|
| 336 |
for idx,matr_layer in enumerate(self.matryoshka_layers):
|
| 337 |
-
source_embedding_proj = projection_fn[idx](source_embedding[matr_layer])
|
| 338 |
|
| 339 |
normalized_source_embedding, embedding_norms = pool_and_normalize(
|
| 340 |
source_embedding_proj,
|
|
|
|
| 1 |
from transformers import Starcoder2Model
|
| 2 |
import sys
|
| 3 |
+
from config import ModularStarEncoderConfig
|
| 4 |
import os
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from typing import Optional, Tuple, Union, List
|
|
|
|
| 307 |
output_attentions=output_attentions,
|
| 308 |
output_hidden_states=True,
|
| 309 |
return_dict=True,
|
| 310 |
+
)
|
| 311 |
|
| 312 |
|
| 313 |
+
DEVICE = source_embedding.hidden_states[-1].get_device()
|
| 314 |
if DEVICE<0:
|
| 315 |
DEVICE = "cpu"
|
| 316 |
|
|
|
|
| 334 |
|
| 335 |
pooled_and_normalized = []
|
| 336 |
for idx,matr_layer in enumerate(self.matryoshka_layers):
|
| 337 |
+
source_embedding_proj = projection_fn[idx](source_embedding.hidden_states[matr_layer])
|
| 338 |
|
| 339 |
normalized_source_embedding, embedding_norms = pool_and_normalize(
|
| 340 |
source_embedding_proj,
|