File size: 5,574 Bytes
03ae676
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import numpy as np
import torch
import torch.nn as nn
from models.aggregator import make_model
from models.cls_modules import LinearClassifier
from datasets.dataset_WSI import WSIPatchDataset
from models.feature_extractor import vit_base
from utils.wsi_utils import extract_tissue_patch_coords
from torch.utils.data import DataLoader

from huggingface_hub import PyTorchModelHubMixin


class EXAONEPathV1p5Downstream(nn.Module, PyTorchModelHubMixin):
    """Slide-level inference pipeline built from a patch encoder and an aggregator.

    Given the path to a whole-slide image, the pipeline
    (1) obtains tissue patch coordinates via ``extract_tissue_patch_coords``,
    (2) encodes every patch with ``self.feature_extractor`` and L2-normalizes
    the result, and (3) feeds a random subset of the patch features (plus their
    coordinates) to ``self.agg_model``.

    Args:
        step_size: Stride passed to ``extract_tissue_patch_coords``.
        patch_size: Patch size passed to the coordinate extractor and to
            ``WSIPatchDataset``.
        num_sampled_patch: Upper bound on the number of patches handed to the
            aggregator; when the slide has fewer patches, all are used.
        macenko: Forwarded to ``WSIPatchDataset`` (presumably toggles Macenko
            stain normalization -- confirm in the dataset implementation).
    """

    def __init__(
        self, step_size=256, patch_size=256, num_sampled_patch=999999, macenko=True
    ):
        super().__init__()
        self.step_size = step_size
        self.patch_size = patch_size
        self.macenko = macenko
        self.num_sampled_patch = num_sampled_patch
        # Preferred device, kept for backward compatibility. ``nn.Module.to()``
        # does not update this attribute, so inference code resolves the real
        # device from the sub-model parameters (see ``_device_of``).
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.config = {
            "step_size": step_size,
            "patch_size": patch_size,
            "macenko": macenko,
            "num_sampled_patch": num_sampled_patch,
        }

        # The sub-models are intentionally left on whatever device they are
        # created on; callers move the whole module with ``.to(...)``.
        self.feature_extractor = vit_base()

        aggregator = make_model(
            embed_dim=768,
            droprate=0.0,
            num_registers=0,
            depth=4,
        )
        self.agg_model = LinearClassifier(aggregator, pool='mean')

    def _device_of(self, module: nn.Module) -> torch.device:
        """Return the device of ``module``'s first parameter.

        Falls back to ``self.device`` when the module has no parameters.
        """
        try:
            return next(module.parameters()).device
        except StopIteration:
            return self.device

    @torch.no_grad()
    def _extract_features(self, svs_path: str, batch_size: int, return_coord: bool):
        """Encode every tissue patch of a slide.

        Args:
            svs_path: Path to the slide, forwarded to the coordinate extractor
                and to ``WSIPatchDataset``.
            batch_size: Batch size of the patch ``DataLoader``.
            return_coord: Forwarded to ``WSIPatchDataset``. When true, each
                batch is unpacked as ``(patches, coords)`` and the coordinates
                are collected as well.

        Returns:
            ``(features, coords)`` where ``features`` is the CPU concatenation
            of the L2-normalized per-batch encoder outputs and ``coords`` is the
            concatenation of the per-batch coordinates, or ``None`` when
            ``return_coord`` is false.

        Raises:
            RuntimeError: If the data loader yields no patches.
        """
        tissue_coords = extract_tissue_patch_coords(
            svs_path, patch_size=self.patch_size, step_size=self.step_size
        )

        # Send inputs to wherever the encoder weights actually live.
        device = self._device_of(self.feature_extractor)
        on_cuda = device.type == "cuda"

        self.feature_extractor.eval()
        patch_dataset = WSIPatchDataset(
            coords=tissue_coords,
            wsi_path=svs_path,
            pretrained=True,
            macenko=self.macenko,
            patch_size=self.patch_size,
            return_coord=return_coord,
        )
        patch_loader = DataLoader(
            dataset=patch_dataset,
            batch_size=batch_size,
            num_workers=batch_size * 2 if on_cuda else 0,
            pin_memory=on_cuda,
        )

        features_list = []
        coords_list = []
        for count, items in enumerate(patch_loader):
            if return_coord:
                patches, batch_coords = items
                # Coordinates stay on the CPU next to the features; they are
                # moved to the aggregator's device only once, after sampling.
                coords_list.append(batch_coords)
            else:
                patches = items
            print(
                f"batch {count+1}/{len(patch_loader)}, {count * batch_size} patches processed",
                end="\r",
            )
            patches = patches.to(device, non_blocking=True)

            # (B, D); D is presumably 768 to match the aggregator's embed_dim
            # -- TODO confirm against vit_base.
            feature = self.feature_extractor(patches)
            feature /= feature.norm(dim=-1, keepdim=True)  # use normalized features
            # Blocking copy: a non-blocking device-to-host transfer could be
            # read by torch.cat before the data has actually arrived.
            features_list.append(feature.to("cpu"))

        print("")
        print("Feature extraction finished")

        if not features_list:
            raise RuntimeError(f"No tissue patches were extracted from {svs_path!r}")

        features = torch.cat(features_list)
        patch_coords = torch.cat(coords_list) if return_coord else None
        return features, patch_coords

    @torch.no_grad()
    def forward(self, svs_path: str, feature_extractor_batch_size: int = 8):
        """Run the full pipeline on one slide.

        Args:
            svs_path: Path to the slide file.
            feature_extractor_batch_size: Batch size used while encoding patches.

        Returns:
            The first batch element of the second value returned by
            ``self.agg_model`` (named ``Y_prob``), moved to the CPU.

        Raises:
            RuntimeError: If no patches could be extracted from the slide.
        """
        features, coords = self._extract_features(
            svs_path, feature_extractor_batch_size, return_coord=True
        )

        # Randomly keep at most ``num_sampled_patch`` patches.
        total_samples = features.shape[0]
        num_samples = min(self.num_sampled_patch, total_samples)
        indices = torch.randperm(total_samples)[:num_samples]
        sampled_features = features[indices]
        sampled_coords = coords[indices]

        # Aggregate features; ``[None]`` adds the leading batch dimension.
        self.agg_model.eval()
        agg_device = self._device_of(self.agg_model)
        _, Y_prob, _ = self.agg_model(
            sampled_features[None].to(agg_device),
            sampled_coords[None].to(agg_device),
        )
        return Y_prob[0].cpu()

    @torch.no_grad()
    def forward_feature_extractor(self, svs_path: str, feature_extractor_batch_size: int = 8):
        """Return the L2-normalized patch features of one slide.

        Args:
            svs_path: Path to the slide file.
            feature_extractor_batch_size: Batch size used while encoding patches.

        Returns:
            CPU tensor with one row of normalized features per extracted patch.

        Raises:
            RuntimeError: If no patches could be extracted from the slide.
        """
        features, _ = self._extract_features(
            svs_path, feature_extractor_batch_size, return_coord=False
        )
        return features