create scripts/pl_models.py (#3)
- create scripts/pl_models.py (7fcfad978dbd4198fe8e424df30b21f27fce070e)
Co-authored-by: Ryan Keivanfar <[email protected]>
- scripts/pl_models.py +483 -0
scripts/pl_models.py
ADDED
@@ -0,0 +1,483 @@
# pl_models.py
"""
This module defines PyTorch Lightning modules for the Tahoeformer project.
It includes a base model class (`LitBaseModel`) and the main experimental model
(`LitEnformerSMILES`) which combines an Enformer-based DNA sequence model with
drug information (SMILES string processed into Morgan Fingerprints) and dose information
to predict gene expression.

Key components:
- masked_mse: A utility loss function for Mean Squared Error that handles NaN targets.
- LitBaseModel: A base LightningModule providing common training, validation, test steps,
  optimizer configuration, and basic metric logging hooks.
- LitEnformerSMILES: The primary model for predicting drug-induced gene expression changes,
  using Enformer for DNA and Morgan fingerprints for drugs.
- MetricLogger: A PyTorch Lightning Callback for detailed logging of predictions.
"""

import pandas as pd
import os
import torch
import torch.nn as nn
import lightning.pytorch as pl
from enformer_pytorch.finetune import HeadAdapterWrapper
from enformer_pytorch import Enformer
from torchmetrics.regression import PearsonCorrCoef, R2Score
from warnings import warn
import wandb
import numpy as np  # Added for MetricLogger consistency

# --- Utility Functions ---
def masked_mse(y_hat, y):
    """
    Computes Mean Squared Error (MSE) while ignoring NaN values in the target tensor.

    Args:
        y_hat (torch.Tensor): The predicted values.
        y (torch.Tensor): The target values, which may contain NaNs.

    Returns:
        torch.Tensor: A scalar tensor representing the masked MSE. Returns 0.0 if all targets are NaN.
    """
    mask = torch.isnan(y)
    if mask.all():  # Handle case where all targets in batch are NaN
        return torch.tensor(0.0, device=y_hat.device, requires_grad=True)
    mse = torch.mean((y[~mask] - y_hat[~mask]) ** 2)
    return mse

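A quick sanity check of the masking behaviour (illustrative snippet, not part of the committed file; it assumes `masked_mse` as defined above is in scope):

import torch

y_hat = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.5, float('nan'), 3.0])
loss = masked_mse(y_hat, y)   # the NaN position is dropped from the average
print(loss)                   # tensor(0.1250) == ((1.5 - 1.0)**2 + (3.0 - 3.0)**2) / 2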
# --- Base Lightning Module ---
class LitBaseModel(pl.LightningModule):
    """
    A base PyTorch Lightning module providing common boilerplate for training and evaluation.

    This class implements a generic training/validation/test step, loss calculation using
    `masked_mse`, optimizer configuration (AdamW), and hooks for accumulating outputs
    for detailed metric logging via the `MetricLogger` callback.

    Derived classes are expected to implement the `forward` method.

    Hyperparameters:
        learning_rate (float): The learning rate for the optimizer.
        loss_alpha (float): A coefficient for the primary loss term (MSE). Useful if
            additional loss terms were to be added.
        weight_decay (float, optional): Weight decay for the AdamW optimizer. If None,
            AdamW's internal default is used.
        eval_gene_sets (dict, optional): A dictionary where keys are set names (e.g., 'oncogenes')
            and values are lists of gene IDs. Used by `MetricLogger`
            to compute metrics for specific gene subsets.
    """
    def __init__(self, learning_rate=5e-6, loss_alpha=1.0, weight_decay=None,
                 eval_gene_sets=None):  # eval_gene_sets: dict {'train': [genes], 'valid': [genes], 'test': [genes]}
        """
        Initializes the LitBaseModel.

        Args:
            learning_rate (float, optional): Learning rate. Defaults to 5e-6.
            loss_alpha (float, optional): Alpha for MSE loss. Defaults to 1.0.
            weight_decay (float, optional): Weight decay for AdamW. If None, uses the optimizer default.
                Defaults to None.
            eval_gene_sets (dict, optional): Dictionary of gene sets for targeted evaluation.
                Keys are names, values are lists of gene IDs.
                Defaults to None.
        """
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.loss_alpha = loss_alpha  # alpha for mse vs. other terms (if any)
        self.weight_decay = weight_decay
        self.eval_gene_sets = eval_gene_sets if eval_gene_sets else {}

        # Results accumulated per epoch for MetricLogger
        self.epoch_outputs = []

    def loss_fn(self, y_hat, y):
        """
        Calculates the loss for the model.

        Currently uses `masked_mse` scaled by `self.loss_alpha`.

        Args:
            y_hat (torch.Tensor): Predicted values from the model.
            y (torch.Tensor): Ground truth target values.

        Returns:
            torch.Tensor: The computed loss value.
        """
        mse_term = masked_mse(y_hat, y)
        # Potentially: add other loss terms here, weighted by (1 - loss_alpha) if desired
        return self.loss_alpha * mse_term

    def _common_step(self, batch, batch_idx, step_type):
        """
        A common step for training, validation, and testing.

        This method unpacks the batch, performs a forward pass, calculates the loss,
        logs the loss, and accumulates outputs for epoch-level metric calculation
        (for validation and test steps).

        Args:
            batch: The batch of data from the DataLoader. Expected to be a tuple containing
                DNA sequence, Morgan fingerprints, dose, target expression,
                and metadata (gene_id, drug_id, cell_line).
            batch_idx (int): The index of the current batch.
            step_type (str): A string indicating the type of step ('train', 'val', or 'test').

        Returns:
            torch.Tensor: The loss for the current batch.
        """
        # Batch structure will change after dataset modification:
        # (dna_seq, morgan_fingerprints, dose, target_expression, gene_id, drug_id, cell_line)
        dna_seq, morgan_fingerprints, dose, target_expression, gene_id, drug_id, cell_line = batch

        y_hat = self(dna_seq, morgan_fingerprints, dose)  # Call forward method of derived class

        loss = self.loss_fn(y_hat, target_expression)
        # Log at epoch level only; show in the progress bar for val/test.
        self.log(f'{step_type}_loss', loss, batch_size=target_expression.shape[0],
                 on_step=False, on_epoch=True, prog_bar=(step_type != 'train'))

        if step_type != 'train':
            # Prepare data for MetricLogger
            batch_size = target_expression.shape[0]
            for i in range(batch_size):
                item_data = {
                    'pred': y_hat[i].detach(),
                    'target': target_expression[i].detach(),
                    'gene_id': gene_id[i],
                    'drug_id': drug_id[i],
                    'cell_line': cell_line[i],
                    'rank': self.trainer.global_rank
                }
                self.epoch_outputs.append(item_data)
        return loss

    def training_step(self, batch, batch_idx):
        """PyTorch Lightning training step. Calls `_common_step`."""
        return self._common_step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """PyTorch Lightning validation step. Calls `_common_step`."""
        return self._common_step(batch, batch_idx, 'val')

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """PyTorch Lightning test step. Calls `_common_step`."""
        return self._common_step(batch, batch_idx, 'test')

    def on_validation_epoch_start(self):
        """Clears accumulated outputs at the start of each validation epoch."""
        self.epoch_outputs = []

    def on_test_epoch_start(self):
        """Clears accumulated outputs at the start of each test epoch."""
        self.epoch_outputs = []

    def configure_optimizers(self):
        """
        Configures the optimizer for the model.

        Uses AdamW with the specified learning rate and weight decay.

        Returns:
            torch.optim.Optimizer: The configured AdamW optimizer.
        """
        if self.weight_decay is None:
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        else:
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        return optimizer

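For reference, a sketch of the seven-field batch that `_common_step` unpacks; the shapes and placeholder IDs below are assumptions for illustration, not the project's actual dataset code:

import torch

batch_size, seq_len, fp_dim = 4, 196_608, 2048   # assumed Enformer window length and ECFP4 size
dummy_batch = (
    torch.randint(0, 2, (batch_size, seq_len, 4)).float(),   # one-hot DNA sequence
    torch.randint(0, 2, (batch_size, fp_dim)).float(),       # Morgan fingerprints
    torch.rand(batch_size, 1),                               # dose
    torch.randn(batch_size, 1),                              # target expression
    ["gene_placeholder"] * batch_size,                       # gene_id (placeholder values)
    ["drug_placeholder"] * batch_size,                       # drug_id (placeholder values)
    ["cell_line_placeholder"] * batch_size,                  # cell_line (placeholder values)
)
# _common_step would unpack this as:
# dna_seq, morgan_fingerprints, dose, target_expression, gene_id, drug_id, cell_line = dummy_batch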
# --- Enformer + Morgan Fingerprints Model ---
class LitEnformerSMILES(LitBaseModel):  # Consider renaming to LitEnformerMorgan for clarity
    """
    A PyTorch Lightning module that combines genomic sequence information (via Enformer)
    with drug chemical structure (represented by Morgan fingerprints) and drug dose
    to predict gene expression changes.

    The model architecture consists of three main branches:
    1. DNA Module: Uses a pre-trained Enformer model (with an adapted head) to extract
       features from a one-hot encoded DNA sequence centered around a gene's TSS.
    2. Drug Module: Uses pre-computed Morgan fingerprints as the drug representation.
    3. Dose Module: Directly uses the numerical dose value.

    Features from these three branches are concatenated and passed through a multi-layer
    fusion head (MLP with ReLU, BatchNorm, Dropout) to produce the final prediction
    of gene expression.

    Inherits common training and evaluation logic from `LitBaseModel`.
    """
    def __init__(self,
                 enformer_model_name: str = 'EleutherAI/enformer-official-rough',
                 enformer_target_length: int = -1,
                 num_output_tracks_enformer_head: int = 1,
                 morgan_fingerprint_dim: int = 2048,  # dim of the Morgan fingerprint vector
                 dose_input_dim: int = 1,
                 fusion_hidden_dim: int = 256,
                 final_output_tracks: int = 1,
                 learning_rate=5e-6,
                 loss_alpha=1.0,
                 weight_decay=None,
                 eval_gene_sets=None):
        """
        Initializes the LitEnformerSMILES (or LitEnformerMorgan) model.

        Args:
            enformer_model_name (str, optional): Name or path of the pre-trained Enformer model.
            enformer_target_length (int, optional): Target length for Enformer's internal pooling.
            num_output_tracks_enformer_head (int, optional): Output features from the Enformer head.
            morgan_fingerprint_dim (int, optional): Dimensionality of the Morgan fingerprint vector
                (e.g., 2048 for ECFP4). Defaults to 2048.
            dose_input_dim (int, optional): Dimensionality of the drug dose input. Defaults to 1.
            fusion_hidden_dim (int, optional): Hidden dimension for the fusion MLP. Defaults to 256.
            final_output_tracks (int, optional): Number of final output values. Defaults to 1.
            learning_rate (float, optional): Learning rate. Defaults to 5e-6.
            loss_alpha (float, optional): Weight for MSE loss. Defaults to 1.0.
            weight_decay (float, optional): Weight decay. Defaults to None.
            eval_gene_sets (dict, optional): Gene sets for targeted evaluation. Defaults to None.
        """
        super().__init__(learning_rate, loss_alpha, weight_decay, eval_gene_sets)
        self.save_hyperparameters(
            "enformer_model_name", "enformer_target_length",
            "num_output_tracks_enformer_head", "morgan_fingerprint_dim",
            "dose_input_dim", "fusion_hidden_dim", "final_output_tracks",
            "learning_rate", "loss_alpha", "weight_decay"
        )

        # 1. DNA Module (Enformer with HeadAdapter)
        enformer_pretrained = Enformer.from_pretrained(
            self.hparams.enformer_model_name,
            target_length=self.hparams.enformer_target_length
        )
        self.dna_module = HeadAdapterWrapper(
            enformer=enformer_pretrained,
            num_tracks=self.hparams.num_output_tracks_enformer_head,
            post_transformer_embed=False,
            output_activation=nn.Identity()
        )

        # 2. Drug Module (Morgan fingerprints are provided as input directly)
        # No layers needed here as fingerprints are pre-computed.
        # self.hparams.morgan_fingerprint_dim defines the expected input dimension.

        # 3. Fusion Head
        # Input dimension uses morgan_fingerprint_dim
        fusion_input_dim = self.hparams.num_output_tracks_enformer_head + self.hparams.morgan_fingerprint_dim + self.hparams.dose_input_dim
        self.fusion_head = nn.Sequential(
            nn.Linear(fusion_input_dim, self.hparams.fusion_hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(self.hparams.fusion_hidden_dim),
            nn.Dropout(0.25),
            nn.Linear(self.hparams.fusion_hidden_dim, self.hparams.fusion_hidden_dim // 2),
            nn.ReLU(),
            nn.BatchNorm1d(self.hparams.fusion_hidden_dim // 2),
            nn.Dropout(0.1),
            nn.Linear(self.hparams.fusion_hidden_dim // 2, self.hparams.final_output_tracks)
        )

    def forward(self, dna_seq, morgan_fingerprints, dose):
        """
        Defines the forward pass of the LitEnformerSMILES model using Morgan fingerprints.

        Args:
            dna_seq (torch.Tensor): Batch of one-hot encoded DNA sequences.
                Shape: (batch_size, sequence_length, 4).
            morgan_fingerprints (torch.Tensor): Batch of pre-computed Morgan fingerprint vectors.
                Shape: (batch_size, morgan_fingerprint_dim).
            dose (torch.Tensor): Batch of drug dose values.
                Shape: (batch_size, dose_input_dim).

        Returns:
            torch.Tensor: The model's prediction. Shape: (batch_size, final_output_tracks).
        """
        # --- DNA Processing ---
        dna_out_intermediate = self.dna_module(dna_seq, freeze_enformer=False)
        center_seq_idx = dna_out_intermediate.shape[1] // 2
        dna_features = dna_out_intermediate[:, center_seq_idx, :]

        # --- Drug Processing (Morgan Fingerprints) ---
        # Morgan fingerprints are directly used as features.
        smiles_features = morgan_fingerprints  # Shape: (batch_size, morgan_fingerprint_dim)

        # --- Dose Processing ---
        if dose.ndim == 1:
            dose = dose.unsqueeze(-1)

        # --- Feature Combination & Final Prediction ---
        combined_features = torch.cat([dna_features, smiles_features, dose], dim=1)
        prediction = self.fusion_head(combined_features)
        return prediction

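A minimal forward-pass sketch with random inputs (illustrative only; instantiating the model downloads the pretrained Enformer weights, and the 196,608 bp window is an assumption based on Enformer's standard input length):

import torch

model = LitEnformerSMILES(morgan_fingerprint_dim=2048, fusion_hidden_dim=256)
model.eval()

dna_seq = torch.randint(0, 2, (2, 196_608, 4)).float()   # one-hot DNA window around the TSS
fingerprints = torch.randint(0, 2, (2, 2048)).float()    # pre-computed Morgan fingerprints
dose = torch.rand(2, 1)                                   # drug dose

with torch.no_grad():
    pred = model(dna_seq, fingerprints, dose)
print(pred.shape)   # torch.Size([2, 1]) -> (batch_size, final_output_tracks)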
# --- Metrics Logging Callback ---
class MetricLogger(pl.Callback):
    """
    A PyTorch Lightning Callback for comprehensive metric calculation and logging.

    This callback accumulates predictions and targets during validation and test epochs.
    At the end of these epochs, it:
    1. Processes the accumulated outputs into a pandas DataFrame.
    2. Saves the raw predictions and targets for the epoch to a CSV file.
    3. Logs a sample of these raw predictions as a W&B Table if WandbLogger is used.
    4. Calculates overall performance metrics (MSE, Pearson, R2) for the epoch.
    5. If `eval_gene_sets` are provided in the LightningModule, calculates metrics for these specific gene subsets.
    6. Calculates metrics per cell line if 'cell_line' information is available in the outputs.
    7. Logs all calculated metrics to the LightningModule's logger.

    Attributes:
        save_dir_prefix (str): Prefix for the directory where metric CSVs will be saved.
        current_epoch_data (list): List to accumulate dictionaries of pred/target/metadata per item.
    """
    def __init__(self, save_dir_prefix="results"):
        """
        Initializes the MetricLogger callback.

        Args:
            save_dir_prefix (str, optional): Directory prefix for saving metrics files.
                Defaults to "results".
        """
        super().__init__()
        self.save_dir_prefix = save_dir_prefix
        self.current_epoch_data = []

    def _process_epoch_outputs(self, pl_module, stage):
        """
        Processes the raw outputs collected during an epoch into a pandas DataFrame.

        Converts tensor data for 'pred' and 'target' columns to NumPy/Python native types.

        Args:
            pl_module (pl.LightningModule): The LightningModule instance.
            stage (str): The current stage (e.g., "validation", "test").

        Returns:
            pd.DataFrame: A DataFrame containing the processed epoch outputs.
                Returns an empty DataFrame if no outputs were collected.
        """
        if not hasattr(pl_module, 'epoch_outputs') or not pl_module.epoch_outputs:
            warn(f"No outputs collected (pl_module.epoch_outputs is missing or empty) during {stage} epoch for MetricLogger.")
            return pd.DataFrame()

        df = pd.DataFrame(pl_module.epoch_outputs)

        for col in ['pred', 'target']:
            if col in df.columns and not df[col].empty:
                if isinstance(df[col].iloc[0], torch.Tensor):
                    df[col] = df[col].apply(lambda x: x.cpu().float().numpy().item() if x.numel() == 1 else x.cpu().float().numpy())
        return df

    def on_validation_epoch_end(self, trainer, pl_module):
        """Hook called at the end of the validation epoch."""
        if hasattr(pl_module, 'epoch_outputs') and pl_module.epoch_outputs:
            self.current_epoch_data = self._process_epoch_outputs(pl_module, "validation")
            if not self.current_epoch_data.empty:
                self._log_and_save_metrics(trainer, pl_module, "validation")
        else:
            warn("MetricLogger: pl_module.epoch_outputs not found or empty at on_validation_epoch_end.")

    def on_test_epoch_end(self, trainer, pl_module):
        """Hook called at the end of the test epoch."""
        if hasattr(pl_module, 'epoch_outputs') and pl_module.epoch_outputs:
            self.current_epoch_data = self._process_epoch_outputs(pl_module, "test")
            if not self.current_epoch_data.empty:
                self._log_and_save_metrics(trainer, pl_module, "test")
        else:
            warn("MetricLogger: pl_module.epoch_outputs not found or empty at on_test_epoch_end.")

    def _log_and_save_metrics(self, trainer, pl_module, stage):
        """
        Calculates, logs, and saves metrics for the current stage and epoch.

        Args:
            trainer (pl.Trainer): The PyTorch Lightning Trainer instance.
            pl_module (pl.LightningModule): The LightningModule instance.
            stage (str): The current stage (e.g., "validation", "test").
        """
        epoch = trainer.current_epoch if trainer.current_epoch is not None else -1
        save_dir = getattr(pl_module.hparams, 'save_dir',
                           os.path.join(self.save_dir_prefix, f"run_{trainer.logger.version if trainer.logger else 'local'}"))
        os.makedirs(save_dir, exist_ok=True)

        raw_preds_path = os.path.join(save_dir, f"{stage}_predictions_epoch_{epoch}.csv")
        self.current_epoch_data.to_csv(raw_preds_path, index=False)

        if trainer.logger and hasattr(trainer.logger, 'experiment') and isinstance(trainer.logger.experiment, wandb.sdk.wandb_run.Run):
            try:
                trainer.logger.experiment.log({f"{stage}_raw_predictions_epoch_{epoch}": wandb.Table(dataframe=self.current_epoch_data.head(1000))})
            except Exception as e:
                warn(f"MetricLogger: Failed to log raw predictions table to W&B: {e}")

        overall_metrics = self._calculate_metrics_for_group(self.current_epoch_data, pl_module.device)
        if overall_metrics:
            pl_module.log_dict({f"{stage}_{k}_epoch": v for k, v in overall_metrics.items()}, sync_dist=True)

        if hasattr(pl_module, 'eval_gene_sets') and pl_module.eval_gene_sets and isinstance(pl_module.eval_gene_sets, dict) and 'gene_id' in self.current_epoch_data.columns:
            for split_name, gene_list in pl_module.eval_gene_sets.items():
                if not gene_list: continue
                split_df = self.current_epoch_data[self.current_epoch_data['gene_id'].isin(gene_list)]
                if not split_df.empty:
                    split_metrics = self._calculate_metrics_for_group(split_df, pl_module.device)
                    if split_metrics:
                        pl_module.log_dict({f"{stage}_{split_name}_genes_{k}_epoch": v for k, v in split_metrics.items()}, sync_dist=True)

        if 'cell_line' in self.current_epoch_data.columns:
            for cell_line, group_df in self.current_epoch_data.groupby('cell_line'):
                cl_metrics = self._calculate_metrics_for_group(group_df, pl_module.device)
                if cl_metrics:
                    pl_module.log_dict({f"{stage}_{cell_line}_cell_line_{k}_epoch": v for k, v in cl_metrics.items()}, sync_dist=True)

    def _calculate_metrics_for_group(self, df_group, device):
        """
        Calculates regression metrics (MSE, Pearson, R2) for a given group of predictions.

        Args:
            df_group (pd.DataFrame): DataFrame containing 'pred' and 'target' columns for the group.
            device (torch.device): The device to perform calculations on.

        Returns:
            dict: A dictionary of calculated metrics (mse, pearson, r2). Returns empty if data is insufficient.
        """
        if df_group.empty or 'pred' not in df_group.columns or 'target' not in df_group.columns:
            return {}

        preds_np = np.array(df_group['pred'].tolist(), dtype=np.float32)
        targets_np = np.array(df_group['target'].tolist(), dtype=np.float32)

        preds = torch.tensor(preds_np).to(device)
        targets = torch.tensor(targets_np).to(device)

        if preds.ndim > 1:
            # Collapse trailing singleton dimensions (e.g., (N, 1) -> (N,)).
            preds = preds.squeeze()
            targets = targets.squeeze()

        if preds.numel() == 0 or targets.numel() == 0 or preds.shape != targets.shape:
            warn(f"Skipping metrics calculation for a group due to mismatched or empty preds/targets. Pred shape: {preds.shape}, Target shape: {targets.shape}")
            return {}

        mse_val_tensor = masked_mse(preds.unsqueeze(-1) if preds.ndim == 1 else preds,
                                    targets.unsqueeze(-1) if targets.ndim == 1 else targets)
        calculated_metrics = {'mse': mse_val_tensor.item()}

        if preds.numel() < 2:
            warn(f"Skipping Pearson/R2 for a group with < 2 samples. Found {preds.numel()} samples. Only MSE will be reported.")
            return calculated_metrics

        preds_for_corr = preds.squeeze()
        targets_for_corr = targets.squeeze()

        if preds_for_corr.shape != targets_for_corr.shape or (preds_for_corr.ndim > 1 and preds_for_corr.shape[1] > 1):
            warn(f"Skipping Pearson/R2 due to incompatible shapes after squeeze for correlation. Pred: {preds_for_corr.shape}, Target: {targets_for_corr.shape}")
            return calculated_metrics

        try:
            pearson_fn = PearsonCorrCoef().to(device)
            pearson_val = pearson_fn(preds_for_corr, targets_for_corr)
            calculated_metrics['pearson'] = pearson_val.item()
        except Exception as e:
            warn(f"Could not compute Pearson Correlation: {e}. Preds shape: {preds_for_corr.shape}, Targets shape: {targets_for_corr.shape}")

        try:
            r2_fn = R2Score().to(device)
            r2_val = r2_fn(preds_for_corr, targets_for_corr)
            calculated_metrics['r2'] = r2_val.item()
        except Exception as e:
            warn(f"Could not compute R2 Score: {e}. Preds shape: {preds_for_corr.shape}, Targets shape: {targets_for_corr.shape}")

        return calculated_metrics
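The module and callback above might be wired together roughly as follows; the Trainer arguments and DataLoaders are placeholders rather than the project's actual training script:

import lightning.pytorch as pl

model = LitEnformerSMILES(learning_rate=5e-6, morgan_fingerprint_dim=2048)
metric_logger = MetricLogger(save_dir_prefix="results")

trainer = pl.Trainer(
    max_epochs=5,
    accelerator="auto",
    callbacks=[metric_logger],   # MetricLogger consumes pl_module.epoch_outputs at epoch end
)
# trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
# trainer.test(model, dataloaders=test_loader)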