# 2ms: init commit (03ae676)
import os
import torch
from openslide import OpenSlide
from utils.preprocessor import MacenkoNormalizer, preprocessor
from torch.utils.data import Dataset
class WSIPatchDataset(Dataset):
    """Map-style dataset that reads fixed-size patches from one slide.

    Each item is the patch located at ``coords[idx]``, read through
    OpenSlide, converted to RGB and passed through the transform built by
    ``preprocessor``.

    The open slide handle is stored on ``self.wsi``. Because that handle
    cannot be pickled, the instance drops it when pickled and reopens the
    slide from ``self.wsi_path`` when unpickled, so the dataset can be sent
    to ``DataLoader`` worker processes that are started with ``spawn``.
    """

    def __init__(
        self,
        coords,
        wsi_path,
        pretrained=False,
        patch_size=256,
        patch_level=0,
        macenko=True,
        return_coord=False,
    ):
        """Open the slide and build the patch transform.

        Args:
            coords: Indexable, sized collection of patch locations. Each
                entry is handed unchanged to ``OpenSlide.read_region`` as
                its ``location`` argument (presumably an ``(x, y)`` pair in
                level-0 pixels, per the OpenSlide API -- confirm against
                the caller).
            wsi_path: Path of the slide file passed to ``OpenSlide``.
            pretrained: Forwarded to ``preprocessor``.
            patch_size: Side length of the square region requested from
                ``read_region``.
            patch_level: Slide level passed to ``read_region``.
            macenko: When true, a ``MacenkoNormalizer`` loaded from
                ``macenko_target/macenko_param.pt`` is handed to
                ``preprocessor``; otherwise ``None`` is handed over.
            return_coord: When true, ``__getitem__`` also returns the
                patch location as a tensor.
        """
        super().__init__()
        self.pretrained = pretrained
        # Kept so the handle can be reopened after unpickling.
        self.wsi_path = wsi_path
        self.wsi = OpenSlide(wsi_path)
        self.patch_size = patch_size
        self.patch_level = patch_level
        self.return_coord = return_coord

        if macenko:
            # The parameter file sits two directory levels above this
            # module, inside "macenko_target".
            base_dir = os.path.dirname(os.path.dirname(__file__))
            normalizer = MacenkoNormalizer(
                target_path=os.path.join(
                    base_dir, "macenko_target", "macenko_param.pt"
                )
            )
        else:
            normalizer = None

        self.roi_transforms = preprocessor(
            pretrained=pretrained, normalizer=normalizer
        )
        self.coords = coords
        self.length = len(self.coords)

    def __len__(self):
        """Return the number of patch locations given at construction."""
        return self.length

    def __getitem__(self, idx):
        """Return the transformed patch at ``coords[idx]``.

        Returns:
            The transformed image, or ``(image, torch.tensor(coord))`` when
            ``return_coord`` is true.
        """
        coord = self.coords[idx]
        region = self.wsi.read_region(
            coord, self.patch_level, (self.patch_size, self.patch_size)
        )
        img = self.roi_transforms(region.convert("RGB"))
        if self.return_coord:
            return img, torch.tensor(coord)
        return img

    def close(self):
        """Release the slide handle; calling it again is a no-op."""
        if self.wsi is not None:
            self.wsi.close()
            self.wsi = None

    def __getstate__(self):
        """Return the instance state without the unpicklable slide handle."""
        state = self.__dict__.copy()
        state["wsi"] = None
        return state

    def __setstate__(self, state):
        """Restore the state and reopen the slide in the current process."""
        self.__dict__.update(state)
        self.wsi = OpenSlide(self.wsi_path)