fix mvdream

Changed files:
- .gitignore (+1 -0)
- README.md (+3 -0)
- convert_mvdream_to_diffusers.py (+6 -0)
- mvdream/attention.py (+1 -4)
- mvdream/models.py (+9 -23)
- mvdream/util.py (+0 -9)
.gitignore
CHANGED

@@ -4,4 +4,5 @@
 *.pyc
 
 weights
+models
 sd-v2*
README.md
CHANGED

@@ -4,6 +4,9 @@ modified from https://github.com/KokeCacao/mvdream-hf.
 
 ### convert weights
 ```bash
+# dependency
+pip install -U omegaconf diffusers safetensors huggingface_hub transformers accelerate
+
 # download original ckpt
 wget https://huggingface.co/MVDream/MVDream/resolve/main/sd-v2.1-base-4view.pt
 wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd-v2-base.yaml
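Not part of the diff, but a quick way to sanity-check the two downloaded files before running the converter. This is a hypothetical snippet, not repo code; it assumes the checkpoint is an old-style pickled `torch.save` file with a `state_dict` key, the usual layout for SD-derived checkpoints.

```python
# sanity_check.py -- hypothetical helper, not part of this repo
import torch
from omegaconf import OmegaConf

# Old-style pickled checkpoint; on torch>=2.6 torch.load defaults to
# weights_only=True, which rejects such files, hence the explicit flag.
ckpt = torch.load("sd-v2.1-base-4view.pt", map_location="cpu", weights_only=False)
state_dict = ckpt.get("state_dict", ckpt)  # SD-style ckpts usually nest under "state_dict"
print(f"loaded {len(state_dict)} tensors")

config = OmegaConf.load("sd-v2-base.yaml")
print(config.model.params.unet_config.params)  # UNet hyperparameters the converter reads
```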
convert_mvdream_to_diffusers.py
CHANGED

@@ -405,6 +405,12 @@ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, de
     # )
     # print(f"Unet Config: {original_config.model.params.unet_config.params}")
     unet_config = create_unet_config(original_config)
+
+    # remove unused configs
+    del unet_config['legacy']
+    del unet_config['use_linear_in_transformer']
+    del unet_config['use_spatial_transformer']
+
     unet = MultiViewUNetModel(**unet_config)
     unet.register_to_config(**unet_config)
     # print(f"Unet State Dict: {unet.state_dict().keys()}")
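The three `del` statements drop LDM-specific keys that `MultiViewUNetModel` has no use for, and they also keep those keys out of the dict recorded by `register_to_config(**unet_config)`. A more generic way to express the same idea is to filter the config against the constructor's signature; a minimal sketch, with an invented helper name:

```python
import inspect

def filter_ctor_kwargs(cls, cfg: dict) -> dict:
    """Keep only the keys that cls.__init__ actually accepts."""
    params = inspect.signature(cls.__init__).parameters
    # A constructor that declares **kwargs accepts every key anyway.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(cfg)
    return {k: v for k, v in cfg.items() if k in params}

# usage sketch:
# unet = MultiViewUNetModel(**filter_ctor_kwargs(MultiViewUNetModel, unet_config))
```

Note that this commit also adds `**kwargs` to `MultiViewUNetModel.__init__` (see the models.py diff below), so stray keys would be tolerated either way; the explicit `del`s just keep the registered diffusers config clean.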
mvdream/attention.py
CHANGED

@@ -1,6 +1,3 @@
-# obtained and modified from https://github.com/bytedance/MVDream
-
-import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

@@ -14,9 +11,9 @@ from .util import checkpoint, zero_module
 try:
     import xformers  # type: ignore
     import xformers.ops  # type: ignore
-
     XFORMERS_IS_AVAILBLE = True
 except:
+    print(f'[WARN] xformers is unavailable!')
     XFORMERS_IS_AVAILBLE = False
 
 # CrossAttn precision handling
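The added `print` makes the previously silent fallback visible. The `try/except` itself is the standard optional-dependency probe: set a module-level flag at import time, branch on it later. A self-contained sketch of how such a flag typically gates the attention path is below; it is illustrative, not the repo's code, and it narrows the bare `except:` to `except ImportError:` so unrelated errors are not swallowed.

```python
import torch
import torch.nn.functional as F

try:
    import xformers.ops  # type: ignore
    XFORMERS_IS_AVAILBLE = True
except ImportError:
    print("[WARN] xformers is unavailable!")
    XFORMERS_IS_AVAILBLE = False

def attention(q, k, v):
    # q, k, v: (batch, seq_len, n_heads, dim_head), the layout xformers expects
    if XFORMERS_IS_AVAILBLE:
        return xformers.ops.memory_efficient_attention(q, k, v)
    # PyTorch >= 2.0 fallback wants (batch, n_heads, seq_len, dim_head)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    return F.scaled_dot_product_attention(q, k, v).transpose(1, 2)
```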
mvdream/models.py
CHANGED

@@ -1,6 +1,3 @@
-# obtained and modified from https://github.com/bytedance/MVDream
-
-import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

@@ -9,7 +6,6 @@ from diffusers.models.modeling_utils import ModelMixin
 from typing import Any, List, Optional
 from torch import Tensor
 
-from abc import abstractmethod
 from .util import (
     checkpoint,
     conv_nd,

@@ -19,19 +15,8 @@ from .util import (
 )
 from .attention import SpatialTransformer, SpatialTransformer3D
 
-class TimestepBlock(nn.Module):
-    """
-    Any module where forward() takes timestep embeddings as a second argument.
-    """
 
-    def forward(self, x, emb):
-        """
-        Apply the module to `x` given `emb` timestep embeddings.
-        """
-
-
-class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+class CondSequential(nn.Sequential):
     """
     A sequential module that passes timestep embeddings to the children that
     support it as an extra input.

@@ -39,7 +24,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
 
     def forward(self, x, emb, context=None, num_frames=1):
         for layer in self:
-            if isinstance(layer, TimestepBlock):
+            if isinstance(layer, ResBlock):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer3D):
                 x = layer(x, context, num_frames=num_frames)

@@ -117,7 +102,7 @@ class Downsample(nn.Module):
         return self.op(x)
 
 
-class ResBlock(TimestepBlock):
+class ResBlock(nn.Module):
     """
     A residual block that can optionally change the number of channels.
     :param channels: the number of input channels.

@@ -289,6 +274,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
         disable_middle_self_attn=False,
         adm_in_channels=None,
         camera_dim=None,
+        **kwargs,
     ):
         super().__init__()
         assert context_dim is not None

@@ -383,7 +369,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
 
         self.input_blocks = nn.ModuleList(
             [
-                TimestepEmbedSequential(
+                CondSequential(
                     conv_nd(dims, in_channels, model_channels, 3, padding=1)
                 )
             ]

@@ -430,13 +416,13 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                         use_checkpoint=use_checkpoint,
                     )
                 )
-                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self.input_blocks.append(CondSequential(*layers))
                 self._feature_size += ch
                 input_block_chans.append(ch)
             if level != len(channel_mult) - 1:
                 out_ch = ch
                 self.input_blocks.append(
-                    TimestepEmbedSequential(
+                    CondSequential(
                         ResBlock(
                             ch,
                             time_embed_dim,

@@ -464,7 +450,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                 num_heads = ch // num_head_channels
                 dim_head = num_head_channels
 
-        self.middle_block = TimestepEmbedSequential(
+        self.middle_block = CondSequential(
             ResBlock(
                 ch,
                 time_embed_dim,

@@ -550,7 +536,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                     else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                 )
                 ds //= 2
-                self.output_blocks.append(TimestepEmbedSequential(*layers))
+                self.output_blocks.append(CondSequential(*layers))
                 self._feature_size += ch
 
         self.out = nn.Sequential(
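Summarizing this file's refactor: the abstract `TimestepBlock` base class is removed, `ResBlock` now derives from `nn.Module` directly, and `TimestepEmbedSequential` is renamed to `CondSequential`, which dispatches on concrete layer types instead. A runnable toy sketch of that dispatch pattern follows; the `Toy*` modules are invented stand-ins, not the repo's real `ResBlock`/`SpatialTransformer3D`.

```python
import torch
import torch.nn as nn

class ToyResBlock(nn.Module):
    """Stand-in for ResBlock: consumes the timestep embedding."""
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, emb):
        return self.proj(x) + emb

class ToyTransformer3D(nn.Module):
    """Stand-in for SpatialTransformer3D: consumes cross-attention context."""
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, context=None, num_frames=1):
        return self.proj(x if context is None else x + context)

class CondSequential(nn.Sequential):
    """Route extra conditioning only to the children that accept it."""
    def forward(self, x, emb, context=None, num_frames=1):
        for layer in self:
            if isinstance(layer, ToyResBlock):
                x = layer(x, emb)
            elif isinstance(layer, ToyTransformer3D):
                x = layer(x, context, num_frames=num_frames)
            else:
                x = layer(x)
        return x

block = CondSequential(ToyResBlock(8), ToyTransformer3D(8), nn.LayerNorm(8))
out = block(torch.randn(2, 8), emb=torch.zeros(2, 8))
print(out.shape)  # torch.Size([2, 8])
```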
mvdream/util.py
CHANGED

@@ -1,12 +1,3 @@
-# adopted from
-# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
-# and
-# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
-# and
-# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
-#
-# thanks!
-
 import math
 import torch
 import torch.nn as nn