clean up

- README.md +8 -1
- main.py +13 -3
- mvdream/models.py +13 -25
- mvdream/util.py +0 -196
README.md
CHANGED

@@ -12,7 +12,14 @@ wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd
 python convert_mvdream_to_diffusers.py --checkpoint_path ./sd-v2.1-base-4view.pt --dump_path ./weights --original_config_file ./sd-v2-base.yaml --half --to_safetensors --test
 ```
 
-###
+### usage
+
+example:
+```bash
+python main.py "a cute owl"
+```
+
+detailed usage:
 ```python
 import torch
 import kiui
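The conversion step and the new usage section fit together: `convert_mvdream_to_diffusers.py` writes diffusers-format weights to `./weights`, and the usage snippet loads them back. A minimal end-to-end sketch (paths follow the README; the four-view return value is an assumption inferred from the indexing in main.py below):

```python
import torch
from mvdream.pipeline_mvdream import MVDreamStableDiffusionPipeline

# load the weights that convert_mvdream_to_diffusers.py dumped to ./weights
pipe = MVDreamStableDiffusionPipeline.from_pretrained("./weights", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# the 4-view model generates four views of the same prompt
images = pipe("a cute owl")  # assumed: a sequence of 4 images, indexed 0..3 as in main.py
```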
main.py
CHANGED

@@ -1,11 +1,21 @@
 import torch
 import kiui
+import numpy as np
+import argparse
 from mvdream.pipeline_mvdream import MVDreamStableDiffusionPipeline
 
 pipe = MVDreamStableDiffusionPipeline.from_pretrained('./weights', torch_dtype=torch.float16)
 pipe = pipe.to("cuda")
 
-prompt = "a photo of an astronaut riding a horse on mars"
-image = pipe(prompt)
 
-
+parser = argparse.ArgumentParser(description='MVDream')
+parser.add_argument('prompt', type=str, default="a cute owl 3d model")
+args = parser.parse_args()
+
+while True:
+    image = pipe(args.prompt)
+    grid = np.concatenate([
+        np.concatenate([image[0], image[2]], axis=0),
+        np.concatenate([image[1], image[3]], axis=0),
+    ], axis=1)
+    kiui.vis.plot_image(grid)
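The new loop tiles the four generated views into a 2x2 grid before plotting: views 0 and 2 form the left column, views 1 and 3 the right. A standalone sketch of the same layout with dummy arrays (the HxWx3 uint8 shape is an assumption about the pipeline output):

```python
import numpy as np

# four dummy "views", standing in for the pipeline's output list
views = [np.full((256, 256, 3), i, dtype=np.uint8) for i in range(4)]

# left column: views 0 and 2 stacked vertically; right column: views 1 and 3
grid = np.concatenate([
    np.concatenate([views[0], views[2]], axis=0),
    np.concatenate([views[1], views[3]], axis=0),
], axis=1)

assert grid.shape == (512, 512, 3)  # a 2x2 tiling of 256x256 views
```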
mvdream/models.py
CHANGED

@@ -10,10 +10,8 @@ from abc import abstractmethod
 from .util import (
     checkpoint,
     conv_nd,
-    linear,
     avg_pool_nd,
     zero_module,
-    normalization,
     timestep_embedding,
 )
 from .attention import SpatialTransformer, SpatialTransformer3D
@@ -56,7 +54,7 @@ class MultiViewUNetWrapperModel(ModelMixin, ConfigMixin):
                  adm_in_channels=None,
                  camera_dim=None,):
         super().__init__()
-        self.unet
+        self.unet = MultiViewUNetModel(
             image_size=image_size,
             in_channels=in_channels,
             model_channels=model_channels,
@@ -218,7 +216,7 @@ class ResBlock(TimestepBlock):
         self.use_scale_shift_norm = use_scale_shift_norm
 
         self.in_layers = nn.Sequential(
-            normalization(channels),
+            nn.GroupNorm(32, channels),
             nn.SiLU(),
             conv_nd(dims, channels, self.out_channels, 3, padding=1),
         )
@@ -236,13 +234,13 @@ class ResBlock(TimestepBlock):
 
         self.emb_layers = nn.Sequential(
             nn.SiLU(),
-            linear(
+            nn.Linear(
                 emb_channels,
                 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
             ),
         )
         self.out_layers = nn.Sequential(
-            normalization(self.out_channels),
+            nn.GroupNorm(32, self.out_channels),
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
@@ -310,7 +308,7 @@ class AttentionBlock(nn.Module):
         assert (channels % num_head_channels == 0), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
         self.num_heads = channels // num_head_channels
         self.use_checkpoint = use_checkpoint
-        self.norm = normalization(channels)
+        self.norm = nn.GroupNorm(32, channels)
         self.qkv = conv_nd(1, channels, channels * 3, 1)
         if use_new_attention_order:
             # split qkv before split heads
@@ -418,16 +416,6 @@ class QKVAttention(nn.Module):
         return count_flops_attn(model, _x, y)
 
 
-class Timestep(nn.Module):
-
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, t):
-        return timestep_embedding(t, self.dim)
-
-
 class MultiViewUNetModel(nn.Module):
     """
     The full multi-view UNet model with attention, timestep embedding and camera embedding.
@@ -545,17 +533,17 @@ class MultiViewUNetModel(nn.Module):
 
         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim),
+            nn.Linear(model_channels, time_embed_dim),
             nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim),
+            nn.Linear(time_embed_dim, time_embed_dim),
         )
 
         if camera_dim is not None:
             time_embed_dim = model_channels * 4
             self.camera_embed = nn.Sequential(
-                linear(camera_dim, time_embed_dim),
+                nn.Linear(camera_dim, time_embed_dim),
                 nn.SiLU(),
-                linear(time_embed_dim, time_embed_dim),
+                nn.Linear(time_embed_dim, time_embed_dim),
             )
 
         if self.num_classes is not None:
@@ -567,9 +555,9 @@ class MultiViewUNetModel(nn.Module):
         elif self.num_classes == "sequential":
             assert adm_in_channels is not None
             self.label_emb = nn.Sequential(nn.Sequential(
-                linear(adm_in_channels, time_embed_dim),
+                nn.Linear(adm_in_channels, time_embed_dim),
                 nn.SiLU(),
-                linear(time_embed_dim, time_embed_dim),
+                nn.Linear(time_embed_dim, time_embed_dim),
             ))
         else:
             raise ValueError()
@@ -722,13 +710,13 @@ class MultiViewUNetModel(nn.Module):
         self._feature_size += ch
 
         self.out = nn.Sequential(
-            normalization(ch),
+            nn.GroupNorm(32, ch),
             nn.SiLU(),
             zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
         )
         if self.predict_codebook_ids:
             self.id_predictor = nn.Sequential(
-                normalization(ch),
+                nn.GroupNorm(32, ch),
                 conv_nd(dims, model_channels, n_embed, 1),
                 #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
             )
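All of the models.py edits inline two one-line helpers that this commit deletes from mvdream/util.py: `linear(...)` was just `nn.Linear(...)`, and `normalization(channels)` was a 32-group GroupNorm whose forward was a plain pass-through. A quick sketch checking that the replacements are drop-in compatible (same parameter names and shapes, so existing checkpoints still load):

```python
import torch.nn as nn

# the helpers deleted from mvdream/util.py in this commit, reproduced verbatim
class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x)

def normalization(channels):
    return GroupNorm32(32, channels)

def linear(*args, **kwargs):
    return nn.Linear(*args, **kwargs)

channels = 320
old_norm, new_norm = normalization(channels), nn.GroupNorm(32, channels)
old_lin, new_lin = linear(channels, 4 * channels), nn.Linear(channels, 4 * channels)

# parameter names and shapes match, so state_dicts saved against the old
# modules load into the new ones unchanged
assert [(k, v.shape) for k, v in old_norm.state_dict().items()] == \
       [(k, v.shape) for k, v in new_norm.state_dict().items()]
assert [(k, v.shape) for k, v in old_lin.state_dict().items()] == \
       [(k, v.shape) for k, v in new_lin.state_dict().items()]
```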
mvdream/util.py
CHANGED

@@ -10,136 +10,7 @@
 import math
 import torch
 import torch.nn as nn
-import numpy as np
-import importlib
 from einops import repeat
-from typing import Any
-
-
-def instantiate_from_config(config):
-    if not "target" in config:
-        if config == '__is_first_stage__':
-            return None
-        elif config == "__is_unconditional__":
-            return None
-        raise KeyError("Expected key `target` to instantiate.")
-    return get_obj_from_str(config["target"])(**config.get("params", dict()))
-
-
-def get_obj_from_str(string, reload=False):
-    module, cls = string.rsplit(".", 1)
-    if reload:
-        module_imp = importlib.import_module(module)
-        importlib.reload(module_imp)
-    return getattr(importlib.import_module(module, package=None), cls)
-
-
-def make_beta_schedule(schedule,
-                       n_timestep,
-                       linear_start=1e-4,
-                       linear_end=2e-2,
-                       cosine_s=8e-3):
-    if schedule == "linear":
-        betas = (torch.linspace(linear_start**0.5,
-                                linear_end**0.5,
-                                n_timestep,
-                                dtype=torch.float64)**2)
-
-    elif schedule == "cosine":
-        timesteps = (
-            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep +
-            cosine_s)
-        alphas = timesteps / (1 + cosine_s) * np.pi / 2
-        alphas = torch.cos(alphas).pow(2)
-        alphas = alphas / alphas[0]
-        betas = 1 - alphas[1:] / alphas[:-1]
-        betas = np.clip(betas, a_min=0, a_max=0.999)
-
-    elif schedule == "sqrt_linear":
-        betas = torch.linspace(linear_start,
-                               linear_end,
-                               n_timestep,
-                               dtype=torch.float64)
-    elif schedule == "sqrt":
-        betas = torch.linspace(linear_start,
-                               linear_end,
-                               n_timestep,
-                               dtype=torch.float64)**0.5
-    else:
-        raise ValueError(f"schedule '{schedule}' unknown.")
-    return betas.numpy()  # type: ignore
-
-
-def make_ddim_timesteps(ddim_discr_method,
-                        num_ddim_timesteps,
-                        num_ddpm_timesteps,
-                        verbose=True):
-    if ddim_discr_method == 'uniform':
-        c = num_ddpm_timesteps // num_ddim_timesteps
-        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
-    elif ddim_discr_method == 'quad':
-        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8),
-                                       num_ddim_timesteps))**2).astype(int)
-    else:
-        raise NotImplementedError(
-            f'There is no ddim discretization method called "{ddim_discr_method}"'
-        )
-
-    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
-    # add one to get the final alpha values right (the ones from first scale to data during sampling)
-    steps_out = ddim_timesteps + 1
-    if verbose:
-        print(f'Selected timesteps for ddim sampler: {steps_out}')
-    return steps_out
-
-
-def make_ddim_sampling_parameters(alphacums,
-                                  ddim_timesteps,
-                                  eta,
-                                  verbose=True):
-    # select alphas for computing the variance schedule
-    alphas = alphacums[ddim_timesteps]
-    alphas_prev = np.asarray([alphacums[0]] +
-                             alphacums[ddim_timesteps[:-1]].tolist())
-
-    # according the the formula provided in https://arxiv.org/abs/2010.02502
-    sigmas = eta * np.sqrt(
-        (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
-    if verbose:
-        print(
-            f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
-        )
-        print(
-            f'For the chosen value of eta, which is {eta}, '
-            f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
-        )
-    return sigmas, alphas, alphas_prev
-
-
-def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
-    """
-    Create a beta schedule that discretizes the given alpha_t_bar function,
-    which defines the cumulative product of (1-beta) over time from t = [0,1].
-    :param num_diffusion_timesteps: the number of betas to produce.
-    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
-                      produces the cumulative product of (1-beta) up to that
-                      part of the diffusion process.
-    :param max_beta: the maximum beta to use; use values lower than 1 to
-                     prevent singularities.
-    """
-    betas = []
-    for i in range(num_diffusion_timesteps):
-        t1 = i / num_diffusion_timesteps
-        t2 = (i + 1) / num_diffusion_timesteps
-        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
-    return np.array(betas)
-
-
-def extract_into_tensor(a, t, x_shape):
-    b, *_ = t.shape
-    out = a.gather(-1, t)
-    return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
-
 
 def checkpoint(func, inputs, params, flag):
     """
@@ -227,45 +98,6 @@ def zero_module(module):
         p.detach().zero_()
     return module
 
-
-def scale_module(module, scale):
-    """
-    Scale the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().mul_(scale)
-    return module
-
-
-def mean_flat(tensor):
-    """
-    Take the mean over all non-batch dimensions.
-    """
-    return tensor.mean(dim=list(range(1, len(tensor.shape))))
-
-
-def normalization(channels):
-    """
-    Make a standard normalization layer.
-    :param channels: number of input channels.
-    :return: an nn.Module for normalization.
-    """
-    return GroupNorm32(32, channels)
-
-
-# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
-class SiLU(nn.Module):
-
-    def forward(self, x):
-        return x * torch.sigmoid(x)
-
-
-class GroupNorm32(nn.GroupNorm):
-
-    def forward(self, x):
-        return super().forward(x)
-
-
 def conv_nd(dims, *args, **kwargs):
     """
     Create a 1D, 2D, or 3D convolution module.
@@ -279,13 +111,6 @@ def conv_nd(dims, *args, **kwargs):
     raise ValueError(f"unsupported dimensions: {dims}")
 
 
-def linear(*args, **kwargs):
-    """
-    Create a linear module.
-    """
-    return nn.Linear(*args, **kwargs)
-
-
 def avg_pool_nd(dims, *args, **kwargs):
     """
     Create a 1D, 2D, or 3D average pooling module.
@@ -297,24 +122,3 @@ def avg_pool_nd(dims, *args, **kwargs):
     elif dims == 3:
         return nn.AvgPool3d(*args, **kwargs)
     raise ValueError(f"unsupported dimensions: {dims}")
-
-
-class HybridConditioner(nn.Module):
-
-    def __init__(self, c_concat_config, c_crossattn_config):
-        super().__init__()
-        self.concat_conditioner: Any = instantiate_from_config(c_concat_config)
-        self.crossattn_conditioner: Any = instantiate_from_config(
-            c_crossattn_config)
-
-    def forward(self, c_concat, c_crossattn):
-        c_concat = self.concat_conditioner(c_concat)
-        c_crossattn = self.crossattn_conditioner(c_crossattn)
-        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
-
-
-def noise_like(shape, device, repeat=False):
-    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
-        shape[0], *((1, ) * (len(shape) - 1)))
-    noise = lambda: torch.randn(shape, device=device)
-    return repeat_noise() if repeat else noise()
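After the purge, util.py keeps only what models.py still imports: `checkpoint`, `conv_nd`, `avg_pool_nd`, `zero_module`, and `timestep_embedding`. A small usage sketch of the surviving factory helpers (import path as in this repo; shapes are illustrative):

```python
import torch
import torch.nn as nn
from mvdream.util import avg_pool_nd, conv_nd, zero_module

# dims picks Conv1d/Conv2d/Conv3d; 2 is the usual image case
conv = conv_nd(2, 4, 320, 3, padding=1)
assert isinstance(conv, nn.Conv2d)

# zero_module zeroes a module's parameters in place, the standard init
# for the output convs it wraps in models.py
out_conv = zero_module(conv_nd(2, 320, 4, 3, padding=1))
assert all(p.detach().abs().sum() == 0 for p in out_conv.parameters())

x = torch.randn(1, 4, 64, 64)
y = out_conv(conv(x))
assert y.shape == (1, 4, 64, 64) and y.abs().sum() == 0  # zeroed conv outputs zeros

pool = avg_pool_nd(2, kernel_size=2)
assert pool(x).shape == (1, 4, 32, 32)
```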