import numpy as np
import torch
import torch.nn as nn
from models.aggregator import make_model
from models.cls_modules import LinearClassifier
from datasets.dataset_WSI import WSIPatchDataset
from models.feature_extractor import vit_base
from utils.wsi_utils import extract_tissue_patch_coords
from torch.utils.data import DataLoader
from huggingface_hub import PyTorchModelHubMixin


class EXAONEPathV1p5Downstream(nn.Module, PyTorchModelHubMixin):
    """Whole-slide-image classifier: ViT patch encoder plus slide-level aggregator.

    A slide is tiled into tissue patches via ``extract_tissue_patch_coords``.
    Each patch is embedded by a ``vit_base`` feature extractor, and the
    embeddings are L2-normalised. A random subset (at most
    ``num_sampled_patch``) of patch features, together with their
    coordinates, is then fed to a ``LinearClassifier`` that wraps a 4-layer
    aggregator built by ``make_model``.

    Args:
        step_size: Stride, in pixels, between neighbouring patch origins.
        patch_size: Side length, in pixels, of each square patch.
        num_sampled_patch: Upper bound on the number of patch features passed
            to the aggregator.
        macenko: Forwarded to ``WSIPatchDataset``. Presumably this toggles
            Macenko stain normalisation (TODO: confirm against the dataset).
    """

    def __init__(
        self, step_size=256, patch_size=256, num_sampled_patch=999999, macenko=True
    ):
        super().__init__()
        self.step_size = step_size
        self.patch_size = patch_size
        self.macenko = macenko
        self.num_sampled_patch = num_sampled_patch
        # Preferred device. The submodules are NOT moved here, so inference
        # sends inputs to wherever the weights actually live (see
        # ``_module_device``). This attribute is only a fallback.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.config = {
            "step_size": step_size,
            "patch_size": patch_size,
            "macenko": macenko,
            "num_sampled_patch": num_sampled_patch,
        }

        self.feature_extractor = vit_base()
        # embed_dim=768 must match the feature extractor's output width.
        self.agg_model = LinearClassifier(
            make_model(
                embed_dim=768,
                droprate=0.0,
                num_registers=0,
                depth=4,
            ),
            pool="mean",
        )

    def _module_device(self, module):
        """Return the device of *module*'s first parameter, or ``self.device`` if it has none."""
        first_param = next(module.parameters(), None)
        return self.device if first_param is None else first_param.device

    def _make_patch_loader(self, svs_path, batch_size, device, return_coord):
        """Tile the slide at *svs_path* into tissue patches and wrap them in a DataLoader."""
        coords = extract_tissue_patch_coords(
            svs_path, patch_size=self.patch_size, step_size=self.step_size
        )
        patch_dataset = WSIPatchDataset(
            coords=coords,
            wsi_path=svs_path,
            pretrained=True,
            macenko=self.macenko,
            patch_size=self.patch_size,
            return_coord=return_coord,
        )
        on_cuda = device.type == "cuda"
        return DataLoader(
            dataset=patch_dataset,
            batch_size=batch_size,
            # Original heuristic: two loader workers per batch element when
            # running on GPU, in-process loading otherwise.
            num_workers=batch_size * 2 if on_cuda else 0,
            pin_memory=on_cuda,
        )

    @torch.no_grad()
    def _extract_features(self, svs_path, batch_size, with_coords):
        """Embed every tissue patch of a slide.

        Returns:
            ``(features, coords)``, both on the CPU. ``features`` stacks the
            L2-normalised patch embeddings along dim 0. ``coords`` is the
            concatenated per-patch coordinates when *with_coords* is True,
            otherwise ``None``.

        Raises:
            RuntimeError: If the slide yields no tissue patches.
        """
        device = self._module_device(self.feature_extractor)
        self.feature_extractor.eval()
        loader = self._make_patch_loader(
            svs_path, batch_size, device, return_coord=with_coords
        )

        feature_chunks = []
        coord_chunks = []
        for batch_idx, batch in enumerate(loader):
            print(
                f"batch {batch_idx+1}/{len(loader)}, {batch_idx * batch_size} patches processed",
                end="\r",
            )
            if with_coords:
                patches, batch_coords = batch
                # Coordinates stay on the CPU; they are only needed on the
                # device once, at aggregation time.
                coord_chunks.append(batch_coords)
            else:
                patches = batch
            patches = patches.to(device, non_blocking=True)
            # Presumably [B, 768] for vit_base (the aggregator expects
            # embed_dim=768). TODO: confirm.
            feature = self.feature_extractor(patches)
            # Same as feature / feature.norm(...) for non-zero norms, but
            # never divides by zero.
            feature = nn.functional.normalize(feature, dim=-1)
            # Blocking copy: a non_blocking device-to-host copy could still
            # be in flight when torch.cat reads it below.
            feature_chunks.append(feature.cpu())
        print("")
        print("Feature extraction finished")

        if not feature_chunks:
            raise RuntimeError(f"No tissue patches were extracted from {svs_path!r}")
        features = torch.cat(feature_chunks)
        coords = torch.cat(coord_chunks) if with_coords else None
        return features, coords

    @torch.no_grad()
    def forward(self, svs_path: str, feature_extractor_batch_size: int = 8):
        """Classify the slide at *svs_path*.

        Args:
            svs_path: Path to the whole-slide image.
            feature_extractor_batch_size: Patches per feature-extractor batch.

        Returns:
            The first batch element of the aggregator's probability output,
            on the CPU. Presumably this is a vector of per-class
            probabilities (TODO: confirm against LinearClassifier).
        """
        features, coords = self._extract_features(
            svs_path, feature_extractor_batch_size, with_coords=True
        )

        # Randomly subsample (and shuffle) at most num_sampled_patch patches.
        total_samples = features.shape[0]
        num_samples = min(self.num_sampled_patch, total_samples)
        indices = torch.randperm(total_samples)[:num_samples]
        sampled_features = features[indices]
        sampled_coords = coords[indices]

        # Aggregate features; [None] adds a batch dimension of size 1.
        device = self._module_device(self.agg_model)
        self.agg_model.eval()
        _, y_prob, _ = self.agg_model(
            sampled_features[None].to(device), sampled_coords[None].to(device)
        )
        return y_prob[0].cpu()

    @torch.no_grad()
    def forward_feature_extractor(
        self, svs_path: str, feature_extractor_batch_size: int = 8
    ):
        """Return the L2-normalised features of every tissue patch, stacked on the CPU.

        Args:
            svs_path: Path to the whole-slide image.
            feature_extractor_batch_size: Patches per feature-extractor batch.

        Raises:
            RuntimeError: If the slide yields no tissue patches.
        """
        features, _ = self._extract_features(
            svs_path, feature_extractor_batch_size, with_coords=False
        )
        return features