import numpy as np
import torch
import torch.nn as nn
from models.aggregator import make_model
from models.cls_modules import LinearClassifier
from datasets.dataset_WSI import WSIPatchDataset
from models.feature_extractor import vit_base
from utils.wsi_utils import extract_tissue_patch_coords
from torch.utils.data import DataLoader
from huggingface_hub import PyTorchModelHubMixin


class EXAONEPathV1p5Downstream(nn.Module, PyTorchModelHubMixin):
    """Whole-slide image (WSI) downstream classifier: ViT-based patch-level
    feature extraction followed by slide-level feature aggregation."""

    def __init__(
        self, step_size=256, patch_size=256, num_sampled_patch=999999, macenko=True
    ):
        super().__init__()
self.step_size = step_size
self.patch_size = patch_size
self.macenko = macenko
self.num_sampled_patch = num_sampled_patch
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.config = {
"step_size": step_size,
"patch_size": patch_size,
"macenko": macenko,
"num_sampled_patch": num_sampled_patch,
}
        # Patch-level feature extractor (ViT-Base backbone)
        self.feature_extractor = vit_base()
        # Slide-level aggregator with a linear classification head
        self.agg_model = make_model(
            embed_dim=768,
            droprate=0.0,
            num_registers=0,
            depth=4,
        )
        self.agg_model = LinearClassifier(self.agg_model, pool='mean')

    @torch.no_grad()
    def forward(self, svs_path: str, feature_extractor_batch_size: int = 8):
        """Run end-to-end inference on a single WSI and return class probabilities."""
# Extract patches
coords = extract_tissue_patch_coords(
svs_path, patch_size=self.patch_size, step_size=self.step_size
)
# Extract patch-level features
self.feature_extractor.eval()
patch_dataset = WSIPatchDataset(
coords=coords,
wsi_path=svs_path,
pretrained=True,
macenko=self.macenko,
patch_size=self.patch_size,
return_coord=True,
)
patch_loader = DataLoader(
dataset=patch_dataset,
batch_size=feature_extractor_batch_size,
num_workers=(
feature_extractor_batch_size * 2 if self.device.type == "cuda" else 0
),
pin_memory=self.device.type == "cuda",
)
features_list = []
coords_list = []
        for count, (patches, batch_coords) in enumerate(patch_loader):
            print(
                f"batch {count+1}/{len(patch_loader)}, {count * feature_extractor_batch_size} patches processed",
                end="\r",
            )
            patches = patches.to(self.device, non_blocking=True)
            feature = self.feature_extractor(patches)  # [B, 768]
            feature /= feature.norm(dim=-1, keepdim=True)  # use L2-normalized features
            feature = feature.to("cpu", non_blocking=True)
            features_list.append(feature)
            batch_coords = batch_coords.to(self.device, non_blocking=True)
            coords_list.append(batch_coords)
print("")
print("Feature extraction finished")
features = torch.cat(features_list)
coords = torch.cat(coords_list)
        # Randomly subsample up to num_sampled_patch patch features
        total_samples = features.shape[0]
        num_samples = min(self.num_sampled_patch, total_samples)
        indices = torch.randperm(total_samples)[:num_samples]
        sampled_features = features[indices]
        sampled_coords = coords[indices]
# Aggregate features
        # Aggregate patch features into a slide-level prediction
        self.agg_model.eval()
        logits, Y_prob, Y_hat = self.agg_model(
            sampled_features[None].to(self.device),
            sampled_coords[None].to(self.device),
        )
        probs = Y_prob[0].cpu()
        return probs

    @torch.no_grad()
    def forward_feature_extractor(self, svs_path: str, feature_extractor_batch_size: int = 8):
        """Extract normalized patch-level features for a single WSI (no aggregation)."""
# Extract patches
coords = extract_tissue_patch_coords(
svs_path, patch_size=self.patch_size, step_size=self.step_size
)
# Extract patch-level features
self.feature_extractor.eval()
patch_dataset = WSIPatchDataset(
coords=coords,
wsi_path=svs_path,
pretrained=True,
macenko=self.macenko,
patch_size=self.patch_size,
return_coord=False
)
patch_loader = DataLoader(
dataset=patch_dataset,
batch_size=feature_extractor_batch_size,
num_workers=(
feature_extractor_batch_size * 2 if self.device.type == "cuda" else 0
),
pin_memory=self.device.type == "cuda",
)
features_list = []
for count, patches in enumerate(patch_loader):
print(
f"batch {count+1}/{len(patch_loader)}, {count * feature_extractor_batch_size} patches processed",
end="\r",
)
patches = patches.to(self.device, non_blocking=True)
            feature = self.feature_extractor(patches)  # [B, 768]
            feature /= feature.norm(dim=-1, keepdim=True)  # use L2-normalized features
feature = feature.to("cpu", non_blocking=True)
features_list.append(feature)
print("")
print("Feature extraction finished")
features = torch.cat(features_list)
return features
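

# --- Usage sketch (illustrative, not part of the module) ---
# A minimal example of running the model on a single slide. The .svs path below
# is a placeholder; in practice you would typically load pretrained weights via
# EXAONEPathV1p5Downstream.from_pretrained(<hub repo id>) provided by
# PyTorchModelHubMixin rather than instantiating with random weights.
if __name__ == "__main__":
    import sys

    svs_path = sys.argv[1] if len(sys.argv) > 1 else "path/to/slide.svs"  # placeholder path

    # Instantiate directly (random weights) for a quick smoke test.
    model = EXAONEPathV1p5Downstream(step_size=256, patch_size=256, macenko=True)
    model = model.to(model.device)  # move submodules onto the device chosen in __init__

    probs = model(svs_path, feature_extractor_batch_size=8)
    print("class probabilities:", probs)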