fix mvdream
- .gitignore +1 -0
- README.md +3 -0
- convert_mvdream_to_diffusers.py +6 -0
- mvdream/attention.py +1 -4
- mvdream/models.py +9 -23
- mvdream/util.py +0 -9
    	
.gitignore CHANGED

@@ -4,4 +4,5 @@
 *.pyc
 
 weights
+models
 sd-v2*
    	
README.md CHANGED

@@ -4,6 +4,9 @@ modified from https://github.com/KokeCacao/mvdream-hf.
 
 ### convert weights
 ```bash
+# dependency
+pip install -U omegaconf diffusers safetensors huggingface_hub transformers accelerate
+
 # download original ckpt
 wget https://huggingface.co/MVDream/MVDream/resolve/main/sd-v2.1-base-4view.pt
 wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd-v2-base.yaml
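With the checkpoint and config downloaded, the conversion entry point is `convert_from_original_mvdream_ckpt` in `convert_mvdream_to_diffusers.py`. Below is a minimal sketch of calling it directly, assuming both files sit in the working directory; the function takes at least `checkpoint_path` and `original_config_file` (its remaining parameters are truncated in the hunk header below, so they are omitted here):

```python
# Sketch: invoke the converter directly. Further keyword arguments
# (e.g. a device) may be required by the actual signature.
from convert_mvdream_to_diffusers import convert_from_original_mvdream_ckpt

convert_from_original_mvdream_ckpt(
    checkpoint_path="sd-v2.1-base-4view.pt",
    original_config_file="sd-v2-base.yaml",
)
```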
    	
convert_mvdream_to_diffusers.py CHANGED

@@ -405,6 +405,12 @@ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, de
     # )
     # print(f"Unet Config: {original_config.model.params.unet_config.params}")
     unet_config = create_unet_config(original_config)
+
+    # remove unused configs
+    del unet_config['legacy']
+    del unet_config['use_linear_in_transformer']
+    del unet_config['use_spatial_transformer']
+
     unet = MultiViewUNetModel(**unet_config)
     unet.register_to_config(**unet_config)
     # print(f"Unet State Dict: {unet.state_dict().keys()}")
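The three deleted keys come from the original LDM-style UNet config but are not parameters that `MultiViewUNetModel` consumes, so stripping them keeps them out of the registered diffusers config; the `**kwargs` added to `MultiViewUNetModel.__init__` in `mvdream/models.py` below complements this by absorbing any other unrecognized entries instead of raising `TypeError`. A slightly more defensive variant of the cleanup (a sketch, not the committed code) uses `pop` with a default so an already-absent key is not an error:

```python
# Sketch: pop() with a default does not raise KeyError
# if a key is already missing from the config dict.
for key in ("legacy", "use_linear_in_transformer", "use_spatial_transformer"):
    unet_config.pop(key, None)
```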
    	
mvdream/attention.py CHANGED

@@ -1,6 +1,3 @@
-# obtained and modified from https://github.com/bytedance/MVDream
-
-import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -14,9 +11,9 @@ from .util import checkpoint, zero_module
 try:
     import xformers  # type: ignore
     import xformers.ops  # type: ignore
-
     XFORMERS_IS_AVAILBLE = True
 except:
+    print(f'[WARN] xformers is unavailable!')
     XFORMERS_IS_AVAILBLE = False
 
 # CrossAttn precision handling
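The new warning makes the previously silent fallback visible. Note that the bare `except:` also swallows failures other than a missing package; a narrower variant (a sketch, not the committed code) would catch `ImportError` specifically:

```python
try:
    import xformers  # type: ignore
    import xformers.ops  # type: ignore
    XFORMERS_IS_AVAILBLE = True
except ImportError:
    # Only a missing or broken install lands here; other errors propagate.
    print('[WARN] xformers is unavailable!')
    XFORMERS_IS_AVAILBLE = False
```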
    	
mvdream/models.py CHANGED

@@ -1,6 +1,3 @@
-# obtained and modified from https://github.com/bytedance/MVDream
-
-import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -9,7 +6,6 @@ from diffusers.models.modeling_utils import ModelMixin
 from typing import Any, List, Optional
 from torch import Tensor
 
-from abc import abstractmethod
 from .util import (
     checkpoint,
     conv_nd,
@@ -19,19 +15,8 @@ from .util import (
 )
 from .attention import SpatialTransformer, SpatialTransformer3D
 
-class TimestepBlock(nn.Module):
-    """
-    Any module where forward() takes timestep embeddings as a second argument.
-    """
 
-
-    def forward(self, x, emb):
-        """
-        Apply the module to `x` given `emb` timestep embeddings.
-        """
-
-
-class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+class CondSequential(nn.Sequential):
     """
     A sequential module that passes timestep embeddings to the children that
     support it as an extra input.
@@ -39,7 +24,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
 
     def forward(self, x, emb, context=None, num_frames=1):
         for layer in self:
-            if isinstance(layer, TimestepBlock):
+            if isinstance(layer, ResBlock):
                 x = layer(x, emb)
             elif isinstance(layer, SpatialTransformer3D):
                 x = layer(x, context, num_frames=num_frames)
@@ -117,7 +102,7 @@ class Downsample(nn.Module):
         return self.op(x)
 
 
-class ResBlock(TimestepBlock):
+class ResBlock(nn.Module):
     """
     A residual block that can optionally change the number of channels.
     :param channels: the number of input channels.
@@ -289,6 +274,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
         disable_middle_self_attn=False,
         adm_in_channels=None,
         camera_dim=None,
+        **kwargs,
     ):
         super().__init__()
         assert context_dim is not None
@@ -383,7 +369,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
 
         self.input_blocks = nn.ModuleList(
             [
-                TimestepEmbedSequential(
+                CondSequential(
                     conv_nd(dims, in_channels, model_channels, 3, padding=1)
                 )
             ]
@@ -430,13 +416,13 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                             use_checkpoint=use_checkpoint,
                         )
                    )
-                self.input_blocks.append(TimestepEmbedSequential(*layers))
+                self.input_blocks.append(CondSequential(*layers))
                 self._feature_size += ch
                 input_block_chans.append(ch)
             if level != len(channel_mult) - 1:
                 out_ch = ch
                 self.input_blocks.append(
-                    TimestepEmbedSequential(
+                    CondSequential(
                         ResBlock(
                             ch,
                             time_embed_dim,
@@ -464,7 +450,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
             num_heads = ch // num_head_channels
             dim_head = num_head_channels
 
-        self.middle_block = TimestepEmbedSequential(
+        self.middle_block = CondSequential(
             ResBlock(
                 ch,
                 time_embed_dim,
@@ -550,7 +536,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
                     else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                 )
                 ds //= 2
-            self.output_blocks.append(TimestepEmbedSequential(*layers))
+            self.output_blocks.append(CondSequential(*layers))
             self._feature_size += ch
 
         self.out = nn.Sequential(
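`CondSequential` is the renamed `TimestepEmbedSequential` with the abstract `TimestepBlock` indirection removed: dispatch now keys on the concrete `ResBlock` type. A self-contained toy illustrating the dispatch pattern (the stand-in modules and shapes are invented for illustration; the real `ResBlock` and `SpatialTransformer3D` have richer constructors):

```python
import torch
import torch.nn as nn

class ToyResBlock(nn.Module):
    # Stands in for ResBlock: consumes the timestep embedding.
    def forward(self, x, emb):
        return x + emb

class ToyTransformer3D(nn.Module):
    # Stands in for SpatialTransformer3D: consumes context and num_frames.
    def forward(self, x, context, num_frames=1):
        return x if context is None else x + context

class CondSequential(nn.Sequential):
    # Same dispatch shape as the class in the diff above.
    def forward(self, x, emb, context=None, num_frames=1):
        for layer in self:
            if isinstance(layer, ToyResBlock):
                x = layer(x, emb)
            elif isinstance(layer, ToyTransformer3D):
                x = layer(x, context, num_frames=num_frames)
            else:
                x = layer(x)
        return x

block = CondSequential(ToyResBlock(), nn.Identity(), ToyTransformer3D())
out = block(torch.ones(2, 4), torch.zeros(2, 4), context=None, num_frames=4)
print(out.shape)  # torch.Size([2, 4])
```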
    	
mvdream/util.py CHANGED

@@ -1,12 +1,3 @@
-# adopted from
-# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
-# and
-# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
-# and
-# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
-#
-# thanks!
-
 import math
 import torch
 import torch.nn as nn

