not-lain, Maikou committed on
Commit 522e1ad · verified · 0 Parent(s)

Duplicate from Maikou/Michelangelo

Co-authored-by: Zhao <[email protected]>

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitattributes +35 -0
  2. .gitignore +3 -0
  3. README.md +13 -0
  4. checkpoints/aligned_shape_latents/shapevae-256.ckpt +3 -0
  5. checkpoints/clip/clip-vit-large-patch14 +1 -0
  6. checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt +3 -0
  7. checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt +3 -0
  8. configs/aligned_shape_latents/shapevae-256.yaml +46 -0
  9. configs/image_cond_diffuser_asl/image-ASLDM-256.yaml +97 -0
  10. configs/text_cond_diffuser_asl/text-ASLDM-256.yaml +98 -0
  11. example_data/image/car.jpg +0 -0
  12. example_data/surface/surface.npz +3 -0
  13. gradio_app.py +372 -0
  14. gradio_cached_dir/example/img_example/airplane.jpg +0 -0
  15. gradio_cached_dir/example/img_example/alita.jpg +0 -0
  16. gradio_cached_dir/example/img_example/bag.jpg +0 -0
  17. gradio_cached_dir/example/img_example/bench.jpg +0 -0
  18. gradio_cached_dir/example/img_example/building.jpg +0 -0
  19. gradio_cached_dir/example/img_example/burger.jpg +0 -0
  20. gradio_cached_dir/example/img_example/car.jpg +0 -0
  21. gradio_cached_dir/example/img_example/loopy.jpg +0 -0
  22. gradio_cached_dir/example/img_example/mario.jpg +0 -0
  23. gradio_cached_dir/example/img_example/ship.jpg +0 -0
  24. inference.py +181 -0
  25. michelangelo/__init__.py +1 -0
  26. michelangelo/data/__init__.py +1 -0
  27. michelangelo/data/templates.json +69 -0
  28. michelangelo/data/transforms.py +407 -0
  29. michelangelo/data/utils.py +59 -0
  30. michelangelo/graphics/__init__.py +1 -0
  31. michelangelo/graphics/primitives/__init__.py +9 -0
  32. michelangelo/graphics/primitives/mesh.py +114 -0
  33. michelangelo/graphics/primitives/volume.py +21 -0
  34. michelangelo/models/__init__.py +1 -0
  35. michelangelo/models/asl_diffusion/__init__.py +1 -0
  36. michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py +483 -0
  37. michelangelo/models/asl_diffusion/asl_udt.py +104 -0
  38. michelangelo/models/asl_diffusion/base.py +13 -0
  39. michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py +393 -0
  40. michelangelo/models/asl_diffusion/inference_utils.py +80 -0
  41. michelangelo/models/conditional_encoders/__init__.py +3 -0
  42. michelangelo/models/conditional_encoders/clip.py +89 -0
  43. michelangelo/models/conditional_encoders/encoder_factory.py +562 -0
  44. michelangelo/models/modules/__init__.py +3 -0
  45. michelangelo/models/modules/checkpoint.py +69 -0
  46. michelangelo/models/modules/diffusion_transformer.py +218 -0
  47. michelangelo/models/modules/distributions.py +100 -0
  48. michelangelo/models/modules/embedder.py +213 -0
  49. michelangelo/models/modules/transformer_blocks.py +286 -0
  50. michelangelo/models/modules/transformer_vit.py +308 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+ .idea
+ .vscode
+ __pycache__
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ license: lgpl-3.0
+ pipeline_tag: text-to-3d
+ tags:
+ - image-to-3d
+ ---
+
+ # Michelangelo
+
+ * [Project Page](https://neuralcarver.github.io/michelangelo/)
+ * [Paper](https://arxiv.org/abs/2306.17115)
+ * [Code](https://github.com/NeuralCarver/Michelangelo)
+ * [Demo](https://huggingface.co/spaces/Maikou/Michelangelo)
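Note: the `.ckpt` files added in this commit (see the checkpoint entries below) are stored as Git LFS pointers. A minimal sketch for pulling the actual weights, configs and example data locally with `huggingface_hub`; the repo id is the duplication source named above, and `allow_patterns` is just a convenience filter:

```python
# Minimal sketch: fetch the LFS-tracked checkpoints, configs and example data locally.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="Maikou/Michelangelo",
    allow_patterns=["checkpoints/*", "configs/*", "example_data/*"],
)
print(local_dir)  # root folder containing checkpoints/, configs/ and example_data/
```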
checkpoints/aligned_shape_latents/shapevae-256.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0391b81c36240e8f766fedf4265df599884193a5ef65354525074b9a00887454
+ size 3934164973
checkpoints/clip/clip-vit-large-patch14 ADDED
@@ -0,0 +1 @@
+ Subproject commit 8d052a0f05efbaefbc9e8786ba291cfdf93e5bff
checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:83eda8e4f81034dee7674b3ce1ff03a4900181f0f0d7bc461e1a8692fb379b0f
+ size 1999253985
checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:af546b1f877a41d71f63c3a11394779e77c954002c50dc8e75359338224f615b
+ size 4076140813
configs/aligned_shape_latents/shapevae-256.yaml ADDED
@@ -0,0 +1,46 @@
+ model:
+   target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
+   params:
+     shape_module_cfg:
+       target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
+       params:
+         num_latents: 256
+         embed_dim: 64
+         point_feats: 3 # normal
+         num_freqs: 8
+         include_pi: false
+         heads: 12
+         width: 768
+         num_encoder_layers: 8
+         num_decoder_layers: 16
+         use_ln_post: true
+         init_scale: 0.25
+         qkv_bias: false
+         use_checkpoint: true
+     aligned_module_cfg:
+       target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
+       params:
+         clip_model_version: "./checkpoints/clip/clip-vit-large-patch14"
+
+     loss_cfg:
+       target: michelangelo.models.tsal.loss.ContrastKLNearFar
+       params:
+         contrast_weight: 0.1
+         near_weight: 0.1
+         kl_weight: 0.001
+
+     optimizer_cfg:
+       optimizer:
+         target: torch.optim.AdamW
+         params:
+           betas: [0.9, 0.99]
+           eps: 1.e-6
+           weight_decay: 1.e-2
+
+       scheduler:
+         target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
+         params:
+           warm_up_steps: 5000
+           f_start: 1.e-6
+           f_min: 1.e-3
+           f_max: 1.0
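The blocks above are instantiated generically from `target`/`params` pairs. As a rough sketch mirroring the `load_model` helper in `inference.py` further down in this diff, the shape VAE can be built from this config as follows; paths are the ones added in this commit:

```python
# Sketch: build the aligned shape VAE from this config (paths as added in this commit).
import torch
from michelangelo.utils.misc import get_config_from_file, instantiate_from_config

config = get_config_from_file("./configs/aligned_shape_latents/shapevae-256.yaml")
if hasattr(config, "model"):
    config = config.model  # same unwrapping step as in inference.py

vae = instantiate_from_config(config, ckpt_path="./checkpoints/aligned_shape_latents/shapevae-256.ckpt")
vae = vae.eval().to("cuda" if torch.cuda.is_available() else "cpu")
```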
configs/image_cond_diffuser_asl/image-ASLDM-256.yaml ADDED
@@ -0,0 +1,97 @@
+ model:
+   target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser
+   params:
+     first_stage_config:
+       target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
+       params:
+         shape_module_cfg:
+           target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
+           params:
+             num_latents: &num_latents 256
+             embed_dim: &embed_dim 64
+             point_feats: 3 # normal
+             num_freqs: 8
+             include_pi: false
+             heads: 12
+             width: 768
+             num_encoder_layers: 8
+             num_decoder_layers: 16
+             use_ln_post: true
+             init_scale: 0.25
+             qkv_bias: false
+             use_checkpoint: false
+         aligned_module_cfg:
+           target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
+           params:
+             clip_model_version: "./checkpoints/clip/clip-vit-large-patch14"
+
+         loss_cfg:
+           target: torch.nn.Identity
+
+     cond_stage_config:
+       target: michelangelo.models.conditional_encoders.encoder_factory.FrozenCLIPImageGridEmbedder
+       params:
+         version: "./checkpoints/clip/clip-vit-large-patch14"
+         zero_embedding_radio: 0.1
+
+     first_stage_key: "surface"
+     cond_stage_key: "image"
+     scale_by_std: false
+
+     denoiser_cfg:
+       target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser
+       params:
+         input_channels: *embed_dim
+         output_channels: *embed_dim
+         n_ctx: *num_latents
+         width: 768
+         layers: 6 # 2 * 6 + 1 = 13
+         heads: 12
+         context_dim: 1024
+         init_scale: 1.0
+         skip_ln: true
+         use_checkpoint: true
+
+     scheduler_cfg:
+       guidance_scale: 7.5
+       num_inference_steps: 50
+       eta: 0.0
+
+       noise:
+         target: diffusers.schedulers.DDPMScheduler
+         params:
+           num_train_timesteps: 1000
+           beta_start: 0.00085
+           beta_end: 0.012
+           beta_schedule: "scaled_linear"
+           variance_type: "fixed_small"
+           clip_sample: false
+       denoise:
+         target: diffusers.schedulers.DDIMScheduler
+         params:
+           num_train_timesteps: 1000
+           beta_start: 0.00085
+           beta_end: 0.012
+           beta_schedule: "scaled_linear"
+           clip_sample: false # clip sample to -1~1
+           set_alpha_to_one: false
+           steps_offset: 1
+
+     optimizer_cfg:
+       optimizer:
+         target: torch.optim.AdamW
+         params:
+           betas: [0.9, 0.99]
+           eps: 1.e-6
+           weight_decay: 1.e-2
+
+       scheduler:
+         target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
+         params:
+           warm_up_steps: 5000
+           f_start: 1.e-6
+           f_min: 1.e-3
+           f_max: 1.0
+
+     loss_cfg:
+       loss_type: "mse"
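`&num_latents`/`&embed_dim` are YAML anchors reused below as `*num_latents`/`*embed_dim`, which keeps the denoiser's channel width and context length tied to the shape VAE. A quick sanity check (plain PyYAML resolves the anchors on load; key paths follow the nesting above):

```python
# Sketch: confirm that the YAML anchors resolve to matching values.
import yaml

with open("./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml") as f:
    cfg = yaml.safe_load(f)

shape_params = cfg["model"]["params"]["first_stage_config"]["params"]["shape_module_cfg"]["params"]
denoiser_params = cfg["model"]["params"]["denoiser_cfg"]["params"]
assert denoiser_params["input_channels"] == shape_params["embed_dim"]  # 64
assert denoiser_params["n_ctx"] == shape_params["num_latents"]         # 256
```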
configs/text_cond_diffuser_asl/text-ASLDM-256.yaml ADDED
@@ -0,0 +1,98 @@
+ model:
+   target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser
+   params:
+     first_stage_config:
+       target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
+       params:
+         shape_module_cfg:
+           target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
+           params:
+             num_latents: &num_latents 256
+             embed_dim: &embed_dim 64
+             point_feats: 3 # normal
+             num_freqs: 8
+             include_pi: false
+             heads: 12
+             width: 768
+             num_encoder_layers: 8
+             num_decoder_layers: 16
+             use_ln_post: true
+             init_scale: 0.25
+             qkv_bias: false
+             use_checkpoint: true
+         aligned_module_cfg:
+           target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
+           params:
+             clip_model_version: "./checkpoints/clip/clip-vit-large-patch14"
+
+         loss_cfg:
+           target: torch.nn.Identity
+
+     cond_stage_config:
+       target: michelangelo.models.conditional_encoders.encoder_factory.FrozenAlignedCLIPTextEmbedder
+       params:
+         version: "./checkpoints/clip/clip-vit-large-patch14"
+         zero_embedding_radio: 0.1
+         max_length: 77
+
+     first_stage_key: "surface"
+     cond_stage_key: "text"
+     scale_by_std: false
+
+     denoiser_cfg:
+       target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser
+       params:
+         input_channels: *embed_dim
+         output_channels: *embed_dim
+         n_ctx: *num_latents
+         width: 768
+         layers: 8 # 2 * 6 + 1 = 13
+         heads: 12
+         context_dim: 768
+         init_scale: 1.0
+         skip_ln: true
+         use_checkpoint: true
+
+     scheduler_cfg:
+       guidance_scale: 7.5
+       num_inference_steps: 50
+       eta: 0.0
+
+       noise:
+         target: diffusers.schedulers.DDPMScheduler
+         params:
+           num_train_timesteps: 1000
+           beta_start: 0.00085
+           beta_end: 0.012
+           beta_schedule: "scaled_linear"
+           variance_type: "fixed_small"
+           clip_sample: false
+       denoise:
+         target: diffusers.schedulers.DDIMScheduler
+         params:
+           num_train_timesteps: 1000
+           beta_start: 0.00085
+           beta_end: 0.012
+           beta_schedule: "scaled_linear"
+           clip_sample: false # clip sample to -1~1
+           set_alpha_to_one: false
+           steps_offset: 1
+
+     optimizer_cfg:
+       optimizer:
+         target: torch.optim.AdamW
+         params:
+           betas: [0.9, 0.99]
+           eps: 1.e-6
+           weight_decay: 1.e-2
+
+       scheduler:
+         target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
+         params:
+           warm_up_steps: 5000
+           f_start: 1.e-6
+           f_min: 1.e-3
+           f_max: 1.0
+
+     loss_cfg:
+       loss_type: "mse"
example_data/image/car.jpg ADDED
example_data/surface/surface.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0893e44d82ada683baa656a718beaf6ec19fc28b6816b451f56645530d5bb962
+ size 1201024
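`load_surface` in `inference.py` (below) reads this archive and expects `points` and `normals` arrays. A hedged sketch for producing a compatible file from your own mesh with `trimesh`; `my_mesh.obj` is a placeholder path, and pre-normalizing the mesh to roughly the unit cube is an assumption made to match the ±1.1/±1.25 bounds used elsewhere in this commit:

```python
# Sketch: build a surface.npz compatible with load_surface() in inference.py.
# "my_mesh.obj" is a placeholder; normalize vertices to roughly [-1, 1] beforehand.
import numpy as np
import trimesh

mesh = trimesh.load("my_mesh.obj", force="mesh")
points, face_idx = trimesh.sample.sample_surface(mesh, 100_000)  # surface point samples
normals = mesh.face_normals[face_idx]                            # per-sample normals
np.savez("surface.npz",
         points=points.astype(np.float32),
         normals=normals.astype(np.float32))
```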
gradio_app.py ADDED
@@ -0,0 +1,372 @@
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import time
4
+ from collections import OrderedDict
5
+ from PIL import Image
6
+ import torch
7
+ import trimesh
8
+ from typing import Optional, List
9
+ from einops import repeat, rearrange
10
+ import numpy as np
11
+ from michelangelo.models.tsal.tsal_base import Latent2MeshOutput
12
+ from michelangelo.utils.misc import get_config_from_file, instantiate_from_config
13
+ from michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer
14
+ from michelangelo.utils.visualizers import html_util
15
+
16
+ import gradio as gr
17
+
18
+
19
+ gradio_cached_dir = "./gradio_cached_dir"
20
+ os.makedirs(gradio_cached_dir, exist_ok=True)
21
+
22
+ save_mesh = False
23
+
24
+ state = ""
25
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26
+
27
+ box_v = 1.1
28
+ viewer = PyThreeJSViewer(settings={}, render_mode="WEBSITE")
29
+
30
+ image_model_config_dict = OrderedDict({
31
+ "ASLDM-256-obj": {
32
+ "config": "./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml",
33
+ "ckpt_path": "./checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt",
34
+ },
35
+ })
36
+
37
+ text_model_config_dict = OrderedDict({
38
+ "ASLDM-256": {
39
+ "config": "./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml",
40
+ "ckpt_path": "./checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt",
41
+ },
42
+ })
43
+
44
+
45
+ class InferenceModel(object):
46
+ model = None
47
+ name = ""
48
+
49
+
50
+ text2mesh_model = InferenceModel()
51
+ image2mesh_model = InferenceModel()
52
+
53
+
54
+ def set_state(s):
55
+ global state
56
+ state = s
57
+ print(s)
58
+
59
+
60
+ def output_to_html_frame(mesh_outputs: List[Latent2MeshOutput], bbox_size: float,
61
+ image: Optional[np.ndarray] = None,
62
+ html_frame: bool = False):
63
+ global viewer
64
+
65
+ for i in range(len(mesh_outputs)):
66
+ mesh = mesh_outputs[i]
67
+ if mesh is None:
68
+ continue
69
+
70
+ mesh_v = mesh.mesh_v.copy()
71
+ mesh_v[:, 0] += i * np.max(bbox_size)
72
+ mesh_v[:, 2] += np.max(bbox_size)
73
+ viewer.add_mesh(mesh_v, mesh.mesh_f)
74
+
75
+ mesh_tag = viewer.to_html(html_frame=False)
76
+
77
+ if image is not None:
78
+ image_tag = html_util.to_image_embed_tag(image)
79
+ frame = f"""
80
+ <table border = "1">
81
+ <tr>
82
+ <td>{image_tag}</td>
83
+ <td>{mesh_tag}</td>
84
+ </tr>
85
+ </table>
86
+ """
87
+ else:
88
+ frame = mesh_tag
89
+
90
+ if html_frame:
91
+ frame = html_util.to_html_frame(frame)
92
+
93
+ viewer.reset()
94
+
95
+ return frame
96
+
97
+
98
+ def load_model(model_name: str, model_config_dict: dict, inference_model: InferenceModel):
99
+ global device
100
+
101
+ if inference_model.name == model_name:
102
+ model = inference_model.model
103
+ else:
104
+ assert model_name in model_config_dict
105
+
106
+ if inference_model.model is not None:
107
+ del inference_model.model
108
+
109
+ config_ckpt_path = model_config_dict[model_name]
110
+
111
+ model_config = get_config_from_file(config_ckpt_path["config"])
112
+ if hasattr(model_config, "model"):
113
+ model_config = model_config.model
114
+
115
+ model = instantiate_from_config(model_config, ckpt_path=config_ckpt_path["ckpt_path"])
116
+ model = model.to(device)
117
+ model = model.eval()
118
+
119
+ inference_model.model = model
120
+ inference_model.name = model_name
121
+
122
+ return model
123
+
124
+
125
+ def prepare_img(image: np.ndarray):
126
+ image_pt = torch.tensor(image).float()
127
+ image_pt = image_pt / 255 * 2 - 1
128
+ image_pt = rearrange(image_pt, "h w c -> c h w")
129
+
130
+ return image_pt
131
+
132
+ def prepare_model_viewer(fp):
133
+ content = f"""
134
+ <head>
135
+ <script
136
+ type="module" src="https://ajax.googleapis.com/ajax/libs/model-viewer/3.1.1/model-viewer.min.js">
137
+ </script>
138
+ </head>
139
+ <body>
140
+ <model-viewer
141
+ style="height: 150px; width: 150px;"
142
+ rotation-per-second="10deg"
143
+ id="t1"
144
+ src="file/gradio_cached_dir/{fp}"
145
+ environment-image="neutral"
146
+ camera-target="0m 0m 0m"
147
+ orientation="0deg 90deg 170deg"
148
+ shadow-intensity="1"
149
+ ar:true
150
+ auto-rotate
151
+ camera-controls>
152
+ </model-viewer>
153
+ </body>
154
+ """
155
+ return content
156
+
157
+ def prepare_html_frame(content):
158
+ frame = f"""
159
+ <html>
160
+ <body>
161
+ {content}
162
+ </body>
163
+ </html>
164
+ """
165
+ return frame
166
+
167
+ def prepare_html_body(content):
168
+ frame = f"""
169
+ <body>
170
+ {content}
171
+ </body>
172
+ """
173
+ return frame
174
+
175
+ def post_process_mesh_outputs(mesh_outputs):
176
+ # html_frame = output_to_html_frame(mesh_outputs, 2 * box_v, image=None, html_frame=True)
177
+ html_content = output_to_html_frame(mesh_outputs, 2 * box_v, image=None, html_frame=False)
178
+ html_frame = prepare_html_frame(html_content)
179
+
180
+ # filename = f"{time.time()}.html"
181
+ filename = f"text-256-{time.time()}.html"
182
+ html_filepath = os.path.join(gradio_cached_dir, filename)
183
+ with open(html_filepath, "w") as writer:
184
+ writer.write(html_frame)
185
+
186
+ '''
187
+ Bug: The iframe tag does not work in Gradio.
188
+ The chrome returns "No resource with given URL found"
189
+ Solutions:
190
+ https://github.com/gradio-app/gradio/issues/884
191
+ Due to security restrictions, the server can only serve files located alongside gradio_app.py.
192
+ The path has format "file/TARGET_FILE_PATH"
193
+ '''
194
+
195
+ iframe_tag = f'<iframe src="file/gradio_cached_dir/{filename}" width="600%" height="400" frameborder="0"></iframe>'
196
+
197
+ filelist = []
198
+ filenames = []
199
+ for i, mesh in enumerate(mesh_outputs):
200
+ mesh.mesh_f = mesh.mesh_f[:, ::-1]
201
+ mesh_output = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f)
202
+
203
+ name = str(i) + "_out_mesh.obj"
204
+ filepath = gradio_cached_dir + "/" + name
205
+ mesh_output.export(filepath, include_normals=True)
206
+ filelist.append(filepath)
207
+ filenames.append(name)
208
+
209
+ filelist.append(html_filepath)
210
+ return iframe_tag, filelist
211
+
212
+ def image2mesh(image: np.ndarray,
213
+ model_name: str = "subsp+pk_asl_perceiver=01_01_udt=03",
214
+ num_samples: int = 4,
215
+ guidance_scale: int = 7.5,
216
+ octree_depth: int = 7):
217
+ global device, gradio_cached_dir, image_model_config_dict, box_v
218
+
219
+ # load model
220
+ model = load_model(model_name, image_model_config_dict, image2mesh_model)
221
+
222
+ # prepare image inputs
223
+ image_pt = prepare_img(image)
224
+ image_pt = repeat(image_pt, "c h w -> b c h w", b=num_samples)
225
+
226
+ sample_inputs = {
227
+ "image": image_pt
228
+ }
229
+ mesh_outputs = model.sample(
230
+ sample_inputs,
231
+ sample_times=1,
232
+ guidance_scale=guidance_scale,
233
+ return_intermediates=False,
234
+ bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
235
+ octree_depth=octree_depth,
236
+ )[0]
237
+
238
+ iframe_tag, filelist = post_process_mesh_outputs(mesh_outputs)
239
+
240
+ return iframe_tag, gr.update(value=filelist, visible=True)
241
+
242
+
243
+ def text2mesh(text: str,
244
+ model_name: str = "subsp+pk_asl_perceiver=01_01_udt=03",
245
+ num_samples: int = 4,
246
+ guidance_scale: int = 7.5,
247
+ octree_depth: int = 7):
248
+ global device, gradio_cached_dir, text_model_config_dict, text2mesh_model, box_v
249
+
250
+ # load model
251
+ model = load_model(model_name, text_model_config_dict, text2mesh_model)
252
+
253
+ # prepare text inputs
254
+ sample_inputs = {
255
+ "text": [text] * num_samples
256
+ }
257
+ mesh_outputs = model.sample(
258
+ sample_inputs,
259
+ sample_times=1,
260
+ guidance_scale=guidance_scale,
261
+ return_intermediates=False,
262
+ bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
263
+ octree_depth=octree_depth,
264
+ )[0]
265
+
266
+ iframe_tag, filelist = post_process_mesh_outputs(mesh_outputs)
267
+
268
+ return iframe_tag, gr.update(value=filelist, visible=True)
269
+
270
+ example_dir = './gradio_cached_dir/example/img_example'
271
+
272
+ first_page_items = [
273
+ 'alita.jpg',
274
+ 'burger.jpg',
275
+ 'loopy.jpg',
276
+ 'building.jpg',
277
+ 'mario.jpg',
278
+ 'car.jpg',
279
+ 'airplane.jpg',
280
+ 'bag.jpg',
281
+ 'bench.jpg',
282
+ 'ship.jpg'
283
+ ]
284
+ raw_example_items = [
285
+ # (os.path.join(example_dir, x), x)
286
+ os.path.join(example_dir, x)
287
+ for x in os.listdir(example_dir)
288
+ if x.endswith(('.jpg', '.png'))
289
+ ]
290
+ example_items = [x for x in raw_example_items if os.path.basename(x) in first_page_items] + [x for x in raw_example_items if os.path.basename(x) not in first_page_items]
291
+
292
+ example_text = [
293
+ ["A 3D model of a car; Audi A6."],
294
+ ["A 3D model of police car; Highway Patrol Charger"]
295
+ ]
296
+
297
+ def set_cache(data: gr.SelectData):
298
+ img_name = os.path.basename(example_items[data.index])
299
+ return os.path.join(example_dir, img_name), os.path.join(img_name)
300
+
301
+ def disable_cache():
302
+ return ""
303
+
304
+ with gr.Blocks() as app:
305
+ gr.Markdown("# Michelangelo")
306
+ gr.Markdown("## [Github](https://github.com/NeuralCarver/Michelangelo) | [Arxiv](https://arxiv.org/abs/2306.17115) | [Project Page](https://neuralcarver.github.io/michelangelo/)")
307
+ gr.Markdown("Michelangelo is a conditional 3D shape generation system built on a shape-image-text aligned latent representation.")
308
+ gr.Markdown("### Hint:")
309
+ gr.Markdown("1. We provide two APIs: Image-conditioned generation and Text-conditioned generation")
310
+ gr.Markdown("2. Note that the Image-conditioned model is trained on multiple 3D datasets like ShapeNet and Objaverse")
311
+ gr.Markdown("3. We provide some examples for you to try. You can also upload images or text as input.")
312
+ gr.Markdown("4. Welcome to share your amazing results with us, and thanks for your interest in our work!")
313
+
314
+ with gr.Row():
315
+ with gr.Column():
316
+
317
+ with gr.Tab("Image to 3D"):
318
+ img = gr.Image(label="Image")
319
+ gr.Markdown("For the best results, we suggest that the images uploaded meet the following three criteria: 1. The object is positioned at the center of the image, 2. The image size is square, and 3. The background is relatively clean.")
320
+ btn_generate_img2obj = gr.Button(value="Generate")
321
+
322
+ with gr.Accordion("Advanced settings", open=False):
323
+ image_dropdown_models = gr.Dropdown(label="Model", value="ASLDM-256-obj",choices=list(image_model_config_dict.keys()))
324
+ num_samples = gr.Slider(label="samples", value=4, minimum=1, maximum=8, step=1)
325
+ guidance_scale = gr.Slider(label="Guidance scale", value=7.5, minimum=3.0, maximum=10.0, step=0.1)
326
+ octree_depth = gr.Slider(label="Octree Depth (for 3D model)", value=7, minimum=4, maximum=8, step=1)
327
+
328
+
329
+ cache_dir = gr.Textbox(value="", visible=False)
330
+ examples = gr.Gallery(label='Examples', value=example_items, elem_id="gallery", allow_preview=False, columns=[4], object_fit="contain")
331
+
332
+ with gr.Tab("Text to 3D"):
333
+ prompt = gr.Textbox(label="Prompt", placeholder="A 3D model of motorcar; Porsche Cayenne Turbo.")
334
+ gr.Markdown("For the best results, we suggest that the prompt follows 'A 3D model of CATEGORY; DESCRIPTION'. For example, A 3D model of motorcar; Porsche Cayenne Turbo.")
335
+ btn_generate_txt2obj = gr.Button(value="Generate")
336
+
337
+ with gr.Accordion("Advanced settings", open=False):
338
+ text_dropdown_models = gr.Dropdown(label="Model", value="ASLDM-256",choices=list(text_model_config_dict.keys()))
339
+ num_samples = gr.Slider(label="samples", value=4, minimum=1, maximum=8, step=1)
340
+ guidance_scale = gr.Slider(label="Guidance scale", value=7.5, minimum=3.0, maximum=10.0, step=0.1)
341
+ octree_depth = gr.Slider(label="Octree Depth (for 3D model)", value=7, minimum=4, maximum=8, step=1)
342
+
343
+ gr.Markdown("#### Examples:")
344
+ gr.Markdown("1. A 3D model of a coupe; Audi A6.")
345
+ gr.Markdown("2. A 3D model of a motorcar; Hummer H2 SUT.")
346
+ gr.Markdown("3. A 3D model of an airplane; Airbus.")
347
+ gr.Markdown("4. A 3D model of a fighter aircraft; Attack Fighter.")
348
+ gr.Markdown("5. A 3D model of a chair; Simple Wooden Chair.")
349
+ gr.Markdown("6. A 3D model of a laptop computer; Dell Laptop.")
350
+ gr.Markdown("7. A 3D model of a lamp; ceiling light.")
351
+ gr.Markdown("8. A 3D model of a rifle; AK47.")
352
+ gr.Markdown("9. A 3D model of a knife; Sword.")
353
+ gr.Markdown("10. A 3D model of a vase; Plant in pot.")
354
+
355
+ with gr.Column():
356
+ model_3d = gr.HTML()
357
+ file_out = gr.File(label="Files", visible=False)
358
+
359
+ outputs = [model_3d, file_out]
360
+
361
+ img.upload(disable_cache, outputs=cache_dir)
362
+ examples.select(set_cache, outputs=[img, cache_dir])
363
+ print(f'line:404: {cache_dir}')
364
+ btn_generate_img2obj.click(image2mesh, inputs=[img, image_dropdown_models, num_samples,
365
+ guidance_scale, octree_depth],
366
+ outputs=outputs, api_name="generate_img2obj")
367
+
368
+ btn_generate_txt2obj.click(text2mesh, inputs=[prompt, text_dropdown_models, num_samples,
369
+ guidance_scale, octree_depth],
370
+ outputs=outputs, api_name="generate_txt2obj")
371
+
372
+ app.launch(server_name="0.0.0.0", server_port=8008, share=False)
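Because both buttons register an `api_name`, the endpoints can in principle also be driven programmatically once the app is running on port 8008. An untested sketch with `gradio_client`; the argument order follows the `inputs=` lists above, and how Gradio exposes the endpoint is an assumption, not something documented in this repo:

```python
# Sketch: call the text-to-3D endpoint registered above as api_name="generate_txt2obj".
# Assumes gradio_app.py is running locally on port 8008 (see app.launch above).
from gradio_client import Client

client = Client("http://localhost:8008/")
html_view, files = client.predict(
    "A 3D model of a chair; Simple Wooden Chair.",  # prompt
    "ASLDM-256",                                    # model dropdown
    4,                                              # number of samples
    7.5,                                            # guidance scale
    7,                                              # octree depth
    api_name="/generate_txt2obj",
)
```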
gradio_cached_dir/example/img_example/airplane.jpg ADDED
gradio_cached_dir/example/img_example/alita.jpg ADDED
gradio_cached_dir/example/img_example/bag.jpg ADDED
gradio_cached_dir/example/img_example/bench.jpg ADDED
gradio_cached_dir/example/img_example/building.jpg ADDED
gradio_cached_dir/example/img_example/burger.jpg ADDED
gradio_cached_dir/example/img_example/car.jpg ADDED
gradio_cached_dir/example/img_example/loopy.jpg ADDED
gradio_cached_dir/example/img_example/mario.jpg ADDED
gradio_cached_dir/example/img_example/ship.jpg ADDED
inference.py ADDED
@@ -0,0 +1,181 @@
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import time
4
+ from collections import OrderedDict
5
+ from typing import Optional, List
6
+ import argparse
7
+ from functools import partial
8
+
9
+ from einops import repeat, rearrange
10
+ import numpy as np
11
+ from PIL import Image
12
+ import trimesh
13
+ import cv2
14
+
15
+ import torch
16
+ import pytorch_lightning as pl
17
+
18
+ from michelangelo.models.tsal.tsal_base import Latent2MeshOutput
19
+ from michelangelo.models.tsal.inference_utils import extract_geometry
20
+ from michelangelo.utils.misc import get_config_from_file, instantiate_from_config
21
+ from michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer
22
+ from michelangelo.utils.visualizers import html_util
23
+
24
+ def load_model(args):
25
+
26
+ model_config = get_config_from_file(args.config_path)
27
+ if hasattr(model_config, "model"):
28
+ model_config = model_config.model
29
+
30
+ model = instantiate_from_config(model_config, ckpt_path=args.ckpt_path)
31
+ model = model.cuda()
32
+ model = model.eval()
33
+
34
+ return model
35
+
36
+ def load_surface(fp):
37
+
38
+ with np.load(fp) as input_pc:
39
+ surface = input_pc['points']
40
+ normal = input_pc['normals']
41
+
42
+ rng = np.random.default_rng()
43
+ ind = rng.choice(surface.shape[0], 4096, replace=False)
44
+ surface = torch.FloatTensor(surface[ind])
45
+ normal = torch.FloatTensor(normal[ind])
46
+
47
+ surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda()
48
+
49
+ return surface
50
+
51
+ def prepare_image(args, number_samples=2):
52
+
53
+ image = cv2.imread(f"{args.image_path}")
54
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
55
+
56
+ image_pt = torch.tensor(image).float()
57
+ image_pt = image_pt / 255 * 2 - 1
58
+ image_pt = rearrange(image_pt, "h w c -> c h w")
59
+
60
+ image_pt = repeat(image_pt, "c h w -> b c h w", b=number_samples)
61
+
62
+ return image_pt
63
+
64
+ def save_output(args, mesh_outputs):
65
+
66
+ os.makedirs(args.output_dir, exist_ok=True)
67
+ for i, mesh in enumerate(mesh_outputs):
68
+ mesh.mesh_f = mesh.mesh_f[:, ::-1]
69
+ mesh_output = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f)
70
+
71
+ name = str(i) + "_out_mesh.obj"
72
+ mesh_output.export(os.path.join(args.output_dir, name), include_normals=True)
73
+
74
+ print(f'-----------------------------------------------------------------------------')
75
+ print(f'>>> Finished and mesh saved in {args.output_dir}')
76
+ print(f'-----------------------------------------------------------------------------')
77
+
78
+ return 0
79
+
80
+ def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000):
81
+
82
+ surface = load_surface(args.pointcloud_path)
83
+
84
+ # encoding
85
+ shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True)
86
+ shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents)
87
+
88
+ # decoding
89
+ latents = model.model.shape_model.decode(shape_zq)
90
+ geometric_func = partial(model.model.shape_model.query_geometry, latents=latents)
91
+
92
+ # reconstruction
93
+ mesh_v_f, has_surface = extract_geometry(
94
+ geometric_func=geometric_func,
95
+ device=surface.device,
96
+ batch_size=surface.shape[0],
97
+ bounds=bounds,
98
+ octree_depth=octree_depth,
99
+ num_chunks=num_chunks,
100
+ )
101
+ recon_mesh = trimesh.Trimesh(mesh_v_f[0][0], mesh_v_f[0][1])
102
+
103
+ # save
104
+ os.makedirs(args.output_dir, exist_ok=True)
105
+ recon_mesh.export(os.path.join(args.output_dir, 'reconstruction.obj'))
106
+
107
+ print(f'-----------------------------------------------------------------------------')
108
+ print(f'>>> Finished and mesh saved in {os.path.join(args.output_dir, "reconstruction.obj")}')
109
+ print(f'-----------------------------------------------------------------------------')
110
+
111
+ return 0
112
+
113
+ def image2mesh(args, model, guidance_scale=7.5, box_v=1.1, octree_depth=7):
114
+
115
+ sample_inputs = {
116
+ "image": prepare_image(args)
117
+ }
118
+
119
+ mesh_outputs = model.sample(
120
+ sample_inputs,
121
+ sample_times=1,
122
+ guidance_scale=guidance_scale,
123
+ return_intermediates=False,
124
+ bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
125
+ octree_depth=octree_depth,
126
+ )[0]
127
+
128
+ save_output(args, mesh_outputs)
129
+
130
+ return 0
131
+
132
+ def text2mesh(args, model, num_samples=2, guidance_scale=7.5, box_v=1.1, octree_depth=7):
133
+
134
+ sample_inputs = {
135
+ "text": [args.text] * num_samples
136
+ }
137
+ mesh_outputs = model.sample(
138
+ sample_inputs,
139
+ sample_times=1,
140
+ guidance_scale=guidance_scale,
141
+ return_intermediates=False,
142
+ bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
143
+ octree_depth=octree_depth,
144
+ )[0]
145
+
146
+ save_output(args, mesh_outputs)
147
+
148
+ return 0
149
+
150
+ task_dick = {
151
+ 'reconstruction': reconstruction,
152
+ 'image2mesh': image2mesh,
153
+ 'text2mesh': text2mesh,
154
+ }
155
+
156
+ if __name__ == "__main__":
157
+ '''
158
+ 1. Reconstruct point cloud
159
+ 2. Image-conditioned generation
160
+ 3. Text-conditioned generation
161
+ '''
162
+ parser = argparse.ArgumentParser()
163
+ parser.add_argument("--task", type=str, choices=['reconstruction', 'image2mesh', 'text2mesh'], required=True)
164
+ parser.add_argument("--config_path", type=str, required=True)
165
+ parser.add_argument("--ckpt_path", type=str, required=True)
166
+ parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface/surface.npz', help='Path to the input point cloud (.npz with "points" and "normals" arrays)')
167
+ parser.add_argument("--image_path", type=str, help='Path to the input image')
168
+ parser.add_argument("--text", type=str, help='Input text in the format "A 3D model of CATEGORY; DESCRIPTION", e.g. "A 3D model of motorcar; Porsche 911."')
169
+ parser.add_argument("--output_dir", type=str, default='./output')
170
+ parser.add_argument("-s", "--seed", type=int, default=0)
171
+ args = parser.parse_args()
172
+
173
+ pl.seed_everything(args.seed)
174
+
175
+ print(f'-----------------------------------------------------------------------------')
176
+ print(f'>>> Running {args.task}')
177
+ args.output_dir = os.path.join(args.output_dir, args.task)
178
+ print(f'>>> Output directory: {args.output_dir}')
179
+ print(f'-----------------------------------------------------------------------------')
180
+
181
+ task_dick[args.task](args, load_model(args))
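With the files added in this commit, the three tasks map to invocations along these lines (paths as listed above; treat them as a sketch): `python inference.py --task reconstruction --config_path ./configs/aligned_shape_latents/shapevae-256.yaml --ckpt_path ./checkpoints/aligned_shape_latents/shapevae-256.ckpt --pointcloud_path ./example_data/surface/surface.npz`; `python inference.py --task image2mesh --config_path ./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml --ckpt_path ./checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt --image_path ./example_data/image/car.jpg`; `python inference.py --task text2mesh --config_path ./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml --ckpt_path ./checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt --text "A 3D model of motorcar; Porsche 911."`. Outputs are written under `./output/<task>/`.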
michelangelo/__init__.py ADDED
@@ -0,0 +1 @@
+ # -*- coding: utf-8 -*-
michelangelo/data/__init__.py ADDED
@@ -0,0 +1 @@
+ # -*- coding: utf-8 -*-
michelangelo/data/templates.json ADDED
@@ -0,0 +1,69 @@
+ {
+     "shape": [
+         "a point cloud model of {}.",
+         "There is a {} in the scene.",
+         "There is the {} in the scene.",
+         "a photo of a {} in the scene.",
+         "a photo of the {} in the scene.",
+         "a photo of one {} in the scene.",
+         "itap of a {}.",
+         "itap of my {}.",
+         "itap of the {}.",
+         "a photo of a {}.",
+         "a photo of my {}.",
+         "a photo of the {}.",
+         "a photo of one {}.",
+         "a photo of many {}.",
+         "a good photo of a {}.",
+         "a good photo of the {}.",
+         "a bad photo of a {}.",
+         "a bad photo of the {}.",
+         "a photo of a nice {}.",
+         "a photo of the nice {}.",
+         "a photo of a cool {}.",
+         "a photo of the cool {}.",
+         "a photo of a weird {}.",
+         "a photo of the weird {}.",
+         "a photo of a small {}.",
+         "a photo of the small {}.",
+         "a photo of a large {}.",
+         "a photo of the large {}.",
+         "a photo of a clean {}.",
+         "a photo of the clean {}.",
+         "a photo of a dirty {}.",
+         "a photo of the dirty {}.",
+         "a bright photo of a {}.",
+         "a bright photo of the {}.",
+         "a dark photo of a {}.",
+         "a dark photo of the {}.",
+         "a photo of a hard to see {}.",
+         "a photo of the hard to see {}.",
+         "a low resolution photo of a {}.",
+         "a low resolution photo of the {}.",
+         "a cropped photo of a {}.",
+         "a cropped photo of the {}.",
+         "a close-up photo of a {}.",
+         "a close-up photo of the {}.",
+         "a jpeg corrupted photo of a {}.",
+         "a jpeg corrupted photo of the {}.",
+         "a blurry photo of a {}.",
+         "a blurry photo of the {}.",
+         "a pixelated photo of a {}.",
+         "a pixelated photo of the {}.",
+         "a black and white photo of the {}.",
+         "a black and white photo of a {}",
+         "a plastic {}.",
+         "the plastic {}.",
+         "a toy {}.",
+         "the toy {}.",
+         "a plushie {}.",
+         "the plushie {}.",
+         "a cartoon {}.",
+         "the cartoon {}.",
+         "an embroidered {}.",
+         "the embroidered {}.",
+         "a painting of the {}.",
+         "a painting of a {}."
+     ]
+
+ }
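These are CLIP-style prompt templates; `{}` is replaced by the category name. A minimal sketch of expanding them into text prompts (encoding and averaging the resulting embeddings is left to the caller):

```python
# Sketch: expand the "shape" templates for one category name.
import json

with open("michelangelo/data/templates.json") as f:
    templates = json.load(f)["shape"]

prompts = [t.format("chair") for t in templates]
print(len(prompts), prompts[0])  # 64 prompts; e.g. "a point cloud model of chair."
```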
michelangelo/data/transforms.py ADDED
@@ -0,0 +1,407 @@
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import time
4
+ import numpy as np
5
+ import warnings
6
+ import random
7
+ from omegaconf.listconfig import ListConfig
8
+ from webdataset import pipelinefilter
9
+ import torch
10
+ import torchvision.transforms.functional as TVF
11
+ from torchvision.transforms import InterpolationMode
12
+ from torchvision.transforms.transforms import _interpolation_modes_from_int
13
+ from typing import Sequence
14
+
15
+ from michelangelo.utils import instantiate_from_config
16
+
17
+
18
+ def _uid_buffer_pick(buf_dict, rng):
19
+ uid_keys = list(buf_dict.keys())
20
+ selected_uid = rng.choice(uid_keys)
21
+ buf = buf_dict[selected_uid]
22
+
23
+ k = rng.randint(0, len(buf) - 1)
24
+ sample = buf[k]
25
+ buf[k] = buf[-1]
26
+ buf.pop()
27
+
28
+ if len(buf) == 0:
29
+ del buf_dict[selected_uid]
30
+
31
+ return sample
32
+
33
+
34
+ def _add_to_buf_dict(buf_dict, sample):
35
+ key = sample["__key__"]
36
+ uid, uid_sample_id = key.split("_")
37
+ if uid not in buf_dict:
38
+ buf_dict[uid] = []
39
+ buf_dict[uid].append(sample)
40
+
41
+ return buf_dict
42
+
43
+
44
+ def _uid_shuffle(data, bufsize=1000, initial=100, rng=None, handler=None):
45
+ """Shuffle the data in the stream.
46
+
47
+ This uses a buffer of size `bufsize`. Shuffling at
48
+ startup is less random; this is traded off against
49
+ yielding samples quickly.
50
+
51
+ data: iterator
52
+ bufsize: buffer size for shuffling
53
+ returns: iterator
54
+ rng: either random module or random.Random instance
55
+
56
+ """
57
+ if rng is None:
58
+ rng = random.Random(int((os.getpid() + time.time()) * 1e9))
59
+ initial = min(initial, bufsize)
60
+ buf_dict = dict()
61
+ current_samples = 0
62
+ for sample in data:
63
+ _add_to_buf_dict(buf_dict, sample)
64
+ current_samples += 1
65
+
66
+ if current_samples < bufsize:
67
+ try:
68
+ _add_to_buf_dict(buf_dict, next(data)) # skipcq: PYL-R1708
69
+ current_samples += 1
70
+ except StopIteration:
71
+ pass
72
+
73
+ if current_samples >= initial:
74
+ current_samples -= 1
75
+ yield _uid_buffer_pick(buf_dict, rng)
76
+
77
+ while current_samples > 0:
78
+ current_samples -= 1
79
+ yield _uid_buffer_pick(buf_dict, rng)
80
+
81
+
82
+ uid_shuffle = pipelinefilter(_uid_shuffle)
83
+
84
+
85
+ class RandomSample(object):
86
+ def __init__(self,
87
+ num_volume_samples: int = 1024,
88
+ num_near_samples: int = 1024):
89
+
90
+ super().__init__()
91
+
92
+ self.num_volume_samples = num_volume_samples
93
+ self.num_near_samples = num_near_samples
94
+
95
+ def __call__(self, sample):
96
+ rng = np.random.default_rng()
97
+
98
+ # 1. sample surface input
99
+ total_surface = sample["surface"]
100
+ ind = rng.choice(total_surface.shape[0], replace=False)
101
+ surface = total_surface[ind]
102
+
103
+ # 2. sample volume/near geometric points
104
+ vol_points = sample["vol_points"]
105
+ vol_label = sample["vol_label"]
106
+ near_points = sample["near_points"]
107
+ near_label = sample["near_label"]
108
+
109
+ ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False)
110
+ vol_points = vol_points[ind]
111
+ vol_label = vol_label[ind]
112
+ vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1)
113
+
114
+ ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False)
115
+ near_points = near_points[ind]
116
+ near_label = near_label[ind]
117
+ near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1)
118
+
119
+ # concat sampled volume and near points
120
+ geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0)
121
+
122
+ sample = {
123
+ "surface": surface,
124
+ "geo_points": geo_points
125
+ }
126
+
127
+ return sample
128
+
129
+
130
+ class SplitRandomSample(object):
131
+ def __init__(self,
132
+ use_surface_sample: bool = False,
133
+ num_surface_samples: int = 4096,
134
+ num_volume_samples: int = 1024,
135
+ num_near_samples: int = 1024):
136
+
137
+ super().__init__()
138
+
139
+ self.use_surface_sample = use_surface_sample
140
+ self.num_surface_samples = num_surface_samples
141
+ self.num_volume_samples = num_volume_samples
142
+ self.num_near_samples = num_near_samples
143
+
144
+ def __call__(self, sample):
145
+
146
+ rng = np.random.default_rng()
147
+
148
+ # 1. sample surface input
149
+ surface = sample["surface"]
150
+
151
+ if self.use_surface_sample:
152
+ replace = surface.shape[0] < self.num_surface_samples
153
+ ind = rng.choice(surface.shape[0], self.num_surface_samples, replace=replace)
154
+ surface = surface[ind]
155
+
156
+ # 2. sample volume/near geometric points
157
+ vol_points = sample["vol_points"]
158
+ vol_label = sample["vol_label"]
159
+ near_points = sample["near_points"]
160
+ near_label = sample["near_label"]
161
+
162
+ ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False)
163
+ vol_points = vol_points[ind]
164
+ vol_label = vol_label[ind]
165
+ vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1)
166
+
167
+ ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False)
168
+ near_points = near_points[ind]
169
+ near_label = near_label[ind]
170
+ near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1)
171
+
172
+ # concat sampled volume and near points
173
+ geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0)
174
+
175
+ sample = {
176
+ "surface": surface,
177
+ "geo_points": geo_points
178
+ }
179
+
180
+ return sample
181
+
182
+
183
+ class FeatureSelection(object):
184
+
185
+ VALID_SURFACE_FEATURE_DIMS = {
186
+ "none": [0, 1, 2], # xyz
187
+ "watertight_normal": [0, 1, 2, 3, 4, 5], # xyz, normal
188
+ "normal": [0, 1, 2, 6, 7, 8]
189
+ }
190
+
191
+ def __init__(self, surface_feature_type: str):
192
+
193
+ self.surface_feature_type = surface_feature_type
194
+ self.surface_dims = self.VALID_SURFACE_FEATURE_DIMS[surface_feature_type]
195
+
196
+ def __call__(self, sample):
197
+ sample["surface"] = sample["surface"][:, self.surface_dims]
198
+ return sample
199
+
200
+
201
+ class AxisScaleTransform(object):
202
+ def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
203
+ assert isinstance(interval, (tuple, list, ListConfig))
204
+ self.interval = interval
205
+ self.min_val = interval[0]
206
+ self.max_val = interval[1]
207
+ self.inter_size = interval[1] - interval[0]
208
+ self.jitter = jitter
209
+ self.jitter_scale = jitter_scale
210
+
211
+ def __call__(self, sample):
212
+
213
+ surface = sample["surface"][..., 0:3]
214
+ geo_points = sample["geo_points"][..., 0:3]
215
+
216
+ scaling = torch.rand(1, 3) * self.inter_size + self.min_val
217
+ # print(scaling)
218
+ surface = surface * scaling
219
+ geo_points = geo_points * scaling
220
+
221
+ scale = (1 / torch.abs(surface).max().item()) * 0.999999
222
+ surface *= scale
223
+ geo_points *= scale
224
+
225
+ if self.jitter:
226
+ surface += self.jitter_scale * torch.randn_like(surface)
227
+ surface.clamp_(min=-1.015, max=1.015)
228
+
229
+ sample["surface"][..., 0:3] = surface
230
+ sample["geo_points"][..., 0:3] = geo_points
231
+
232
+ return sample
233
+
234
+
235
+ class ToTensor(object):
236
+
237
+ def __init__(self, tensor_keys=("surface", "geo_points", "tex_points")):
238
+ self.tensor_keys = tensor_keys
239
+
240
+ def __call__(self, sample):
241
+ for key in self.tensor_keys:
242
+ if key not in sample:
243
+ continue
244
+
245
+ sample[key] = torch.tensor(sample[key], dtype=torch.float32)
246
+
247
+ return sample
248
+
249
+
250
+ class AxisScale(object):
251
+ def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
252
+ assert isinstance(interval, (tuple, list, ListConfig))
253
+ self.interval = interval
254
+ self.jitter = jitter
255
+ self.jitter_scale = jitter_scale
256
+
257
+ def __call__(self, surface, *args):
258
+ scaling = torch.rand(1, 3) * 0.5 + 0.75
259
+ # print(scaling)
260
+ surface = surface * scaling
261
+ scale = (1 / torch.abs(surface).max().item()) * 0.999999
262
+ surface *= scale
263
+
264
+ args_outputs = []
265
+ for _arg in args:
266
+ _arg = _arg * scaling * scale
267
+ args_outputs.append(_arg)
268
+
269
+ if self.jitter:
270
+ surface += self.jitter_scale * torch.randn_like(surface)
271
+ surface.clamp_(min=-1, max=1)
272
+
273
+ if len(args) == 0:
274
+ return surface
275
+ else:
276
+ return surface, *args_outputs
277
+
278
+
279
+ class RandomResize(torch.nn.Module):
280
+ """Apply randomly Resize with a given probability."""
281
+
282
+ def __init__(
283
+ self,
284
+ size,
285
+ resize_radio=(0.5, 1),
286
+ allow_resize_interpolations=(InterpolationMode.BICUBIC, InterpolationMode.BILINEAR, InterpolationMode.BILINEAR),
287
+ interpolation=InterpolationMode.BICUBIC,
288
+ max_size=None,
289
+ antialias=None,
290
+ ):
291
+ super().__init__()
292
+ if not isinstance(size, (int, Sequence)):
293
+ raise TypeError(f"Size should be int or sequence. Got {type(size)}")
294
+ if isinstance(size, Sequence) and len(size) not in (1, 2):
295
+ raise ValueError("If size is a sequence, it should have 1 or 2 values")
296
+
297
+ self.size = size
298
+ self.max_size = max_size
299
+ # Backward compatibility with integer value
300
+ if isinstance(interpolation, int):
301
+ warnings.warn(
302
+ "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
303
+ "Please use InterpolationMode enum."
304
+ )
305
+ interpolation = _interpolation_modes_from_int(interpolation)
306
+
307
+ self.interpolation = interpolation
308
+ self.antialias = antialias
309
+
310
+ self.resize_radio = resize_radio
311
+ self.allow_resize_interpolations = allow_resize_interpolations
312
+
313
+ def random_resize_params(self):
314
+ radio = torch.rand(1) * (self.resize_radio[1] - self.resize_radio[0]) + self.resize_radio[0]
315
+
316
+ if isinstance(self.size, int):
317
+ size = int(self.size * radio)
318
+ elif isinstance(self.size, Sequence):
319
+ size = list(self.size)
320
+ size = (int(size[0] * radio), int(size[1] * radio))
321
+ else:
322
+ raise RuntimeError()
323
+
324
+ interpolation = self.allow_resize_interpolations[
325
+ torch.randint(low=0, high=len(self.allow_resize_interpolations), size=(1,))
326
+ ]
327
+ return size, interpolation
328
+
329
+ def forward(self, img):
330
+ size, interpolation = self.random_resize_params()
331
+ img = TVF.resize(img, size, interpolation, self.max_size, self.antialias)
332
+ img = TVF.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
333
+ return img
334
+
335
+ def __repr__(self) -> str:
336
+ detail = f"(size={self.size}, interpolation={self.interpolation.value},"
337
+ detail += f"max_size={self.max_size}, antialias={self.antialias}), resize_radio={self.resize_radio}"
338
+ return f"{self.__class__.__name__}{detail}"
339
+
340
+
341
+ class Compose(object):
342
+ """Composes several transforms together. This transform does not support torchscript.
343
+ Please, see the note below.
344
+
345
+ Args:
346
+ transforms (list of ``Transform`` objects): list of transforms to compose.
347
+
348
+ Example:
349
+ >>> transforms.Compose([
350
+ >>> transforms.CenterCrop(10),
351
+ >>> transforms.ToTensor(),
352
+ >>> ])
353
+
354
+ .. note::
355
+ In order to script the transformations, please use ``torch.nn.Sequential`` as below.
356
+
357
+ >>> transforms = torch.nn.Sequential(
358
+ >>> transforms.CenterCrop(10),
359
+ >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
360
+ >>> )
361
+ >>> scripted_transforms = torch.jit.script(transforms)
362
+
363
+ Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
364
+ `lambda` functions or ``PIL.Image``.
365
+
366
+ """
367
+
368
+ def __init__(self, transforms):
369
+ self.transforms = transforms
370
+
371
+ def __call__(self, *args):
372
+ for t in self.transforms:
373
+ args = t(*args)
374
+ return args
375
+
376
+ def __repr__(self):
377
+ format_string = self.__class__.__name__ + '('
378
+ for t in self.transforms:
379
+ format_string += '\n'
380
+ format_string += ' {0}'.format(t)
381
+ format_string += '\n)'
382
+ return format_string
383
+
384
+
385
+ def identity(*args, **kwargs):
386
+ if len(args) == 1:
387
+ return args[0]
388
+ else:
389
+ return args
390
+
391
+
392
+ def build_transforms(cfg):
393
+
394
+ if cfg is None:
395
+ return identity
396
+
397
+ transforms = []
398
+
399
+ for transform_name, cfg_instance in cfg.items():
400
+ transform_instance = instantiate_from_config(cfg_instance)
401
+ transforms.append(transform_instance)
402
+ print(f"Build transform: {transform_instance}")
403
+
404
+ transforms = Compose(transforms)
405
+
406
+ return transforms
407
+
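`build_transforms` consumes a mapping of `target`/`params` entries, the same convention the YAML configs above use with `instantiate_from_config`. A hedged sketch of a pipeline assembled from the classes in this file; the parameter values and ordering are illustrative, and in training a `RandomSample`/`SplitRandomSample` step would normally run first to produce the `surface`/`geo_points` keys:

```python
# Sketch: assemble a preprocessing pipeline with build_transforms.
# Parameter values are illustrative; only the target/params structure matters.
from omegaconf import OmegaConf
from michelangelo.data.transforms import build_transforms

cfg = OmegaConf.create({
    "feature_selection": {
        "target": "michelangelo.data.transforms.FeatureSelection",
        "params": {"surface_feature_type": "watertight_normal"},  # keep xyz + normal columns
    },
    "to_tensor": {
        "target": "michelangelo.data.transforms.ToTensor",
        "params": {"tensor_keys": ["surface", "geo_points"]},
    },
    "axis_scale": {
        "target": "michelangelo.data.transforms.AxisScaleTransform",
        "params": {"interval": [0.75, 1.25], "jitter": True},
    },
})
transforms = build_transforms(cfg)  # prints each transform as it is built
```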
michelangelo/data/utils.py ADDED
@@ -0,0 +1,59 @@
+ # -*- coding: utf-8 -*-
+
+ import torch
+ import numpy as np
+
+
+ def worker_init_fn(_):
+     worker_info = torch.utils.data.get_worker_info()
+     worker_id = worker_info.id
+
+     # dataset = worker_info.dataset
+     # split_size = dataset.num_records // worker_info.num_workers
+     # # reset num_records to the true number to retain reliable length information
+     # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
+     # current_id = np.random.choice(len(np.random.get_state()[1]), 1)
+     # return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
+
+     return np.random.seed(np.random.get_state()[1][0] + worker_id)
+
+
+ def collation_fn(samples, combine_tensors=True, combine_scalars=True):
+     """Collate a list of sample dicts into a single batched dict.
+
+     Args:
+         samples (list[dict]): samples produced by the dataset.
+         combine_tensors: stack torch tensors / numpy arrays along a new batch dim.
+         combine_scalars: pack int/float fields into numpy arrays.
+
+     Returns:
+         dict: one entry per key, batched.
+     """
+
+     result = {}
+
+     keys = samples[0].keys()
+
+     for key in keys:
+         result[key] = []
+
+     for sample in samples:
+         for key in keys:
+             val = sample[key]
+             result[key].append(val)
+
+     for key in keys:
+         val_list = result[key]
+         if isinstance(val_list[0], (int, float)):
+             if combine_scalars:
+                 result[key] = np.array(result[key])
+
+         elif isinstance(val_list[0], torch.Tensor):
+             if combine_tensors:
+                 result[key] = torch.stack(val_list)
+
+         elif isinstance(val_list[0], np.ndarray):
+             if combine_tensors:
+                 result[key] = np.stack(val_list)
+
+     return result
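A short sketch of how `worker_init_fn` and `collation_fn` plug into a standard PyTorch `DataLoader`; `ToyDataset` here is a stand-in for any dataset that yields dicts of arrays or tensors:

```python
# Sketch: wire the helpers above into a DataLoader (ToyDataset is a stand-in).
import numpy as np
from torch.utils.data import Dataset, DataLoader
from michelangelo.data.utils import collation_fn, worker_init_fn

class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {"surface": np.random.rand(4096, 6).astype(np.float32), "idx": idx}

loader = DataLoader(ToyDataset(), batch_size=4, num_workers=2,
                    collate_fn=collation_fn, worker_init_fn=worker_init_fn)
batch = next(iter(loader))
print(batch["surface"].shape)  # (4, 4096, 6): numpy arrays are stacked with np.stack
```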
michelangelo/graphics/__init__.py ADDED
@@ -0,0 +1 @@
+ # -*- coding: utf-8 -*-
michelangelo/graphics/primitives/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # -*- coding: utf-8 -*-
+
+ from .volume import generate_dense_grid_points
+
+ from .mesh import (
+     MeshOutput,
+     save_obj,
+     savemeshtes2
+ )
michelangelo/graphics/primitives/mesh.py ADDED
@@ -0,0 +1,114 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import os
4
+ import cv2
5
+ import numpy as np
6
+ import PIL.Image
7
+ from typing import Optional
8
+
9
+ import trimesh
10
+
11
+
12
+ def save_obj(pointnp_px3, facenp_fx3, fname):
13
+ fid = open(fname, "w")
14
+ write_str = ""
15
+ for pidx, p in enumerate(pointnp_px3):
16
+ pp = p
17
+ write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2])
18
+
19
+ for i, f in enumerate(facenp_fx3):
20
+ f1 = f + 1
21
+ write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2])
22
+ fid.write(write_str)
23
+ fid.close()
24
+ return
25
+
26
+
27
+ def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname):
28
+ fol, na = os.path.split(fname)
29
+ na, _ = os.path.splitext(na)
30
+
31
+ matname = "%s/%s.mtl" % (fol, na)
32
+ fid = open(matname, "w")
33
+ fid.write("newmtl material_0\n")
34
+ fid.write("Kd 1 1 1\n")
35
+ fid.write("Ka 0 0 0\n")
36
+ fid.write("Ks 0.4 0.4 0.4\n")
37
+ fid.write("Ns 10\n")
38
+ fid.write("illum 2\n")
39
+ fid.write("map_Kd %s.png\n" % na)
40
+ fid.close()
41
+ ####
42
+
43
+ fid = open(fname, "w")
44
+ fid.write("mtllib %s.mtl\n" % na)
45
+
46
+ for pidx, p in enumerate(pointnp_px3):
47
+ pp = p
48
+ fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2]))
49
+
50
+ for pidx, p in enumerate(tcoords_px2):
51
+ pp = p
52
+ fid.write("vt %f %f\n" % (pp[0], pp[1]))
53
+
54
+ fid.write("usemtl material_0\n")
55
+ for i, f in enumerate(facenp_fx3):
56
+ f1 = f + 1
57
+ f2 = facetex_fx3[i] + 1
58
+ fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
59
+ fid.close()
60
+
61
+ PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save(
62
+ os.path.join(fol, "%s.png" % na))
63
+
64
+ return
65
+
66
+
67
+ class MeshOutput(object):
68
+
69
+ def __init__(self,
70
+ mesh_v: np.ndarray,
71
+ mesh_f: np.ndarray,
72
+ vertex_colors: Optional[np.ndarray] = None,
73
+ uvs: Optional[np.ndarray] = None,
74
+ mesh_tex_idx: Optional[np.ndarray] = None,
75
+ tex_map: Optional[np.ndarray] = None):
76
+
77
+ self.mesh_v = mesh_v
78
+ self.mesh_f = mesh_f
79
+ self.vertex_colors = vertex_colors
80
+ self.uvs = uvs
81
+ self.mesh_tex_idx = mesh_tex_idx
82
+ self.tex_map = tex_map
83
+
84
+ def contain_uv_texture(self):
85
+ return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None)
86
+
87
+ def contain_vertex_colors(self):
88
+ return self.vertex_colors is not None
89
+
90
+ def export(self, fname):
91
+
92
+ if self.contain_uv_texture():
93
+ savemeshtes2(
94
+ self.mesh_v,
95
+ self.uvs,
96
+ self.mesh_f,
97
+ self.mesh_tex_idx,
98
+ self.tex_map,
99
+ fname
100
+ )
101
+
102
+ elif self.contain_vertex_colors():
103
+ mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors)
104
+ mesh_obj.export(fname)
105
+
106
+ else:
107
+ save_obj(
108
+ self.mesh_v,
109
+ self.mesh_f,
110
+ fname
111
+ )
112
+
113
+
114
+
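`MeshOutput.export` picks the writer based on which optional fields are set: textured OBJ via `savemeshtes2`, vertex-colored mesh via `trimesh`, otherwise a plain OBJ via `save_obj`. A tiny sketch with dummy geometry:

```python
# Sketch: export a plain (untextured, uncolored) mesh through MeshOutput.
import numpy as np
from michelangelo.graphics.primitives import MeshOutput

verts = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
faces = np.array([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]], dtype=np.int64)

MeshOutput(mesh_v=verts, mesh_f=faces).export("tetrahedron.obj")  # falls back to save_obj
```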
michelangelo/graphics/primitives/volume.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import numpy as np
4
+
5
+
6
+ def generate_dense_grid_points(bbox_min: np.ndarray,
7
+ bbox_max: np.ndarray,
8
+ octree_depth: int,
9
+ indexing: str = "ij"):
10
+ length = bbox_max - bbox_min
11
+ num_cells = np.exp2(octree_depth)
12
+ x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
13
+ y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
14
+ z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
15
+ [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
16
+ xyz = np.stack((xs, ys, zs), axis=-1)
17
+ xyz = xyz.reshape(-1, 3)
18
+ grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
19
+
20
+ return xyz, grid_size, length
21
+
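
Usage note: generate_dense_grid_points returns the (2^octree_depth + 1)^3 corner coordinates of a regular grid spanning the bounding box, together with the per-axis grid size and box extent. A minimal sketch of the call; the depth and box values are illustrative only:

    import numpy as np
    # assumes the repository root is on PYTHONPATH
    from michelangelo.graphics.primitives.volume import generate_dense_grid_points

    bbox_min = np.array([-1.0, -1.0, -1.0])
    bbox_max = np.array([1.0, 1.0, 1.0])
    xyz, grid_size, length = generate_dense_grid_points(bbox_min, bbox_max, octree_depth=7)
    print(xyz.shape, grid_size, length)   # (2146689, 3) [129, 129, 129] [2. 2. 2.]
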
michelangelo/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # -*- coding: utf-8 -*-
michelangelo/models/asl_diffusion/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # -*- coding: utf-8 -*-
michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py ADDED
@@ -0,0 +1,483 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from omegaconf import DictConfig
4
+ from typing import List, Tuple, Dict, Optional, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.optim import lr_scheduler
10
+ import pytorch_lightning as pl
11
+ from pytorch_lightning.utilities import rank_zero_only
12
+
13
+ from einops import rearrange
14
+
15
+ from diffusers.schedulers import (
16
+ DDPMScheduler,
17
+ DDIMScheduler,
18
+ KarrasVeScheduler,
19
+ DPMSolverMultistepScheduler
20
+ )
21
+
22
+ from michelangelo.utils import instantiate_from_config
23
+ # from michelangelo.models.tsal.tsal_base import ShapeAsLatentPLModule
24
+ from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentPLModule
25
+ from michelangelo.models.asl_diffusion.inference_utils import ddim_sample
26
+
27
+ SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]
28
+
29
+
30
+ def disabled_train(self, mode=True):
31
+ """Overwrite model.train with this function to make sure train/eval mode
32
+ does not change anymore."""
33
+ return self
34
+
35
+
36
+ class ASLDiffuser(pl.LightningModule):
37
+ first_stage_model: Optional[AlignedShapeAsLatentPLModule]
38
+ # cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
39
+ model: nn.Module
40
+
41
+ def __init__(self, *,
42
+ first_stage_config,
43
+ denoiser_cfg,
44
+ scheduler_cfg,
45
+ optimizer_cfg,
46
+ loss_cfg,
47
+ first_stage_key: str = "surface",
48
+ cond_stage_key: str = "image",
49
+ cond_stage_trainable: bool = True,
50
+ scale_by_std: bool = False,
51
+ z_scale_factor: float = 1.0,
52
+ ckpt_path: Optional[str] = None,
53
+ ignore_keys: Union[Tuple[str], List[str]] = ()):
54
+
55
+ super().__init__()
56
+
57
+ self.first_stage_key = first_stage_key
58
+ self.cond_stage_key = cond_stage_key
59
+ self.cond_stage_trainable = cond_stage_trainable
60
+
61
+ # 1. initialize first stage.
62
+ # Note: the condition model is contained in the first stage model.
63
+ self.first_stage_config = first_stage_config
64
+ self.first_stage_model = None
65
+ # self.instantiate_first_stage(first_stage_config)
66
+
67
+ # 2. initialize conditional stage
68
+ # self.instantiate_cond_stage(cond_stage_config)
69
+ self.cond_stage_model = {
70
+ "image": self.encode_image,
71
+ "image_unconditional_embedding": self.empty_img_cond,
72
+ "text": self.encode_text,
73
+ "text_unconditional_embedding": self.empty_text_cond,
74
+ "surface": self.encode_surface,
75
+ "surface_unconditional_embedding": self.empty_surface_cond,
76
+ }
77
+
78
+ # 3. diffusion model
79
+ self.model = instantiate_from_config(
80
+ denoiser_cfg, device=None, dtype=None
81
+ )
82
+
83
+ self.optimizer_cfg = optimizer_cfg
84
+
85
+ # 4. scheduling strategy
86
+ self.scheduler_cfg = scheduler_cfg
87
+
88
+ self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
89
+ self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)
90
+
91
+ # 5. loss configures
92
+ self.loss_cfg = loss_cfg
93
+
94
+ self.scale_by_std = scale_by_std
95
+ if scale_by_std:
96
+ self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
97
+ else:
98
+ self.z_scale_factor = z_scale_factor
99
+
100
+ self.ckpt_path = ckpt_path
101
+ if ckpt_path is not None:
102
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
103
+
104
+ def instantiate_first_stage(self, config):
105
+ model = instantiate_from_config(config)
106
+ self.first_stage_model = model.eval()
107
+ self.first_stage_model.train = disabled_train
108
+ for param in self.first_stage_model.parameters():
109
+ param.requires_grad = False
110
+
111
+ self.first_stage_model = self.first_stage_model.to(self.device)
112
+
113
+ # def instantiate_cond_stage(self, config):
114
+ # if not self.cond_stage_trainable:
115
+ # if config == "__is_first_stage__":
116
+ # print("Using first stage also as cond stage.")
117
+ # self.cond_stage_model = self.first_stage_model
118
+ # elif config == "__is_unconditional__":
119
+ # print(f"Training {self.__class__.__name__} as an unconditional model.")
120
+ # self.cond_stage_model = None
121
+ # # self.be_unconditional = True
122
+ # else:
123
+ # model = instantiate_from_config(config)
124
+ # self.cond_stage_model = model.eval()
125
+ # self.cond_stage_model.train = disabled_train
126
+ # for param in self.cond_stage_model.parameters():
127
+ # param.requires_grad = False
128
+ # else:
129
+ # assert config != "__is_first_stage__"
130
+ # assert config != "__is_unconditional__"
131
+ # model = instantiate_from_config(config)
132
+ # self.cond_stage_model = model
133
+
134
+ def init_from_ckpt(self, path, ignore_keys=()):
135
+ state_dict = torch.load(path, map_location="cpu")["state_dict"]
136
+
137
+ keys = list(state_dict.keys())
138
+ for k in keys:
139
+ for ik in ignore_keys:
140
+ if k.startswith(ik):
141
+ print("Deleting key {} from state_dict.".format(k))
142
+ del state_dict[k]
143
+
144
+ missing, unexpected = self.load_state_dict(state_dict, strict=False)
145
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
146
+ if len(missing) > 0:
147
+ print(f"Missing Keys: {missing}")
148
+ print(f"Unexpected Keys: {unexpected}")
149
+
150
+ @property
151
+ def zero_rank(self):
152
+ if self._trainer:
153
+ zero_rank = self.trainer.local_rank == 0
154
+ else:
155
+ zero_rank = True
156
+
157
+ return zero_rank
158
+
159
+ def configure_optimizers(self) -> Tuple[List, List]:
160
+
161
+ lr = self.learning_rate
162
+
163
+ trainable_parameters = list(self.model.parameters())
164
+ # if the conditional encoder is trainable
165
+
166
+ # if self.cond_stage_trainable:
167
+ # conditioner_params = [p for p in self.cond_stage_model.parameters() if p.requires_grad]
168
+ # trainable_parameters += conditioner_params
169
+ # print(f"number of trainable conditional parameters: {len(conditioner_params)}.")
170
+
171
+ if self.optimizer_cfg is None:
172
+ optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
173
+ schedulers = []
174
+ else:
175
+ optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
176
+ scheduler_func = instantiate_from_config(
177
+ self.optimizer_cfg.scheduler,
178
+ max_decay_steps=self.trainer.max_steps,
179
+ lr_max=lr
180
+ )
181
+ scheduler = {
182
+ "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
183
+ "interval": "step",
184
+ "frequency": 1
185
+ }
186
+ optimizers = [optimizer]
187
+ schedulers = [scheduler]
188
+
189
+ return optimizers, schedulers
190
+
191
+ @torch.no_grad()
192
+ def encode_text(self, text):
193
+
194
+ b = text.shape[0]
195
+ text_tokens = rearrange(text, "b t l -> (b t) l")
196
+ text_embed = self.first_stage_model.model.encode_text_embed(text_tokens)
197
+ text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
198
+ text_embed = text_embed.mean(dim=1)
199
+ text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
200
+
201
+ return text_embed
202
+
203
+ @torch.no_grad()
204
+ def encode_image(self, img):
205
+
206
+ return self.first_stage_model.model.encode_image_embed(img)
207
+
208
+ @torch.no_grad()
209
+ def encode_surface(self, surface):
210
+
211
+ return self.first_stage_model.model.encode_shape_embed(surface, return_latents=False)
212
+
213
+ @torch.no_grad()
214
+ def empty_text_cond(self, cond):
215
+
216
+ return torch.zeros_like(cond, device=cond.device)
217
+
218
+ @torch.no_grad()
219
+ def empty_img_cond(self, cond):
220
+
221
+ return torch.zeros_like(cond, device=cond.device)
222
+
223
+ @torch.no_grad()
224
+ def empty_surface_cond(self, cond):
225
+
226
+ return torch.zeros_like(cond, device=cond.device)
227
+
228
+ @torch.no_grad()
229
+ def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
230
+
231
+ z_q = self.first_stage_model.encode(surface, sample_posterior)
232
+ z_q = self.z_scale_factor * z_q
233
+
234
+ return z_q
235
+
236
+ @torch.no_grad()
237
+ def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
238
+
239
+ z_q = 1. / self.z_scale_factor * z_q
240
+ latents = self.first_stage_model.decode(z_q, **kwargs)
241
+ return latents
242
+
243
+ @rank_zero_only
244
+ @torch.no_grad()
245
+ def on_train_batch_start(self, batch, batch_idx):
246
+ # only for very first batch
247
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
248
+ and batch_idx == 0 and self.ckpt_path is None:
249
+ # set rescale weight to 1./std of encodings
250
+ print("### USING STD-RESCALING ###")
251
+
252
+ z_q = self.encode_first_stage(batch[self.first_stage_key])
253
+ z = z_q.detach()
254
+
255
+ del self.z_scale_factor
256
+ self.register_buffer("z_scale_factor", 1. / z.flatten().std())
257
+ print(f"setting self.z_scale_factor to {self.z_scale_factor}")
258
+
259
+ print("### USING STD-RESCALING ###")
260
+
261
+ def compute_loss(self, model_outputs, split):
262
+ """
263
+
264
+ Args:
265
+ model_outputs (dict):
266
+ - x_0:
267
+ - noise:
268
+ - noise_prior:
269
+ - noise_pred:
270
+ - noise_pred_prior:
271
+
272
+ split (str):
273
+
274
+ Returns:
275
+
276
+ """
277
+
278
+ pred = model_outputs["pred"]
279
+
280
+ if self.noise_scheduler.prediction_type == "epsilon":
281
+ target = model_outputs["noise"]
282
+ elif self.noise_scheduler.prediction_type == "sample":
283
+ target = model_outputs["x_0"]
284
+ else:
285
+ raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")
286
+
287
+ if self.loss_cfg.loss_type == "l1":
288
+ simple = F.l1_loss(pred, target, reduction="mean")
289
+ elif self.loss_cfg.loss_type in ["mse", "l2"]:
290
+ simple = F.mse_loss(pred, target, reduction="mean")
291
+ else:
292
+ raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")
293
+
294
+ total_loss = simple
295
+
296
+ loss_dict = {
297
+ f"{split}/total_loss": total_loss.clone().detach(),
298
+ f"{split}/simple": simple.detach(),
299
+ }
300
+
301
+ return total_loss, loss_dict
302
+
303
+ def forward(self, batch):
304
+ """
305
+
306
+ Args:
307
+ batch:
308
+
309
+ Returns:
310
+
311
+ """
312
+
313
+ if self.first_stage_model is None:
314
+ self.instantiate_first_stage(self.first_stage_config)
315
+
316
+ latents = self.encode_first_stage(batch[self.first_stage_key])
317
+
318
+ # conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])
319
+
320
+ conditions = self.cond_stage_model[self.cond_stage_key](batch[self.cond_stage_key]).unsqueeze(1)
321
+
322
+ mask = torch.rand((len(conditions), 1, 1), device=conditions.device, dtype=conditions.dtype) >= 0.1
323
+ conditions = conditions * mask.to(conditions)
324
+
325
+ # Sample noise that we'll add to the latents
326
+ # [batch_size, n_token, latent_dim]
327
+ noise = torch.randn_like(latents)
328
+ bs = latents.shape[0]
329
+ # Sample a random timestep for each motion
330
+ timesteps = torch.randint(
331
+ 0,
332
+ self.noise_scheduler.config.num_train_timesteps,
333
+ (bs,),
334
+ device=latents.device,
335
+ )
336
+ timesteps = timesteps.long()
337
+ # Add noise to the latents according to the noise magnitude at each timestep
338
+ noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
339
+
340
+ # diffusion model forward
341
+ noise_pred = self.model(noisy_z, timesteps, conditions)
342
+
343
+ diffusion_outputs = {
344
+ "x_0": noisy_z,
345
+ "noise": noise,
346
+ "pred": noise_pred
347
+ }
348
+
349
+ return diffusion_outputs
350
+
351
+ def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
352
+ batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
353
+ """
354
+
355
+ Args:
356
+ batch (dict): the batch sample, and it contains:
357
+ - surface (torch.FloatTensor):
358
+ - image (torch.FloatTensor): if provided, [bs, 3, h, w], item range [0, 1]
359
+ - depth (torch.FloatTensor): if provided, [bs, 1, h, w], item range [-1, 1]
360
+ - normal (torch.FloatTensor): if provided, [bs, 3, h, w], item range [-1, 1]
361
+ - text (list of str):
362
+
363
+ batch_idx (int):
364
+
365
+ optimizer_idx (int):
366
+
367
+ Returns:
368
+ loss (torch.FloatTensor):
369
+
370
+ """
371
+
372
+ diffusion_outputs = self(batch)
373
+
374
+ loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
375
+ self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
376
+
377
+ return loss
378
+
379
+ def validation_step(self, batch: Dict[str, torch.FloatTensor],
380
+ batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
381
+ """
382
+
383
+ Args:
384
+ batch (dict): the batch sample, and it contains:
385
+ - surface_pc (torch.FloatTensor): [n_pts, 4]
386
+ - surface_feats (torch.FloatTensor): [n_pts, c]
387
+ - text (list of str):
388
+
389
+ batch_idx (int):
390
+
391
+ optimizer_idx (int):
392
+
393
+ Returns:
394
+ loss (torch.FloatTensor):
395
+
396
+ """
397
+
398
+ diffusion_outputs = self(batch)
399
+
400
+ loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
401
+ self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
402
+
403
+ return loss
404
+
405
+ @torch.no_grad()
406
+ def sample(self,
407
+ batch: Dict[str, Union[torch.FloatTensor, List[str]]],
408
+ sample_times: int = 1,
409
+ steps: Optional[int] = None,
410
+ guidance_scale: Optional[float] = None,
411
+ eta: float = 0.0,
412
+ return_intermediates: bool = False, **kwargs):
413
+
414
+ if self.first_stage_model is None:
415
+ self.instantiate_first_stage(self.first_stage_config)
416
+
417
+ if steps is None:
418
+ steps = self.scheduler_cfg.num_inference_steps
419
+
420
+ if guidance_scale is None:
421
+ guidance_scale = self.scheduler_cfg.guidance_scale
422
+ do_classifier_free_guidance = guidance_scale > 0
423
+
424
+ # conditional encode
425
+ xc = batch[self.cond_stage_key]
426
+ # cond = self.cond_stage_model[self.cond_stage_key](xc)
427
+ cond = self.cond_stage_model[self.cond_stage_key](xc).unsqueeze(1)
428
+
429
+ if do_classifier_free_guidance:
430
+ """
431
+ Note: There are two kinds of uncond for text.
432
+ 1: using "" as uncond text; (in SAL diffusion)
433
+ 2: zeros_like(cond) as uncond text; (in MDM)
434
+ """
435
+ # un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
436
+ un_cond = self.cond_stage_model[f"{self.cond_stage_key}_unconditional_embedding"](cond)
437
+ # un_cond = torch.zeros_like(cond, device=cond.device)
438
+ cond = torch.cat([un_cond, cond], dim=0)
439
+
440
+ outputs = []
441
+ latents = None
442
+
443
+ if not return_intermediates:
444
+ for _ in range(sample_times):
445
+ sample_loop = ddim_sample(
446
+ self.denoise_scheduler,
447
+ self.model,
448
+ shape=self.first_stage_model.latent_shape,
449
+ cond=cond,
450
+ steps=steps,
451
+ guidance_scale=guidance_scale,
452
+ do_classifier_free_guidance=do_classifier_free_guidance,
453
+ device=self.device,
454
+ eta=eta,
455
+ disable_prog=not self.zero_rank
456
+ )
457
+ for sample, t in sample_loop:
458
+ latents = sample
459
+ outputs.append(self.decode_first_stage(latents, **kwargs))
460
+ else:
461
+
462
+ sample_loop = ddim_sample(
463
+ self.denoise_scheduler,
464
+ self.model,
465
+ shape=self.first_stage_model.latent_shape,
466
+ cond=cond,
467
+ steps=steps,
468
+ guidance_scale=guidance_scale,
469
+ do_classifier_free_guidance=do_classifier_free_guidance,
470
+ device=self.device,
471
+ eta=eta,
472
+ disable_prog=not self.zero_rank
473
+ )
474
+
475
+ iter_size = steps // sample_times
476
+ i = 0
477
+ for sample, t in sample_loop:
478
+ latents = sample
479
+ if i % iter_size == 0 or i == steps - 1:
480
+ outputs.append(self.decode_first_stage(latents, **kwargs))
481
+ i += 1
482
+
483
+ return outputs
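
In forward(), roughly 10% of the condition embeddings are zeroed out before the denoiser sees them, which is the standard classifier-free-guidance training trick. A standalone illustration of just that masking step (shapes are made up; this is a sketch, not the module's API):

    import torch

    conditions = torch.randn(4, 1, 768)              # [bs, 1, context_dim], made-up sizes
    keep = torch.rand(len(conditions), 1, 1) >= 0.1  # True for ~90% of the rows
    conditions = conditions * keep.to(conditions)    # dropped rows become all-zero embeddings
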
michelangelo/models/asl_diffusion/asl_udt.py ADDED
@@ -0,0 +1,104 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from typing import Optional
6
+ from diffusers.models.embeddings import Timesteps
7
+ import math
8
+
9
+ from michelangelo.models.modules.transformer_blocks import MLP
10
+ from michelangelo.models.modules.diffusion_transformer import UNetDiffusionTransformer
11
+
12
+
13
+ class ConditionalASLUDTDenoiser(nn.Module):
14
+
15
+ def __init__(self, *,
16
+ device: Optional[torch.device],
17
+ dtype: Optional[torch.dtype],
18
+ input_channels: int,
19
+ output_channels: int,
20
+ n_ctx: int,
21
+ width: int,
22
+ layers: int,
23
+ heads: int,
24
+ context_dim: int,
25
+ context_ln: bool = True,
26
+ skip_ln: bool = False,
27
+ init_scale: float = 0.25,
28
+ flip_sin_to_cos: bool = False,
29
+ use_checkpoint: bool = False):
30
+ super().__init__()
31
+
32
+ self.use_checkpoint = use_checkpoint
33
+
34
+ init_scale = init_scale * math.sqrt(1.0 / width)
35
+
36
+ self.backbone = UNetDiffusionTransformer(
37
+ device=device,
38
+ dtype=dtype,
39
+ n_ctx=n_ctx,
40
+ width=width,
41
+ layers=layers,
42
+ heads=heads,
43
+ skip_ln=skip_ln,
44
+ init_scale=init_scale,
45
+ use_checkpoint=use_checkpoint
46
+ )
47
+ self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
48
+ self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
49
+ self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
50
+
51
+ # timestep embedding
52
+ self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0)
53
+ self.time_proj = MLP(
54
+ device=device, dtype=dtype, width=width, init_scale=init_scale
55
+ )
56
+
57
+ self.context_embed = nn.Sequential(
58
+ nn.LayerNorm(context_dim, device=device, dtype=dtype),
59
+ nn.Linear(context_dim, width, device=device, dtype=dtype),
60
+ )
61
+
62
+ if context_ln:
63
+ self.context_embed = nn.Sequential(
64
+ nn.LayerNorm(context_dim, device=device, dtype=dtype),
65
+ nn.Linear(context_dim, width, device=device, dtype=dtype),
66
+ )
67
+ else:
68
+ self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype)
69
+
70
+ def forward(self,
71
+ model_input: torch.FloatTensor,
72
+ timestep: torch.LongTensor,
73
+ context: torch.FloatTensor):
74
+
75
+ r"""
76
+ Args:
77
+ model_input (torch.FloatTensor): [bs, n_data, c]
78
+ timestep (torch.LongTensor): [bs,]
79
+ context (torch.FloatTensor): [bs, context_tokens, c]
80
+
81
+ Returns:
82
+ sample (torch.FloatTensor): [bs, n_data, c]
83
+
84
+ """
85
+
86
+ _, n_data, _ = model_input.shape
87
+
88
+ # 1. time
89
+ t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)
90
+
91
+ # 2. conditions projector
92
+ context = self.context_embed(context)
93
+
94
+ # 3. denoiser
95
+ x = self.input_proj(model_input)
96
+ x = torch.cat([t_emb, context, x], dim=1)
97
+ x = self.backbone(x)
98
+ x = self.ln_post(x)
99
+ x = x[:, -n_data:]
100
+ sample = self.output_proj(x)
101
+
102
+ return sample
103
+
104
+
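
The denoiser prepends one timestep token and the context tokens to the latent sequence, runs the transformer backbone, and keeps only the last n_data tokens. A hedged, untested sketch of a forward pass with made-up sizes, assuming the repository is importable and the backbone accepts the resulting sequence length:

    import torch
    # assumes the repository root is on PYTHONPATH; all sizes below are illustrative
    from michelangelo.models.asl_diffusion.asl_udt import ConditionalASLUDTDenoiser

    denoiser = ConditionalASLUDTDenoiser(
        device=None, dtype=None,
        input_channels=64, output_channels=64,
        n_ctx=256, width=768, layers=6, heads=12,
        context_dim=768,
    )
    x = torch.randn(2, 256, 64)        # noisy latents   [bs, n_data, c]
    t = torch.randint(0, 1000, (2,))   # timesteps       [bs]
    ctx = torch.randn(2, 1, 768)       # conditioning    [bs, context_tokens, context_dim]
    pred = denoiser(x, t, ctx)         # predicted noise [bs, n_data, c]
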
michelangelo/models/asl_diffusion/base.py ADDED
@@ -0,0 +1,13 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class BaseDenoiser(nn.Module):
8
+
9
+ def __init__(self):
10
+ super().__init__()
11
+
12
+ def forward(self, x, t, context):
13
+ raise NotImplementedError
michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py ADDED
@@ -0,0 +1,393 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from omegaconf import DictConfig
4
+ from typing import List, Tuple, Dict, Optional, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.optim import lr_scheduler
10
+ import pytorch_lightning as pl
11
+ from pytorch_lightning.utilities import rank_zero_only
12
+
13
+ from diffusers.schedulers import (
14
+ DDPMScheduler,
15
+ DDIMScheduler,
16
+ KarrasVeScheduler,
17
+ DPMSolverMultistepScheduler
18
+ )
19
+
20
+ from michelangelo.utils import instantiate_from_config
21
+ from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentPLModule
22
+ from michelangelo.models.asl_diffusion.inference_utils import ddim_sample
23
+
24
+ SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]
25
+
26
+
27
+ def disabled_train(self, mode=True):
28
+ """Overwrite model.train with this function to make sure train/eval mode
29
+ does not change anymore."""
30
+ return self
31
+
32
+
33
+ class ClipASLDiffuser(pl.LightningModule):
34
+ first_stage_model: Optional[AlignedShapeAsLatentPLModule]
35
+ cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
36
+ model: nn.Module
37
+
38
+ def __init__(self, *,
39
+ first_stage_config,
40
+ cond_stage_config,
41
+ denoiser_cfg,
42
+ scheduler_cfg,
43
+ optimizer_cfg,
44
+ loss_cfg,
45
+ first_stage_key: str = "surface",
46
+ cond_stage_key: str = "image",
47
+ scale_by_std: bool = False,
48
+ z_scale_factor: float = 1.0,
49
+ ckpt_path: Optional[str] = None,
50
+ ignore_keys: Union[Tuple[str], List[str]] = ()):
51
+
52
+ super().__init__()
53
+
54
+ self.first_stage_key = first_stage_key
55
+ self.cond_stage_key = cond_stage_key
56
+
57
+ # 1. lazy initialize first stage
58
+ self.instantiate_first_stage(first_stage_config)
59
+
60
+ # 2. initialize conditional stage
61
+ self.instantiate_cond_stage(cond_stage_config)
62
+
63
+ # 3. diffusion model
64
+ self.model = instantiate_from_config(
65
+ denoiser_cfg, device=None, dtype=None
66
+ )
67
+
68
+ self.optimizer_cfg = optimizer_cfg
69
+
70
+ # 4. scheduling strategy
71
+ self.scheduler_cfg = scheduler_cfg
72
+
73
+ self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
74
+ self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)
75
+
76
+ # 5. loss configures
77
+ self.loss_cfg = loss_cfg
78
+
79
+ self.scale_by_std = scale_by_std
80
+ if scale_by_std:
81
+ self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
82
+ else:
83
+ self.z_scale_factor = z_scale_factor
84
+
85
+ self.ckpt_path = ckpt_path
86
+ if ckpt_path is not None:
87
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
88
+
89
+ def instantiate_non_trainable_model(self, config):
90
+ model = instantiate_from_config(config)
91
+ model = model.eval()
92
+ model.train = disabled_train
93
+ for param in model.parameters():
94
+ param.requires_grad = False
95
+
96
+ return model
97
+
98
+ def instantiate_first_stage(self, first_stage_config):
99
+ self.first_stage_model = self.instantiate_non_trainable_model(first_stage_config)
100
+ self.first_stage_model.set_shape_model_only()
101
+
102
+ def instantiate_cond_stage(self, cond_stage_config):
103
+ self.cond_stage_model = self.instantiate_non_trainable_model(cond_stage_config)
104
+
105
+ def init_from_ckpt(self, path, ignore_keys=()):
106
+ state_dict = torch.load(path, map_location="cpu")["state_dict"]
107
+
108
+ keys = list(state_dict.keys())
109
+ for k in keys:
110
+ for ik in ignore_keys:
111
+ if k.startswith(ik):
112
+ print("Deleting key {} from state_dict.".format(k))
113
+ del state_dict[k]
114
+
115
+ missing, unexpected = self.load_state_dict(state_dict, strict=False)
116
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
117
+ if len(missing) > 0:
118
+ print(f"Missing Keys: {missing}")
119
+ print(f"Unexpected Keys: {unexpected}")
120
+
121
+ @property
122
+ def zero_rank(self):
123
+ if self._trainer:
124
+ zero_rank = self.trainer.local_rank == 0
125
+ else:
126
+ zero_rank = True
127
+
128
+ return zero_rank
129
+
130
+ def configure_optimizers(self) -> Tuple[List, List]:
131
+
132
+ lr = self.learning_rate
133
+
134
+ trainable_parameters = list(self.model.parameters())
135
+ if self.optimizer_cfg is None:
136
+ optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
137
+ schedulers = []
138
+ else:
139
+ optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
140
+ scheduler_func = instantiate_from_config(
141
+ self.optimizer_cfg.scheduler,
142
+ max_decay_steps=self.trainer.max_steps,
143
+ lr_max=lr
144
+ )
145
+ scheduler = {
146
+ "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
147
+ "interval": "step",
148
+ "frequency": 1
149
+ }
150
+ optimizers = [optimizer]
151
+ schedulers = [scheduler]
152
+
153
+ return optimizers, schedulers
154
+
155
+ @torch.no_grad()
156
+ def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
157
+
158
+ z_q = self.first_stage_model.encode(surface, sample_posterior)
159
+ z_q = self.z_scale_factor * z_q
160
+
161
+ return z_q
162
+
163
+ @torch.no_grad()
164
+ def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
165
+
166
+ z_q = 1. / self.z_scale_factor * z_q
167
+ latents = self.first_stage_model.decode(z_q, **kwargs)
168
+ return latents
169
+
170
+ @rank_zero_only
171
+ @torch.no_grad()
172
+ def on_train_batch_start(self, batch, batch_idx):
173
+ # only for very first batch
174
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
175
+ and batch_idx == 0 and self.ckpt_path is None:
176
+ # set rescale weight to 1./std of encodings
177
+ print("### USING STD-RESCALING ###")
178
+
179
+ z_q = self.encode_first_stage(batch[self.first_stage_key])
180
+ z = z_q.detach()
181
+
182
+ del self.z_scale_factor
183
+ self.register_buffer("z_scale_factor", 1. / z.flatten().std())
184
+ print(f"setting self.z_scale_factor to {self.z_scale_factor}")
185
+
186
+ print("### USING STD-RESCALING ###")
187
+
188
+ def compute_loss(self, model_outputs, split):
189
+ """
190
+
191
+ Args:
192
+ model_outputs (dict):
193
+ - x_0:
194
+ - noise:
195
+ - noise_prior:
196
+ - noise_pred:
197
+ - noise_pred_prior:
198
+
199
+ split (str):
200
+
201
+ Returns:
202
+
203
+ """
204
+
205
+ pred = model_outputs["pred"]
206
+
207
+ if self.noise_scheduler.prediction_type == "epsilon":
208
+ target = model_outputs["noise"]
209
+ elif self.noise_scheduler.prediction_type == "sample":
210
+ target = model_outputs["x_0"]
211
+ else:
212
+ raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")
213
+
214
+ if self.loss_cfg.loss_type == "l1":
215
+ simple = F.l1_loss(pred, target, reduction="mean")
216
+ elif self.loss_cfg.loss_type in ["mse", "l2"]:
217
+ simple = F.mse_loss(pred, target, reduction="mean")
218
+ else:
219
+ raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")
220
+
221
+ total_loss = simple
222
+
223
+ loss_dict = {
224
+ f"{split}/total_loss": total_loss.clone().detach(),
225
+ f"{split}/simple": simple.detach(),
226
+ }
227
+
228
+ return total_loss, loss_dict
229
+
230
+ def forward(self, batch):
231
+ """
232
+
233
+ Args:
234
+ batch:
235
+
236
+ Returns:
237
+
238
+ """
239
+
240
+ latents = self.encode_first_stage(batch[self.first_stage_key])
241
+ conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])
242
+
243
+ # Sample noise that we"ll add to the latents
244
+ # [batch_size, n_token, latent_dim]
245
+ noise = torch.randn_like(latents)
246
+ bs = latents.shape[0]
247
+ # Sample a random timestep for each motion
248
+ timesteps = torch.randint(
249
+ 0,
250
+ self.noise_scheduler.config.num_train_timesteps,
251
+ (bs,),
252
+ device=latents.device,
253
+ )
254
+ timesteps = timesteps.long()
255
+ # Add noise to the latents according to the noise magnitude at each timestep
256
+ noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
257
+
258
+ # diffusion model forward
259
+ noise_pred = self.model(noisy_z, timesteps, conditions)
260
+
261
+ diffusion_outputs = {
262
+ "x_0": noisy_z,
263
+ "noise": noise,
264
+ "pred": noise_pred
265
+ }
266
+
267
+ return diffusion_outputs
268
+
269
+ def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
270
+ batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
271
+ """
272
+
273
+ Args:
274
+ batch (dict): the batch sample, and it contains:
275
+ - surface (torch.FloatTensor):
276
+ - image (torch.FloatTensor): if provided, [bs, 3, h, w], item range [0, 1]
277
+ - depth (torch.FloatTensor): if provided, [bs, 1, h, w], item range [-1, 1]
278
+ - normal (torch.FloatTensor): if provided, [bs, 3, h, w], item range [-1, 1]
279
+ - text (list of str):
280
+
281
+ batch_idx (int):
282
+
283
+ optimizer_idx (int):
284
+
285
+ Returns:
286
+ loss (torch.FloatTensor):
287
+
288
+ """
289
+
290
+ diffusion_outputs = self(batch)
291
+
292
+ loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
293
+ self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
294
+
295
+ return loss
296
+
297
+ def validation_step(self, batch: Dict[str, torch.FloatTensor],
298
+ batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
299
+ """
300
+
301
+ Args:
302
+ batch (dict): the batch sample, and it contains:
303
+ - surface_pc (torch.FloatTensor): [n_pts, 4]
304
+ - surface_feats (torch.FloatTensor): [n_pts, c]
305
+ - text (list of str):
306
+
307
+ batch_idx (int):
308
+
309
+ optimizer_idx (int):
310
+
311
+ Returns:
312
+ loss (torch.FloatTensor):
313
+
314
+ """
315
+
316
+ diffusion_outputs = self(batch)
317
+
318
+ loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
319
+ self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
320
+
321
+ return loss
322
+
323
+ @torch.no_grad()
324
+ def sample(self,
325
+ batch: Dict[str, Union[torch.FloatTensor, List[str]]],
326
+ sample_times: int = 1,
327
+ steps: Optional[int] = None,
328
+ guidance_scale: Optional[float] = None,
329
+ eta: float = 0.0,
330
+ return_intermediates: bool = False, **kwargs):
331
+
332
+ if steps is None:
333
+ steps = self.scheduler_cfg.num_inference_steps
334
+
335
+ if guidance_scale is None:
336
+ guidance_scale = self.scheduler_cfg.guidance_scale
337
+ do_classifier_free_guidance = guidance_scale > 0
338
+
339
+ # conditional encode
340
+ xc = batch[self.cond_stage_key]
341
+
342
+ # print(self.first_stage_model.device, self.cond_stage_model.device, self.device)
343
+
344
+ cond = self.cond_stage_model(xc)
345
+
346
+ if do_classifier_free_guidance:
347
+ un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
348
+ cond = torch.cat([un_cond, cond], dim=0)
349
+
350
+ outputs = []
351
+ latents = None
352
+
353
+ if not return_intermediates:
354
+ for _ in range(sample_times):
355
+ sample_loop = ddim_sample(
356
+ self.denoise_scheduler,
357
+ self.model,
358
+ shape=self.first_stage_model.latent_shape,
359
+ cond=cond,
360
+ steps=steps,
361
+ guidance_scale=guidance_scale,
362
+ do_classifier_free_guidance=do_classifier_free_guidance,
363
+ device=self.device,
364
+ eta=eta,
365
+ disable_prog=not self.zero_rank
366
+ )
367
+ for sample, t in sample_loop:
368
+ latents = sample
369
+ outputs.append(self.decode_first_stage(latents, **kwargs))
370
+ else:
371
+
372
+ sample_loop = ddim_sample(
373
+ self.denoise_scheduler,
374
+ self.model,
375
+ shape=self.first_stage_model.latent_shape,
376
+ cond=cond,
377
+ steps=steps,
378
+ guidance_scale=guidance_scale,
379
+ do_classifier_free_guidance=do_classifier_free_guidance,
380
+ device=self.device,
381
+ eta=eta,
382
+ disable_prog=not self.zero_rank
383
+ )
384
+
385
+ iter_size = steps // sample_times
386
+ i = 0
387
+ for sample, t in sample_loop:
388
+ latents = sample
389
+ if i % iter_size == 0 or i == steps - 1:
390
+ outputs.append(self.decode_first_stage(latents, **kwargs))
391
+ i += 1
392
+
393
+ return outputs
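
Both diffuser modules are meant to be built from the YAML configs shipped in this repository via instantiate_from_config. A hedged loading sketch; the top-level `model` key and the keyword pass-through of `ckpt_path` are assumptions based on the constructors above, not a verified recipe:

    from omegaconf import OmegaConf
    # assumes the repository root is on PYTHONPATH and the checkpoints have been downloaded
    from michelangelo.utils import instantiate_from_config

    cfg = OmegaConf.load("configs/image_cond_diffuser_asl/image-ASLDM-256.yaml")
    model = instantiate_from_config(
        cfg.model,  # assumed top-level key in the config file
        ckpt_path="checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt",
    )
    model = model.eval()
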
michelangelo/models/asl_diffusion/inference_utils.py ADDED
@@ -0,0 +1,80 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ from tqdm import tqdm
5
+ from typing import Tuple, List, Union, Optional
6
+ from diffusers.schedulers import DDIMScheduler
7
+
8
+
9
+ __all__ = ["ddim_sample"]
10
+
11
+
12
+ def ddim_sample(ddim_scheduler: DDIMScheduler,
13
+ diffusion_model: torch.nn.Module,
14
+ shape: Union[List[int], Tuple[int]],
15
+ cond: torch.FloatTensor,
16
+ steps: int,
17
+ eta: float = 0.0,
18
+ guidance_scale: float = 3.0,
19
+ do_classifier_free_guidance: bool = True,
20
+ generator: Optional[torch.Generator] = None,
21
+ device: torch.device = "cuda:0",
22
+ disable_prog: bool = True):
23
+
24
+ assert steps > 0, f"{steps} must be > 0."
25
+
26
+ # init latents
27
+ bsz = cond.shape[0]
28
+ if do_classifier_free_guidance:
29
+ bsz = bsz // 2
30
+
31
+ latents = torch.randn(
32
+ (bsz, *shape),
33
+ generator=generator,
34
+ device=cond.device,
35
+ dtype=cond.dtype,
36
+ )
37
+ # scale the initial noise by the standard deviation required by the scheduler
38
+ latents = latents * ddim_scheduler.init_noise_sigma
39
+ # set timesteps
40
+ ddim_scheduler.set_timesteps(steps)
41
+ timesteps = ddim_scheduler.timesteps.to(device)
42
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
43
+ # eta (η) is only used with the DDIMScheduler, and between [0, 1]
44
+ extra_step_kwargs = {
45
+ "eta": eta,
46
+ "generator": generator
47
+ }
48
+
49
+ # reverse
50
+ for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)):
51
+ # expand the latents if we are doing classifier free guidance
52
+ latent_model_input = (
53
+ torch.cat([latents] * 2)
54
+ if do_classifier_free_guidance
55
+ else latents
56
+ )
57
+ # latent_model_input = scheduler.scale_model_input(latent_model_input, t)
58
+ # predict the noise residual
59
+ timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
60
+ timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
61
+ noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond)
62
+
63
+ # perform guidance
64
+ if do_classifier_free_guidance:
65
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
66
+ noise_pred = noise_pred_uncond + guidance_scale * (
67
+ noise_pred_text - noise_pred_uncond
68
+ )
69
+ # text_embeddings_for_guidance = encoder_hidden_states.chunk(
70
+ # 2)[1] if do_classifier_free_guidance else encoder_hidden_states
71
+ # compute the previous noisy sample x_t -> x_t-1
72
+ latents = ddim_scheduler.step(
73
+ noise_pred, t, latents, **extra_step_kwargs
74
+ ).prev_sample
75
+
76
+ yield latents, t
77
+
78
+
79
+ def karra_sample():
80
+ pass
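
ddim_sample is a generator that yields (latents, t) after every denoising step, so the caller decides which intermediate latents to decode. The following toy loop exercises the generator end to end with a dummy denoiser; the scheduler settings and tensor shapes are made up and no first-stage model is involved:

    import torch
    import torch.nn as nn
    from diffusers.schedulers import DDIMScheduler
    # assumes the repository root is on PYTHONPATH
    from michelangelo.models.asl_diffusion.inference_utils import ddim_sample

    class ToyDenoiser(nn.Module):
        def forward(self, x, t, context):
            return torch.zeros_like(x)  # pretend the predicted noise is zero

    scheduler = DDIMScheduler(num_train_timesteps=1000)
    cond = torch.zeros(2, 1, 768)  # [uncond; cond] already concatenated along the batch
    latents = None
    for latents, t in ddim_sample(scheduler, ToyDenoiser(),
                                  shape=(256, 64), cond=cond, steps=10,
                                  guidance_scale=3.0,
                                  do_classifier_free_guidance=True,
                                  device="cpu", disable_prog=False):
        pass
    print(latents.shape)  # torch.Size([1, 256, 64])
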
michelangelo/models/conditional_encoders/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .clip import CLIPEncoder
michelangelo/models/conditional_encoders/clip.py ADDED
@@ -0,0 +1,89 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ from dataclasses import dataclass
7
+ from torchvision.transforms import Normalize
8
+ from transformers import CLIPModel, CLIPTokenizer
9
+ from transformers.utils import ModelOutput
10
+ from typing import Iterable, Optional, Union, List
11
+
12
+
13
+ ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
14
+
15
+
16
+ @dataclass
17
+ class CLIPEmbedOutput(ModelOutput):
18
+ last_hidden_state: torch.FloatTensor = None
19
+ pooler_output: torch.FloatTensor = None
20
+ embeds: torch.FloatTensor = None
21
+
22
+
23
+ class CLIPEncoder(torch.nn.Module):
24
+
25
+ def __init__(self, model_path="openai/clip-vit-base-patch32"):
26
+
27
+ super().__init__()
28
+
29
+ # Load the CLIP model and processor
30
+ self.model: CLIPModel = CLIPModel.from_pretrained(model_path)
31
+ self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
32
+ self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
33
+
34
+ self.model.eval()  # setting .training directly does not propagate to submodules
35
+ for p in self.model.parameters():
36
+ p.requires_grad = False
37
+
38
+ @torch.no_grad()
39
+ def encode_image(self, images: Iterable[Optional[ImageType]]):
40
+ pixel_values = self.image_preprocess(images)
41
+
42
+ vision_outputs = self.model.vision_model(pixel_values=pixel_values)
43
+
44
+ pooler_output = vision_outputs[1] # pooled_output
45
+ image_features = self.model.visual_projection(pooler_output)
46
+
47
+ visual_embeds = CLIPEmbedOutput(
48
+ last_hidden_state=vision_outputs.last_hidden_state,
49
+ pooler_output=pooler_output,
50
+ embeds=image_features
51
+ )
52
+
53
+ return visual_embeds
54
+
55
+ @torch.no_grad()
56
+ def encode_text(self, texts: List[str]):
57
+ text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt")
58
+
59
+ text_outputs = self.model.text_model(input_ids=text_inputs.input_ids, attention_mask=text_inputs.attention_mask)  # pass the token ids, not the whole BatchEncoding
60
+
61
+ pooler_output = text_outputs[1] # pooled_output
62
+ text_features = self.model.text_projection(pooler_output)
63
+
64
+ text_embeds = CLIPEmbedOutput(
65
+ last_hidden_state=text_outputs.last_hidden_state,
66
+ pooler_output=pooler_output,
67
+ embeds=text_features
68
+ )
69
+
70
+ return text_embeds
71
+
72
+ def forward(self,
73
+ images: Iterable[Optional[ImageType]],
74
+ texts: List[str]):
75
+
76
+ visual_embeds = self.encode_image(images)
77
+ text_embeds = self.encode_text(texts)
78
+
79
+ return visual_embeds, text_embeds
80
+
81
+
82
+
83
+
84
+
85
+
86
+
87
+
88
+
89
+
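
A small text-encoding sketch for CLIPEncoder; the default openai/clip-vit-base-patch32 weights are downloaded from the Hugging Face hub on first use, and the 512-dimensional projection size is specific to that checkpoint:

    # assumes the repository root is on PYTHONPATH
    from michelangelo.models.conditional_encoders import CLIPEncoder

    encoder = CLIPEncoder(model_path="openai/clip-vit-base-patch32")
    text_out = encoder.encode_text(["a photo of a sports car"])
    print(text_out.embeds.shape)  # torch.Size([1, 512]) for this checkpoint
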
michelangelo/models/conditional_encoders/encoder_factory.py ADDED
@@ -0,0 +1,562 @@
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torchvision import transforms
7
+ from transformers import CLIPModel, CLIPTokenizer
8
+ from collections import OrderedDict
9
+ import clip  # OpenAI CLIP package; needed for clip.load() in MoECLIPImageEncoder
10
+ from michelangelo.data.transforms import RandomResize
11
+
12
+
13
+ class AbstractEncoder(nn.Module):
14
+ embedding_dim: int
15
+
16
+ def __init__(self):
17
+ super().__init__()
18
+
19
+ def encode(self, *args, **kwargs):
20
+ raise NotImplementedError
21
+
22
+
23
+ class ClassEmbedder(nn.Module):
24
+ def __init__(self, embed_dim, n_classes=1000, key="class"):
25
+ super().__init__()
26
+ self.key = key
27
+ self.embedding = nn.Embedding(n_classes, embed_dim)
28
+
29
+ def forward(self, batch, key=None):
30
+ if key is None:
31
+ key = self.key
32
+ # this is for use in crossattn
33
+ c = batch[key][:, None]
34
+ c = self.embedding(c)
35
+ return c
36
+
37
+
38
+ class FrozenCLIPTextEmbedder(AbstractEncoder):
39
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
40
+
41
+ def __init__(
42
+ self,
43
+ version="openai/clip-vit-large-patch14",
44
+ tokenizer_version=None,
45
+ device="cuda",
46
+ max_length=77,
47
+ zero_embedding_radio: float = 0.1,
48
+ ):
49
+ super().__init__()
50
+ self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version)
51
+
52
+ self.device = device
53
+ self.max_length = max_length
54
+ self.zero_embedding_radio = zero_embedding_radio
55
+
56
+ self.clip_dict = OrderedDict()
57
+ self.clip_name = os.path.split(version)[-1]
58
+
59
+ transformer = CLIPModel.from_pretrained(version).text_model
60
+
61
+ for param in transformer.parameters():
62
+ param.requires_grad = False
63
+ self.clip_dict[self.clip_name] = transformer
64
+
65
+ self._move_flag = False
66
+
67
+ @property
68
+ def clip(self):
69
+ return self.clip_dict[self.clip_name]
70
+
71
+ def move(self):
72
+ if self._move_flag:
73
+ return
74
+
75
+ self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
76
+ self._move_flag = True
77
+
78
+ def unconditional_embedding(self, batch_size):
79
+ empty_text = [""] * batch_size
80
+ empty_z = self.forward(empty_text)
81
+ return empty_z
82
+
83
+ def forward(self, text):
84
+ self.move()
85
+
86
+ batch_encoding = self.tokenizer(
87
+ text,
88
+ truncation=True,
89
+ max_length=self.max_length,
90
+ return_length=True,
91
+ return_overflowing_tokens=False,
92
+ padding="max_length",
93
+ return_tensors="pt",
94
+ )
95
+
96
+ tokens = batch_encoding["input_ids"].to(self.device)
97
+ outputs = self.clip(input_ids=tokens)
98
+
99
+ z = outputs.last_hidden_state
100
+ return z
101
+
102
+ def encode(self, text):
103
+ batch_size = len(text)
104
+ batch_mask = torch.rand((batch_size,))
105
+ for i in range(batch_size):
106
+ if batch_mask[i] < self.zero_embedding_radio:
107
+ text[i] = ""
108
+
109
+ return self(text)
110
+
111
+ class FrozenAlignedCLIPTextEmbedder(AbstractEncoder):
112
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
113
+
114
+ def __init__(
115
+ self,
116
+ version="openai/clip-vit-large-patch14",
117
+ tokenizer_version=None,
118
+ device="cuda",
119
+ max_length=77,
120
+ zero_embedding_radio: float = 0.1,
121
+ ):
122
+ super().__init__()
123
+ self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version)
124
+
125
+ self.device = device
126
+ self.max_length = max_length
127
+ self.zero_embedding_radio = zero_embedding_radio
128
+
129
+ self.clip_dict = OrderedDict()
130
+ self.clip_name = os.path.split(version)[-1]
131
+
132
+ transformer = CLIPModel.from_pretrained(version).text_model
133
+
134
+ for param in transformer.parameters():
135
+ param.requires_grad = False
136
+ self.clip_dict[self.clip_name] = transformer
137
+
138
+ self._move_flag = False
139
+
140
+ @property
141
+ def clip(self):
142
+ return self.clip_dict[self.clip_name]
143
+
144
+ def move(self):
145
+ if self._move_flag:
146
+ return
147
+
148
+ self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
149
+ self._move_flag = True
150
+
151
+ def unconditional_embedding(self, batch_size):
152
+ empty_text = [""] * batch_size
153
+ empty_z = self.forward(empty_text)
154
+ return empty_z
155
+
156
+ def forward(self, text):
157
+ self.move()
158
+
159
+ batch_encoding = self.tokenizer(
160
+ text,
161
+ truncation=True,
162
+ max_length=self.max_length,
163
+ return_length=True,
164
+ return_overflowing_tokens=False,
165
+ padding="max_length",
166
+ return_tensors="pt",
167
+ )
168
+
169
+ tokens = batch_encoding["input_ids"].to(self.device)
170
+ outputs = self.clip(input_ids=tokens)
171
+
172
+ z = outputs.last_hidden_state
173
+ return z
174
+
175
+ def encode(self, text):
176
+ batch_size = len(text)
177
+ batch_mask = torch.rand((batch_size,))
178
+ for i in range(batch_size):
179
+ if batch_mask[i] < self.zero_embedding_radio:
180
+ text[i] = ""
181
+
182
+ return self(text)
183
+
184
+
185
+ class FrozenCLIPImageEmbedder(AbstractEncoder):
186
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
187
+
188
+ def __init__(
189
+ self,
190
+ version="openai/clip-vit-large-patch14",
191
+ device="cuda",
192
+ zero_embedding_radio=0.1,
193
+ normalize_embedding=True,
194
+ num_projection_vector=0,
195
+ linear_mapping_bias=True,
196
+ reverse_visual_projection=False,
197
+ ):
198
+ super().__init__()
199
+
200
+ self.device = device
201
+
202
+ self.clip_dict = OrderedDict()
203
+ self.clip_name = os.path.split(version)[-1]
204
+
205
+ clip_model = CLIPModel.from_pretrained(version)
206
+ clip_model.text_model = None
207
+ clip_model.text_projection = None
208
+ clip_model = clip_model.eval()
209
+ for param in clip_model.parameters():  # freeze the CLIP weights (self has no parameters yet at this point)
210
+ param.requires_grad = False
211
+ self.clip_dict[self.clip_name] = clip_model
212
+
213
+ self.transform = transforms.Compose(
214
+ [
215
+ transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
216
+ transforms.CenterCrop(224), # crop a (224, 224) square
217
+ transforms.Normalize(
218
+ mean=[0.48145466, 0.4578275, 0.40821073],
219
+ std=[0.26862954, 0.26130258, 0.27577711],
220
+ ),
221
+ ]
222
+ )
223
+ self.zero_embedding_radio = zero_embedding_radio
224
+
225
+ self.num_projection_vector = num_projection_vector
226
+ self.reverse_visual_projection = reverse_visual_projection
227
+ self.normalize_embedding = normalize_embedding
228
+
229
+ embedding_dim = (
230
+ clip_model.visual_projection.in_features
231
+ if reverse_visual_projection
232
+ else clip_model.visual_projection.out_features
233
+ )
234
+ self.embedding_dim = embedding_dim
235
+ if self.num_projection_vector > 0:
236
+ self.projection = nn.Linear(
237
+ embedding_dim,
238
+ clip_model.visual_projection.out_features * num_projection_vector,
239
+ bias=linear_mapping_bias,
240
+ )
241
+ nn.init.normal_(self.projection.weight, std=embedding_dim ** -0.5)
242
+
243
+ self._move_flag = False
244
+
245
+ @property
246
+ def clip(self):
247
+ return self.clip_dict[self.clip_name]
248
+
249
+ def unconditional_embedding(self, batch_size):
250
+ zero = torch.zeros(
251
+ batch_size,
252
+ 1,
253
+ self.embedding_dim,
254
+ device=self.device,
255
+ dtype=self.clip.visual_projection.weight.dtype,
256
+ )
257
+ if self.num_projection_vector > 0:
258
+ zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
259
+ return zero
260
+
261
+ def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
262
+ if value_range is not None:
263
+ low, high = value_range
264
+ image = (image - low) / (high - low)
265
+
266
+ image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)
267
+
268
+ if self.reverse_visual_projection:
269
+ z = self.clip.vision_model(self.transform(image))[1]
270
+ else:
271
+ z = self.clip.get_image_features(self.transform(image))
272
+
273
+ if self.normalize_embedding:
274
+ z = z / z.norm(dim=-1, keepdim=True)
275
+ if z.ndim == 2:
276
+ z = z.unsqueeze(dim=-2)
277
+
278
+ if zero_embedding_radio > 0:
279
+ mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio  # keep with probability 1 - zero_embedding_radio, as in the other encoders
280
+ z = z * mask.to(z)
281
+
282
+ if self.num_projection_vector > 0:
283
+ z = self.projection(z).view(len(image), self.num_projection_vector, -1)
284
+
285
+ return z
286
+
287
+ def move(self):
288
+ if self._move_flag:
289
+ return
290
+
291
+ self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
292
+ self._move_flag = True
293
+
294
+ def encode(self, image):
295
+ self.move()
296
+ return self(image, zero_embedding_radio=self.zero_embedding_radio)
297
+
298
+
299
+ class FrozenCLIPImageGridEmbedder(AbstractEncoder):
300
+
301
+ def __init__(
302
+ self,
303
+ version="openai/clip-vit-large-patch14",
304
+ device="cuda",
305
+ zero_embedding_radio=0.1,
306
+ ):
307
+ super().__init__()
308
+
309
+ self.device = device
310
+
311
+ self.clip_dict = OrderedDict()
312
+ self.clip_name = os.path.split(version)[-1]
313
+
314
+ clip_model: CLIPModel = CLIPModel.from_pretrained(version)
315
+ clip_model.text_model = None
316
+ clip_model.text_projection = None
317
+ clip_model = clip_model.eval()
318
+ for param in clip_model.parameters():  # freeze the CLIP weights (self has no parameters yet at this point)
319
+ param.requires_grad = False
320
+ self.clip_dict[self.clip_name] = clip_model
321
+
322
+ self.transform = transforms.Compose(
323
+ [
324
+ transforms.Resize(224, transforms.InterpolationMode.BILINEAR, antialias=True),
325
+ transforms.CenterCrop(224), # crop a (224, 224) square
326
+ transforms.Normalize(
327
+ mean=[0.48145466, 0.4578275, 0.40821073],
328
+ std=[0.26862954, 0.26130258, 0.27577711],
329
+ ),
330
+ ]
331
+ )
332
+ self.zero_embedding_radio = zero_embedding_radio
333
+ self.embedding_dim = clip_model.vision_embed_dim
334
+
335
+ self._move_flag = False
336
+
337
+ @property
338
+ def clip(self):
339
+ return self.clip_dict[self.clip_name]
340
+
341
+ def move(self):
342
+ if self._move_flag:
343
+ return
344
+
345
+ self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device)
346
+ self._move_flag = True
347
+
348
+ def unconditional_embedding(self, batch_size):
349
+ zero = torch.zeros(
350
+ batch_size,
351
+ self.clip.vision_model.embeddings.num_positions,
352
+ self.embedding_dim,
353
+ device=self.device,
354
+ dtype=self.clip.visual_projection.weight.dtype,
355
+ )
356
+ return zero
357
+
358
+ def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
359
+ self.move()
360
+
361
+ if value_range is not None:
362
+ low, high = value_range
363
+ image = (image - low) / (high - low)
364
+
365
+ image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype)
366
+
367
+ z = self.clip.vision_model(self.transform(image)).last_hidden_state
368
+
369
+ if zero_embedding_radio > 0:
370
+ mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
371
+ z = z * mask.to(z)
372
+
373
+ return z
374
+
375
+ def encode(self, image):
376
+ return self(image, zero_embedding_radio=self.zero_embedding_radio)
377
+
378
+
379
+ class MoECLIPImageEncoder(nn.Module):
380
+ def __init__(
381
+ self,
382
+ versions,
383
+ hidden_state_dim,
384
+ num_projection_vector=8,
385
+ zero_embedding_radio=0.1,
386
+ device="cuda",
387
+ precision="fp16",
388
+ normalize=False,
389
+ clip_max=0,
390
+ transform_type="base",
391
+ argument_p=0.2,
392
+ ):
393
+ super().__init__()
394
+
395
+ self.device = torch.device(device)
396
+ self.hidden_state_dim = hidden_state_dim
397
+ self.zero_embedding_radio = zero_embedding_radio
398
+ self.num_projection_vector = num_projection_vector
399
+ self.dtype = dict(fp16=torch.float16, fp32=torch.float32, bf16=torch.bfloat16)[precision]
400
+ self.normalize = normalize
401
+ self.clip_max = clip_max
402
+
403
+ if transform_type == "base":
404
+ self.transform = transforms.Compose(
405
+ [
406
+ transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
407
+ transforms.CenterCrop(224), # crop a (224, 224) square
408
+ transforms.Normalize(
409
+ mean=[0.48145466, 0.4578275, 0.40821073],
410
+ std=[0.26862954, 0.26130258, 0.27577711],
411
+ ),
412
+ ]
413
+ )
414
+ elif transform_type == "crop_blur_resize":
415
+ self.transform = transforms.Compose(
416
+ [
417
+ transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
418
+ transforms.CenterCrop(224), # crop a (224, 224) square
419
+ transforms.RandomApply(
420
+ transforms=[
421
+ transforms.RandomResizedCrop(
422
+ size=224,
423
+ scale=(0.8, 1.0),
424
+ ratio=(0.99, 1.01),
425
+ interpolation=transforms.InterpolationMode.BICUBIC,
426
+ ),
427
+ ],
428
+ p=argument_p,
429
+ ),
430
+ transforms.RandomApply(
431
+ transforms=[
432
+ transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 5)),
433
+ ],
434
+ p=argument_p,
435
+ ),
436
+ transforms.RandomApply(
437
+ transforms=[
438
+ RandomResize(size=224, resize_radio=(0.2, 1)),
439
+ ],
440
+ p=argument_p,
441
+ ),
442
+ transforms.Normalize(
443
+ mean=[0.48145466, 0.4578275, 0.40821073],
444
+ std=[0.26862954, 0.26130258, 0.27577711],
445
+ ),
446
+ ]
447
+ )
448
+ else:
449
+ raise ValueError(f"invalid {transform_type=}")
450
+
451
+ if isinstance(versions, str):
452
+ versions = (versions,)
453
+
454
+ # If the clips were registered as direct submodules of this class, 1) every checkpoint would store redundant copies of their weights, and 2) Lightning would call .to() on them, converting the LayerNorm weights to fp16 as well.
455
+ clips = OrderedDict()
456
+
457
+ for v in versions:
458
+ # Because the clips are not submodules, passing device="cuda" directly here would wrongly put every CLIP model's weights on cuda:0.
459
+ clips[v], _ = clip.load(name=v, device="cpu", jit=False, download_root=None)
460
+ delattr(clips[v], "transformer")
461
+ clips[v].eval()
462
+ clips[v].requires_grad_(False)
463
+
464
+ self.clips_hidden_dim = sum(clips[v].ln_final.weight.size(0) for v in clips)
465
+
466
+ if self.num_projection_vector == 0:
467
+ self.projection = nn.Identity()
468
+ else:
469
+ self.projection = nn.Linear(self.clips_hidden_dim, hidden_state_dim * self.num_projection_vector, bias=True)
470
+ self.projection.to(dtype=self.dtype)
471
+ nn.init.normal_(self.projection.weight, std=self.clips_hidden_dim ** -0.5)
472
+
473
+ self.clips = clips
474
+
475
+ self._move_flag = False
476
+
477
+ def move(self):
478
+ if self._move_flag:
479
+ return
480
+
481
+ def convert_weights(model: nn.Module):
482
+ """Convert applicable model parameters to fp16"""
483
+
484
+ def _convert_weights_to_fp16(l):
485
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
486
+ l.weight.data = l.weight.data.type(self.dtype)
487
+ if l.bias is not None:
488
+ l.bias.data = l.bias.data.type(self.dtype)
489
+
490
+ if isinstance(l, nn.MultiheadAttention):
491
+ for attr in [
492
+ *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
493
+ "in_proj_bias",
494
+ "bias_k",
495
+ "bias_v",
496
+ ]:
497
+ tensor = getattr(l, attr)
498
+ if tensor is not None:
499
+ tensor.data = tensor.data.type(self.dtype)
500
+
501
+ for name in ["text_projection", "proj"]:
502
+ if hasattr(l, name):
503
+ attr = getattr(l, name)
504
+ if attr is not None:
505
+ attr.data = attr.data.type(self.dtype)
506
+
507
+ model.apply(_convert_weights_to_fp16)
508
+
509
+ for k in self.clips:
510
+ self.clips[k].to(self.device)
511
+ convert_weights(self.clips[k]) # fp32 -> self.dtype
512
+ self._move_flag = True
513
+
514
+ def unconditional_embedding(self, batch_size=None):
515
+ zero = torch.zeros(
516
+ batch_size,
517
+ self.clips_hidden_dim,
518
+ device=self.device,
519
+ dtype=self.dtype,
520
+ )
521
+ if self.num_projection_vector > 0:
522
+ zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
523
+ return zero
524
+
525
+ def convert_embedding(self, z):
526
+ if self.num_projection_vector > 0:
527
+ z = self.projection(z.type(self.projection.weight.dtype)).view(len(z), self.num_projection_vector, -1)
528
+ return z
529
+
530
+ def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
531
+ if value_range is not None:
532
+ low, high = value_range
533
+ image = (image - low) / (high - low)
534
+
535
+ image = self.transform(image)
536
+
537
+ with torch.no_grad():
538
+ embs = []
539
+ for v in self.clips:
540
+ x = self.clips[v].encode_image(image)
541
+ if self.normalize:
542
+ x = x / x.norm(p=2, dim=-1, keepdim=True) * (x.size(-1) ** 0.5)
543
+ # clip_max only works with normalization
544
+ if self.clip_max > 0:
545
+ x = x.clamp(-self.clip_max, self.clip_max)
546
+ embs.append(x)
547
+
548
+ z = torch.cat(embs, dim=-1)
549
+ if self.normalize:
550
+ z /= z.size(-1) ** 0.5
551
+
552
+ if zero_embedding_radio > 0:
553
+ mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
554
+ z = z * mask.to(z)  # multiply (not add) so that dropped samples become all-zero embeddings
555
+
556
+ if self.num_projection_vector > 0:
557
+ z = self.projection(z).view(len(image), self.num_projection_vector, -1)
558
+ return z
559
+
560
+ def encode(self, image):
561
+ self.move()
562
+ return self(image, zero_embedding_radio=self.zero_embedding_radio)
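A minimal sketch of the conditioning dropout that `zero_embedding_radio` implements above (classifier-free-guidance-style dropout of the image embedding). The shapes, the helper name, and the 0.1 ratio are illustrative assumptions, not values from this repository:

import torch

def drop_conditioning(z: torch.Tensor, zero_embedding_radio: float) -> torch.Tensor:
    # Keep-mask per sample: 1.0 keeps the embedding, 0.0 replaces it with zeros,
    # mirroring the all-zero unconditional embedding used at sampling time.
    mask = torch.rand(len(z), 1, 1, device=z.device, dtype=z.dtype) >= zero_embedding_radio
    return z * mask.to(z)

z = torch.randn(8, 1, 768)                      # (batch, tokens, feature) conditioning embeddings
z = drop_conditioning(z, zero_embedding_radio=0.1)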
michelangelo/models/modules/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .checkpoint import checkpoint
michelangelo/models/modules/checkpoint.py ADDED
@@ -0,0 +1,69 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124
4
+ """
5
+
6
+ import torch
7
+ from typing import Callable, Iterable, Sequence, Union
8
+
9
+
10
+ def checkpoint(
11
+ func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]],
12
+ inputs: Sequence[torch.Tensor],
13
+ params: Iterable[torch.Tensor],
14
+ flag: bool,
15
+ use_deepspeed: bool = False
16
+ ):
17
+ """
18
+ Evaluate a function without caching intermediate activations, allowing for
19
+ reduced memory at the expense of extra compute in the backward pass.
20
+ :param func: the function to evaluate.
21
+ :param inputs: the argument sequence to pass to `func`.
22
+ :param params: a sequence of parameters `func` depends on but does not
23
+ explicitly take as arguments.
24
+ :param flag: if False, disable gradient checkpointing.
25
+ :param use_deepspeed: if True, use deepspeed
26
+ """
27
+ if flag:
28
+ if use_deepspeed:
29
+ import deepspeed
30
+ return deepspeed.checkpointing.checkpoint(func, *inputs)
31
+
32
+ args = tuple(inputs) + tuple(params)
33
+ return CheckpointFunction.apply(func, len(inputs), *args)
34
+ else:
35
+ return func(*inputs)
36
+
37
+
38
+ class CheckpointFunction(torch.autograd.Function):
39
+ @staticmethod
40
+ @torch.cuda.amp.custom_fwd
41
+ def forward(ctx, run_function, length, *args):
42
+ ctx.run_function = run_function
43
+ ctx.input_tensors = list(args[:length])
44
+ ctx.input_params = list(args[length:])
45
+
46
+ with torch.no_grad():
47
+ output_tensors = ctx.run_function(*ctx.input_tensors)
48
+ return output_tensors
49
+
50
+ @staticmethod
51
+ @torch.cuda.amp.custom_bwd
52
+ def backward(ctx, *output_grads):
53
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
54
+ with torch.enable_grad():
55
+ # Fixes a bug where the first op in run_function modifies the
56
+ # Tensor storage in place, which is not allowed for detach()'d
57
+ # Tensors.
58
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
59
+ output_tensors = ctx.run_function(*shallow_copies)
60
+ input_grads = torch.autograd.grad(
61
+ output_tensors,
62
+ ctx.input_tensors + ctx.input_params,
63
+ output_grads,
64
+ allow_unused=True,
65
+ )
66
+ del ctx.input_tensors
67
+ del ctx.input_params
68
+ del output_tensors
69
+ return (None, None) + input_grads
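A minimal usage sketch of the `checkpoint` helper above; the module and tensor sizes are illustrative assumptions:

import torch
import torch.nn as nn

from michelangelo.models.modules.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(8, 64, requires_grad=True)

# flag=True recomputes `block` during the backward pass instead of storing its
# intermediate activations; flag=False falls back to a plain function call.
y = checkpoint(block, (x,), block.parameters(), flag=True)
y.pow(2).mean().backward()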
michelangelo/models/modules/diffusion_transformer.py ADDED
@@ -0,0 +1,218 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ from typing import Optional
7
+
8
+ from michelangelo.models.modules.checkpoint import checkpoint
9
+ from michelangelo.models.modules.transformer_blocks import (
10
+ init_linear,
11
+ MLP,
12
+ MultiheadCrossAttention,
13
+ MultiheadAttention,
14
+ ResidualAttentionBlock
15
+ )
16
+
17
+
18
+ class AdaLayerNorm(nn.Module):
19
+ def __init__(self,
20
+ device: torch.device,
21
+ dtype: torch.dtype,
22
+ width: int):
23
+
24
+ super().__init__()
25
+
26
+ self.silu = nn.SiLU(inplace=True)
27
+ self.linear = nn.Linear(width, width * 2, device=device, dtype=dtype)
28
+ self.layernorm = nn.LayerNorm(width, elementwise_affine=False, device=device, dtype=dtype)
29
+
30
+ def forward(self, x, timestep):
31
+ emb = self.linear(timestep)
32
+ scale, shift = torch.chunk(emb, 2, dim=2)
33
+ x = self.layernorm(x) * (1 + scale) + shift
34
+ return x
35
+
36
+
37
+ class DitBlock(nn.Module):
38
+ def __init__(
39
+ self,
40
+ *,
41
+ device: torch.device,
42
+ dtype: torch.dtype,
43
+ n_ctx: int,
44
+ width: int,
45
+ heads: int,
46
+ context_dim: int,
47
+ qkv_bias: bool = False,
48
+ init_scale: float = 1.0,
49
+ use_checkpoint: bool = False
50
+ ):
51
+ super().__init__()
52
+
53
+ self.use_checkpoint = use_checkpoint
54
+
55
+ self.attn = MultiheadAttention(
56
+ device=device,
57
+ dtype=dtype,
58
+ n_ctx=n_ctx,
59
+ width=width,
60
+ heads=heads,
61
+ init_scale=init_scale,
62
+ qkv_bias=qkv_bias
63
+ )
64
+ self.ln_1 = AdaLayerNorm(device, dtype, width)
65
+
66
+ if context_dim is not None:
67
+ self.ln_2 = AdaLayerNorm(device, dtype, width)
68
+ self.cross_attn = MultiheadCrossAttention(
69
+ device=device,
70
+ dtype=dtype,
71
+ width=width,
72
+ heads=heads,
73
+ data_width=context_dim,
74
+ init_scale=init_scale,
75
+ qkv_bias=qkv_bias
76
+ )
77
+
78
+ self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
79
+ self.ln_3 = AdaLayerNorm(device, dtype, width)
80
+
81
+ def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
82
+ return checkpoint(self._forward, (x, t, context), self.parameters(), self.use_checkpoint)
83
+
84
+ def _forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
85
+ x = x + self.attn(self.ln_1(x, t))
86
+ if context is not None:
87
+ x = x + self.cross_attn(self.ln_2(x, t), context)
88
+ x = x + self.mlp(self.ln_3(x, t))
89
+ return x
90
+
91
+
92
+ class DiT(nn.Module):
93
+ def __init__(
94
+ self,
95
+ *,
96
+ device: Optional[torch.device],
97
+ dtype: Optional[torch.dtype],
98
+ n_ctx: int,
99
+ width: int,
100
+ layers: int,
101
+ heads: int,
102
+ context_dim: int,
103
+ init_scale: float = 0.25,
104
+ qkv_bias: bool = False,
105
+ use_checkpoint: bool = False
106
+ ):
107
+ super().__init__()
108
+ self.n_ctx = n_ctx
109
+ self.width = width
110
+ self.layers = layers
111
+
112
+ self.resblocks = nn.ModuleList(
113
+ [
114
+ DitBlock(
115
+ device=device,
116
+ dtype=dtype,
117
+ n_ctx=n_ctx,
118
+ width=width,
119
+ heads=heads,
120
+ context_dim=context_dim,
121
+ qkv_bias=qkv_bias,
122
+ init_scale=init_scale,
123
+ use_checkpoint=use_checkpoint
124
+ )
125
+ for _ in range(layers)
126
+ ]
127
+ )
128
+
129
+ def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None):
130
+ for block in self.resblocks:
131
+ x = block(x, t, context)
132
+ return x
133
+
134
+
135
+ class UNetDiffusionTransformer(nn.Module):
136
+ def __init__(
137
+ self,
138
+ *,
139
+ device: Optional[torch.device],
140
+ dtype: Optional[torch.dtype],
141
+ n_ctx: int,
142
+ width: int,
143
+ layers: int,
144
+ heads: int,
145
+ init_scale: float = 0.25,
146
+ qkv_bias: bool = False,
147
+ skip_ln: bool = False,
148
+ use_checkpoint: bool = False
149
+ ):
150
+ super().__init__()
151
+
152
+ self.n_ctx = n_ctx
153
+ self.width = width
154
+ self.layers = layers
155
+
156
+ self.encoder = nn.ModuleList()
157
+ for _ in range(layers):
158
+ resblock = ResidualAttentionBlock(
159
+ device=device,
160
+ dtype=dtype,
161
+ n_ctx=n_ctx,
162
+ width=width,
163
+ heads=heads,
164
+ init_scale=init_scale,
165
+ qkv_bias=qkv_bias,
166
+ use_checkpoint=use_checkpoint
167
+ )
168
+ self.encoder.append(resblock)
169
+
170
+ self.middle_block = ResidualAttentionBlock(
171
+ device=device,
172
+ dtype=dtype,
173
+ n_ctx=n_ctx,
174
+ width=width,
175
+ heads=heads,
176
+ init_scale=init_scale,
177
+ qkv_bias=qkv_bias,
178
+ use_checkpoint=use_checkpoint
179
+ )
180
+
181
+ self.decoder = nn.ModuleList()
182
+ for _ in range(layers):
183
+ resblock = ResidualAttentionBlock(
184
+ device=device,
185
+ dtype=dtype,
186
+ n_ctx=n_ctx,
187
+ width=width,
188
+ heads=heads,
189
+ init_scale=init_scale,
190
+ qkv_bias=qkv_bias,
191
+ use_checkpoint=use_checkpoint
192
+ )
193
+ linear = nn.Linear(width * 2, width, device=device, dtype=dtype)
194
+ init_linear(linear, init_scale)
195
+
196
+ layer_norm = nn.LayerNorm(width, device=device, dtype=dtype) if skip_ln else None
197
+
198
+ self.decoder.append(nn.ModuleList([resblock, linear, layer_norm]))
199
+
200
+ def forward(self, x: torch.Tensor):
201
+
202
+ enc_outputs = []
203
+ for block in self.encoder:
204
+ x = block(x)
205
+ enc_outputs.append(x)
206
+
207
+ x = self.middle_block(x)
208
+
209
+ for i, (resblock, linear, layer_norm) in enumerate(self.decoder):
210
+ x = torch.cat([enc_outputs.pop(), x], dim=-1)
211
+ x = linear(x)
212
+
213
+ if layer_norm is not None:
214
+ x = layer_norm(x)
215
+
216
+ x = resblock(x)
217
+
218
+ return x
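A shape-level sketch of running the DiT stack defined above; the widths, token counts, and context size are illustrative assumptions rather than the repository's configs:

import torch

from michelangelo.models.modules.diffusion_transformer import DiT

model = DiT(device=None, dtype=None, n_ctx=256, width=512, layers=4,
            heads=8, context_dim=768)

x = torch.randn(2, 256, 512)        # noised latent tokens
t = torch.randn(2, 1, 512)          # timestep embedding, broadcast by AdaLayerNorm
context = torch.randn(2, 77, 768)   # e.g. CLIP features for cross-attention
out = model(x, t, context)          # -> (2, 256, 512)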
michelangelo/models/modules/distributions.py ADDED
@@ -0,0 +1,100 @@
1
+ import torch
2
+ import numpy as np
3
+ from typing import Union, List
4
+
5
+
6
+ class AbstractDistribution(object):
7
+ def sample(self):
8
+ raise NotImplementedError()
9
+
10
+ def mode(self):
11
+ raise NotImplementedError()
12
+
13
+
14
+ class DiracDistribution(AbstractDistribution):
15
+ def __init__(self, value):
16
+ self.value = value
17
+
18
+ def sample(self):
19
+ return self.value
20
+
21
+ def mode(self):
22
+ return self.value
23
+
24
+
25
+ class DiagonalGaussianDistribution(object):
26
+ def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1):
27
+ self.feat_dim = feat_dim
28
+ self.parameters = parameters
29
+
30
+ if isinstance(parameters, list):
31
+ self.mean = parameters[0]
32
+ self.logvar = parameters[1]
33
+ else:
34
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim)
35
+
36
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
37
+ self.deterministic = deterministic
38
+ self.std = torch.exp(0.5 * self.logvar)
39
+ self.var = torch.exp(self.logvar)
40
+ if self.deterministic:
41
+ self.var = self.std = torch.zeros_like(self.mean)
42
+
43
+ def sample(self):
44
+ x = self.mean + self.std * torch.randn_like(self.mean)
45
+ return x
46
+
47
+ def kl(self, other=None, dims=(1, 2, 3)):
48
+ if self.deterministic:
49
+ return torch.Tensor([0.])
50
+ else:
51
+ if other is None:
52
+ return 0.5 * torch.mean(torch.pow(self.mean, 2)
53
+ + self.var - 1.0 - self.logvar,
54
+ dim=dims)
55
+ else:
56
+ return 0.5 * torch.mean(
57
+ torch.pow(self.mean - other.mean, 2) / other.var
58
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
59
+ dim=dims)
60
+
61
+ def nll(self, sample, dims=(1, 2, 3)):
62
+ if self.deterministic:
63
+ return torch.Tensor([0.])
64
+ logtwopi = np.log(2.0 * np.pi)
65
+ return 0.5 * torch.sum(
66
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
67
+ dim=dims)
68
+
69
+ def mode(self):
70
+ return self.mean
71
+
72
+
73
+ def normal_kl(mean1, logvar1, mean2, logvar2):
74
+ """
75
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
76
+ Compute the KL divergence between two gaussians.
77
+ Shapes are automatically broadcasted, so batches can be compared to
78
+ scalars, among other use cases.
79
+ """
80
+ tensor = None
81
+ for obj in (mean1, logvar1, mean2, logvar2):
82
+ if isinstance(obj, torch.Tensor):
83
+ tensor = obj
84
+ break
85
+ assert tensor is not None, "at least one argument must be a Tensor"
86
+
87
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
88
+ # Tensors, but it does not work for torch.exp().
89
+ logvar1, logvar2 = [
90
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
91
+ for x in (logvar1, logvar2)
92
+ ]
93
+
94
+ return 0.5 * (
95
+ -1.0
96
+ + logvar2
97
+ - logvar1
98
+ + torch.exp(logvar1 - logvar2)
99
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
100
+ )
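A small sketch of how `DiagonalGaussianDistribution` is typically consumed: split an encoder output into mean/log-variance, draw a reparameterized sample, and compute a KL term. Shapes are illustrative assumptions:

import torch

from michelangelo.models.modules.distributions import DiagonalGaussianDistribution

moments = torch.randn(4, 256, 2 * 64)            # (batch, tokens, 2 * latent_dim)
posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)

z = posterior.sample()                           # (4, 256, 64), reparameterized sample
kl = posterior.kl(dims=(1, 2))                   # (4,), KL against a standard normal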
michelangelo/models/modules/embedder.py ADDED
@@ -0,0 +1,213 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import math
7
+
8
+ VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"]
9
+
10
+
11
+ class FourierEmbedder(nn.Module):
12
+ """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
13
+ each feature dimension of `x[..., i]` into:
14
+ [
15
+ sin(x[..., i]),
16
+ sin(f_1*x[..., i]),
17
+ sin(f_2*x[..., i]),
18
+ ...
19
+ sin(f_N * x[..., i]),
20
+ cos(x[..., i]),
21
+ cos(f_1*x[..., i]),
22
+ cos(f_2*x[..., i]),
23
+ ...
24
+ cos(f_N * x[..., i]),
25
+ x[..., i] # only present if include_input is True.
26
+ ], here f_i is the frequency.
27
+
28
+ The frequencies are indexed by i in [0, 1, 2, ..., num_freqs - 1].
29
+ If logspace is True, the frequency f_i is 2^i;
30
+ otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1).
31
+
32
+ Args:
33
+ num_freqs (int): the number of frequencies, default is 6;
34
+ logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
35
+ otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
36
+ input_dim (int): the input dimension, default is 3;
37
+ include_input (bool): include the input tensor or not, default is True.
38
+
39
+ Attributes:
40
+ frequencies (torch.Tensor): if logspace is True, the frequency f_i is 2^i;
42
+ otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1);
42
+
43
+ out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
44
+ otherwise, it is input_dim * num_freqs * 2.
45
+
46
+ """
47
+
48
+ def __init__(self,
49
+ num_freqs: int = 6,
50
+ logspace: bool = True,
51
+ input_dim: int = 3,
52
+ include_input: bool = True,
53
+ include_pi: bool = True) -> None:
54
+
55
+ """The initialization"""
56
+
57
+ super().__init__()
58
+
59
+ if logspace:
60
+ frequencies = 2.0 ** torch.arange(
61
+ num_freqs,
62
+ dtype=torch.float32
63
+ )
64
+ else:
65
+ frequencies = torch.linspace(
66
+ 1.0,
67
+ 2.0 ** (num_freqs - 1),
68
+ num_freqs,
69
+ dtype=torch.float32
70
+ )
71
+
72
+ if include_pi:
73
+ frequencies *= torch.pi
74
+
75
+ self.register_buffer("frequencies", frequencies, persistent=False)
76
+ self.include_input = include_input
77
+ self.num_freqs = num_freqs
78
+
79
+ self.out_dim = self.get_dims(input_dim)
80
+
81
+ def get_dims(self, input_dim):
82
+ temp = 1 if self.include_input or self.num_freqs == 0 else 0
83
+ out_dim = input_dim * (self.num_freqs * 2 + temp)
84
+
85
+ return out_dim
86
+
87
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
88
+ """ Forward process.
89
+
90
+ Args:
91
+ x: tensor of shape [..., dim]
92
+
93
+ Returns:
94
+ embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
95
+ where temp is 1 if include_input is True and 0 otherwise.
96
+ """
97
+
98
+ if self.num_freqs > 0:
99
+ embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
100
+ if self.include_input:
101
+ return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
102
+ else:
103
+ return torch.cat((embed.sin(), embed.cos()), dim=-1)
104
+ else:
105
+ return x
106
+
107
+
108
+ class LearnedFourierEmbedder(nn.Module):
109
+ """ following @crowsonkb "s lead with learned sinusoidal pos emb """
110
+ """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
111
+
112
+ def __init__(self, in_channels, dim):
113
+ super().__init__()
114
+ assert (dim % 2) == 0
115
+ half_dim = dim // 2
116
+ per_channel_dim = half_dim // in_channels
117
+ self.weights = nn.Parameter(torch.randn(per_channel_dim))
118
+
119
+ def forward(self, x):
120
+ """
121
+
122
+ Args:
123
+ x (torch.FloatTensor): [..., c]
124
+
125
+ Returns:
126
+ x (torch.FloatTensor): [..., d]
127
+ """
128
+
129
+ # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d]
130
+ freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1)
131
+ fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1)
132
+ return fouriered
133
+
134
+
135
+ class TriplaneLearnedFourierEmbedder(nn.Module):
136
+ def __init__(self, in_channels, dim):
137
+ super().__init__()
138
+
139
+ self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
140
+ self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
141
+ self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim)
142
+
143
+ self.out_dim = in_channels + dim
144
+
145
+ def forward(self, x):
146
+
147
+ yz_embed = self.yz_plane_embedder(x)
148
+ xz_embed = self.xz_plane_embedder(x)
149
+ xy_embed = self.xy_plane_embedder(x)
150
+
151
+ embed = yz_embed + xz_embed + xy_embed
152
+
153
+ return embed
154
+
155
+
156
+ def sequential_pos_embed(num_len, embed_dim):
157
+ assert embed_dim % 2 == 0
158
+
159
+ pos = torch.arange(num_len, dtype=torch.float32)
160
+ omega = torch.arange(embed_dim // 2, dtype=torch.float32)
161
+ omega /= embed_dim / 2.
162
+ omega = 1. / 10000 ** omega # (D/2,)
163
+
164
+ pos = pos.reshape(-1) # (M,)
165
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
166
+
167
+ emb_sin = torch.sin(out) # (M, D/2)
168
+ emb_cos = torch.cos(out) # (M, D/2)
169
+
170
+ embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
171
+
172
+ return embeddings
173
+
174
+
175
+ def timestep_embedding(timesteps, dim, max_period=10000):
176
+ """
177
+ Create sinusoidal timestep embeddings.
178
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
179
+ These may be fractional.
180
+ :param dim: the dimension of the output.
181
+ :param max_period: controls the minimum frequency of the embeddings.
182
+ :return: an [N x dim] Tensor of positional embeddings.
183
+ """
184
+ half = dim // 2
185
+ freqs = torch.exp(
186
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
187
+ ).to(device=timesteps.device)
188
+ args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
189
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
190
+ if dim % 2:
191
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
192
+ return embedding
193
+
194
+
195
+ def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4,
196
+ num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16,
197
+ log2_hashmap_size=19, desired_resolution=None):
198
+ if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1):
199
+ return nn.Identity(), input_dim
200
+
201
+ elif embed_type == "fourier":
202
+ embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim,
203
+ logspace=True, include_input=True)
204
+ return embedder_obj, embedder_obj.out_dim
205
+
206
+ elif embed_type == "hashgrid":
207
+ raise NotImplementedError
208
+
209
+ elif embed_type == "sphere_harmonic":
210
+ raise NotImplementedError
211
+
212
+ else:
213
+ raise ValueError(f"{embed_type} is not valid. Currently only supprts {VALID_EMBED_TYPES}")
michelangelo/models/modules/transformer_blocks.py ADDED
@@ -0,0 +1,286 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from typing import Optional
8
+
9
+ from michelangelo.models.modules.checkpoint import checkpoint
10
+
11
+
12
+ def init_linear(l, stddev):
13
+ nn.init.normal_(l.weight, std=stddev)
14
+ if l.bias is not None:
15
+ nn.init.constant_(l.bias, 0.0)
16
+
17
+
18
+ class MultiheadAttention(nn.Module):
19
+ def __init__(
20
+ self,
21
+ *,
22
+ device: torch.device,
23
+ dtype: torch.dtype,
24
+ n_ctx: int,
25
+ width: int,
26
+ heads: int,
27
+ init_scale: float,
28
+ qkv_bias: bool,
29
+ flash: bool = False
30
+ ):
31
+ super().__init__()
32
+ self.n_ctx = n_ctx
33
+ self.width = width
34
+ self.heads = heads
35
+ self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype)
36
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
37
+ self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash)
38
+ init_linear(self.c_qkv, init_scale)
39
+ init_linear(self.c_proj, init_scale)
40
+
41
+ def forward(self, x):
42
+ x = self.c_qkv(x)
43
+ x = checkpoint(self.attention, (x,), (), True)
44
+ x = self.c_proj(x)
45
+ return x
46
+
47
+
48
+ class QKVMultiheadAttention(nn.Module):
49
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False):
50
+ super().__init__()
51
+ self.device = device
52
+ self.dtype = dtype
53
+ self.heads = heads
54
+ self.n_ctx = n_ctx
55
+ self.flash = flash
56
+
57
+ def forward(self, qkv):
58
+ bs, n_ctx, width = qkv.shape
59
+ attn_ch = width // self.heads // 3
60
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
61
+ qkv = qkv.view(bs, n_ctx, self.heads, -1)
62
+ q, k, v = torch.split(qkv, attn_ch, dim=-1)
63
+
64
+ if self.flash:
65
+ out = F.scaled_dot_product_attention(q, k, v)
66
+ else:
67
+ weight = torch.einsum(
68
+ "bthc,bshc->bhts", q * scale, k * scale
69
+ ) # More stable with f16 than dividing afterwards
70
+ wdtype = weight.dtype
71
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
72
+ out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
73
+
74
+ return out
75
+
76
+
77
+ class ResidualAttentionBlock(nn.Module):
78
+ def __init__(
79
+ self,
80
+ *,
81
+ device: torch.device,
82
+ dtype: torch.dtype,
83
+ n_ctx: int,
84
+ width: int,
85
+ heads: int,
86
+ init_scale: float = 1.0,
87
+ qkv_bias: bool = True,
88
+ flash: bool = False,
89
+ use_checkpoint: bool = False
90
+ ):
91
+ super().__init__()
92
+
93
+ self.use_checkpoint = use_checkpoint
94
+
95
+ self.attn = MultiheadAttention(
96
+ device=device,
97
+ dtype=dtype,
98
+ n_ctx=n_ctx,
99
+ width=width,
100
+ heads=heads,
101
+ init_scale=init_scale,
102
+ qkv_bias=qkv_bias,
103
+ flash=flash
104
+ )
105
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
106
+ self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
107
+ self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
108
+
109
+ def _forward(self, x: torch.Tensor):
110
+ x = x + self.attn(self.ln_1(x))
111
+ x = x + self.mlp(self.ln_2(x))
112
+ return x
113
+
114
+ def forward(self, x: torch.Tensor):
115
+ return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
116
+
117
+
118
+ class MultiheadCrossAttention(nn.Module):
119
+ def __init__(
120
+ self,
121
+ *,
122
+ device: torch.device,
123
+ dtype: torch.dtype,
124
+ width: int,
125
+ heads: int,
126
+ init_scale: float,
127
+ qkv_bias: bool = True,
128
+ flash: bool = False,
129
+ n_data: Optional[int] = None,
130
+ data_width: Optional[int] = None,
131
+ ):
132
+ super().__init__()
133
+ self.n_data = n_data
134
+ self.width = width
135
+ self.heads = heads
136
+ self.data_width = width if data_width is None else data_width
137
+ self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype)
138
+ self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype)
139
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
140
+ self.attention = QKVMultiheadCrossAttention(
141
+ device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash
142
+ )
143
+ init_linear(self.c_q, init_scale)
144
+ init_linear(self.c_kv, init_scale)
145
+ init_linear(self.c_proj, init_scale)
146
+
147
+ def forward(self, x, data):
148
+ x = self.c_q(x)
149
+ data = self.c_kv(data)
150
+ x = checkpoint(self.attention, (x, data), (), True)
151
+ x = self.c_proj(x)
152
+ return x
153
+
154
+
155
+ class QKVMultiheadCrossAttention(nn.Module):
156
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int,
157
+ flash: bool = False, n_data: Optional[int] = None):
158
+
159
+ super().__init__()
160
+ self.device = device
161
+ self.dtype = dtype
162
+ self.heads = heads
163
+ self.n_data = n_data
164
+ self.flash = flash
165
+
166
+ def forward(self, q, kv):
167
+ _, n_ctx, _ = q.shape
168
+ bs, n_data, width = kv.shape
169
+ attn_ch = width // self.heads // 2
170
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
171
+ q = q.view(bs, n_ctx, self.heads, -1)
172
+ kv = kv.view(bs, n_data, self.heads, -1)
173
+ k, v = torch.split(kv, attn_ch, dim=-1)
174
+
175
+ if self.flash:
176
+ out = F.scaled_dot_product_attention(q, k, v)
177
+ else:
178
+ weight = torch.einsum(
179
+ "bthc,bshc->bhts", q * scale, k * scale
180
+ ) # More stable with f16 than dividing afterwards
181
+ wdtype = weight.dtype
182
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
183
+ out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
184
+
185
+ return out
186
+
187
+
188
+ class ResidualCrossAttentionBlock(nn.Module):
189
+ def __init__(
190
+ self,
191
+ *,
192
+ device: Optional[torch.device],
193
+ dtype: Optional[torch.dtype],
194
+ n_data: Optional[int] = None,
195
+ width: int,
196
+ heads: int,
197
+ data_width: Optional[int] = None,
198
+ init_scale: float = 0.25,
199
+ qkv_bias: bool = True,
200
+ flash: bool = False
201
+ ):
202
+ super().__init__()
203
+
204
+ if data_width is None:
205
+ data_width = width
206
+
207
+ self.attn = MultiheadCrossAttention(
208
+ device=device,
209
+ dtype=dtype,
210
+ n_data=n_data,
211
+ width=width,
212
+ heads=heads,
213
+ data_width=data_width,
214
+ init_scale=init_scale,
215
+ qkv_bias=qkv_bias,
216
+ flash=flash,
217
+ )
218
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
219
+ self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
220
+ self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
221
+ self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
222
+
223
+ def forward(self, x: torch.Tensor, data: torch.Tensor):
224
+ x = x + self.attn(self.ln_1(x), self.ln_2(data))
225
+ x = x + self.mlp(self.ln_3(x))
226
+ return x
227
+
228
+
229
+ class MLP(nn.Module):
230
+ def __init__(self, *,
231
+ device: Optional[torch.device],
232
+ dtype: Optional[torch.dtype],
233
+ width: int,
234
+ init_scale: float):
235
+ super().__init__()
236
+ self.width = width
237
+ self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
238
+ self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
239
+ self.gelu = nn.GELU()
240
+ init_linear(self.c_fc, init_scale)
241
+ init_linear(self.c_proj, init_scale)
242
+
243
+ def forward(self, x):
244
+ return self.c_proj(self.gelu(self.c_fc(x)))
245
+
246
+
247
+ class Transformer(nn.Module):
248
+ def __init__(
249
+ self,
250
+ *,
251
+ device: Optional[torch.device],
252
+ dtype: Optional[torch.dtype],
253
+ n_ctx: int,
254
+ width: int,
255
+ layers: int,
256
+ heads: int,
257
+ init_scale: float = 0.25,
258
+ qkv_bias: bool = True,
259
+ flash: bool = False,
260
+ use_checkpoint: bool = False
261
+ ):
262
+ super().__init__()
263
+ self.n_ctx = n_ctx
264
+ self.width = width
265
+ self.layers = layers
266
+ self.resblocks = nn.ModuleList(
267
+ [
268
+ ResidualAttentionBlock(
269
+ device=device,
270
+ dtype=dtype,
271
+ n_ctx=n_ctx,
272
+ width=width,
273
+ heads=heads,
274
+ init_scale=init_scale,
275
+ qkv_bias=qkv_bias,
276
+ flash=flash,
277
+ use_checkpoint=use_checkpoint
278
+ )
279
+ for _ in range(layers)
280
+ ]
281
+ )
282
+
283
+ def forward(self, x: torch.Tensor):
284
+ for block in self.resblocks:
285
+ x = block(x)
286
+ return x
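A shape-level sketch combining the `Transformer` and `ResidualCrossAttentionBlock` defined above; widths, token counts, and the conditioning size are illustrative assumptions:

import torch

from michelangelo.models.modules.transformer_blocks import (
    ResidualCrossAttentionBlock,
    Transformer,
)

encoder = Transformer(device=None, dtype=None, n_ctx=256, width=512, layers=6, heads=8)
cross = ResidualCrossAttentionBlock(device=None, dtype=None, width=512, heads=8, data_width=768)

tokens = torch.randn(2, 256, 512)    # latent tokens
cond = torch.randn(2, 77, 768)       # conditioning features
tokens = encoder(tokens)             # (2, 256, 512)
tokens = cross(tokens, cond)         # (2, 256, 512)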
michelangelo/models/modules/transformer_vit.py ADDED
@@ -0,0 +1,308 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ from typing import Optional
7
+ import warnings
8
+
9
+ from michelangelo.models.modules.checkpoint import checkpoint
10
+
11
+
12
+ def _trunc_normal_(tensor, mean, std, a, b):
13
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
14
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
15
+ def norm_cdf(x):
16
+ # Computes standard normal cumulative distribution function
17
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
18
+
19
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
20
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
21
+ "The distribution of values may be incorrect.",
22
+ stacklevel=2)
23
+
24
+ # Values are generated by using a truncated uniform distribution and
25
+ # then using the inverse CDF for the normal distribution.
26
+ # Get upper and lower cdf values
27
+ l = norm_cdf((a - mean) / std)
28
+ u = norm_cdf((b - mean) / std)
29
+
30
+ # Uniformly fill tensor with values from [l, u], then translate to
31
+ # [2l-1, 2u-1].
32
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
33
+
34
+ # Use inverse cdf transform for normal distribution to get truncated
35
+ # standard normal
36
+ tensor.erfinv_()
37
+
38
+ # Transform to proper mean, std
39
+ tensor.mul_(std * math.sqrt(2.))
40
+ tensor.add_(mean)
41
+
42
+ # Clamp to ensure it's in the proper range
43
+ tensor.clamp_(min=a, max=b)
44
+ return tensor
45
+
46
+
47
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
48
+ # type: (Tensor | nn.Parameter, float, float, float, float) -> Tensor
49
+ r"""Fills the input Tensor with values drawn from a truncated
50
+ normal distribution. The values are effectively drawn from the
51
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
52
+ with values outside :math:`[a, b]` redrawn until they are within
53
+ the bounds. The method used for generating the random values works
54
+ best when :math:`a \leq \text{mean} \leq b`.
55
+ NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
56
+ applied while sampling the normal with mean/std applied, therefore a, b args
57
+ should be adjusted to match the range of mean, std args.
58
+ Args:
59
+ tensor: an n-dimensional `torch.Tensor`
60
+ mean: the mean of the normal distribution
61
+ std: the standard deviation of the normal distribution
62
+ a: the minimum cutoff value
63
+ b: the maximum cutoff value
64
+ Examples:
65
+ >>> w = torch.empty(3, 5)
66
+ >>> nn.init.trunc_normal_(w)
67
+ """
68
+ with torch.no_grad():
69
+ return _trunc_normal_(tensor, mean, std, a, b)
70
+
71
+
72
+ def init_weights(m):
73
+ if isinstance(m, nn.Linear):
74
+ trunc_normal_(m.weight, std=.02)
75
+ if isinstance(m, nn.Linear) and m.bias is not None:
76
+ nn.init.constant_(m.bias, 0)
77
+ elif isinstance(m, nn.LayerNorm):
78
+ nn.init.constant_(m.bias, 0)
79
+ nn.init.constant_(m.weight, 1.0)
80
+
81
+
82
+ class MultiheadAttention(nn.Module):
83
+ def __init__(
84
+ self,
85
+ *,
86
+ device: torch.device,
87
+ dtype: torch.dtype,
88
+ n_ctx: int,
89
+ width: int,
90
+ heads: int,
91
+ qkv_bias: bool
92
+ ):
93
+ super().__init__()
94
+ self.n_ctx = n_ctx
95
+ self.width = width
96
+ self.heads = heads
97
+ self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype)
98
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
99
+ self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx)
100
+
101
+ def forward(self, x):
102
+ x = self.c_qkv(x)
103
+ x = checkpoint(self.attention, (x,), (), True)
104
+ x = self.c_proj(x)
105
+ return x
106
+
107
+
108
+ class QKVMultiheadAttention(nn.Module):
109
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int):
110
+ super().__init__()
111
+ self.device = device
112
+ self.dtype = dtype
113
+ self.heads = heads
114
+ self.n_ctx = n_ctx
115
+
116
+ def forward(self, qkv):
117
+ bs, n_ctx, width = qkv.shape
118
+ attn_ch = width // self.heads // 3
119
+ scale = 1 / math.sqrt(attn_ch)
120
+ qkv = qkv.view(bs, n_ctx, self.heads, -1)
121
+ q, k, v = torch.split(qkv, attn_ch, dim=-1)
122
+ weight = torch.einsum("bthc,bshc->bhts", q, k) * scale
123
+ wdtype = weight.dtype
124
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
125
+ return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
126
+
127
+
128
+ class ResidualAttentionBlock(nn.Module):
129
+ def __init__(
130
+ self,
131
+ *,
132
+ device: torch.device,
133
+ dtype: torch.dtype,
134
+ n_ctx: int,
135
+ width: int,
136
+ heads: int,
137
+ qkv_bias: bool = True,
138
+ use_checkpoint: bool = False
139
+ ):
140
+ super().__init__()
141
+
142
+ self.use_checkpoint = use_checkpoint
143
+
144
+ self.attn = MultiheadAttention(
145
+ device=device,
146
+ dtype=dtype,
147
+ n_ctx=n_ctx,
148
+ width=width,
149
+ heads=heads,
150
+ qkv_bias=qkv_bias
151
+ )
152
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
153
+ self.mlp = MLP(device=device, dtype=dtype, width=width)
154
+ self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
155
+
156
+ def _forward(self, x: torch.Tensor):
157
+ x = x + self.attn(self.ln_1(x))
158
+ x = x + self.mlp(self.ln_2(x))
159
+ return x
160
+
161
+ def forward(self, x: torch.Tensor):
162
+ return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)
163
+
164
+
165
+ class MultiheadCrossAttention(nn.Module):
166
+ def __init__(
167
+ self,
168
+ *,
169
+ device: torch.device,
170
+ dtype: torch.dtype,
171
+ width: int,
172
+ heads: int,
173
+ qkv_bias: bool = True,
174
+ n_data: Optional[int] = None,
175
+ data_width: Optional[int] = None,
176
+ ):
177
+ super().__init__()
178
+ self.n_data = n_data
179
+ self.width = width
180
+ self.heads = heads
181
+ self.data_width = width if data_width is None else data_width
182
+ self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype)
183
+ self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype)
184
+ self.c_proj = nn.Linear(width, width, device=device, dtype=dtype)
185
+ self.attention = QKVMultiheadCrossAttention(
186
+ device=device, dtype=dtype, heads=heads, n_data=n_data
187
+ )
188
+
189
+ def forward(self, x, data):
190
+ x = self.c_q(x)
191
+ data = self.c_kv(data)
192
+ x = checkpoint(self.attention, (x, data), (), True)
193
+ x = self.c_proj(x)
194
+ return x
195
+
196
+
197
+ class QKVMultiheadCrossAttention(nn.Module):
198
+ def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_data: Optional[int] = None):
199
+ super().__init__()
200
+ self.device = device
201
+ self.dtype = dtype
202
+ self.heads = heads
203
+ self.n_data = n_data
204
+
205
+ def forward(self, q, kv):
206
+ _, n_ctx, _ = q.shape
207
+ bs, n_data, width = kv.shape
208
+ attn_ch = width // self.heads // 2
209
+ scale = 1 / math.sqrt(attn_ch)
210
+ q = q.view(bs, n_ctx, self.heads, -1)
211
+ kv = kv.view(bs, n_data, self.heads, -1)
212
+ k, v = torch.split(kv, attn_ch, dim=-1)
213
+ weight = torch.einsum("bthc,bshc->bhts", q, k) * scale
214
+ wdtype = weight.dtype
215
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
216
+ return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
217
+
218
+
219
+ class ResidualCrossAttentionBlock(nn.Module):
220
+ def __init__(
221
+ self,
222
+ *,
223
+ device: Optional[torch.device],
224
+ dtype: Optional[torch.dtype],
225
+ n_data: Optional[int] = None,
226
+ width: int,
227
+ heads: int,
228
+ data_width: Optional[int] = None,
229
+ qkv_bias: bool = True
230
+ ):
231
+ super().__init__()
232
+
233
+ if data_width is None:
234
+ data_width = width
235
+
236
+ self.attn = MultiheadCrossAttention(
237
+ device=device,
238
+ dtype=dtype,
239
+ n_data=n_data,
240
+ width=width,
241
+ heads=heads,
242
+ data_width=data_width,
243
+ qkv_bias=qkv_bias
244
+ )
245
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
246
+ self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype)
247
+ self.mlp = MLP(device=device, dtype=dtype, width=width)
248
+ self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype)
249
+
250
+ def forward(self, x: torch.Tensor, data: torch.Tensor):
251
+ x = x + self.attn(self.ln_1(x), self.ln_2(data))
252
+ x = x + self.mlp(self.ln_3(x))
253
+ return x
254
+
255
+
256
+ class MLP(nn.Module):
257
+ def __init__(self, *,
258
+ device: Optional[torch.device],
259
+ dtype: Optional[torch.dtype],
260
+ width: int):
261
+ super().__init__()
262
+ self.width = width
263
+ self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype)
264
+ self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype)
265
+ self.gelu = nn.GELU()
266
+
267
+ def forward(self, x):
268
+ return self.c_proj(self.gelu(self.c_fc(x)))
269
+
270
+
271
+ class Transformer(nn.Module):
272
+ def __init__(
273
+ self,
274
+ *,
275
+ device: Optional[torch.device],
276
+ dtype: Optional[torch.dtype],
277
+ n_ctx: int,
278
+ width: int,
279
+ layers: int,
280
+ heads: int,
281
+ qkv_bias: bool = True,
282
+ use_checkpoint: bool = False
283
+ ):
284
+ super().__init__()
285
+ self.n_ctx = n_ctx
286
+ self.width = width
287
+ self.layers = layers
288
+ self.resblocks = nn.ModuleList(
289
+ [
290
+ ResidualAttentionBlock(
291
+ device=device,
292
+ dtype=dtype,
293
+ n_ctx=n_ctx,
294
+ width=width,
295
+ heads=heads,
296
+ qkv_bias=qkv_bias,
297
+ use_checkpoint=use_checkpoint
298
+ )
299
+ for _ in range(layers)
300
+ ]
301
+ )
302
+
303
+ self.apply(init_weights)
304
+
305
+ def forward(self, x: torch.Tensor):
306
+ for block in self.resblocks:
307
+ x = block(x)
308
+ return x
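For comparison with the blocks above, a minimal run of this ViT-style `Transformer`, which differs mainly in its truncated-normal weight initialization and plain 1/sqrt(d) attention scaling; sizes are illustrative assumptions:

import torch

from michelangelo.models.modules.transformer_vit import Transformer

vit = Transformer(device=None, dtype=None, n_ctx=64, width=384, layers=4, heads=6)
x = torch.randn(2, 64, 384)
y = vit(x)                           # (2, 64, 384)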