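"""Downstream whole-slide image (WSI) classification model (EXAONEPath v1.5).

Pipeline: extract tissue patch coordinates from a slide (.svs), embed each patch
with a ViT-Base feature extractor (optionally Macenko stain-normalized), then
aggregate the patch embeddings into slide-level class probabilities.
"""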
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from huggingface_hub import PyTorchModelHubMixin

from models.aggregator import make_model
from models.cls_modules import LinearClassifier
from models.feature_extractor import vit_base
from datasets.dataset_WSI import WSIPatchDataset
from utils.wsi_utils import extract_tissue_patch_coords

class EXAONEPathV1p5Downstream(nn.Module, PyTorchModelHubMixin):
    def __init__(
        self, step_size=256, patch_size=256, num_sampled_patch=999999, macenko=True
    ):
        super().__init__()
        self.step_size = step_size
        self.patch_size = patch_size
        self.macenko = macenko
        self.num_sampled_patch = num_sampled_patch
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.config = {
            "step_size": step_size,
            "patch_size": patch_size,
            "macenko": macenko,
            "num_sampled_patch": num_sampled_patch,
        }

        # Patch-level feature extractor (ViT-Base), kept on the same device the
        # forward pass moves its input batches to.
        self.feature_extractor = vit_base().to(self.device)

        # Slide-level aggregator: a depth-4 aggregation backbone over the 768-dim
        # patch embeddings, wrapped with a mean-pooling linear classifier head.
        self.agg_model = make_model(
            embed_dim=768,
            droprate=0.0,
            num_registers=0,
            depth=4,
        )
        self.agg_model = LinearClassifier(self.agg_model, pool="mean").to(self.device)
|
    @torch.no_grad()
    def forward(self, svs_path: str, feature_extractor_batch_size: int = 8):
        # Locate tissue-containing patches on the slide.
        coords = extract_tissue_patch_coords(
            svs_path, patch_size=self.patch_size, step_size=self.step_size
        )

        # Embed each patch with the frozen feature extractor.
        self.feature_extractor.eval()
        patch_dataset = WSIPatchDataset(
            coords=coords,
            wsi_path=svs_path,
            pretrained=True,
            macenko=self.macenko,
            patch_size=self.patch_size,
            return_coord=True,
        )
        patch_loader = DataLoader(
            dataset=patch_dataset,
            batch_size=feature_extractor_batch_size,
            num_workers=(
                feature_extractor_batch_size * 2 if self.device.type == "cuda" else 0
            ),
            pin_memory=self.device.type == "cuda",
        )

        features_list = []
        coords_list = []
        for count, (patches, coords) in enumerate(patch_loader):
            print(
                f"batch {count + 1}/{len(patch_loader)}, "
                f"{count * feature_extractor_batch_size} patches processed",
                end="\r",
            )
            patches = patches.to(self.device, non_blocking=True)

            # L2-normalize each patch embedding and keep it on the CPU.
            feature = self.feature_extractor(patches)
            feature /= feature.norm(dim=-1, keepdim=True)
            feature = feature.to("cpu", non_blocking=True)
            features_list.append(feature)

            coords = coords.to(self.device, non_blocking=True)
            coords_list.append(coords)

        print("")
        print("Feature extraction finished")

        features = torch.cat(features_list)
        coords = torch.cat(coords_list)
        total_samples = features.shape[0]

        # Randomly subsample at most num_sampled_patch patches for aggregation.
        num_samples = min(self.num_sampled_patch, total_samples)
        indices = torch.randperm(total_samples)[:num_samples]
        sampled_features = features[indices]
        sampled_coords = coords[indices]

        # Aggregate the sampled patch features into slide-level class probabilities.
        self.agg_model.eval()
        logits, Y_prob, Y_hat = self.agg_model(
            sampled_features[None].to(self.device),
            sampled_coords[None].to(self.device),
        )
        probs = Y_prob[0].cpu()

        return probs
|
    @torch.no_grad()
    def forward_feature_extractor(
        self, svs_path: str, feature_extractor_batch_size: int = 8
    ):
        # Locate tissue-containing patches on the slide.
        coords = extract_tissue_patch_coords(
            svs_path, patch_size=self.patch_size, step_size=self.step_size
        )

        # Embed each patch with the frozen feature extractor; coordinates are not
        # returned since no aggregation is performed here.
        self.feature_extractor.eval()
        patch_dataset = WSIPatchDataset(
            coords=coords,
            wsi_path=svs_path,
            pretrained=True,
            macenko=self.macenko,
            patch_size=self.patch_size,
            return_coord=False,
        )
        patch_loader = DataLoader(
            dataset=patch_dataset,
            batch_size=feature_extractor_batch_size,
            num_workers=(
                feature_extractor_batch_size * 2 if self.device.type == "cuda" else 0
            ),
            pin_memory=self.device.type == "cuda",
        )

        features_list = []
        for count, patches in enumerate(patch_loader):
            print(
                f"batch {count + 1}/{len(patch_loader)}, "
                f"{count * feature_extractor_batch_size} patches processed",
                end="\r",
            )
            patches = patches.to(self.device, non_blocking=True)

            # L2-normalize each patch embedding and keep it on the CPU.
            feature = self.feature_extractor(patches)
            feature /= feature.norm(dim=-1, keepdim=True)
            feature = feature.to("cpu", non_blocking=True)
            features_list.append(feature)

        print("")
        print("Feature extraction finished")

        features = torch.cat(features_list)
        return features
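
# Minimal usage sketch, not part of the model definition. It assumes the trained
# weights are published on the Hugging Face Hub via PyTorchModelHubMixin; the repo
# id and the slide path below are placeholders, not confirmed values.
if __name__ == "__main__":
    model = EXAONEPathV1p5Downstream.from_pretrained(
        "LGAI-EXAONE/EXAONE-Path-1.5"  # hypothetical repo id
    )

    # Slide-level prediction: returns class probabilities for one .svs slide.
    probs = model("path/to/slide.svs", feature_extractor_batch_size=8)
    print(probs)

    # Patch-level embeddings only (skips the aggregation head).
    feats = model.forward_feature_extractor("path/to/slide.svs")
    print(feats.shape)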