Update unet/mv_unet.py

unet/mv_unet.py CHANGED: +0 -84
@@ -39,55 +39,6 @@ def get_camera(
     return torch.from_numpy(np.stack(cameras, axis=0)).float()  # [num_frames, 16]
 
 
-def checkpoint(func, inputs, params, flag):
-    """
-    Evaluate a function without caching intermediate activations, allowing for
-    reduced memory at the expense of extra compute in the backward pass.
-    :param func: the function to evaluate.
-    :param inputs: the argument sequence to pass to `func`.
-    :param params: a sequence of parameters `func` depends on but does not
-        explicitly take as arguments.
-    :param flag: if False, disable gradient checkpointing.
-    """
-    if flag:
-        args = tuple(inputs) + tuple(params)
-        return CheckpointFunction.apply(func, len(inputs), *args)
-    else:
-        return func(*inputs)
-
-
-class CheckpointFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, run_function, length, *args):
-        ctx.run_function = run_function
-        ctx.input_tensors = list(args[:length])
-        ctx.input_params = list(args[length:])
-
-        with torch.no_grad():
-            output_tensors = ctx.run_function(*ctx.input_tensors)
-        return output_tensors
-
-    @staticmethod
-    def backward(ctx, *output_grads):
-        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
-        with torch.enable_grad():
-            # Fixes a bug where the first op in run_function modifies the
-            # Tensor storage in place, which is not allowed for detach()'d
-            # Tensors.
-            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
-            output_tensors = ctx.run_function(*shallow_copies)
-        input_grads = torch.autograd.grad(
-            output_tensors,
-            ctx.input_tensors + ctx.input_params,
-            output_grads,
-            allow_unused=True,
-        )
-        del ctx.input_tensors
-        del ctx.input_params
-        del output_tensors
-        return (None, None) + input_grads
-
-
 def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
     """
     Create sinusoidal timestep embeddings.
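The deleted `checkpoint`/`CheckpointFunction` pair is the guided-diffusion-style custom implementation of activation checkpointing: skip caching intermediate activations in forward, then recompute them in backward. If a downstream caller still wants that memory/compute trade-off after this commit, PyTorch's built-in `torch.utils.checkpoint` offers the same behavior; a minimal drop-in sketch (hypothetical, not part of this commit):

    # Hypothetical replacement for the removed helper, built on PyTorch's
    # own checkpointing instead of a custom autograd.Function.
    from torch.utils.checkpoint import checkpoint as pt_checkpoint

    def checkpoint(func, inputs, params, flag):
        # `params` is intentionally unused: the non-reentrant implementation
        # tracks parameter gradients through autograd automatically.
        if flag:
            return pt_checkpoint(func, *inputs, use_reentrant=False)
        return func(*inputs)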
@@ -286,7 +237,6 @@ class BasicTransformerBlock3D(nn.Module):
         context_dim,
         dropout=0.0,
         gated_ff=True,
-        checkpoint=True,
         ip_dim=0,
         ip_weight=1,
     ):
@@ -313,14 +263,8 @@ class BasicTransformerBlock3D(nn.Module):
         self.norm1 = nn.LayerNorm(dim)
         self.norm2 = nn.LayerNorm(dim)
         self.norm3 = nn.LayerNorm(dim)
-        self.checkpoint = checkpoint
 
     def forward(self, x, context=None, num_frames=1):
-        return checkpoint(
-            self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
-        )
-
-    def _forward(self, x, context=None, num_frames=1):
         x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
         x = self.attn1(self.norm1(x), context=None) + x
         x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
@@ -341,7 +285,6 @@ class SpatialTransformer3D(nn.Module):
         dropout=0.0,
         ip_dim=0,
         ip_weight=1,
-        use_checkpoint=True,
     ):
         super().__init__()
 
@@ -362,7 +305,6 @@ class SpatialTransformer3D(nn.Module):
                     d_head,
                     context_dim=context_dim[d],
                     dropout=dropout,
-                    checkpoint=use_checkpoint,
                     ip_dim=ip_dim,
                     ip_weight=ip_weight,
                 )
@@ -581,7 +523,6 @@ class ResBlock(nn.Module):
         convolution instead of a smaller 1x1 convolution to change the
         channels in the skip connection.
     :param dims: determines if the signal is 1D, 2D, or 3D.
-    :param use_checkpoint: if True, use gradient checkpointing on this module.
     :param up: if True, use this block for upsampling.
     :param down: if True, use this block for downsampling.
     """
@@ -595,7 +536,6 @@ class ResBlock(nn.Module):
         use_conv=False,
         use_scale_shift_norm=False,
         dims=2,
-        use_checkpoint=False,
         up=False,
         down=False,
     ):
@@ -605,7 +545,6 @@ class ResBlock(nn.Module):
         self.dropout = dropout
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
-        self.use_checkpoint = use_checkpoint
         self.use_scale_shift_norm = use_scale_shift_norm
 
         self.in_layers = nn.Sequential(
@@ -651,17 +590,6 @@ class ResBlock(nn.Module):
             self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
 
     def forward(self, x, emb):
-        """
-        Apply the block to a Tensor, conditioned on a timestep embedding.
-        :param x: an [N x C x ...] Tensor of features.
-        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
-        :return: an [N x C x ...] Tensor of outputs.
-        """
-        return checkpoint(
-            self._forward, (x, emb), self.parameters(), self.use_checkpoint
-        )
-
-    def _forward(self, x, emb):
         if self.updown:
             in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
             h = in_rest(x)
@@ -702,7 +630,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
     :param dims: determines if the signal is 1D, 2D, or 3D.
     :param num_classes: if specified (as an int), then this model will be
         class-conditional with `num_classes` classes.
-    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
     :param num_heads: the number of attention heads in each attention layer.
     :param num_heads_channels: if specified, ignore num_heads and instead use
         a fixed channel width per attention head.
@@ -728,7 +655,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
         conv_resample=True,
         dims=2,
         num_classes=None,
-        use_checkpoint=False,
         num_heads=-1,
         num_head_channels=-1,
         num_heads_upsample=-1,
@@ -794,7 +720,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
         self.num_classes = num_classes
-        self.use_checkpoint = use_checkpoint
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
@@ -868,7 +793,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                         dropout,
                         out_channels=mult * model_channels,
                         dims=dims,
-                        use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
                     )
                 ]
@@ -888,7 +812,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                             dim_head,
                             context_dim=context_dim,
                             depth=transformer_depth,
-                            use_checkpoint=use_checkpoint,
                             ip_dim=self.ip_dim,
                             ip_weight=self.ip_weight,
                         )
@@ -906,7 +829,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                         dropout,
                         out_channels=out_ch,
                         dims=dims,
-                        use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
                         down=True,
                     )
@@ -933,7 +855,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                 time_embed_dim,
                 dropout,
                 dims=dims,
-                use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
             ),
             SpatialTransformer3D(
@@ -942,7 +863,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                 dim_head,
                 context_dim=context_dim,
                 depth=transformer_depth,
-                use_checkpoint=use_checkpoint,
                 ip_dim=self.ip_dim,
                 ip_weight=self.ip_weight,
             ),
@@ -951,7 +871,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                 time_embed_dim,
                 dropout,
                 dims=dims,
-                use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
             ),
         )
@@ -968,7 +887,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                         dropout,
                         out_channels=model_channels * mult,
                         dims=dims,
-                        use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
                     )
                 ]
@@ -988,7 +906,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                             dim_head,
                             context_dim=context_dim,
                             depth=transformer_depth,
-                            use_checkpoint=use_checkpoint,
                             ip_dim=self.ip_dim,
                             ip_weight=self.ip_weight,
                         )
@@ -1002,7 +919,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                         dropout,
                         out_channels=out_ch,
                         dims=dims,
-                        use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
                         up=True,
                     )
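With `use_checkpoint` removed end to end, every `ResBlock` and `SpatialTransformer3D` now runs its forward pass directly. If the old memory profile is still wanted for training, checkpointing can be re-applied from the call site without touching this file; a hypothetical wrapper sketch, not part of this commit:

    # Hypothetical call-site wrapper: re-enables activation checkpointing
    # around any submodule after the built-in flag was removed.
    import torch
    from torch.utils.checkpoint import checkpoint as pt_checkpoint

    class CheckpointedBlock(torch.nn.Module):
        def __init__(self, block: torch.nn.Module):
            super().__init__()
            self.block = block

        def forward(self, *args, **kwargs):
            if torch.is_grad_enabled():
                # Recompute activations in backward instead of storing them.
                return pt_checkpoint(self.block, *args, use_reentrant=False, **kwargs)
            return self.block(*args, **kwargs)

Note that wrapping submodules changes state_dict keys, so any such wrapper is best applied after the pretrained weights are loaded.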