# EXAONE-Path-2.0: utils/tensor_utils.py
import typing as t
import torch
import torchvision.transforms.functional as TF
def tile(x: torch.Tensor, size: int, pad_value: int | float | None = None):
    """Split a (..., C, H, W) tensor into non-overlapping (C, size, size) tiles,
    padding the bottom/right edges with `pad_value` (default 0) so H and W become
    multiples of `size`; tiles are returned row-major as (N, C, size, size)."""
    C, H, W = x.shape[-3:]
pad_h = (size - H % size) % size
pad_w = (size - W % size) % size
if pad_h > 0 or pad_w > 0:
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), value=pad_value)
nh, nw = x.size(-2) // size, x.size(-1) // size
return (
x.view(-1, C, nh, size, nw, size)
.permute(0, 2, 4, 1, 3, 5)
.reshape(-1, C, size, size)
)
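# Illustrative shape check (assumed input; not part of the original module):
#   x = torch.rand(3, 500, 700)              # (C, H, W) image
#   patches = tile(x, size=256)              # padded internally to (3, 512, 768)
#   patches.shape == (6, 3, 256, 256)        # 2 rows x 3 cols of tiles, row-major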
def small_tiles_to_large_tiles(
small_tiles: torch.Tensor,
width: int,
large_tile_size: int,
sampled_large_tiles_idx: list | torch.Tensor | None = None,
) -> torch.Tensor:
    """Reassemble row-major small tiles into large tiles of side `large_tile_size`.
    `width` is the overall width spanned by one row of small tiles (same units as the
    tile size); `sampled_large_tiles_idx`, if given, restricts the output to those
    row-major large-tile indices."""
    has_channel = small_tiles.ndim == 4
small_tile_size = small_tiles.size(-1)
num_small_tiles = small_tiles.size(0)
nw = width // small_tile_size
nh = num_small_tiles // nw
r = large_tile_size // small_tile_size
num_large_tiles = (nh // r) * (nw // r)
large_tile_indices = (
range(num_large_tiles)
if sampled_large_tiles_idx is None
else sampled_large_tiles_idx
)
tiles = []
for k in large_tile_indices:
start_row = (k // (nw // r)) * r
start_col = (k % (nw // r)) * r
for i in range(start_row, start_row + r):
for j in range(start_col, start_col + r):
tiles.append(small_tiles[i * nw + j])
stacked = torch.stack(tiles, dim=0).view(-1, r, r, *small_tiles.shape[1:])
if has_channel:
large_tiles = stacked.permute(0, 3, 1, 4, 2, 5).reshape(
-1, small_tiles.size(1), large_tile_size, large_tile_size
)
else:
large_tiles = stacked.permute(0, 1, 3, 2, 4).reshape(
-1, large_tile_size, large_tile_size
)
return large_tiles
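# Illustrative usage (assumed sizes): reassemble 256-px tiles of a 1024x1536 region
# into 512-px large tiles, optionally only for a sampled subset of large tiles.
#   small = tile(torch.rand(3, 1024, 1536), size=256)      # (24, 3, 256, 256)
#   large = small_tiles_to_large_tiles(small, width=1536, large_tile_size=512)
#   large.shape == (6, 3, 512, 512)
#   subset = small_tiles_to_large_tiles(
#       small, width=1536, large_tile_size=512, sampled_large_tiles_idx=[0, 4]
#   )                                                       # (2, 3, 512, 512)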
def small_tile_flags_to_large_tile_flags(
small_tile_flags: torch.Tensor,
width: int,
small_tile_size: int,
large_tile_size: int,
aggregation: t.Literal["any", "all"] = "any",
):
    """Aggregate per-small-tile boolean flags into one flag per large tile, taking
    `any` or `all` over each large tile's constituent small tiles."""
    small_tile_flags = small_tile_flags.view(-1, 1, 1)
num_small_tiles = small_tile_flags.size(0)
nw = width // small_tile_size
r = large_tile_size // small_tile_size
num_large_tiles = num_small_tiles // r**2
large_tile_flags = small_tiles_to_large_tiles(
small_tile_flags,
width=nw,
large_tile_size=r,
).view(num_large_tiles, -1)
return (
large_tile_flags.any(-1) if aggregation == "any" else large_tile_flags.all(-1)
)
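# Illustrative usage (assumed sizes): a 512-px large tile is kept if any of its
# four 256-px sub-tiles is flagged (e.g. as tissue foreground).
#   flags = torch.rand(24) > 0.5                            # one flag per small tile
#   large_flags = small_tile_flags_to_large_tile_flags(
#       flags, width=1536, small_tile_size=256, large_tile_size=512, aggregation="any"
#   )                                                       # bool tensor of shape (6,)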
def format_first_stg_act_as_second_stg_inp(
x: torch.Tensor,
height: int,
width: int,
small_tile_size: int,
large_tile_size: int,
):
    """Rearrange per-small-tile activations (N * nh * nw, D), laid out row-major over
    an (nh, nw) grid of small tiles, into second-stage inputs of shape
    (N * (nh // r) * (nw // r), D, r, r), where r = large_tile_size // small_tile_size.
    Height and width must be multiples of both tile sizes."""
    assert height % small_tile_size == 0 and width % small_tile_size == 0
    assert height % large_tile_size == 0 and width % large_tile_size == 0
D = x.size(1)
nh, nw = height // small_tile_size, width // small_tile_size
r = large_tile_size // small_tile_size
x = x.view(-1, nh, nw, D)
x = x.permute(0, 3, 1, 2).reshape(-1, D, nh // r, r, nw // r, r)
x = x.permute(0, 2, 4, 1, 3, 5).reshape(-1, D, r, r)
return x
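# Illustrative shape flow (assumed dims): per-small-tile features become one
# (D, r, r) spatial grid per large tile for the second stage.
#   feats = torch.rand(24, 768)                             # 24 small-tile embeddings
#   inp = format_first_stg_act_as_second_stg_inp(
#       feats, height=1024, width=1536, small_tile_size=256, large_tile_size=512
#   )                                                       # (6, 768, 2, 2)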
def format_second_stg_inp_as_first_stg_act(
x: torch.Tensor, height: int, width: int, small_tile_size: int, large_tile_size: int
):
    """Inverse of `format_first_stg_act_as_second_stg_inp`: restore (N, D, r, r)
    second-stage inputs to flat per-small-tile activations of shape (N * r * r, D)
    in the original row-major order."""
    D = x.size(1)
nh, nw = height // small_tile_size, width // small_tile_size
r = large_tile_size // small_tile_size
x = x.view(-1, nh // r, nw // r, D, r, r)
x = x.permute(0, 3, 1, 4, 2, 5).reshape(-1, D, nh, nw)
x = x.permute(0, 2, 3, 1).reshape(-1, D)
return x
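# Round-trip sketch (assumed dims): undoes the rearrangement above, recovering the
# flat row-major per-small-tile layout.
#   inp = torch.rand(6, 768, 2, 2)
#   act = format_second_stg_inp_as_first_stg_act(
#       inp, height=1024, width=1536, small_tile_size=256, large_tile_size=512
#   )                                                       # (24, 768)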
def format_second_stg_act_as_third_stg_inp(
x: torch.Tensor,
height: int,
width: int,
large_tile_size: int,
):
    """Lay out per-large-tile activations (N * nh * nw, D) as a contiguous
    (N, D, nh, nw) feature map for the third (slide-level) stage."""
    D = x.size(1)
nh = height // large_tile_size
nw = width // large_tile_size
return x.view(-1, nh, nw, D).permute(0, 3, 1, 2).contiguous()
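# Illustrative (assumed dims): per-large-tile embeddings laid out as a 2-D feature
# map for the slide-level stage.
#   acts = torch.rand(6, 1024)
#   fmap = format_second_stg_act_as_third_stg_inp(
#       acts, height=1024, width=1536, large_tile_size=512
#   )                                                       # (1, 1024, 2, 3)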
def forward_with_batch_size_limit(
net,
x: torch.Tensor,
batch_size_on_gpu: int,
device: str | torch.device,
out_device: str | torch.device,
preproc_fn: t.Callable[[torch.Tensor], torch.Tensor] | None = None,
dtype: torch.dtype = torch.float32,
):
    """Run `net` over `x` in chunks of at most `batch_size_on_gpu`, padding each chunk
    to a fixed batch size so the compiled forward sees static shapes, and collect the
    outputs on `out_device`."""
    features = []
for start_idx in range(0, x.size(0), batch_size_on_gpu):
end_idx = min(x.size(0), start_idx + batch_size_on_gpu)
batch = x[start_idx:end_idx].to(device=device, non_blocking=True)
batch = preproc_fn(batch) if preproc_fn else batch
batch = batch.to(dtype=dtype, non_blocking=True)
actual_bs = end_idx - start_idx
batch = pad_to_batch(batch, batch_size_on_gpu)
batch: torch.Tensor = forward_compiled(net, batch)
# batch = net(batch)
features.append(batch[:actual_bs].to(device=out_device, non_blocking=True))
if torch.device(out_device).type == "cpu" and torch.device(device).type == "cuda":
torch.cuda.synchronize()
return torch.cat(features)
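# Usage sketch (illustrative; `encoder` and the shapes are assumptions, not defined
# in this module): encode far more tiles than fit on the GPU at once.
#   with torch.no_grad():
#       feats = forward_with_batch_size_limit(
#           encoder,
#           tiles,                           # e.g. (10_000, 3, 256, 256) on CPU
#           batch_size_on_gpu=256,
#           device="cuda",
#           out_device="cpu",
#           preproc_fn=scale_and_normalize,  # optional per-chunk preprocessing
#           dtype=torch.bfloat16,
#       )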
@t.overload
def backward_with_batch_size_limit(
net,
x: torch.Tensor,
grad: torch.Tensor,
batch_size_on_gpu: int,
device: str | torch.device,
out_device: str | torch.device,
dtype: torch.dtype,
ret_grad: t.Literal[True],
) -> torch.Tensor: ...
@t.overload
def backward_with_batch_size_limit(
net,
x: torch.Tensor,
grad: torch.Tensor,
batch_size_on_gpu: int,
device: str | torch.device,
out_device: str | torch.device,
dtype: torch.dtype,
ret_grad: t.Literal[False],
) -> None: ...
def backward_with_batch_size_limit(
net,
x: torch.Tensor,
grad: torch.Tensor,
batch_size_on_gpu: int,
device: str | torch.device,
out_device: str | torch.device,
dtype: torch.dtype,
ret_grad: bool,
):
    """Backpropagate `grad` (the gradient w.r.t. `net`'s output for every row of `x`)
    through `net` in chunks of at most `batch_size_on_gpu`, re-running the forward per
    chunk. Parameter gradients accumulate on `net`; if `ret_grad` is True, the gradient
    with respect to `x` is also returned on `out_device`."""
    assert x.size(0) == grad.size(0)
grads = []
total = x.size(0)
for start in range(0, total, batch_size_on_gpu):
end = min(total, start + batch_size_on_gpu)
actual_bs = end - start
batch = x[start:end].to(device=device, dtype=dtype, non_blocking=True)
batch = pad_to_batch(batch, batch_size_on_gpu)
if ret_grad:
batch.requires_grad_(True)
with torch.autocast(device_type="cuda", dtype=dtype):
out = net(batch)
# out = forward_compiled(net, batch)
grad_batch = grad[start:end].to(device=device, dtype=dtype, non_blocking=True)
grad_batch = pad_to_batch(grad_batch, batch_size_on_gpu)
with torch._dynamo.utils.maybe_enable_compiled_autograd(
True, fullgraph=True, dynamic=False
):
out.backward(grad_batch)
if ret_grad:
assert batch.grad is not None
grads.append(batch.grad[:actual_bs].to(out_device, non_blocking=True))
if ret_grad:
if (
torch.device(out_device).type == "cpu"
and torch.device(device).type == "cuda"
):
torch.cuda.synchronize()
return torch.cat(grads)
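# Usage sketch (illustrative; names are assumptions): replay the forward chunk by
# chunk and push an upstream gradient through it, optionally returning the gradient
# w.r.t. the inputs for an earlier stage.
#   input_grads = backward_with_batch_size_limit(
#       encoder,
#       tiles,                               # same tensor that was fed forward
#       upstream_grad,                       # gradient w.r.t. the encoder outputs
#       batch_size_on_gpu=256,
#       device="cuda",
#       out_device="cpu",
#       dtype=torch.bfloat16,
#       ret_grad=True,
#   )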
@torch.compile(fullgraph=True, dynamic=False)
def forward_compiled(net, x: torch.Tensor) -> torch.Tensor:
    """Compiled wrapper around `net(x)`; `dynamic=False` assumes a fixed batch shape
    (see `pad_to_batch`)."""
    return net(x)
def pad_to_batch(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Zero-pad `x` along dim 0 so its batch dimension equals `batch_size`."""
    assert (
        x.size(0) <= batch_size
    ), f"tensor of shape '{x.shape}' cannot be padded to batch size '{batch_size}'"
    pad = batch_size - x.size(0)
    return torch.cat([x, x.new_zeros((pad,) + x.shape[1:])], dim=0) if pad > 0 else x
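# Illustrative: padding keeps every chunk at the fixed `batch_size_on_gpu`,
# presumably so the `dynamic=False` compiled graph always sees the same shape.
#   pad_to_batch(torch.rand(3, 8), batch_size=5).shape == (5, 8)   # last 2 rows are zeros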
def scale_and_normalize(x: torch.Tensor, inplace: bool = False):
    """Clamp pixel values to [0, 255], rescale to [0, 1], and normalize with the
    standard ImageNet channel mean and std."""
    x = x.clamp_(0, 255) if inplace else x.clamp(0, 255)
x = TF.normalize(
x / 255, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=inplace
)
return x
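# Illustrative (assumed input): map uint8-range RGB tiles to ImageNet-normalized floats.
#   imgs = torch.randint(0, 256, (4, 3, 256, 256)).float()
#   normed = scale_and_normalize(imgs)       # same shape, roughly zero-centered per channel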
def combine_tile_list(tile_list: list[torch.Tensor], ncols: int):
"""
Combines a flat list of tile tensors (each with shape (C, H, W)) into one output tensor,
arranging them in a grid with the specified number of columns. The tiles in the final row
or column may have different sizes.
Args:
tile_list (list of torch.Tensor): A flat list of tile tensors, each with shape
(channels, tile_height, tile_width). It is assumed
that the number of channels is consistent across all tiles.
ncols (int): Number of columns to arrange the tiles in.
Returns:
torch.Tensor: A tensor of shape (channels, total_height, total_width), where:
- total_height is the sum of maximum tile heights in each row.
- total_width is the sum of maximum tile widths in each column.
"""
if not tile_list:
raise ValueError("tile_list is empty")
ntiles = len(tile_list)
nrows = (ntiles + ncols - 1) // ncols # Ceiling division to get the number of rows
# Convert the flat tile list into a nested list (rows of tiles)
nested_tiles = [tile_list[i * ncols : (i + 1) * ncols] for i in range(nrows)]
# Compute the maximum tile height for each row
row_heights = [max(tile.shape[1] for tile in row) for row in nested_tiles]
# Compute the maximum tile width for each column (consider only rows that have a tile in that column)
col_widths = []
for col in range(ncols):
max_width = 0
for row in nested_tiles:
if col < len(row):
tile_w = row[col].shape[2]
if tile_w > max_width:
max_width = tile_w
col_widths.append(max_width)
# Calculate the total output dimensions
total_height = sum(row_heights)
total_width = sum(col_widths)
# Determine the number of channels from the first tile
channels = tile_list[0].shape[0]
# Preallocate the output tensor (this avoids repeated concatenation and extra memory copies)
out_tensor = torch.zeros(
channels,
total_height,
total_width,
dtype=tile_list[0].dtype,
device=tile_list[0].device,
)
# Place each tile in its proper location by calculating offsets
y_offset = 0
for i, row in enumerate(nested_tiles):
x_offset = 0
for j, tile in enumerate(row):
tile_h, tile_w = tile.shape[1], tile.shape[2]
out_tensor[
:, y_offset : y_offset + tile_h, x_offset : x_offset + tile_w
] = tile
x_offset += col_widths[j]
y_offset += row_heights[i]
return out_tensor
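# Illustrative usage (assumed tile sizes): stitch four tiles into a 2x2 mosaic where
# the last row and column are smaller than the rest.
#   tiles = [torch.rand(3, 256, 256), torch.rand(3, 256, 100),
#            torch.rand(3, 60, 256), torch.rand(3, 60, 100)]
#   canvas = combine_tile_list(tiles, ncols=2)              # (3, 316, 356)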