further clean!

- main.py +2 -2
- mvdream/attention.py +13 -85
- mvdream/models.py +12 -176
main.py

@@ -5,8 +5,8 @@ import argparse
 from mvdream.pipeline_mvdream import MVDreamStableDiffusionPipeline
 
 pipe = MVDreamStableDiffusionPipeline.from_pretrained(
-
-    "ashawkey/mvdream-sd2.1-diffusers",
+    "./weights", # local weights
+    # "ashawkey/mvdream-sd2.1-diffusers",
     torch_dtype=torch.float16
 )
 pipe = pipe.to("cuda")
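A minimal usage sketch for the updated entry point, assuming the repository layout from this commit (weights under ./weights). The automatic fallback to the Hub checkpoint is an added convenience for illustration, not part of the commit itself:

import os
import torch
from mvdream.pipeline_mvdream import MVDreamStableDiffusionPipeline

# Prefer the local weights directory introduced above; fall back to the
# published Hub checkpoint if it is missing (hypothetical fallback).
model_id = "./weights" if os.path.isdir("./weights") else "ashawkey/mvdream-sd2.1-diffusers"

pipe = MVDreamStableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")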
mvdream/attention.py

@@ -2,14 +2,14 @@
 
 import math
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
+from torch.amp.autocast_mode import autocast
 
 from inspect import isfunction
-from torch import nn, einsum
-from torch.amp.autocast_mode import autocast
 from einops import rearrange, repeat
 from typing import Optional, Any
-from .util import checkpoint
+from .util import checkpoint, zero_module
 
 try:
     import xformers  # type: ignore

@@ -25,28 +25,12 @@ import os
 _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
 
 
-def uniq(arr):
-    return {el: True for el in arr}.keys()
-
-
 def default(val, d):
     if val is not None:
         return val
     return d() if isfunction(d) else d
 
 
-def max_neg_value(t):
-    return -torch.finfo(t.dtype).max
-
-
-def init_(tensor):
-    dim = tensor.shape[-1]
-    std = 1 / math.sqrt(dim)
-    tensor.uniform_(-std, std)
-    return tensor
-
-
-# feedforward
 class GEGLU(nn.Module):
     def __init__(self, dim_in, dim_out):
         super().__init__()

@@ -76,66 +60,6 @@ class FeedForward(nn.Module):
         return self.net(x)
 
 
-def zero_module(module):
-    """
-    Zero out the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
-
-
-def Normalize(in_channels):
-    return torch.nn.GroupNorm(
-        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
-    )
-
-
-class SpatialSelfAttention(nn.Module):
-    def __init__(self, in_channels):
-        super().__init__()
-        self.in_channels = in_channels
-
-        self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(
-            in_channels, in_channels, kernel_size=1, stride=1, padding=0
-        )
-        self.k = torch.nn.Conv2d(
-            in_channels, in_channels, kernel_size=1, stride=1, padding=0
-        )
-        self.v = torch.nn.Conv2d(
-            in_channels, in_channels, kernel_size=1, stride=1, padding=0
-        )
-        self.proj_out = torch.nn.Conv2d(
-            in_channels, in_channels, kernel_size=1, stride=1, padding=0
-        )
-
-    def forward(self, x):
-        h_ = x
-        h_ = self.norm(h_)
-        q = self.q(h_)
-        k = self.k(h_)
-        v = self.v(h_)
-
-        # compute attention
-        b, c, h, w = q.shape
-        q = rearrange(q, "b c h w -> b (h w) c")
-        k = rearrange(k, "b c h w -> b c (h w)")
-        w_ = torch.einsum("bij,bjk->bik", q, k)
-
-        w_ = w_ * (int(c) ** (-0.5))
-        w_ = torch.nn.functional.softmax(w_, dim=2)
-
-        # attend to values
-        v = rearrange(v, "b c h w -> b c (h w)")
-        w_ = rearrange(w_, "b i j -> b j i")
-        h_ = torch.einsum("bij,bjk->bik", v, w_)
-        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
-        h_ = self.proj_out(h_)
-
-        return x + h_
-
-
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
         super().__init__()

@@ -167,9 +91,9 @@ class CrossAttention(nn.Module):
         if _ATTN_PRECISION == "fp32":
             with autocast(enabled=False, device_type="cuda"):
                 q, k = q.float(), k.float()
-                sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
+                sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
         else:
-            sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
+            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
 
         del q, k
 

@@ -182,7 +106,7 @@ class CrossAttention(nn.Module):
         # attention, what we cannot get enough of
         sim = sim.softmax(dim=-1)
 
-        out = einsum("b i j, b j d -> b i d", sim, v)
+        out = torch.einsum("b i j, b j d -> b i d", sim, v)
         out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
         return self.to_out(out)
 

@@ -326,7 +250,9 @@ class SpatialTransformer(nn.Module):
             context_dim = [context_dim]
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels)
+        self.norm = nn.GroupNorm(
+            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+        )
         if not use_linear:
             self.proj_in = nn.Conv2d(
                 in_channels, inner_dim, kernel_size=1, stride=1, padding=0

@@ -410,7 +336,7 @@ class SpatialTransformer3D(nn.Module):
         dropout=0.0,
         context_dim=None,
         disable_self_attn=False,
-        use_linear=False,
+        use_linear=True,
         use_checkpoint=True,
     ):
         super().__init__()

@@ -419,7 +345,9 @@ class SpatialTransformer3D(nn.Module):
             context_dim = [context_dim]
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels)
+        self.norm = nn.GroupNorm(
+            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+        )
         if not use_linear:
             self.proj_in = nn.Conv2d(
                 in_channels, inner_dim, kernel_size=1, stride=1, padding=0
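The deleted Normalize helper was only a thin wrapper around torch.nn.GroupNorm, which is why the two transformer classes can now construct the norm inline. A minimal sketch of the equivalence, with an illustrative channel count:

import torch.nn as nn

in_channels = 320  # illustrative value

# Old: self.norm = Normalize(in_channels)
# New: the GroupNorm is built directly with the same arguments.
norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)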
mvdream/models.py

@@ -1,8 +1,7 @@
 # obtained and modified from https://github.com/bytedance/MVDream
 
 import math
-import
-import torch as th
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from diffusers.configuration_utils import ConfigMixin

@@ -223,7 +222,7 @@ class ResBlock(TimestepBlock):
             emb_out = emb_out[..., None]
         if self.use_scale_shift_norm:
             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = th.chunk(emb_out, 2, dim=1)
+            scale, shift = torch.chunk(emb_out, 2, dim=1)
             h = out_norm(h) * (1 + scale) + shift
             h = out_rest(h)
         else:

@@ -232,112 +231,6 @@ class ResBlock(TimestepBlock):
         return self.skip_connection(x) + h
 
 
-class AttentionBlock(nn.Module):
-    """
-    An attention block that allows spatial positions to attend to each other.
-    Originally ported from here, but adapted to the N-d case.
-    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    """
-
-    def __init__(
-        self,
-        channels,
-        num_heads=1,
-        num_head_channels=-1,
-        use_checkpoint=False,
-        use_new_attention_order=False,
-    ):
-        super().__init__()
-        self.channels = channels
-        if num_head_channels == -1:
-            self.num_heads = num_heads
-        else:
-            assert (
-                channels % num_head_channels == 0
-            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
-            self.num_heads = channels // num_head_channels
-        self.use_checkpoint = use_checkpoint
-        self.norm = nn.GroupNorm(32, channels)
-        self.qkv = conv_nd(1, channels, channels * 3, 1)
-        if use_new_attention_order:
-            # split qkv before split heads
-            self.attention = QKVAttention(self.num_heads)
-        else:
-            # split heads before split qkv
-            self.attention = QKVAttentionLegacy(self.num_heads)
-
-        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
-
-    def forward(self, x):
-        return checkpoint(self._forward, (x,), self.parameters(), True)
-
-    def _forward(self, x):
-        b, c, *spatial = x.shape
-        x = x.reshape(b, c, -1)
-        qkv = self.qkv(self.norm(x))
-        h = self.attention(qkv)
-        h = self.proj_out(h)
-        return (x + h).reshape(b, c, *spatial)
-
-
-class QKVAttentionLegacy(nn.Module):
-    """
-    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
-    """
-
-    def __init__(self, n_heads):
-        super().__init__()
-        self.n_heads = n_heads
-
-    def forward(self, qkv):
-        """
-        Apply QKV attention.
-        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
-        :return: an [N x (H * C) x T] tensor after attention.
-        """
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = th.einsum(
-            "bct,bcs->bts", q * scale, k * scale
-        )  # More stable with f16 than dividing afterwards
-        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = th.einsum("bts,bcs->bct", weight, v)
-        return a.reshape(bs, -1, length)
-
-
-class QKVAttention(nn.Module):
-    """
-    A module which performs QKV attention and splits in a different order.
-    """
-
-    def __init__(self, n_heads):
-        super().__init__()
-        self.n_heads = n_heads
-
-    def forward(self, qkv):
-        """
-        Apply QKV attention.
-        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
-        :return: an [N x (H * C) x T] tensor after attention.
-        """
-        bs, width, length = qkv.shape
-        assert width % (3 * self.n_heads) == 0
-        ch = width // (3 * self.n_heads)
-        q, k, v = qkv.chunk(3, dim=1)
-        scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = th.einsum(
-            "bct,bcs->bts",
-            (q * scale).view(bs * self.n_heads, ch, length),
-            (k * scale).view(bs * self.n_heads, ch, length),
-        )  # More stable with f16 than dividing afterwards
-        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
-        return a.reshape(bs, -1, length)
-
-
 class MultiViewUNetModel(ModelMixin, ConfigMixin):
     """
     The full multi-view UNet model with attention, timestep embedding and camera embedding.

@@ -388,34 +281,18 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
         num_heads_upsample=-1,
         use_scale_shift_norm=False,
         resblock_updown=False,
-        use_new_attention_order=False,
-        use_spatial_transformer=False,  # custom transformer support
         transformer_depth=1,  # custom transformer support
         context_dim=None,  # custom transformer support
         n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
-        legacy=True,
         disable_self_attentions=None,
         num_attention_blocks=None,
         disable_middle_self_attn=False,
-        use_linear_in_transformer=False,
         adm_in_channels=None,
        camera_dim=None,
     ):
         super().__init__()
-        if use_spatial_transformer:
-            assert (
-                context_dim is not None
-            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
-
-        if context_dim is not None:
-            assert (
-                use_spatial_transformer
-            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
-            from omegaconf.listconfig import ListConfig
-
-            if type(context_dim) == ListConfig:
-                context_dim = list(context_dim)
-
+        assert context_dim is not None
+
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
 

@@ -535,13 +412,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                     else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
-                    if legacy:
-                        # num_heads = 1
-                        dim_head = (
-                            ch // num_heads
-                            if use_spatial_transformer
-                            else num_head_channels
-                        )
+
                    if disable_self_attentions is not None:
                        disabled_sa = disable_self_attentions[level]
                    else:

@@ -549,22 +420,13 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
 
                    if num_attention_blocks is None or nr < num_attention_blocks[level]:
                        layers.append(
-                            AttentionBlock(
-                                ch,
-                                use_checkpoint=use_checkpoint,
-                                num_heads=num_heads,
-                                num_head_channels=dim_head,
-                                use_new_attention_order=use_new_attention_order,
-                            )
-                            if not use_spatial_transformer
-                            else SpatialTransformer3D(
+                            SpatialTransformer3D(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
-                                use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint,
                            )
                        )

@@ -601,9 +463,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
         else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
-        if legacy:
-            # num_heads = 1
-            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+
         self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,

@@ -613,24 +473,15 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
-            AttentionBlock(
-                ch,
-                use_checkpoint=use_checkpoint,
-                num_heads=num_heads,
-                num_head_channels=dim_head,
-                use_new_attention_order=use_new_attention_order,
-            )
-            if not use_spatial_transformer
-            else SpatialTransformer3D(
+            SpatialTransformer3D(
                ch,
                num_heads,
                dim_head,
                depth=transformer_depth,
                context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn,
-                use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint,
-            ),
+            ),
            ResBlock(
                ch,
                time_embed_dim,

@@ -664,13 +515,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                     else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
-                    if legacy:
-                        # num_heads = 1
-                        dim_head = (
-                            ch // num_heads
-                            if use_spatial_transformer
-                            else num_head_channels
-                        )
+
                    if disable_self_attentions is not None:
                        disabled_sa = disable_self_attentions[level]
                    else:

@@ -678,22 +523,13 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
 
                    if num_attention_blocks is None or i < num_attention_blocks[level]:
                        layers.append(
-                            AttentionBlock(
-                                ch,
-                                use_checkpoint=use_checkpoint,
-                                num_heads=num_heads_upsample,
-                                num_head_channels=dim_head,
-                                use_new_attention_order=use_new_attention_order,
-                            )
-                            if not use_spatial_transformer
-                            else SpatialTransformer3D(
+                            SpatialTransformer3D(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
-                                use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint,
                            )
                        )

@@ -777,7 +613,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
             hs.append(h)
         h = self.middle_block(h, emb, context, num_frames=num_frames)
         for module in self.output_blocks:
-            h = th.cat([h, hs.pop()], dim=1)
+            h = torch.cat([h, hs.pop()], dim=1)
             h = module(h, emb, context, num_frames=num_frames)
         h = h.type(x.dtype)
         if self.predict_codebook_ids:
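With use_spatial_transformer and the AttentionBlock/QKVAttention classes removed, every attention site in MultiViewUNetModel now builds a SpatialTransformer3D directly. A construction sketch using the parameter names visible in the hunks above; the numeric values are illustrative only and not taken from the commit:

from mvdream.attention import SpatialTransformer3D

block = SpatialTransformer3D(
    in_channels=320,          # ch at this resolution
    n_heads=8,                # num_heads
    d_head=40,                # dim_head = ch // num_heads
    depth=1,                  # transformer_depth
    context_dim=1024,         # cross-attention conditioning dim (now asserted non-None)
    disable_self_attn=False,
    use_checkpoint=True,
)
# use_linear now defaults to True, so the Conv2d projection branch shown above is skipped.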