import typing as t

import torch
import torchvision.transforms.functional as TF


def tile(x: torch.Tensor, size: int, pad_value: int | float | None = None):
    """Split an image (or batch of images) into non-overlapping ``size`` x ``size`` tiles.

    The input is padded on the bottom/right with ``pad_value`` so that both spatial
    dimensions become multiples of ``size``; tiles are returned in row-major order
    with shape ``(num_tiles, C, size, size)``.
    """
    C, H, W = x.shape[-3:]

    pad_h = (size - H % size) % size
    pad_w = (size - W % size) % size
    if pad_h > 0 or pad_w > 0:
        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), value=pad_value)

    nh, nw = x.size(-2) // size, x.size(-1) // size
    return (
        x.view(-1, C, nh, size, nw, size)
        .permute(0, 2, 4, 1, 3, 5)
        .reshape(-1, C, size, size)
    )


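# Illustrative usage sketch, not part of the original module: a 500x700 RGB image
# is padded to 672x896 and split into twelve 224x224 tiles; the random data only
# serves to demonstrate the shapes.
def _demo_tile() -> None:
    img = torch.randn(3, 500, 700)
    tiles = tile(img, size=224, pad_value=0)
    assert tiles.shape == (12, 3, 224, 224)

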
def small_tiles_to_large_tiles(
    small_tiles: torch.Tensor,
    width: int,
    large_tile_size: int,
    sampled_large_tiles_idx: list | torch.Tensor | None = None,
) -> torch.Tensor:
    """Regroup row-major small tiles into larger tiles.

    ``width`` is the (padded) image width in pixels, used to recover the shape of
    the small-tile grid. If ``sampled_large_tiles_idx`` is given, only those large
    tiles are assembled, in the given order.
    """
    has_channel = small_tiles.ndim == 4
    small_tile_size = small_tiles.size(-1)
    num_small_tiles = small_tiles.size(0)

    # Shape of the small-tile grid.
    nw = width // small_tile_size
    nh = num_small_tiles // nw

    # Number of small tiles along one side of a large tile.
    r = large_tile_size // small_tile_size

    num_large_tiles = (nh // r) * (nw // r)
    large_tile_indices = (
        range(num_large_tiles)
        if sampled_large_tiles_idx is None
        else sampled_large_tiles_idx
    )

    # Gather the r x r block of small tiles belonging to each requested large tile.
    tiles = []
    for k in large_tile_indices:
        start_row = (k // (nw // r)) * r
        start_col = (k % (nw // r)) * r
        for i in range(start_row, start_row + r):
            for j in range(start_col, start_col + r):
                tiles.append(small_tiles[i * nw + j])

    stacked = torch.stack(tiles, dim=0).view(-1, r, r, *small_tiles.shape[1:])
    if has_channel:
        large_tiles = stacked.permute(0, 3, 1, 4, 2, 5).reshape(
            -1, small_tiles.size(1), large_tile_size, large_tile_size
        )
    else:
        large_tiles = stacked.permute(0, 1, 3, 2, 4).reshape(
            -1, large_tile_size, large_tile_size
        )
    return large_tiles


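# Illustrative round-trip sketch, not part of the original module: when both
# spatial dimensions divide evenly by both tile sizes, regrouping the small tiles
# reproduces tiling the image directly at the large size. Sizes are arbitrary.
def _demo_small_tiles_to_large_tiles() -> None:
    img = torch.randn(3, 64, 96)
    small = tile(img, size=16)
    large = small_tiles_to_large_tiles(small, width=96, large_tile_size=32)
    assert torch.equal(large, tile(img, size=32))

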
def small_tile_flags_to_large_tile_flags(
    small_tile_flags: torch.Tensor,
    width: int,
    small_tile_size: int,
    large_tile_size: int,
    aggregation: t.Literal["any", "all"] = "any",
):
    """Aggregate per-small-tile boolean flags into per-large-tile flags.

    Each flag is treated as a 1x1 "tile" so that ``small_tiles_to_large_tiles`` can
    regroup it; the flags of each large tile are then reduced with ``any`` or ``all``.
    """
    small_tile_flags = small_tile_flags.view(-1, 1, 1)
    num_small_tiles = small_tile_flags.size(0)
    nw = width // small_tile_size
    r = large_tile_size // small_tile_size
    num_large_tiles = num_small_tiles // r**2
    large_tile_flags = small_tiles_to_large_tiles(
        small_tile_flags,
        width=nw,
        large_tile_size=r,
    ).view(num_large_tiles, -1)
    return (
        large_tile_flags.any(-1) if aggregation == "any" else large_tile_flags.all(-1)
    )


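# Illustrative sketch, not part of the original module: a 4x4 grid of small-tile
# flags collapses to a 2x2 grid of large-tile flags; with "any", a large tile is
# flagged if any of its four small tiles is.
def _demo_flag_aggregation() -> None:
    flags = torch.zeros(16, dtype=torch.bool)
    flags[0] = True   # small tile (row 0, col 0) -> large tile 0
    flags[11] = True  # small tile (row 2, col 3) -> large tile 3
    out = small_tile_flags_to_large_tile_flags(
        flags, width=64, small_tile_size=16, large_tile_size=32
    )
    assert out.tolist() == [True, False, False, True]

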
def format_first_stg_act_as_second_stg_inp(
    x: torch.Tensor,
    height: int,
    width: int,
    small_tile_size: int,
    large_tile_size: int,
):
    """Reshape flat first-stage activations (one ``D``-dim vector per small tile,
    in row-major order) into ``(num_large_tiles, D, r, r)`` grids, where ``r`` is
    the number of small tiles along one side of a large tile."""
    assert height % small_tile_size == 0 and width % small_tile_size == 0
    D = x.size(1)
    nh, nw = height // small_tile_size, width // small_tile_size
    r = large_tile_size // small_tile_size
    x = x.view(-1, nh, nw, D)
    x = x.permute(0, 3, 1, 2).reshape(-1, D, nh // r, r, nw // r, r)
    x = x.permute(0, 2, 4, 1, 3, 5).reshape(-1, D, r, r)
    return x


def format_second_stg_inp_as_first_stg_act(
    x: torch.Tensor, height: int, width: int, small_tile_size: int, large_tile_size: int
):
    """Inverse of ``format_first_stg_act_as_second_stg_inp``: flatten
    ``(num_large_tiles, D, r, r)`` grids back into one ``D``-dim vector per small
    tile, in row-major order over the small-tile grid."""
    D = x.size(1)
    nh, nw = height // small_tile_size, width // small_tile_size
    r = large_tile_size // small_tile_size
    x = x.view(-1, nh // r, nw // r, D, r, r)
    x = x.permute(0, 3, 1, 4, 2, 5).reshape(-1, D, nh, nw)
    x = x.permute(0, 2, 3, 1).reshape(-1, D)
    return x


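# Illustrative round-trip sketch, not part of the original module: the two helpers
# above are mutual inverses, so converting activations to the second-stage layout
# and back returns the original tensor unchanged. Sizes are arbitrary.
def _demo_stage_format_round_trip() -> None:
    height = width = 128
    small_tile_size, large_tile_size = 16, 32
    D = 64
    acts = torch.randn((height // small_tile_size) * (width // small_tile_size), D)
    grids = format_first_stg_act_as_second_stg_inp(
        acts, height, width, small_tile_size, large_tile_size
    )
    back = format_second_stg_inp_as_first_stg_act(
        grids, height, width, small_tile_size, large_tile_size
    )
    assert grids.shape == (16, D, 2, 2)
    assert torch.equal(back, acts)

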
def format_second_stg_act_as_third_stg_inp(
    x: torch.Tensor,
    height: int,
    width: int,
    large_tile_size: int,
):
    """Reshape flat second-stage activations (one ``D``-dim vector per large tile,
    in row-major order) into a ``(batch, D, nh, nw)`` map over the large-tile grid."""
    D = x.size(1)
    nh = height // large_tile_size
    nw = width // large_tile_size
    return x.view(-1, nh, nw, D).permute(0, 3, 1, 2).contiguous()


def forward_with_batch_size_limit(
    net,
    x: torch.Tensor,
    batch_size_on_gpu: int,
    device: str | torch.device,
    out_device: str | torch.device,
    preproc_fn: t.Callable[[torch.Tensor], torch.Tensor] | None = None,
    dtype: torch.dtype = torch.float32,
):
    """Run ``net`` over ``x`` in chunks of at most ``batch_size_on_gpu`` samples.

    Each chunk is moved to ``device``, optionally preprocessed, cast to ``dtype``,
    and zero-padded to a fixed batch size so the compiled forward sees static
    shapes; the un-padded outputs are collected on ``out_device``.
    """
    features = []
    for start_idx in range(0, x.size(0), batch_size_on_gpu):
        end_idx = min(x.size(0), start_idx + batch_size_on_gpu)
        batch = x[start_idx:end_idx].to(device=device, non_blocking=True)
        batch = preproc_fn(batch) if preproc_fn else batch
        batch = batch.to(dtype=dtype, non_blocking=True)
        actual_bs = end_idx - start_idx
        batch = pad_to_batch(batch, batch_size_on_gpu)
        batch: torch.Tensor = forward_compiled(net, batch)

        features.append(batch[:actual_bs].to(device=out_device, non_blocking=True))
    if torch.device(out_device).type == "cpu" and torch.device(device).type == "cuda":
        # Non-blocking device-to-host copies must be synchronized before use.
        torch.cuda.synchronize()
    return torch.cat(features)


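# Illustrative usage sketch, not part of the original module: extract features for
# a large stack of tiles without exceeding a fixed GPU batch size. The encoder and
# sizes are placeholders, and a CUDA device is assumed to be available.
def _demo_forward_with_batch_size_limit() -> None:
    encoder = torch.nn.Sequential(
        torch.nn.Flatten(), torch.nn.Linear(3 * 64 * 64, 128)
    ).cuda()
    tiles = torch.randint(0, 256, (1000, 3, 64, 64)).float()
    feats = forward_with_batch_size_limit(
        encoder,
        tiles,
        batch_size_on_gpu=256,
        device="cuda",
        out_device="cpu",
        preproc_fn=scale_and_normalize,
        dtype=torch.float32,
    )
    assert feats.shape == (1000, 128)

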
@t.overload
def backward_with_batch_size_limit(
    net,
    x: torch.Tensor,
    grad: torch.Tensor,
    batch_size_on_gpu: int,
    device: str | torch.device,
    out_device: str | torch.device,
    dtype: torch.dtype,
    ret_grad: t.Literal[True],
) -> torch.Tensor: ...


@t.overload
def backward_with_batch_size_limit(
    net,
    x: torch.Tensor,
    grad: torch.Tensor,
    batch_size_on_gpu: int,
    device: str | torch.device,
    out_device: str | torch.device,
    dtype: torch.dtype,
    ret_grad: t.Literal[False],
) -> None: ...


def backward_with_batch_size_limit(
    net,
    x: torch.Tensor,
    grad: torch.Tensor,
    batch_size_on_gpu: int,
    device: str | torch.device,
    out_device: str | torch.device,
    dtype: torch.dtype,
    ret_grad: bool,
):
    """Recompute ``net(x)`` in chunks and backpropagate the matching slice of ``grad``.

    Parameter gradients accumulate on ``net`` as usual. If ``ret_grad`` is True, the
    gradient with respect to ``x`` is also collected on ``out_device`` and returned.
    """
    assert x.size(0) == grad.size(0)

    grads = []
    total = x.size(0)
    for start in range(0, total, batch_size_on_gpu):
        end = min(total, start + batch_size_on_gpu)
        actual_bs = end - start

        batch = x[start:end].to(device=device, dtype=dtype, non_blocking=True)
        batch = pad_to_batch(batch, batch_size_on_gpu)
        if ret_grad:
            batch.requires_grad_(True)

        with torch.autocast(device_type="cuda", dtype=dtype):
            out = net(batch)

        # The upstream gradient is zero-padded to the same fixed batch size as the
        # inputs, so padded rows carry zero gradient.
        grad_batch = grad[start:end].to(device=device, dtype=dtype, non_blocking=True)
        grad_batch = pad_to_batch(grad_batch, batch_size_on_gpu)

        with torch._dynamo.utils.maybe_enable_compiled_autograd(
            True, fullgraph=True, dynamic=False
        ):
            out.backward(grad_batch)

        if ret_grad:
            assert batch.grad is not None
            grads.append(batch.grad[:actual_bs].to(out_device, non_blocking=True))

    if ret_grad:
        if (
            torch.device(out_device).type == "cpu"
            and torch.device(device).type == "cuda"
        ):
            torch.cuda.synchronize()
        return torch.cat(grads)


@torch.compile(fullgraph=True, dynamic=False)
def forward_compiled(net, x: torch.Tensor) -> torch.Tensor:
    return net(x)


def pad_to_batch(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Zero-pad ``x`` along the batch dimension up to ``batch_size`` rows."""
    assert (
        x.size(0) <= batch_size
    ), f"'{x.shape}' size tensor cannot be padded to a batch size of '{batch_size}'"
    pad = batch_size - x.size(0)
    return torch.cat([x, x.new_zeros((pad,) + x.shape[1:])], dim=0) if pad > 0 else x


def scale_and_normalize(x: torch.Tensor, inplace: bool = False):
    """Clamp pixel values to [0, 255], rescale to [0, 1], and apply the standard
    ImageNet mean/std normalization."""
    x = x.clamp_(0, 255) if inplace else x.clamp(0, 255)
    x = TF.normalize(
        x / 255, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=inplace
    )
    return x


def combine_tile_list(tile_list: list[torch.Tensor], ncols: int):
    """
    Combines a flat list of tile tensors (each with shape (C, H, W)) into one output
    tensor, arranging them in a grid with the specified number of columns. Tiles in
    the final row or column may have different sizes.

    Args:
        tile_list (list of torch.Tensor): A flat list of tile tensors, each with shape
                                          (channels, tile_height, tile_width). The number
                                          of channels is assumed to be consistent across
                                          all tiles.
        ncols (int): Number of columns to arrange the tiles in.

    Returns:
        torch.Tensor: A tensor of shape (channels, total_height, total_width), where:
                      - total_height is the sum of the maximum tile heights in each row.
                      - total_width is the sum of the maximum tile widths in each column.
    """
    if not tile_list:
        raise ValueError("tile_list is empty")

    ntiles = len(tile_list)
    nrows = (ntiles + ncols - 1) // ncols

    # Group the flat list into rows of at most `ncols` tiles.
    nested_tiles = [tile_list[i * ncols : (i + 1) * ncols] for i in range(nrows)]

    # Maximum tile height per row.
    row_heights = [max(tile.shape[1] for tile in row) for row in nested_tiles]

    # Maximum tile width per column.
    col_widths = []
    for col in range(ncols):
        max_width = 0
        for row in nested_tiles:
            if col < len(row):
                tile_w = row[col].shape[2]
                if tile_w > max_width:
                    max_width = tile_w
        col_widths.append(max_width)

    total_height = sum(row_heights)
    total_width = sum(col_widths)

    channels = tile_list[0].shape[0]

    # Zero-filled canvas; tiles smaller than their grid cell leave the rest as zeros.
    out_tensor = torch.zeros(
        channels,
        total_height,
        total_width,
        dtype=tile_list[0].dtype,
        device=tile_list[0].device,
    )

    # Paste each tile at the top-left corner of its grid cell.
    y_offset = 0
    for i, row in enumerate(nested_tiles):
        x_offset = 0
        for j, tile in enumerate(row):
            tile_h, tile_w = tile.shape[1], tile.shape[2]
            out_tensor[
                :, y_offset : y_offset + tile_h, x_offset : x_offset + tile_w
            ] = tile
            x_offset += col_widths[j]
        y_offset += row_heights[i]

    return out_tensor


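# Illustrative sketch, not part of the original module: stitch four tiles into a
# 2x2 grid. Row heights are 5 and 2 and column widths are 6 and 4, so the canvas
# is 7x10; the bottom-right tile is smaller than its cell, and the remainder of
# that cell stays zero.
def _demo_combine_tile_list() -> None:
    tiles = [
        torch.ones(3, 5, 6),
        torch.ones(3, 5, 4),
        torch.ones(3, 2, 6),
        torch.ones(3, 1, 3),
    ]
    combined = combine_tile_list(tiles, ncols=2)
    assert combined.shape == (3, 7, 10)
    assert combined.sum().item() == 5 * 6 * 3 + 5 * 4 * 3 + 2 * 6 * 3 + 1 * 3 * 3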