import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from huggingface_hub import PyTorchModelHubMixin

from datasets.dataset_WSI import WSIPatchDataset
from models.aggregator import make_model
from models.cls_modules import LinearClassifier
from models.feature_extractor import vit_base
from utils.wsi_utils import extract_tissue_patch_coords

class EXAONEPathV1p5Downstream(nn.Module, PyTorchModelHubMixin):
    def __init__(
        self, step_size=256, patch_size=256, num_sampled_patch=999999, macenko=True
    ):
        super(EXAONEPathV1p5Downstream, self).__init__()
        self.step_size = step_size
        self.patch_size = patch_size
        self.macenko = macenko
        self.num_sampled_patch = num_sampled_patch
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.config = {
            "step_size": step_size,
            "patch_size": patch_size,
            "macenko": macenko,
            "num_sampled_patch": num_sampled_patch,
        }

        # Patch-level feature extractor (ViT-Base backbone, 768-dim embeddings)
        self.feature_extractor = vit_base()
        # self.feature_extractor = self.feature_extractor.to(self.device)
        # self.feature_extractor.eval()

        # Slide-level aggregator with a linear classification head
        self.agg_model = make_model(
            embed_dim=768,
            droprate=0.0,
            num_registers=0,
            depth=4,
        )
        self.agg_model = LinearClassifier(self.agg_model, pool="mean")
        # self.agg_model.to(self.device)
        # self.agg_model.eval()

    @torch.no_grad()
    def forward(self, svs_path: str, feature_extractor_batch_size: int = 8):
        # Extract tissue patch coordinates from the whole-slide image
        coords = extract_tissue_patch_coords(
            svs_path, patch_size=self.patch_size, step_size=self.step_size
        )

        # Extract patch-level features
        self.feature_extractor.eval()
        patch_dataset = WSIPatchDataset(
            coords=coords,
            wsi_path=svs_path,
            pretrained=True,
            macenko=self.macenko,
            patch_size=self.patch_size,
            return_coord=True,
        )
        patch_loader = DataLoader(
            dataset=patch_dataset,
            batch_size=feature_extractor_batch_size,
            num_workers=(
                feature_extractor_batch_size * 2 if self.device.type == "cuda" else 0
            ),
            pin_memory=self.device.type == "cuda",
        )

        features_list = []
        coords_list = []
        for count, (patches, batch_coords) in enumerate(patch_loader):
            print(
                f"batch {count + 1}/{len(patch_loader)}, {count * feature_extractor_batch_size} patches processed",
                end="\r",
            )
            patches = patches.to(self.device, non_blocking=True)
            feature = self.feature_extractor(patches)  # [B, 768]
            feature /= feature.norm(dim=-1, keepdim=True)  # use normalized features
            feature = feature.to("cpu", non_blocking=True)
            features_list.append(feature)
            batch_coords = batch_coords.to(self.device, non_blocking=True)
            coords_list.append(batch_coords)
        print("")
        print("Feature extraction finished")
        features = torch.cat(features_list)
        coords = torch.cat(coords_list)

        # Randomly sample at most num_sampled_patch patch features
        total_samples = features.shape[0]
        num_samples = min(self.num_sampled_patch, total_samples)
        indices = torch.randperm(total_samples)[:num_samples]
        sampled_features = features[indices]
        sampled_coords = coords[indices]

        # Aggregate patch features into slide-level class probabilities
        self.agg_model.eval()
        logits, Y_prob, Y_hat = self.agg_model(
            sampled_features[None].to(self.device),
            sampled_coords[None].to(self.device),
        )
        probs = Y_prob[0].cpu()
        return probs

    @torch.no_grad()
    def forward_feature_extractor(self, svs_path: str, feature_extractor_batch_size: int = 8):
        # Extract tissue patch coordinates from the whole-slide image
        coords = extract_tissue_patch_coords(
            svs_path, patch_size=self.patch_size, step_size=self.step_size
        )

        # Extract patch-level features
        self.feature_extractor.eval()
        patch_dataset = WSIPatchDataset(
            coords=coords,
            wsi_path=svs_path,
            pretrained=True,
            macenko=self.macenko,
            patch_size=self.patch_size,
            return_coord=False,
        )
        patch_loader = DataLoader(
            dataset=patch_dataset,
            batch_size=feature_extractor_batch_size,
            num_workers=(
                feature_extractor_batch_size * 2 if self.device.type == "cuda" else 0
            ),
            pin_memory=self.device.type == "cuda",
        )

        features_list = []
        for count, patches in enumerate(patch_loader):
            print(
                f"batch {count + 1}/{len(patch_loader)}, {count * feature_extractor_batch_size} patches processed",
                end="\r",
            )
            patches = patches.to(self.device, non_blocking=True)
            feature = self.feature_extractor(patches)  # [B, 768]
            feature /= feature.norm(dim=-1, keepdim=True)  # use normalized features
            feature = feature.to("cpu", non_blocking=True)
            features_list.append(feature)
        print("")
        print("Feature extraction finished")
        features = torch.cat(features_list)
        return features
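

# --- Usage sketch ---
# A minimal example of how this downstream model might be driven end to end, kept
# here as a hedged sketch rather than an official entry point. The Hub repo id
# "your-namespace/exaonepath-v1p5-downstream" and the slide path "slide.svs" are
# placeholders (assumptions, not defined by the model above); substitute your own.
# from_pretrained comes from the PyTorchModelHubMixin the class inherits.
if __name__ == "__main__":
    model = EXAONEPathV1p5Downstream.from_pretrained(
        "your-namespace/exaonepath-v1p5-downstream"  # placeholder repo id
    )
    model = model.to(model.device)  # move backbone and aggregator to the chosen device
    model.eval()

    # Slide-level class probabilities for one whole-slide image
    probs = model("slide.svs", feature_extractor_batch_size=8)
    print("predicted class:", int(probs.argmax()))

    # Patch-level features only (one L2-normalized 768-dim vector per tissue patch)
    features = model.forward_feature_extractor("slide.svs")
    print("feature matrix shape:", tuple(features.shape))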