Update chatNT.py
Browse files
chatNT.py
CHANGED
|
@@ -590,7 +590,6 @@ class TorchMultiOmicsModel(PreTrainedModel):
|
|
| 590 |
def __init__(self, config: ChatNTConfig) -> None:
|
| 591 |
print("(debug) Entering in class")
|
| 592 |
if isinstance(config, dict):
|
| 593 |
-
print("(debug) going in if condition")
|
| 594 |
# If config is a dictionary instead of ChatNTConfig (which can happen
|
| 595 |
# depending how the config was saved), we convert it to the config
|
| 596 |
config["gpt_config"]["rope_config"] = RotaryEmbeddingConfig(
|
|
@@ -598,14 +597,24 @@ class TorchMultiOmicsModel(PreTrainedModel):
|
|
| 598 |
)
|
| 599 |
config["gpt_config"] = GptConfig(**config["gpt_config"])
|
| 600 |
config["esm_config"] = ESMTransformerConfig(**config["esm_config"])
|
| 601 |
-
print("(debug) Type esm_config : ", type(config["esm_config"]))
|
| 602 |
-
print("(debug) esm_config : ", config["esm_config"])
|
| 603 |
config["perceiver_resampler_config"] = PerceiverResamplerConfig(
|
| 604 |
**config["perceiver_resampler_config"]
|
| 605 |
)
|
| 606 |
config = ChatNTConfig(**config) # type: ignore
|
| 607 |
-
print("(debug) Type config : ", type(config))
|
| 608 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 609 |
print("(debug) config : ", config)
|
| 610 |
print("(debug) config type : ", type(config))
|
| 611 |
print("(debug) gpt config : ", config.gpt_config)
|
|
|
|
| 590 |
def __init__(self, config: ChatNTConfig) -> None:
|
| 591 |
print("(debug) Entering in class")
|
| 592 |
if isinstance(config, dict):
|
|
|
|
| 593 |
# If config is a dictionary instead of ChatNTConfig (which can happen
|
| 594 |
# depending how the config was saved), we convert it to the config
|
| 595 |
config["gpt_config"]["rope_config"] = RotaryEmbeddingConfig(
|
|
|
|
| 597 |
)
|
| 598 |
config["gpt_config"] = GptConfig(**config["gpt_config"])
|
| 599 |
config["esm_config"] = ESMTransformerConfig(**config["esm_config"])
|
|
|
|
|
|
|
| 600 |
config["perceiver_resampler_config"] = PerceiverResamplerConfig(
|
| 601 |
**config["perceiver_resampler_config"]
|
| 602 |
)
|
| 603 |
config = ChatNTConfig(**config) # type: ignore
|
|
|
|
| 604 |
|
| 605 |
+
else:
|
| 606 |
+
if isinstance(config.gpt_config, dict):
|
| 607 |
+
config.gpt_config["rope_config"] = RotaryEmbeddingConfig(
|
| 608 |
+
**config.gpt_config.["rope_config"]
|
| 609 |
+
)
|
| 610 |
+
config.gpt_config = GptConfig(**config.gpt_config)
|
| 611 |
+
|
| 612 |
+
if isinstance(config.esm_config, dict):
|
| 613 |
+
config.esm_config = ESMTransformerConfig(**config.esm_config)
|
| 614 |
+
|
| 615 |
+
if isinstance(config.perceiver_resampler_config, dict):
|
| 616 |
+
config.esm_config = PerceiverResamplerConfig(**config.perceiver_resampler_config)
|
| 617 |
+
|
| 618 |
print("(debug) config : ", config)
|
| 619 |
print("(debug) config type : ", type(config))
|
| 620 |
print("(debug) gpt config : ", config.gpt_config)
|