Update chatNT.py
Browse files
chatNT.py
CHANGED
|
@@ -416,9 +416,11 @@ class TorchBioBrainDecoder(nn.Module):
|
|
| 416 |
_,
|
| 417 |
bio_embed_dim,
|
| 418 |
) = projected_bio_embeddings.shape
|
| 419 |
-
|
| 420 |
# Insert the bio embeddings at the SEQ token positions
|
| 421 |
processed_tokens_ids = english_token_ids.clone()
|
|
|
|
|
|
|
| 422 |
for bio_seq_num in range(num_bio_sequences):
|
| 423 |
tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
|
| 424 |
processed_tokens_ids,
|
|
@@ -426,6 +428,7 @@ class TorchBioBrainDecoder(nn.Module):
|
|
| 426 |
projected_bio_embeddings[:, bio_seq_num, :, :],
|
| 427 |
bio_seq_num=bio_seq_num,
|
| 428 |
)
|
|
|
|
| 429 |
|
| 430 |
# Regular GPT pass through
|
| 431 |
print("(debug) tokens embeddings shape : ", tokens_embeddings.shape)
|
|
|
|
| 416 |
_,
|
| 417 |
bio_embed_dim,
|
| 418 |
) = projected_bio_embeddings.shape
|
| 419 |
+
|
| 420 |
# Insert the bio embeddings at the SEQ token positions
|
| 421 |
processed_tokens_ids = english_token_ids.clone()
|
| 422 |
+
print("(debug) Inside : processed tokens ids shape : ", processed_tokens_ids.shape)
|
| 423 |
+
print("(debug) Inside : projected bio embeddings shape : ", projected_bio_embeddings.shape)
|
| 424 |
for bio_seq_num in range(num_bio_sequences):
|
| 425 |
tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
|
| 426 |
processed_tokens_ids,
|
|
|
|
| 428 |
projected_bio_embeddings[:, bio_seq_num, :, :],
|
| 429 |
bio_seq_num=bio_seq_num,
|
| 430 |
)
|
| 431 |
+
print("After call : ", tokens_embeddings.shape)
|
| 432 |
|
| 433 |
# Regular GPT pass through
|
| 434 |
print("(debug) tokens embeddings shape : ", tokens_embeddings.shape)
|