Update chatNT.py
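Adds debug print statements that trace tensor shapes through TorchBioBrainDecoder and TorchMultiOmicsModel.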
chatNT.py
CHANGED
@@ -405,7 +405,9 @@ class TorchBioBrainDecoder(nn.Module):
         """
 
         # Compute English token embeddings
+        print("(debug) in biobraindecoder, english tokens ids : ", english_token_ids.shape)
         tokens_embeddings = self.gpt_model.token_embed(english_token_ids)
+        print("(debug) tokens_embeddings shape : ", tokens_embeddings.shape)
 
         if projected_bio_embeddings is not None:
             (
@@ -696,11 +698,14 @@ class TorchMultiOmicsModel(PreTrainedModel):
 
         if projected_bio_embeddings is None:
             # Compute bio sequences embeddings
+            print("(debug) shape bio tokens ids : ", bio_token_ids.shape)
             bio_embeddings_list = [
                 self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
                 for bio_seq_num in range(num_bio_sequences)
             ]
 
+            print("(debug) shape of embeddings : ", bio_embeddings_list[0].shape)
+
             # Project these embeddings
             projected_bio_embeddings = [
                 self.projection_model(
@@ -710,9 +715,14 @@ class TorchMultiOmicsModel(PreTrainedModel):
                 )
                 for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
             ]
+            print("(debug) Shape output projection model : ", projected_bio_embeddings[0].shape)
             projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
+            print("(debug) Shape projected bio embeddings : ", projected_bio_embeddings.shape)
 
         # decode
+        print("(debug) Going in biobrain decoder : ")
+        print("(debug) English token ids : ", english_token_ids.shape)
+        print("(debug) Projected bio embeddings : ", projected_bio_embeddings.shape)
         logits = self.biobrain_decoder(
             english_token_ids=english_token_ids,
             projected_bio_embeddings=projected_bio_embeddings,
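
The traces above are bare print calls, so they fire on every forward pass until removed in a later commit. As a minimal sketch (not part of this change), the same shape traces could instead go through Python's standard logging module and be switched off via an environment variable; the logger name "chatNT.debug", the CHATNT_DEBUG toggle, and the trace_shape helper are illustrative assumptions, not existing chatNT APIs:

import logging
import os

# Hypothetical debug logger; the name and env-var toggle are assumptions.
logger = logging.getLogger("chatNT.debug")
if os.environ.get("CHATNT_DEBUG"):
    logging.basicConfig(level=logging.DEBUG)


def trace_shape(label: str, tensor) -> None:
    """Log a tensor's shape at DEBUG level; silent unless debugging is enabled."""
    logger.debug("(debug) %s : %s", label, tuple(tensor.shape))


# Usage mirroring the prints added above, e.g. inside TorchBioBrainDecoder:
#     trace_shape("english tokens ids", english_token_ids)
#     trace_shape("tokens_embeddings", tokens_embeddings)

With this pattern the instrumentation can stay in the file permanently, since disabling it is a matter of unsetting the environment variable rather than reverting the commit.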