|
import math |
|
import typing as t |
|
from functools import partial |
|
from pathlib import Path |
|
|
|
import torch |
|
import torch.nn as nn |
|
from cucim import CuImage |
|
from huggingface_hub import PyTorchModelHubMixin |
|
from torchvision.transforms import functional as TF |
|
from torchvision.transforms import v2 as T |
|
|
|
from networks.vit import vit4k_base, vit_base, vit_global_base |
|
from utils.tensor_utils import ( |
|
format_first_stg_act_as_second_stg_inp, |
|
format_second_stg_act_as_third_stg_inp, |
|
forward_with_batch_size_limit, |
|
scale_and_normalize, |
|
tile, |
|
) |
|
from utils.wsi_utils import load_slide_img, segment_tissue |
|
|
|
if t.TYPE_CHECKING: |
|
from _typeshed import StrPath |
|
|
|
|
|
class Transform(T.Transform):
    """Base transform whose subclasses implement ``transform(inpt, params)``.

    The private ``_transform`` hook is overridden to hand every call to the
    subclass's ``transform`` method. Presumably this keeps subclasses working
    with torchvision v2 releases that dispatch through ``_transform`` rather
    than ``transform`` — TODO confirm against the pinned torchvision version.
    """

    def _transform(self, inpt, params):
        # Route the torchvision hook to the subclass implementation unchanged.
        handler = self.transform
        return handler(inpt, params)
|
|
|
|
|
class PadToDivisible(Transform):
    """Pad the bottom/right of a tensor so its last two dims are multiples of ``size``.

    Args:
        size: Divisor that the padded height and width must be multiples of;
            must be >= 1.
        pad_value: Constant fill value passed as ``value`` to
            ``torch.nn.functional.pad``; ``None`` defers to that function's
            default.

    Raises:
        ValueError: if ``size`` is smaller than 1.
    """

    def __init__(self, size: int, pad_value: float | None = None):
        super().__init__()
        # A zero divisor would raise ZeroDivisionError inside ``transform``, and a
        # negative one yields negative "padding" amounts that silently crop.
        if size < 1:
            raise ValueError(f"size must be a positive integer, got {size!r}")
        self.size = size
        self.pad_value = pad_value

    def transform(self, inpt, params):
        """Return ``inpt`` padded on the bottom and right with ``pad_value``.

        Args:
            inpt: Tensor with at least 3 dims; the last two are treated as
                (height, width).
            params: Unused; part of the torchvision transform hook signature.

        Returns:
            ``inpt`` itself when both trailing dims are already divisible by
            ``size``, otherwise a new padded tensor.

        Raises:
            TypeError: if ``inpt`` is not a ``torch.Tensor``.
            ValueError: if ``inpt`` has fewer than 3 dimensions.
        """
        # Explicit checks rather than ``assert`` so validation survives ``python -O``.
        if not isinstance(inpt, torch.Tensor):
            raise TypeError(f"expected a torch.Tensor, got {type(inpt).__name__}")
        if inpt.ndim < 3:
            raise ValueError(f"expected a tensor with ndim >= 3, got ndim={inpt.ndim}")

        H, W = inpt.shape[-2:]

        # Amount needed to round each dim up to the next multiple of size;
        # the outer modulo makes it 0 when the dim is already divisible.
        pad_h = (self.size - H % self.size) % self.size
        pad_w = (self.size - W % self.size) % self.size

        if pad_h > 0 or pad_w > 0:
            # F.pad takes (left, right, top, bottom) for the last two dims,
            # so all padding goes to the right and bottom edges.
            inpt = torch.nn.functional.pad(
                inpt, (0, pad_w, 0, pad_h), value=self.pad_value
            )

        return inpt
|
|
|
|
|
class EXAONEPathV20(nn.Module, PyTorchModelHubMixin):
    """Three-stage slide encoder built from chained ViT models.

    ``forward`` tiles a whole-slide image and runs ``vit_base`` over the small
    tiles. It regroups those activations (according to the large tile size) as
    input to ``vit4k_base``, then regroups the second-stage activations as
    input to ``vit_global_base``.

    Args:
        small_tile_size: Edge length, in pixels at the target MPP, of the tiles
            fed to the first stage.
        large_tile_size: Edge length, in pixels at the target MPP, of the regions
            used when formatting activations for the second stage. Presumably a
            multiple of ``small_tile_size``; the ratio is floor-divided.
    """

    def __init__(
        self,
        small_tile_size: int = 256,
        large_tile_size: int = 4096,
    ):
        super().__init__()

        self.small_tile_size = small_tile_size
        self.large_tile_size = large_tile_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # All three stages live on ``self.device`` in eval mode.
        self.model_first_stg = vit_base().to(self.device).eval()
        self.model_second_stg = vit4k_base().to(self.device).eval()
        self.model_third_stg = vit_global_base().to(self.device).eval()

    def forward(
        self,
        svs_path: "StrPath",
        target_mpp: float = 0.5,
        first_stg_batch_size: int = 128,
    ):
        """Encode a whole-slide image with the three chained stages.

        Args:
            svs_path: Path to the slide file, opened with ``CuImage``.
            target_mpp: Microns-per-pixel at which the configured tile sizes are
                defined. It is also used as the slide MPP when the slide's
                metadata has no usable value.
            first_stg_batch_size: Forwarded to ``forward_with_batch_size_limit``
                as ``batch_size_on_gpu`` for the first stage.

        Returns:
            ``(act1, act2, act3)``, where:

            - ``act1`` holds the first-stage activations, restricted to tiles
              whose tissue mask is non-empty.
            - ``act2`` is the full second-stage output.
            - ``act3`` is the full third-stage output.
        """
        small_tiles, is_tile_valid, padded_size, small_tile_size, large_tile_size = (
            self._load_wsi(svs_path, target_mpp=target_mpp)
        )
        width, height = padded_size

        # Autocast only affects ops on its ``device_type``. A hard-coded "cuda"
        # is disabled (with a warning) on CPU-only hosts, even though
        # ``__init__`` placed the models on CPU there. Follow the models' device.
        with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
            with torch.no_grad():
                act1 = forward_with_batch_size_limit(
                    self.model_first_stg,
                    small_tiles,
                    batch_size_on_gpu=first_stg_batch_size,
                    preproc_fn=partial(
                        _preproc,
                        small_tile_size_with_this_mpp=small_tile_size,
                        small_tile_size_with_target_mpp=self.small_tile_size,
                    ),
                    device=self.device,
                    out_device="cpu",
                    dtype=torch.bfloat16,
                )
                act1 = act1.to(self.device)

                act1_formatted = format_first_stg_act_as_second_stg_inp(
                    act1,
                    height=height,
                    width=width,
                    small_tile_size=small_tile_size,
                    large_tile_size=large_tile_size,
                )
                act2: torch.Tensor = self.model_second_stg(act1_formatted)

                act2_formatted = format_second_stg_act_as_third_stg_inp(
                    act2,
                    height=height,
                    width=width,
                    large_tile_size=large_tile_size,
                )
                act3: torch.Tensor = self.model_third_stg(act2_formatted)

        return act1[is_tile_valid], act2, act3

    @staticmethod
    def _read_mpp(wsi_obj, default: float) -> float:
        """Return the Aperio ``MPP`` value from ``wsi_obj.metadata``, or ``default``.

        Falls back to ``default`` (with a warning) in three cases, because the
        tile-size computation divides by the value:

        - the key is missing or the metadata has an unexpected structure;
        - the value cannot be parsed as a float;
        - the value is not a finite positive number.
        """
        import warnings

        try:
            mpp = float(wsi_obj.metadata["aperio"]["MPP"])
        except (KeyError, TypeError, ValueError):
            mpp = None

        if mpp is None or not (math.isfinite(mpp) and mpp > 0):
            warnings.warn(
                f"Warning: MPP metadata not found, using default value of {default}",
                stacklevel=2,
            )
            return default
        return mpp

    def _load_wsi(self, svs_path: "StrPath", target_mpp: float):
        """Load a slide, tile it, and flag the tiles that overlap tissue.

        Args:
            svs_path: Path to the slide file.
            target_mpp: MPP at which ``self.small_tile_size`` and
                ``self.large_tile_size`` are defined.

        Returns:
            ``(tiles, is_tile_valid, padded_size, small_tile_size,
            large_tile_size)``, where:

            - ``padded_size`` is ``(width, height)`` of the padded image.
            - ``is_tile_valid`` marks tiles whose mask tile has a non-zero sum.
            - Both tile sizes are expressed in slide pixels.
        """
        svs_path = str(svs_path)

        with CuImage(svs_path) as wsi_obj:
            mpp = self._read_mpp(wsi_obj, default=target_mpp)

            img = load_slide_img(wsi_obj)
            height, width = img.shape[:2]

            # Tissue mask from segment_tissue(seg_level=-1), resized to the
            # loaded image's (height, width) so it can be tiled alongside it.
            mask_tensor = torch.from_numpy(
                segment_tissue(Path(svs_path), seg_level=-1)[0]
            )
            mask_tensor = TF.resize(mask_tensor.unsqueeze(0), [height, width]).squeeze(0)

            # Channels-last image array -> channels-first tensor.
            x: torch.Tensor = torch.from_numpy(img).permute(2, 0, 1)

            # Convert the configured tile sizes (defined at target_mpp) into
            # slide pixels at the slide's own mpp. The large tile is kept an
            # exact multiple of the small one.
            small_tile_size = math.ceil(self.small_tile_size * (target_mpp / mpp))
            large_tile_size = (
                self.large_tile_size // self.small_tile_size
            ) * small_tile_size

            # Pad so both dims divide evenly into large tiles: the image is
            # filled with 255, the mask with 0 (padding never counts as tissue).
            pad_image = PadToDivisible(large_tile_size, 255)
            pad_mask = PadToDivisible(large_tile_size, 0)

            x = pad_image(x)
            padded_size = (x.size(-1), x.size(-2))

            x = tile(x, small_tile_size)
            mask_padded = pad_mask(mask_tensor.unsqueeze(0))
            mask_tile = tile(mask_padded, small_tile_size).squeeze(1)
            is_tile_valid = mask_tile.sum(dim=(1, 2)) > 0

        return x, is_tile_valid, padded_size, small_tile_size, large_tile_size
|
|
|
|
|
def _preproc(
    x: torch.Tensor,
    small_tile_size_with_this_mpp: int,
    small_tile_size_with_target_mpp: int,
):
    """Bring a batch of tiles to the first-stage input size, then normalize it.

    Args:
        x: Tiles as produced by tiling at the slide's native MPP.
        small_tile_size_with_this_mpp: Current tile edge length in pixels.
        small_tile_size_with_target_mpp: Tile edge length the model expects.

    Returns:
        The output of ``scale_and_normalize``. When the two sizes differ, it
        is applied to the tiles after resizing them to a square of the target
        edge length.
    """
    target_edge = small_tile_size_with_target_mpp
    # Skip the resize entirely when the native tile already has the target size.
    needs_resize = small_tile_size_with_this_mpp != target_edge
    resized = TF.resize(x, [target_edge, target_edge]) if needs_resize else x
    return scale_and_normalize(resized)
|
|