| import torch |
| import sys |
| from sklearn.preprocessing import StandardScaler |
| import pytorch_lightning as pl |
| from torch.utils.data import DataLoader |
| from lightning.pytorch.utilities.combined_loader import CombinedLoader |
| import numpy as np |
| from scipy.spatial import cKDTree |
| import math |
| from functools import partial |
| import matplotlib.pyplot as plt |
| import pandas as pd |
| from torch.utils.data import TensorDataset |
| from sklearn.neighbors import kneighbors_graph |
| import igraph as ig |
| from leidenalg import find_partition, ModularityVertexPartition |
|
|
class WeightedBranchedVeresDataModule(pl.LightningDataModule):
    """Lightning data module for Veres cell data with branched endpoints.

    The CSV at ``data_path`` is grouped by the ``samples`` (timepoint) column.
    Cells at the final timepoint are clustered with Leiden on a kNN graph;
    each cluster becomes a "branch" endpoint, resampled to the size of the
    t=0 population. Clusters with fewer than 100 cells are either discarded
    or merged into the nearest large cluster (by centroid distance),
    depending on ``args.discard``. Per-timepoint and per-branch train/val
    DataLoaders are built eagerly in ``__init__``; branch samples carry
    weights proportional to branch size.
    """

    def __init__(self, args):
        """Store configuration and eagerly build all loaders.

        Args:
            args: namespace with ``data_path``, ``batch_size``, ``dim``,
                ``whiten``, ``split_ratios``, ``metric_clusters`` and,
                optionally, ``branches`` and ``discard``.
        """
        super().__init__()
        self.save_hyperparameters()

        self.data_path = args.data_path
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.k = 20  # number of kNN neighbours for the Leiden graph
        self.num_timesteps = 8

        # NOTE: this initial value is unconditionally overwritten by the
        # number of merged Leiden clusters found in _prepare_data().
        self.num_branches = args.branches if hasattr(args, 'branches') else None
        self.split_ratios = args.split_ratios
        self.metric_clusters = args.metric_clusters
        self.discard_small = args.discard if hasattr(args, 'discard') else False
        self.args = args
        # Data is prepared eagerly here (not in Lightning's prepare_data/setup
        # hooks), so constructing the module performs file I/O and clustering.
        self._prepare_data()

    def _prepare_data(self):
        """Load the CSV, cluster the final timepoint, and build all loaders.

        Side effects: sets ``coords_t0``, ``coords_intermediate``,
        ``branch_endpoints``, ``num_branches``, ``final_t``, ``time_labels``,
        the train/val/test loader dicts, the stacked ``dataset`` tensor with
        its cKDTree, and ``metric_samples_dataloaders``.
        """
        print("Preparing Veres cell data with Leiden clustering in WeightedBranchedVeresLeidenDataModule")
        df = pd.read_csv(self.data_path)

        # Group coordinates by timepoint. The first column is assumed to be
        # the "samples" (timepoint) label; the remaining columns are
        # coordinates — TODO confirm against the CSV schema.
        coords_by_t = {
            t: df[df["samples"] == t].iloc[:, 1:].values
            for t in sorted(df["samples"].unique())
        }

        n0 = coords_by_t[0].shape[0]
        self.n_samples = n0

        print("Timepoint distribution:")
        for t in sorted(coords_by_t.keys()):
            print(f" t={t}: {coords_by_t[t].shape[0]} points")

        # Leiden clustering on a kNN connectivity graph of the final
        # timepoint to discover branch endpoints.
        final_t = max(coords_by_t.keys())
        self.final_t = final_t  # remembered for get_timepoint_data()
        coords_final = coords_by_t[final_t]
        k = self.k  # use the value configured in __init__ (was a duplicated literal)
        knn_graph = kneighbors_graph(coords_final, k, mode='connectivity', include_self=False)
        sources, targets = knn_graph.nonzero()
        edgelist = list(zip(sources.tolist(), targets.tolist()))
        graph = ig.Graph(edgelist, directed=False)
        partition = find_partition(graph, ModularityVertexPartition)
        leiden_labels = np.array(partition.membership)
        n_leiden = len(np.unique(leiden_labels))
        print(f"Leiden found {n_leiden} clusters at t={final_t}")

        df_final = df[df["samples"] == final_t].copy()
        df_final["branch"] = leiden_labels

        cluster_counts = df_final["branch"].value_counts().sort_index()
        print(f"Branch distribution at t={final_t} (pre-merge):")
        print(cluster_counts)

        # Partition clusters into "large" (>= min_cells) and "small"; small
        # ones are either discarded or merged into the nearest large cluster.
        min_cells = 100
        cluster_data_dict = {}
        cluster_sizes = []
        for b in range(n_leiden):
            # iloc[:, 1:-1] drops the leading "samples" column and the
            # trailing "branch" column just added, keeping only coordinates.
            branch_data = df_final[df_final["branch"] == b].iloc[:, 1:-1].values
            cluster_data_dict[b] = branch_data
            cluster_sizes.append(branch_data.shape[0])

        large_clusters = [b for b, size in enumerate(cluster_sizes) if size >= min_cells]
        small_clusters = [b for b, size in enumerate(cluster_sizes) if size < min_cells]

        # Degenerate case: every cluster is small — keep them all.
        if len(large_clusters) == 0:
            large_clusters = list(range(n_leiden))
            small_clusters = []

        if self.discard_small:
            print(f"Discarding {len(small_clusters)} small clusters (< {min_cells} cells)")
            mask = np.isin(leiden_labels, large_clusters)
            df_final = df_final[mask].copy()
            merged_labels = leiden_labels[mask]

            # Re-index surviving cluster ids to a dense 0..n_merged-1 range.
            new_ids = np.unique(merged_labels)
            id_map = {old: new for new, old in enumerate(new_ids)}
            merged_labels = np.array([id_map[x] for x in merged_labels])
            n_merged = len(np.unique(merged_labels))

            df_final["branch"] = merged_labels
            print(f"Kept {n_merged} large clusters")
        else:
            centroids = {b: np.mean(cluster_data_dict[b], axis=0) for b in range(n_leiden) if cluster_data_dict[b].shape[0] > 0}

            merged_labels = leiden_labels.copy()
            for b in small_clusters:
                if cluster_data_dict[b].shape[0] == 0:
                    continue
                # Reassign each small cluster to the large cluster whose
                # centroid is nearest in Euclidean distance.
                dists = [np.linalg.norm(centroids[b] - centroids[bl]) for bl in large_clusters]
                nearest_large = large_clusters[int(np.argmin(dists))]
                merged_labels[leiden_labels == b] = nearest_large

            # Re-index merged cluster ids to a dense 0..n_merged-1 range.
            new_ids = np.unique(merged_labels)
            id_map = {old: new for new, old in enumerate(new_ids)}
            merged_labels = np.array([id_map[x] for x in merged_labels])
            n_merged = len(np.unique(merged_labels))

            df_final["branch"] = merged_labels
            print(f"Merged into {n_merged} clusters")

        cluster_counts_merged = df_final["branch"].value_counts().sort_index()
        print(f"Branch distribution at t={final_t} (post-merge):")
        print(cluster_counts_merged)

        # Resample every branch to exactly n0 points (with replacement when
        # the branch is smaller than n0) so all endpoint tensors match the
        # t=0 population size. NOTE: uses the global numpy RNG, so results
        # depend on any seed set by the caller.
        endpoints = {}
        cluster_sizes = []
        for b in range(n_merged):
            branch_data = df_final[df_final["branch"] == b].iloc[:, 1:-1].values
            cluster_sizes.append(branch_data.shape[0])
            replace = branch_data.shape[0] < n0
            sampled_indices = np.random.choice(branch_data.shape[0], size=n0, replace=replace)
            endpoints[b] = branch_data[sampled_indices]
        total_t_final = sum(cluster_sizes)

        x0 = torch.tensor(coords_by_t[0], dtype=torch.float32)
        self.coords_t0 = x0
        # Every timepoint strictly between 0 and final_t.
        self.coords_intermediate = {t: torch.tensor(coords_by_t[t], dtype=torch.float32)
                                    for t in coords_by_t.keys() if t != 0 and t != final_t}

        self.branch_endpoints = {b: torch.tensor(endpoints[b], dtype=torch.float32) for b in range(n_merged)}
        self.num_branches = n_merged

        # Per-point time labels: 0 for t0, t for intermediates, final_t for
        # every (resampled) branch endpoint.
        time_labels_list = [np.zeros(len(self.coords_t0))]
        for t in sorted(self.coords_intermediate.keys()):
            time_labels_list.append(np.ones(len(self.coords_intermediate[t])) * t)
        for b in range(self.num_branches):
            time_labels_list.append(np.ones(len(self.branch_endpoints[b])) * final_t)
        self.time_labels = np.concatenate(time_labels_list)

        # Train/val split on t=0, keeping at least one full validation batch.
        split_index = int(n0 * self.split_ratios[0])
        if n0 - split_index < self.batch_size:
            split_index = n0 - self.batch_size

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        self.val_x0 = val_x0

        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)

        # Branch sample weights are proportional to post-merge branch sizes.
        branch_weights = [size / total_t_final for size in cluster_sizes]

        # Intermediate timepoints reuse the t=0 split index; this assumes
        # every intermediate timepoint has at least split_index rows —
        # TODO confirm for unbalanced timepoints.
        train_intermediate = {}
        val_intermediate = {}
        self.train_coords_intermediate = {}
        for t in sorted(self.coords_intermediate.keys()):
            coords_t = self.coords_intermediate[t]
            train_coords_t = coords_t[:split_index]
            val_coords_t = coords_t[split_index:]
            train_weights_t = torch.full((train_coords_t.shape[0], 1), fill_value=1.0)
            val_weights_t = torch.full((val_coords_t.shape[0], 1), fill_value=1.0)
            train_intermediate[f"x{t}"] = (train_coords_t, train_weights_t)
            val_intermediate[f"x{t}"] = (val_coords_t, val_weights_t)
            self.train_coords_intermediate[t] = train_coords_t

        train_loaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }
        val_loaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
        }

        # NOTE: string keys sort lexicographically ("x10" < "x2"); harmless
        # here since dict insertion covers the same keys either way.
        for t_key in sorted(train_intermediate.keys()):
            train_coords_t, train_weights_t = train_intermediate[t_key]
            val_coords_t, val_weights_t = val_intermediate[t_key]
            train_loaders[t_key] = DataLoader(
                TensorDataset(train_coords_t, train_weights_t),
                batch_size=self.batch_size,
                shuffle=True,
                drop_last=True
            )
            val_loaders[t_key] = DataLoader(
                TensorDataset(val_coords_t, val_weights_t),
                batch_size=self.batch_size,
                shuffle=False,
                drop_last=True
            )

        for b in range(self.num_branches):
            # Per-branch train/val split, again reserving one full batch for
            # validation when possible. branch_size == n0 by construction.
            branch_size = self.branch_endpoints[b].shape[0]
            branch_split_index = int(branch_size * self.split_ratios[0])
            if branch_size - branch_split_index < self.batch_size:
                branch_split_index = max(0, branch_size - self.batch_size)

            train_branch = self.branch_endpoints[b][:branch_split_index]
            val_branch = self.branch_endpoints[b][branch_split_index:]
            train_branch_weights = torch.full((train_branch.shape[0], 1), fill_value=branch_weights[b])
            val_branch_weights = torch.full((val_branch.shape[0], 1), fill_value=branch_weights[b])
            train_loaders[f"x1_{b+1}"] = DataLoader(
                TensorDataset(train_branch, train_branch_weights),
                batch_size=self.batch_size,
                shuffle=True,
                drop_last=True
            )
            # BUG FIX: validation loaders must not shuffle — every other val
            # loader above uses shuffle=False; shuffling here made the branch
            # validation stream non-deterministic across epochs.
            val_loaders[f"x1_{b+1}"] = DataLoader(
                TensorDataset(val_branch, val_branch_weights),
                batch_size=self.batch_size,
                shuffle=False,
                drop_last=True
            )

        self.train_dataloaders = train_loaders
        self.val_dataloaders = val_loaders

        # Full dataset (all timepoints stacked) and a KD-tree over it; both
        # are consumed by the manifold-projection smoothing operator.
        all_data_list = [coords_by_t[t] for t in sorted(coords_by_t.keys())]
        all_data = np.vstack(all_data_list)
        self.dataset = torch.tensor(all_data, dtype=torch.float32)
        self.tree = cKDTree(all_data)

        # Test loaders serve the whole split in a single batch each.
        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(self.val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Metric samples: cluster 0 is the source population (t=0); cluster 1
        # is everything else (intermediates plus all branch endpoints).
        cluster_0_data = self.coords_t0.cpu().numpy()
        cluster_1_list = [self.coords_intermediate[t].cpu().numpy() for t in sorted(self.coords_intermediate.keys())]
        cluster_1_list.extend([self.branch_endpoints[b].cpu().numpy() for b in range(self.num_branches)])
        cluster_1_data = np.vstack(cluster_1_list)

        self.metric_samples_dataloaders = [
            DataLoader(torch.tensor(cluster_0_data, dtype=torch.float32), batch_size=cluster_0_data.shape[0], shuffle=False, drop_last=False),
            DataLoader(torch.tensor(cluster_1_data, dtype=torch.float32), batch_size=cluster_1_data.shape[0], shuffle=False, drop_last=False),
        ]

    def train_dataloader(self):
        """Combine per-timepoint/per-branch train loaders with metric samples."""
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(self.metric_samples_dataloaders, mode="min_size"),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Combine per-timepoint/per-branch val loaders with metric samples."""
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(self.metric_samples_dataloaders, mode="min_size"),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        """Combine single-batch test loaders with metric samples."""
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(self.metric_samples_dataloaders, mode="min_size"),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_manifold_proj(self, points):
        """Return a smoothing operator bound to this module's KD-tree/dataset.

        ``points`` is accepted for interface compatibility but unused.
        """
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """Softly pull the points in ``x`` toward the data manifold.

        For each point, queries its ``k`` nearest neighbours in ``dataset``
        via the prebuilt cKDTree, weights the neighbours by
        ``exp(-squared_distance / temp)`` (normalised per point), and blends
        the point with the weighted neighbour average using alpha = 0.3.

        Args:
            x: tensor of points, shape (n_points, dim) — assumed 2-D so that
                ``unsqueeze(1)`` broadcasts against neighbours; TODO confirm.
            tree: cKDTree built over the same data as ``dataset``.
            dataset: float32 tensor aligned row-for-row with ``tree``'s data.
            k: number of neighbours to average over.
            temp: softmax temperature; smaller values weight the nearest
                neighbour more heavily.

        Returns:
            Tensor of the same shape as ``x``.
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        nearest_pts = dataset[idx]
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)
        smoothed = (weights * nearest_pts).sum(dim=1)
        alpha = 0.3
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return raw tensors keyed by timepoint plus per-branch endpoints."""
        result = {
            't0': self.coords_t0,
            'time_labels': self.time_labels
        }
        for t in sorted(self.coords_intermediate.keys()):
            result[f't{t}'] = self.coords_intermediate[t]
        # BUG FIX: use the actual final timepoint recorded in _prepare_data()
        # instead of reconstructing it as max(intermediate)+1, which is wrong
        # whenever timepoints are not consecutive integers.
        final_t = self.final_t
        for b in range(self.num_branches):
            result[f't{final_t}_{b}'] = self.branch_endpoints[b]
        return result

    def get_train_intermediate_data(self):
        """Return the train split of intermediate timepoints (fallback: full)."""
        if hasattr(self, 'train_coords_intermediate'):
            return self.train_coords_intermediate
        else:
            print("Warning: train_coords_intermediate not found, returning full intermediate data.")
            return self.coords_intermediate
|
|