mjpyeon committed
Commit adfc8b9 · Parent: 9810060

fix preprocessing

Files changed (2):
  1. .gitignore +3 -0
  2. exaonepath.py +31 -30
.gitignore CHANGED
@@ -172,3 +172,6 @@ cython_debug/
 
 # PyPI configuration file
 .pypirc
+
+# Project-specific files
+test.py
exaonepath.py CHANGED
@@ -1,5 +1,6 @@
 import math
 import typing as t
+from functools import partial
 
 import torch
 import torch.nn as nn
@@ -22,7 +23,13 @@ if t.TYPE_CHECKING:
     from _typeshed import StrPath
 
 
-class PadToDivisible(T.Transform):
+class Transform(T.Transform):
+    # For compatibility with torchvision <= 0.20
+    def _transform(self, inpt, params):
+        return self.transform(inpt, params)
+
+
+class PadToDivisible(Transform):
     def __init__(self, size: int, pad_value: float | None = None):
         super().__init__()
         self.size = size
@@ -44,32 +51,6 @@ class PadToDivisible(T.Transform):
         return inpt
 
 
-class Preprocessing(T.Transform):
-    def __init__(
-        self, small_tile_size_with_this_mpp: int, small_tile_size_with_target_mpp: int
-    ):
-        self.small_tile_size_with_this_mpp = small_tile_size_with_this_mpp
-        self.small_tile_size_with_target_mpp = small_tile_size_with_target_mpp
-
-    def transform(self, inpt, params):
-        assert isinstance(inpt, torch.Tensor) and inpt.ndim >= 3
-
-        # Scale the input tensor to the target MPP
-        if self.small_tile_size_with_this_mpp != self.small_tile_size_with_target_mpp:
-            inpt = TF.resize(
-                inpt,
-                [
-                    self.small_tile_size_with_target_mpp,
-                    self.small_tile_size_with_target_mpp,
-                ],
-            )
-
-        # Normalize the input tensor
-        inpt = scale_and_normalize(inpt)
-
-        return inpt
-
-
 class EXAONEPathV20(nn.Module, PyTorchModelHubMixin):
     def __init__(
         self,
@@ -103,7 +84,8 @@ class EXAONEPathV20(nn.Module, PyTorchModelHubMixin):
             self.model_first_stg,
             small_tiles,
             batch_size_on_gpu=first_stg_batch_size,
-            preproc_fn=Preprocessing(
+            preproc_fn=partial(
+                _preproc,
                 small_tile_size_with_this_mpp=small_tile_size,
                 small_tile_size_with_target_mpp=self.small_tile_size,
             ),
@@ -111,14 +93,14 @@ class EXAONEPathV20(nn.Module, PyTorchModelHubMixin):
             out_device=self.device,
             dtype=torch.bfloat16,
         )
-        act1 = format_first_stg_act_as_second_stg_inp(
+        act1_formatted = format_first_stg_act_as_second_stg_inp(
             act1,
             height=height,
             width=width,
             small_tile_size=small_tile_size,
             large_tile_size=large_tile_size,
         )
-        act2: torch.Tensor = self.model_second_stg(act1)
+        act2: torch.Tensor = self.model_second_stg(act1_formatted)
         act2_formatted = format_second_stg_act_as_third_stg_inp(
             act2,
             height=height,
@@ -126,6 +108,7 @@ class EXAONEPathV20(nn.Module, PyTorchModelHubMixin):
             large_tile_size=large_tile_size,
         )
         act3: torch.Tensor = self.model_third_stg(act2_formatted)
+
        return act1[is_tile_valid], act2, act3
 
     def _load_wsi(self, svs_path: "StrPath", target_mpp: float):
@@ -163,3 +146,21 @@ class EXAONEPathV20(nn.Module, PyTorchModelHubMixin):
         is_tile_valid = mask_tile.sum(dim=(1, 2)) > 0
 
         return x, is_tile_valid, padded_size, small_tile_size, large_tile_size
+
+
+def _preproc(
+    x: torch.Tensor,
+    small_tile_size_with_this_mpp: int,
+    small_tile_size_with_target_mpp: int,
+):
+    # Scale the input tensor to the target MPP
+    if small_tile_size_with_this_mpp != small_tile_size_with_target_mpp:
+        x = TF.resize(
+            x,
+            [small_tile_size_with_target_mpp, small_tile_size_with_target_mpp],
+        )
+
+    # Normalize the input tensor
+    x = scale_and_normalize(x)
+
+    return x
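
What the commit changes: per-tile preprocessing is no longer a stateful `T.Transform` subclass (the removed `Preprocessing` class, which never called `super().__init__()`) but a plain module-level `_preproc` function whose size arguments are bound up front with `functools.partial` before being passed to the first stage as `preproc_fn`. Below is a minimal, self-contained sketch of that calling pattern; the tile sizes and the random batch are illustrative only, and the repo's `scale_and_normalize` step is left out because it is defined elsewhere in the project.

# Sketch of the partial-bound preprocessing pattern introduced by this commit.
# The sizes and the dummy batch are illustrative; scale_and_normalize is omitted.
from functools import partial

import torch
import torchvision.transforms.v2.functional as TF


def _preproc(
    x: torch.Tensor,
    small_tile_size_with_this_mpp: int,
    small_tile_size_with_target_mpp: int,
) -> torch.Tensor:
    # Resize only when the tile was read at a different MPP than the target
    if small_tile_size_with_this_mpp != small_tile_size_with_target_mpp:
        x = TF.resize(
            x,
            [small_tile_size_with_target_mpp, small_tile_size_with_target_mpp],
        )
    return x


# Binding the sizes once yields a callable that only needs the tile batch,
# matching how preproc_fn is constructed inside EXAONEPathV20.
preproc_fn = partial(
    _preproc,
    small_tile_size_with_this_mpp=512,
    small_tile_size_with_target_mpp=256,
)

tiles = torch.rand(4, 3, 512, 512)  # dummy batch of RGB tiles
print(preproc_fn(tiles).shape)      # torch.Size([4, 3, 256, 256])

The new `Transform` shim above `PadToDivisible` appears to be a version bridge: its comment indicates that torchvision <= 0.20 dispatches v2 transforms through `_transform()`, so forwarding `_transform` to the newer `transform()` override point keeps `PadToDivisible` working across versions. Moving the preprocessing logic into a free function sidesteps that dispatch question entirely, since `partial(_preproc, ...)` is just an ordinary callable.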