import math

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from typing import Any, List, Optional
from torch import Tensor

from abc import abstractmethod
from .util import (
    checkpoint,
    conv_nd,
    avg_pool_nd,
    zero_module,
    timestep_embedding,
)
from .attention import SpatialTransformer, SpatialTransformer3D


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None, num_frames=1):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer3D):
                x = layer(x, context, num_frames=num_frames)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x
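
# Usage sketch (illustrative only, not executed): TimestepEmbedSequential routes
# `emb` to TimestepBlock children, `context`/`num_frames` to transformer children,
# and only `x` to plain modules. The channel/head sizes below are arbitrary
# placeholder values, not taken from any particular configuration.
#
#   block = TimestepEmbedSequential(
#       ResBlock(64, 256, dropout=0.0),
#       SpatialTransformer3D(64, 8, 8, context_dim=768),
#   )
#   out = block(x, emb, context, num_frames=4)   # x: [(N*F) x 64 x H x W]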
					
						
					
						
					
						
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(
                dims, self.channels, self.out_channels, 3, padding=padding
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x
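
# Shape note (illustrative): with dims=3 the first (frame/depth) axis is kept and
# only the two innermost spatial dims are doubled, e.g.
#   Upsample(32, use_conv=True, dims=3)(th.randn(1, 32, 8, 16, 16)).shape
#   -> torch.Size([1, 32, 8, 32, 32])
# while dims=2 doubles both spatial dims of an [N x C x H x W] input.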
					
						
					
						
					
						
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
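
# Conditioning sketch (illustrative): with use_scale_shift_norm=True the timestep
# embedding acts as a FiLM-style modulation of the normalized features; otherwise
# it is a plain additive bias. With emb_out broadcast to h's shape, _forward does:
#
#   scale, shift = th.chunk(emb_out, 2, dim=1)
#   h = out_rest(out_norm(h) * (1 + scale) + shift)   # use_scale_shift_norm=True
#   h = out_layers(h + emb_out)                       # use_scale_shift_norm=False
#
# and in both cases returns skip_connection(x) + h.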
					
						
					
						
					
						
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = nn.GroupNorm(32, channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before splitting heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before splitting qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), True)

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # more stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)
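
# Numerical note (illustrative): scaling q and k by 1/sqrt(sqrt(ch)) before the
# einsum is mathematically the same as the usual 1/sqrt(ch) attention scaling,
# but keeps intermediate magnitudes smaller, which is friendlier to float16:
#
#   q, k = th.randn(2, 64, 10), th.randn(2, 64, 10)
#   s = 1 / math.sqrt(math.sqrt(64))
#   assert th.allclose(
#       th.einsum("bct,bcs->bts", q * s, k * s),
#       th.einsum("bct,bcs->bts", q, k) / math.sqrt(64),
#       atol=1e-5,
#   )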
					
						
					
						
					
						
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)


class MultiViewUNetModel(ModelMixin, ConfigMixin):
    """
    The full multi-view UNet model with attention, timestep embedding and camera embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    :param camera_dim: dimensionality of camera input.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,
        transformer_depth=1,
        context_dim=None,
        n_embed=None,
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        adm_in_channels=None,
        camera_dim=None,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert (
                context_dim is not None
            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."

        if context_dim is not None:
            assert (
                use_spatial_transformer
            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
            from omegaconf.listconfig import ListConfig

            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert (
                num_head_channels != -1
            ), "Either num_heads or num_head_channels has to be set"

        if num_head_channels == -1:
            assert (
                num_heads != -1
            ), "Either num_heads or num_head_channels has to be set"

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError(
                    "provide num_res_blocks either as an int (globally constant) or "
                    "as a list/tuple (per-level) with the same length as channel_mult"
                )
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # expected to be a list of booleans, one per level, indicating whether
            # to disable self-attention in the transformer blocks at that level
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                map(
                    lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
                    range(len(num_attention_blocks)),
                )
            )
            print(
                f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                f"attention will still not be set."
            )

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        if camera_dim is not None:
            time_embed_dim = model_channels * 4
            self.camera_embed = nn.Sequential(
                nn.Linear(camera_dim, time_embed_dim),
                nn.SiLU(),
                nn.Linear(time_embed_dim, time_embed_dim),
            )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        nn.Linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        nn.Linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers: List[Any] = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        dim_head = (
                            ch // num_heads
                            if use_spatial_transformer
                            else num_head_channels
                        )
                    if disable_self_attentions is not None:
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if num_attention_blocks is None or nr < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            )
                            if not use_spatial_transformer
                            else SpatialTransformer3D(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint,
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            )
            if not use_spatial_transformer
            else SpatialTransformer3D(
                ch,
                num_heads,
                dim_head,
                depth=transformer_depth,
                context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn,
                use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        dim_head = (
                            ch // num_heads
                            if use_spatial_transformer
                            else num_head_channels
                        )
                    if disable_self_attentions is not None:
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if num_attention_blocks is None or i < num_attention_blocks[level]:
                        layers.append(
                            AttentionBlock(
                                ch,
                                use_checkpoint=use_checkpoint,
                                num_heads=num_heads_upsample,
                                num_head_channels=dim_head,
                                use_new_attention_order=use_new_attention_order,
                            )
                            if not use_spatial_transformer
                            else SpatialTransformer3D(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint,
                            )
                        )
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            nn.GroupNorm(32, ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                nn.GroupNorm(32, ch),
                conv_nd(dims, model_channels, n_embed, 1),
            )

    def forward(
        self,
        x,
        timesteps=None,
        context=None,
        y: Optional[Tensor] = None,
        camera=None,
        num_frames=1,
        **kwargs,
    ):
        """
        Apply the model to an input batch.
        :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via cross-attention.
        :param y: an [N] Tensor of labels, if class-conditional.
        :param num_frames: an integer indicating the number of frames for tensor reshaping.
        :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
        """
        assert (
            x.shape[0] % num_frames == 0
        ), "[UNet] input batch size must be divisible by num_frames!"
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(
            timesteps, self.model_channels, repeat_only=False
        ).to(x.dtype)

        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y is not None
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        if camera is not None:
            assert camera.shape[0] == emb.shape[0]
            emb = emb + self.camera_embed(camera)

        h = x
        for module in self.input_blocks:
            h = module(h, emb, context, num_frames=num_frames)
            hs.append(h)
        h = self.middle_block(h, emb, context, num_frames=num_frames)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context, num_frames=num_frames)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
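

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch): the hyperparameters below are
    # assumptions chosen to keep the model small, not the values of any released
    # MVDream checkpoint. Because of the package-relative imports above, run it
    # as a module (e.g. `python -m <package>.<this_module>`), not as a script.
    num_frames = 4
    model = MultiViewUNetModel(
        image_size=32,
        in_channels=4,
        model_channels=64,
        out_channels=4,
        num_res_blocks=1,
        attention_resolutions=[2, 4],
        channel_mult=(1, 2, 4),
        num_head_channels=32,
        use_spatial_transformer=True,
        transformer_depth=1,
        context_dim=768,
        camera_dim=16,
    )
    x = th.randn(num_frames, 4, 32, 32)             # [(N*F) x C x H x W] with N=1
    timesteps = th.randint(0, 1000, (num_frames,))  # one timestep per view
    context = th.randn(num_frames, 77, 768)         # e.g. per-view text embeddings
    camera = th.randn(num_frames, 16)               # per-view camera conditioning
    out = model(x, timesteps, context=context, camera=camera, num_frames=num_frames)
    print(out.shape)  # expected: (num_frames, 4, 32, 32)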
					
						