beomgyu-kim committed
Commit 6cd6a16 · 1 Parent(s): 0e812de

Add MotifVision model for text-to-image generation

- Implemented the MotifVision class, combining a diffusion transformer (DiT) with a rectified flow loss and multiple text encoders.
- Integrated a Variational Autoencoder (VAE) for image encoding and decoding.
- Added methods for tokenization, text encoding, and sampling images with flow matching.
- Included support for multiple text encoders: T5 and CLIP models.
- Implemented handling of LoRA parameters during state-dict loading.
- Added an evaluation method to compute the loss during inference.
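For orientation, a hedged end-to-end sketch of how the pieces added in this commit fit together (the config path, batch size, and prompts are illustrative assumptions; a CUDA device is assumed because several modules call .cuda() internally):

# Sketch only: mirrors the intended training-time usage of MotifVision described above.
import torch
from configs.configuration_mmdit import MMDiTConfig
from models.modeling_motif_vision import MotifVision

config = MMDiTConfig.from_json_file("configs/mmdit_xlarge_hq.json")
model = MotifVision(config).cuda()  # also downloads and freezes the VAE, T5, CLIP-L and CLIP-G encoders

images = torch.rand(2, 3, config.height, config.width).cuda()  # [0, 1]-ranged training images
captions = ["a photo of a red fox", "a watercolor lighthouse at dusk"]

loss = model(images, captions)[0]  # forward() returns a single-element list with the rectified flow loss
loss.backward()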

configs/configuration_mmdit.py ADDED
@@ -0,0 +1,89 @@
1
+ import json
2
+ from dataclasses import dataclass
3
+
4
+ ENCODED_TEXT_DIM = 4096
5
+ POOLED_TEXT_DIM = 2048
6
+ VAE_COMPRESSION_RATIO = 8
7
+
8
+
9
+ @dataclass
10
+ class MMDiTConfig:
11
+ # General
12
+ num_layers: int = 12
13
+ hidden_dim: int = 768 # common hidden dimension for the transformer arch
14
+ patch_size: int = 2
15
+ image_dim: int = 224
16
+ in_channel: int = 4
17
+ out_channel: int = 4
18
+ modulation_dim: int = ENCODED_TEXT_DIM # input dimension for modulation layer (shifting and scaling)
19
+ height: int = 1024
20
+ width: int = 1024
21
+ vae_compression: int = VAE_COMPRESSION_RATIO # reducing resolution with the VAE
22
+ vae_type: str = "SD3" # SDXL or SD3
23
+ pos_emb_size: int = None
24
+ conv_header: bool = False
25
+
26
+ # Outside of the MMDiT block
27
+ time_embed_dim: int = 2048 # Initial projection (discrete_time embedding) output dim
28
+ pooled_text_dim: int = POOLED_TEXT_DIM
29
+ text_emb_dim: int = 768
30
+
31
+ # MMDiTBlock
32
+ t_emb_dim: int = 256
33
+ attn_embed_dim: int = 768 # hidden dimension during the attention
34
+ mlp_hidden_dim: int = 2048
35
+ attn_mode: str = None # {'flash', 'sdpa', None}
36
+ use_final_layer_norm: bool = False
37
+ use_time_token_in_attn: bool = False
38
+
39
+ # GroupedQueryAttention
40
+ num_attention_heads: int = 12
41
+ num_key_value_heads: int = 6
42
+ use_scaled_dot_product_attention: bool = True
43
+ dropout: float = 0.0
44
+
45
+ # Modulation
46
+ use_modulation: bool = True
47
+ modulation_type: str = "film" # Choose from 'film', 'adain', or 'spade'
48
+
49
+ # Register tokens
50
+ register_token_num: int = 4
51
+ additional_register_token_num: int = 12
52
+
53
+ # use dinov2 feature-align loss
54
+ dinov2_feature_align_loss: bool = False
55
+ feature_align_loss_weight: float = 0.5
56
+ num_feature_align_layers: int = 8 # number of transformer layers to calculate feature-align loss
57
+
58
+ # Personalization related
59
+ image_encoder_name: str = None # if set, the personalized image encoder will be loaded
60
+ freeze_dit_backbone: bool = False
61
+
62
+ # Preference optimization
63
+ preference_train: bool = False
64
+ lora_rank: int = 64
65
+ lora_alpha: int = 8
66
+
67
+ skip_register_token_num: int = 0
68
+
69
+ @classmethod
70
+ def from_json_file(cls, json_file):
71
+ """
72
+ Instantiates an [`MMDiTConfig`] from the path to a JSON file of parameters.
73
+
74
+ Args:
75
+ json_file (`str` or `os.PathLike`):
76
+ Path to the JSON file containing the parameters.
77
+
78
+ Returns:
79
+ [`MMDiTConfig`]: The configuration object instantiated from that JSON file.
80
+
81
+ """
82
+ config_dict = cls._dict_from_json_file(json_file)
83
+ return cls(**config_dict)
84
+
85
+ @classmethod
86
+ def _dict_from_json_file(cls, json_file):
87
+ with open(json_file, "r", encoding="utf-8") as reader:
88
+ text = reader.read()
89
+ return json.loads(text)
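
A short usage sketch for this config class (the JSON path points at the file added below; dataclasses.replace is standard-library behavior, not something defined in this commit):

from dataclasses import replace

from configs.configuration_mmdit import MMDiTConfig

cfg = MMDiTConfig.from_json_file("configs/mmdit_xlarge_hq.json")  # keys absent from the JSON keep their dataclass defaults
cfg_512 = replace(cfg, height=512, width=512)                     # derive a lower-resolution variant without mutating cfg
print(cfg_512.num_layers, cfg_512.height, cfg_512.vae_compression)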
configs/mmdit_xlarge_hq.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "num_layers": 30,
3
+ "hidden_dim": 1920,
4
+ "patch_size": 2,
5
+ "in_channel": 4,
6
+ "out_channel": 4,
7
+ "time_embed_dim": 4096,
8
+ "attn_embed_dim": 4096,
9
+ "num_attention_heads": 30,
10
+ "num_key_value_heads": 30,
11
+ "use_scaled_dot_product_attention": true,
12
+ "dropout": 0.0,
13
+ "mlp_hidden_dim": 7680,
14
+ "use_modulation": true,
15
+ "modulation_type": "film",
16
+ "register_token_num": 4,
17
+ "additional_register_token_num": 0,
18
+ "skip_register_token_num": 0,
19
+ "height": 1024,
20
+ "width": 1024,
21
+ "attn_mode": "flash",
22
+ "use_final_layer_norm": false,
23
+ "pos_emb_size": 64,
24
+ "conv_header": false,
25
+ "use_time_token_in_attn": true
26
+ }
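
A quick sanity check of the dimensions this config implies, sketched the same way MotifDiT and JointAttn consume the fields (the arithmetic below is illustrative, not part of the commit):

from configs.configuration_mmdit import MMDiTConfig

cfg = MMDiTConfig.from_json_file("configs/mmdit_xlarge_hq.json")

head_dim = cfg.hidden_dim // cfg.num_attention_heads  # 1920 // 30 = 64; must divide evenly for JointAttn
latent_hw = cfg.height // cfg.vae_compression         # 1024 // 8 = 128 with the SD3 VAE
num_patches = (latent_hw // cfg.patch_size) ** 2      # (128 // 2) ** 2 = 4096 image tokens per sample
print(head_dim, latent_hw, num_patches)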
models/mixin/encoder_mixin.py ADDED
@@ -0,0 +1,222 @@
1
+ import torch
2
+ from diffusers.models import AutoencoderKL
3
+
4
+ from transformers import CLIPTextModel, CLIPTokenizerFast, T5EncoderModel, T5Tokenizer
5
+
6
+
7
+ class EncoderMixin:
8
+ """Mixin class for handling various encoders in the MotifDiT model.
9
+
10
+ This mixin provides functionality for:
11
+ 1. Loading and initializing encoders (VAE, T5, CLIP-L, CLIP-G)
12
+ 2. Text tokenization and encoding
13
+ 3. Managing encoder parameters and state
14
+ """
15
+
16
+ TOKEN_MAX_LENGTH: int = 256
17
+
18
+ def prepare_embeddings(
19
+ self,
20
+ images: torch.Tensor,
21
+ raw_text: list[str],
22
+ vae: AutoencoderKL,
23
+ t5: T5EncoderModel,
24
+ clip_l: CLIPTextModel,
25
+ clip_g: CLIPTextModel,
26
+ t5_tokenizer: T5Tokenizer,
27
+ clip_l_tokenizer: CLIPTokenizerFast,
28
+ clip_g_tokenizer: CLIPTokenizerFast,
29
+ is_training,
30
+ ) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
31
+ """Prepare image latents and text embeddings for model input.
32
+
33
+ Args:
34
+ images (torch.Tensor): Input images tensor with shape [B, C=3, H, W].
35
+ raw_text (List[str]): List of raw text strings with length B.
36
+ """
37
+ with torch.no_grad():
38
+ latents: torch.Tensor = (
39
+ vae.encode(images).latent_dist.sample() - vae.config.shift_factor
40
+ ) * vae.config.scaling_factor # Latents shape: [B, 16, H//8, W//8]
41
+
42
+ # Tokenize the input text and move tokens and masks to the same device as latents
43
+ tokenizers = [t5_tokenizer, clip_l_tokenizer, clip_g_tokenizer]
44
+ tokens, masks = self.tokenization(raw_text, tokenizers)
45
+ tokens = [token.to(latents.device) for token in tokens]
46
+ masks = [mask.to(latents.device) for mask in masks]
47
+
48
+ # Encode the text and drop unnecessary embeddings
49
+ text_embeddings, pooled_text_embeddings = self.text_encoding(
50
+ tokens,
51
+ masks,
52
+ t5,
53
+ clip_l,
54
+ clip_g,
55
+ t5_tokenizer.pad_token_id,
56
+ clip_l_tokenizer.eos_token_id,
57
+ clip_g_tokenizer.eos_token_id,
58
+ is_training,
59
+ )
60
+ text_embeddings = self.drop_text_emb(text_embeddings)
61
+
62
+ # Convert text embeddings to float
63
+ text_embeddings = [text_embedding.float() for text_embedding in text_embeddings]
64
+
65
+ # Convert pooled text embeddings to float
66
+ pooled_text_embeddings = pooled_text_embeddings.float()
67
+
68
+ return latents, text_embeddings, pooled_text_embeddings
69
+
70
+ def get_freezed_encoders_and_tokenizers(
71
+ self, vae_type: str
72
+ ) -> tuple[
73
+ AutoencoderKL, T5EncoderModel, CLIPTextModel, CLIPTextModel, T5Tokenizer, CLIPTokenizerFast, CLIPTokenizerFast
74
+ ]:
75
+ """Initialize the VAE and text encoders."""
76
+ if vae_type != "SD3":
77
+ raise ValueError(
78
+ f"VAE type must be `SD3` but self.config.vae_type is {vae_type}."
79
+ f" note that the VAE type SDXL is deprecated."
80
+ )
81
+
82
+ vae: AutoencoderKL = AutoencoderKL.from_pretrained(
83
+ "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae"
84
+ )
85
+
86
+ # Text encoders
87
+ # 1. T5-XXL from Google
88
+ t5 = T5EncoderModel.from_pretrained("google/flan-t5-xxl").to(dtype=torch.bfloat16)
89
+ t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
90
+
91
+ # 2. CLIP-L from OpenAI
92
+ clip_l = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(dtype=torch.bfloat16)
93
+ clip_l_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")
94
+
95
+ # 3. CLIP-G from LAION
96
+ clip_g = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k").to(dtype=torch.bfloat16)
97
+ clip_g_tokenizer = CLIPTokenizerFast.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
98
+
99
+ # Freeze all encoders
100
+
101
+ for encoder_module in [vae, clip_l, clip_g, t5]:
102
+ for param in encoder_module.parameters():
103
+ param.requires_grad = False
104
+
105
+ return vae, t5, clip_l, clip_g, t5_tokenizer, clip_l_tokenizer, clip_g_tokenizer
106
+
107
+ def tokenization(
108
+ self, raw_text: list[str], tokenizers: list[T5Tokenizer | CLIPTokenizerFast]
109
+ ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
110
+ """Tokenize the input text using multiple tokenizers.
111
+
112
+ Args:
113
+ raw_text (list[str]): List of input text strings.
114
+
115
+ Returns:
116
+ Tuple[List[torch.Tensor], List[torch.Tensor]]: Lists of tokenized text tensors and attention masks.
117
+ """
118
+ tokens, masks = [], []
119
+ for tokenizer in tokenizers:
120
+ tok = tokenizer(
121
+ raw_text,
122
+ padding="max_length",
123
+ max_length=min(EncoderMixin.TOKEN_MAX_LENGTH, tokenizer.model_max_length),
124
+ return_tensors="pt",
125
+ truncation=True,
126
+ )
127
+ tokens.append(tok.input_ids)
128
+ masks.append(tok.attention_mask)
129
+ return tokens, masks
130
+
131
+ @torch.no_grad()
132
+ def text_encoding(
133
+ self,
134
+ tokens: list[torch.Tensor],
135
+ masks: list[torch.Tensor],
136
+ t5: T5EncoderModel,
137
+ clip_l: CLIPTextModel,
138
+ clip_g: CLIPTextModel,
139
+ t5_pad_token_id: int = 0,
140
+ clip_l_tokenizer_eos_token_id: int = 49407,
141
+ clip_g_tokenizer_eos_token_id: int = 49407,
142
+ is_training: bool = False,
143
+ ) -> tuple[list[torch.Tensor], torch.Tensor]:
144
+ """Encode the tokenized text using multiple text encoders.
145
+
146
+ Args:
147
+ tokens (List[torch.Tensor]): List of tokenized text tensors.
148
+ masks (List[torch.Tensor]): List of attention masks.
149
+
150
+ Returns:
151
+ Tuple[List[torch.Tensor], torch.Tensor]: Text embeddings and pooled text embeddings.
152
+ """
153
+ t5_tokens, clip_l_tokens, clip_g_tokens = tokens
154
+ t5_masks, _, _ = masks
155
+
156
+ # T5 encoding
157
+ t5_emb = t5(t5_tokens, attention_mask=t5_masks)[0]
158
+ t5_emb = t5_emb * (t5_tokens != t5_pad_token_id).unsqueeze(-1)
159
+
160
+ # CLIP encodings
161
+ clip_l_emb = clip_l(input_ids=clip_l_tokens, output_hidden_states=True)
162
+ clip_g_emb = clip_g(input_ids=clip_g_tokens, output_hidden_states=True)
163
+
164
+ # Get pooled outputs
165
+ clip_l_emb_pooled = clip_l_emb.pooler_output # B x 768
166
+ clip_g_emb_pooled = clip_g_emb.pooler_output # B x 1280
167
+
168
+ if is_training:
169
+ clip_l_emb_pooled = self.drop_text_emb(clip_l_emb_pooled)
170
+ clip_g_emb_pooled = self.drop_text_emb(clip_g_emb_pooled)
171
+
172
+ clip_l_emb = clip_l_emb.last_hidden_state # B x L x 768
173
+ clip_g_emb = clip_g_emb.last_hidden_state # B x L x 1280
174
+
175
+ def masking_wo_first_eos(token, eos):
176
+ """Create attention mask without first EOS token."""
177
+ idx = (token != eos).sum(dim=1)
178
+ mask = token != eos
179
+ arange = torch.arange(mask.size(0), device=token.device)
180
+ valid = idx < mask.size(1) # guard when no EOS is present; also works for batch sizes > 1
181
+ mask[arange[valid], idx[valid]] = True
182
+ return mask.unsqueeze(-1) # B x L x 1
183
+
184
+ # Apply masking
185
+ clip_l_emb = clip_l_emb * masking_wo_first_eos(clip_l_tokens, clip_l_tokenizer_eos_token_id)
186
+ clip_g_emb = clip_g_emb * masking_wo_first_eos(clip_g_tokens, clip_g_tokenizer_eos_token_id)
187
+
188
+ encodings = [t5_emb, clip_l_emb, clip_g_emb]
189
+ pooled_encodings = torch.cat([clip_l_emb_pooled, clip_g_emb_pooled], dim=-1) # B x 2048
190
+
191
+ return encodings, pooled_encodings
192
+
193
+ @torch.no_grad()
194
+ def drop_text_emb(
195
+ self, text_embeddings: list[torch.Tensor] | torch.Tensor, drop_prob: float = 0.464
196
+ ) -> list[torch.Tensor] | torch.Tensor:
197
+ """Randomly drop text embeddings with a specified probability.
198
+
199
+ Args:
200
+ text_embeddings (Union[List[torch.Tensor], torch.Tensor]): Text embeddings to be dropped.
201
+ drop_prob (float, optional): Probability of dropping text embeddings. Defaults to 0.464.
202
+
203
+ Returns:
204
+ Union[List[torch.Tensor], torch.Tensor]: Text embeddings with dropped elements.
205
+ """
206
+ if isinstance(text_embeddings, list):
207
+ # For BxLxC features
208
+ for i, text_embedding in enumerate(text_embeddings):
209
+ probs = torch.ones((text_embedding.shape[0])).cuda() * (1 - drop_prob)
210
+ masks = torch.bernoulli(probs).cuda()
211
+ while len(masks.shape) < len(text_embedding.shape):
212
+ masks = masks.unsqueeze(-1)
213
+ text_embeddings[i] = text_embedding * masks # write back so the drop actually takes effect
214
+ else:
215
+ # For a pooled BxC feature
216
+ probs = torch.ones((text_embeddings.shape[0])).cuda() * (1 - drop_prob)
217
+ masks = torch.bernoulli(probs).cuda()
218
+ while len(masks.shape) < len(text_embeddings.shape):
219
+ masks = masks.unsqueeze(-1)
220
+ text_embeddings = text_embeddings * masks
221
+
222
+ return text_embeddings
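
A small, self-contained sketch of the text side of this mixin, showing the token shapes that feed text_encoding (checkpoint downloads are assumed to be available; the prompts are illustrative):

from transformers import CLIPTokenizerFast, T5Tokenizer

t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
clip_l_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")

prompts = ["a corgi surfing a wave", "an isometric city at night"]
for tokenizer in (t5_tokenizer, clip_l_tokenizer):
    out = tokenizer(
        prompts,
        padding="max_length",
        max_length=min(256, tokenizer.model_max_length),  # TOKEN_MAX_LENGTH caps T5 at 256; CLIP stays at 77
        truncation=True,
        return_tensors="pt",
    )
    print(out.input_ids.shape, out.attention_mask.shape)  # [2, 256] for T5, [2, 77] for CLIP-L

The resulting T5 embeddings are [B, 256, 4096], CLIP-L gives [B, 77, 768], CLIP-G gives [B, 77, 1280], and the two pooled CLIP outputs are concatenated into the [B, 2048] vector used for modulation.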
models/mixin/flow_mixin.py ADDED
@@ -0,0 +1,175 @@
1
+ """
2
+ Dongpin Oh: [email protected]
3
+ """
4
+
5
+ import time
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from diffusers.utils.torch_utils import randn_tensor
11
+ from loguru import logger
12
+
13
+
14
+ class FlowMixin:
15
+ """Mixin class for flow-based models."""
16
+
17
+ _is_noise_initialized = False
18
+ MIN_STD: float = 0.0 # minimum size of std for the flow matching
19
+ CLAMP_CONTINUOUS_TIME: float = 0.0
20
+
21
+ def _timestep_shifting(self, timesteps: torch.Tensor, alpha: float = 3.0) -> torch.Tensor:
22
+ """
23
+ Adjust the timesteps for higher resolution images by adding more noise.
24
+ Higher resolutions have more pixels, so more noise is needed to destroy their signal.
25
+
26
+ NOTE that the timestep must be reversed unlike original SD3 style timestep shifting,
27
+ since the flow-timestep is reversed (t=0: original image; t=1: pure noise)
28
+ Args:
29
+ timesteps (torch.Tensor): The original timesteps.
30
+ alpha (float, optional): Scaling factor for timestep shifting. Defaults to 3.0.
31
+
32
+ Returns:
33
+ torch.Tensor: The reversed and shifted timesteps.
34
+ """
35
+ shifted_t = alpha * timesteps / (1 + (alpha - 1) * timesteps)
36
+ reversed_t = 1 - shifted_t
37
+ return reversed_t
38
+
39
+ def get_noisy_input(
40
+ self,
41
+ input: torch.Tensor,
42
+ normal_mean: float = 0.0,
43
+ normal_std: float = 1.0,
44
+ is_finetuning: bool = False,
45
+ t: torch.Tensor | None = None,
46
+ n_timesteps: int = 1000,
47
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
48
+ """
49
+ Generate noisy input, based on the optimal-transport flow.
50
+
51
+ Args:
52
+ input (torch.Tensor): Input tensor.
53
+ normal_mean (float, optional): Mean of the normal distribution for noise. Defaults to 0.0.
54
+ normal_std (float, optional): Standard deviation of the normal distribution for noise. Defaults to 1.0.
55
+ is_finetuning (bool, optional): Whether the model is in finetuning mode. Defaults to False.
56
+ t (torch.Tensor | None, optional): Predefined timesteps. If None, timesteps are sampled. Defaults to None.
57
+ n_timesteps (int, optional): Number of discrete timesteps. Defaults to 1000.
58
+
59
+ Returns:
60
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing noise, noisy input, and timestep.
61
+ - The timestep t is a continuous timestep ranged [0, 1] which cannot directly be used
62
+ for the timestep embedding (needs to be discretized by self.discritize_timestep()).
63
+ """
64
+ b = input.shape[0]
65
+
66
+ if not FlowMixin._is_noise_initialized:
67
+ logger.warning("The torch random seed is changed when generating the initial random noise.")
68
+ current_time = int(time.time() * 1000)
69
+ torch.manual_seed(current_time)
70
+ FlowMixin._is_noise_initialized = True
71
+
72
+ noise = randn_tensor(input.shape).cuda()
73
+
74
+ # Sample timestep from a logit-normal distribution (sigmoid of a Gaussian) with mean 0 and std 1
75
+ if t is None:
76
+ # NOTE: timestep is sampled from a logit-normal distribution to make the model
77
+ # focus on the intermediate timesteps, which are the most informative part of flow-ODE
78
+ t = torch.randn(b).cuda() * normal_std + normal_mean
79
+ t = torch.sigmoid(t)
80
+ else:
81
+ t = t
82
+
83
+ # Clamp t to be within the interval [0, 1] for numerical stability
84
+ t = t.clamp(0 + self.CLAMP_CONTINUOUS_TIME, 1 - self.CLAMP_CONTINUOUS_TIME)
85
+
86
+ if is_finetuning:
87
+ t = self._timestep_shifting(t)
88
+
89
+ # Reshape t to match the dimensions required
90
+ for _ in range(len(noise.shape) - 1):
91
+ t = t.unsqueeze(1)
92
+
93
+ # Generate the noisy input
94
+ noisy_input = (1 - t) * input + (self.MIN_STD + (1 - self.MIN_STD) * t) * noise
95
+
96
+ t_squeezed = t.squeeze()
97
+ if t_squeezed.dim() == 0:
98
+ t_squeezed = t_squeezed.unsqueeze(0)
99
+ return noise, noisy_input, t_squeezed
100
+
101
+ @torch.no_grad()
102
+ def _logit_norm(self, t: torch.Tensor, m: float = 0, s: float = 1) -> torch.Tensor:
103
+ """
104
+ Compute the loss-weight for the flow-matching loss.
105
+ It focuses (gives high weight) on intermediate timesteps, since
106
+ such timesteps are the hardest to match, according to https://arxiv.org/pdf/2403.03206.pdf
107
+ Args:
108
+ t (torch.Tensor): Timestep tensor. (0 to 1)
109
+ m (float, optional): Mean of the logit distribution. Defaults to 0.
110
+ s (float, optional): Standard deviation of the logit distribution. Defaults to 1.
111
+
112
+ Returns:
113
+ torch.Tensor: Weight tensor for the flow-matching loss.
114
+ """
115
+ coef = (1 / (s * ((2 * np.pi) ** 0.5))) * (1 / (t * (1 - t)))
116
+
117
+ def logit(x):
118
+ return torch.log(x) - torch.log(1 - x)
119
+
120
+ exp = torch.exp(-((logit(t) - m) ** 2) / (2 * s**2))
121
+ return coef * exp
122
+
123
+ def rectified_flow_loss(
124
+ self,
125
+ input: torch.Tensor,
126
+ noise: torch.Tensor,
127
+ t: torch.Tensor,
128
+ preds: torch.Tensor,
129
+ use_weighting: bool = False,
130
+ reduce: str = "mean",
131
+ ) -> torch.Tensor:
132
+ """
133
+ Compute the rectified flow loss, https://arxiv.org/pdf/2403.03206.pdf
134
+
135
+ Args:
136
+ input (torch.Tensor): Input tensor.
137
+ noise (torch.Tensor): Noise tensor.
138
+ t (torch.Tensor): Timestep tensor.
139
+ preds (torch.Tensor): Predicted tensor.
140
+ use_weighting (bool, optional): Whether to use weighting for the loss. Defaults to False.
141
+ reduce (str, optional): Reduction method for the loss. Options are 'mean' or 'none'. Defaults to 'mean'.
142
+
143
+ Returns:
144
+ torch.Tensor: Rectified flow loss.
145
+ """
146
+
147
+ # Matching dimension for broadcasting
148
+ t = t.reshape(t.shape[0], *[1 for _ in range(len(input.shape) - len(t.shape))])
149
+
150
+ target_flow = (1 - self.MIN_STD) * noise - input
151
+ loss = F.mse_loss(preds.float(), target_flow.float(), reduction="none")
152
+ if use_weighting:
153
+ weight = self._logit_norm(t).detach()
154
+ loss = loss * weight
155
+ if reduce == "mean":
156
+ loss = loss.mean()
157
+ elif reduce == "none":
158
+ loss = loss
159
+ else:
160
+ raise NotImplementedError
161
+
162
+ return loss
163
+
164
+ def discritize_timestep(self, t: torch.Tensor, n_timesteps: int = 1000) -> torch.Tensor:
165
+ """
166
+ Discretize the continuous timestep.
167
+
168
+ Args:
169
+ t (torch.Tensor): Continuous timestep.
170
+ n_timesteps (int, optional): Number of discrete timesteps. Defaults to 1000.
171
+
172
+ Returns:
173
+ torch.Tensor: Discretized timestep tensor.
174
+ """
175
+ return (t * n_timesteps).round().long()
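
For reference, the interpolation and regression target used above reduce to a few lines of tensor algebra when MIN_STD = 0; the sketch below is illustrative (random tensors stand in for real latents and model predictions):

import torch
import torch.nn.functional as F

x = torch.randn(4, 16, 32, 32)       # "clean" VAE latents
noise = torch.randn_like(x)
t = torch.sigmoid(torch.randn(4))    # logit-normal timestep sampling, as in get_noisy_input
t_b = t.view(-1, 1, 1, 1)

noisy = (1 - t_b) * x + t_b * noise  # straight-line interpolant with MIN_STD = 0
target_flow = noise - x              # what rectified_flow_loss regresses the DiT prediction onto
pred = torch.zeros_like(x)           # stand-in for a model output
print(F.mse_loss(pred, target_flow).item())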
models/modeling_dit.py ADDED
@@ -0,0 +1,653 @@
1
+ import math
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
9
+ from loguru import logger
10
+
11
+ try:
12
+ motif_ops = torch.ops.motif
13
+ MotifRMSNorm = motif_ops.T5LayerNorm
14
+ ScaledDotProductAttention = None
15
+ MotifFlashAttention = motif_ops.flash_attention
16
+ except ImportError: # if motif_ops is not available
17
+ MotifRMSNorm = None
18
+ ScaledDotProductAttention = None
19
+ MotifFlashAttention = None
20
+
21
+ NUM_MODULATIONS = 6
22
+ SD3_LATENT_CHANNEL = 16
23
+ LOW_RES_POSEMB_BASE_SIZE = 16
24
+ HIGH_RES_POSEMB_BASE_SIZE = 64
25
+
26
+ class IdentityConv2d(nn.Module):
27
+ def __init__(self, channels, kernel_size=3, stride=1, padding=1, bias=True):
28
+ super().__init__()
29
+
30
+ self.conv = nn.Conv2d(
31
+ in_channels=channels,
32
+ out_channels=channels,
33
+ kernel_size=kernel_size,
34
+ stride=stride,
35
+ padding=padding,
36
+ bias=bias,
37
+ )
38
+
39
+ self._initialize_identity()
40
+
41
+ def _initialize_identity(self):
42
+ k = self.conv.kernel_size[0]
43
+
44
+ nn.init.zeros_(self.conv.weight)
45
+
46
+ center = k // 2
47
+ for i in range(self.conv.in_channels):
48
+ self.conv.weight.data[i, i, center, center] = 1.0
49
+
50
+ if self.conv.bias is not None:
51
+ nn.init.zeros_(self.conv.bias)
52
+
53
+ def forward(self, x):
54
+ return self.conv(x)
55
+
56
+
57
+ class RMSNorm(nn.Module):
58
+ def __init__(self, hidden_size, eps=1e-6):
59
+ """
60
+ LlamaRMSNorm is equivalent to T5LayerNorm
61
+ """
62
+ super().__init__()
63
+ self.weight = nn.Parameter(torch.ones(hidden_size))
64
+ self.variance_epsilon = eps
65
+ self.mask = None
66
+
67
+ def forward(self, hidden_states):
68
+ input_dtype = hidden_states.dtype
69
+ hidden_states = hidden_states.to(torch.float)
70
+ if self.mask is not None:
71
+ hidden_states = self.mask.to(hidden_states.device).to(hidden_states.dtype) * hidden_states
72
+ variance = hidden_states.pow(2).sum(-1, keepdim=True)
73
+ if self.mask is not None:
74
+ variance /= torch.count_nonzero(self.mask)
75
+ else:
76
+ variance /= hidden_states.shape[-1]
77
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
78
+ return self.weight * hidden_states.to(input_dtype)
79
+
80
+
81
+ class MLP(nn.Module):
82
+ def __init__(self, input_size, hidden_size=None):
83
+ super().__init__()
84
+ if hidden_size is None:
85
+ self.input_size, self.hidden_size = input_size, input_size * 4
86
+ else:
87
+ self.input_size, self.hidden_size = input_size, hidden_size
88
+
89
+ self.gate_proj = nn.Linear(self.input_size, self.hidden_size)
90
+ self.down_proj = nn.Linear(self.hidden_size, self.input_size)
91
+
92
+ self.act_fn = nn.SiLU()
93
+
94
+ def forward(self, x):
95
+ down_proj = self.act_fn(self.gate_proj(x))
96
+ down_proj = self.down_proj(down_proj)
97
+
98
+ return down_proj
99
+
100
+
101
+ class TextTimeEmbToGlobalParams(nn.Module):
102
+ def __init__(self, emb_dim, hidden_dim):
103
+ super().__init__()
104
+ self.projection = nn.Linear(emb_dim, hidden_dim * NUM_MODULATIONS)
105
+
106
+ def forward(self, emb):
107
+ emb = F.silu(emb) # emb: B x D
108
+ params = self.projection(emb) # emb: B x C
109
+ params = params.reshape(params.shape[0], NUM_MODULATIONS, params.shape[-1] // NUM_MODULATIONS) # emb: B x 6 x C
110
+ return params.chunk(6, dim=1) # [B x 1 x C] x 6
111
+
112
+
113
+ class TextTimeEmbedding(nn.Module):
114
+ """
115
+ Input:
116
+ pooled_text_emb (B x C_l)
117
+ time_steps (B)
118
+
119
+ Output:
120
+ combined embedding (B x embed_dim)
121
+ """
122
+
123
+ def __init__(self, time_channel, text_channel, embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0):
124
+ super().__init__()
125
+ self.time_proj = Timesteps(
126
+ time_channel, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=downscale_freq_shift
127
+ )
128
+ self.time_emb = TimestepEmbedding(time_channel, time_channel * 4, out_dim=embed_dim) # Encode time emb with MLP
129
+ self.pooled_text_emb = TimestepEmbedding(
130
+ text_channel, text_channel * 4, out_dim=embed_dim
131
+ ) # Encode pooled text with MLP
132
+
133
+ def forward(self, pooled_text_emb, time_steps):
134
+ time_steps = self.time_proj(time_steps)
135
+ time_emb = self.time_emb(time_steps.to(dtype=torch.bfloat16))
136
+ pooled_text_emb = self.pooled_text_emb(pooled_text_emb)
137
+
138
+ return time_emb + pooled_text_emb
139
+
140
+
141
+ class LatentPatchModule(nn.Module):
142
+ def __init__(self, patch_size, embedding_dim, latent_channels, vae_type):
143
+ super().__init__()
144
+ self.patch_size = patch_size
145
+ self.embedding_dim = embedding_dim
146
+ self.projection_SD3 = nn.Conv2d(SD3_LATENT_CHANNEL, embedding_dim, kernel_size=patch_size, stride=patch_size)
147
+ self.latent_channels = latent_channels
148
+
149
+ def forward(self, x):
150
+ assert (
151
+ x.shape[1] == SD3_LATENT_CHANNEL
152
+ ), f"VAE-Latent channel is not matched with '{SD3_LATENT_CHANNEL}'. current shape: {x.shape}"
153
+ patches = self.projection_SD3(
154
+ x.to(dtype=torch.bfloat16)
155
+ ) # Shape: (B, embedding_dim, num_patches_h, num_patches_w)
156
+ patches = patches.to(dtype=torch.bfloat16)
157
+ patches = patches.contiguous()
158
+ patches = patches.flatten(2) # Shape: (B, embedding_dim, num_patches)
159
+
160
+ patches = patches.transpose(1, 2) # Shape: (B, num_patches, embedding_dim)
161
+ patches = patches.contiguous()
162
+ return patches
163
+
164
+ def unpatchify(self, x):
165
+ """
166
+ x: (N, T, patch_size**2 * C)
167
+ imgs: (N, H, W, C)
168
+ """
169
+ n = x.shape[0]
170
+ c = self.latent_channels
171
+ p = self.patch_size
172
+
173
+ # check the valid patching
174
+ h = w = int(x.shape[1] ** 0.5)
175
+ assert h * w == x.shape[1]
176
+
177
+ x = x.contiguous()
178
+ # (N x T x [C * patch_size**2]) -> (N x H x W x P_1 x P_2 x C)
179
+ x = x.reshape(shape=(n, h, w, p, p, c))
180
+ # x = torch.einsum('nhwpqc->nchpwq', x) # NOTE: einsum may be the problem here, so permute is used instead
181
+
182
+ # (N x H x W x P_1 x P_2 x C) -> (N x C x H x P_1 x W x P_2)
183
+ # (0 . 1 . 2 . 3 . 4 . 5) -> (0 . 5 . 1 . 3 2 . 4 )
184
+ x = x.permute(0, 5, 1, 3, 2, 4)
185
+ return x.reshape(shape=(n, c, h * p, h * p)).contiguous()
186
+
187
+
188
+ class TextConditionModule(nn.Module):
189
+ def __init__(self, text_dim, latent_dim):
190
+ super().__init__()
191
+ self.projection = nn.Linear(text_dim, latent_dim)
192
+
193
+ def forward(self, t5_xxl, clip_a, clip_b):
194
+ clip_emb = torch.cat([clip_a, clip_b], dim=-1)
195
+ clip_emb = torch.nn.functional.pad(clip_emb, (0, t5_xxl.shape[-1] - clip_emb.shape[-1]))
196
+ text_emb = torch.cat([clip_emb, t5_xxl], dim=-2)
197
+ text_emb = self.projection(text_emb.to(torch.bfloat16))
198
+ return text_emb
199
+
200
+
201
+ class MotifDiTBlock(nn.Module):
202
+ def __init__(self, emb_dim, t_emb_dim, attn_emb_dim, mlp_dim, attn_config, text_dim=4096):
203
+ super().__init__()
204
+ self.affine_params_c = TextTimeEmbToGlobalParams(t_emb_dim, emb_dim)
205
+ self.affine_params_x = TextTimeEmbToGlobalParams(t_emb_dim, emb_dim)
206
+
207
+ self.norm_1_c = nn.LayerNorm(emb_dim, elementwise_affine=False)
208
+ self.norm_1_x = nn.LayerNorm(emb_dim, elementwise_affine=False)
209
+ self.linear_1_c = nn.Linear(emb_dim, attn_emb_dim)
210
+ self.linear_1_x = nn.Linear(emb_dim, attn_emb_dim)
211
+
212
+ self.attn = JointAttn(attn_config)
213
+ self.norm_2_c = nn.LayerNorm(emb_dim, elementwise_affine=False)
214
+ self.norm_2_x = nn.LayerNorm(emb_dim, elementwise_affine=False)
215
+ self.mlp_3_c = MLP(emb_dim, mlp_dim)
216
+ self.mlp_3_x = MLP(emb_dim, mlp_dim)
217
+
218
+ def forward(self, x_emb, c_emb, t_emb, perturbed=False):
219
+ """
220
+ x_emb (N, TOKEN_LENGTH x 2, C)
221
+ c_emb (N, T + REGISTER_TOKENS, C)
222
+ t_emb (N, modulation_dim)
223
+ """
224
+
225
+ device = x_emb.device
226
+
227
+ # get global affine transformation parameters
228
+ alpha_x, beta_x, gamma_x, delta_x, epsilon_x, zeta_x = self.affine_params_x(t_emb) # scale and shift for image
229
+ alpha_c, beta_c, gamma_c, delta_c, epsilon_c, zeta_c = self.affine_params_c(t_emb) # scale and shift for text
230
+
231
+ # projection and affine transform before attention
232
+ x_emb_pre_attn = self.linear_1_x((1 + alpha_x) * self.norm_1_x(x_emb) + beta_x)
233
+ c_emb_pre_attn = self.linear_1_c((1 + alpha_c) * self.norm_1_c(c_emb) + beta_c)
234
+
235
+ # attn_output, attn_weight (None), past_key_value (None)
236
+ x_emb_post_attn, c_emb_post_attn = self.attn(
237
+ x_emb_pre_attn, c_emb_pre_attn, perturbed
238
+ ) # mixed feature for both text and image (N, [T_x + T_c], C)
239
+
240
+ # scale with gamma and residual with the original inputs
241
+ x_emb_post_attn = x_emb_post_attn.to(gamma_x.device)
242
+ x_emb_post_attn = (1 + gamma_x) * x_emb_post_attn + x_emb # NOTE: nan loss for self.linear_2_x.bias
243
+ c_emb_post_attn = c_emb_post_attn.to(gamma_c.device)
244
+ c_emb_post_attn = (1 + gamma_c) * c_emb_post_attn + c_emb
245
+
246
+ # norm the features -> affine transform with modulation -> MLP
247
+ normalized_x_emb = self.norm_2_x(x_emb_post_attn).to(delta_x.device)
248
+ normalized_c_emb = self.norm_2_c(c_emb_post_attn).to(delta_c.device)
249
+ x_emb_final = self.mlp_3_x(delta_x * normalized_x_emb + epsilon_x)
250
+ c_emb_final = self.mlp_3_c(delta_c * normalized_c_emb + epsilon_c)
251
+
252
+ # final scaling with zeta and residual with the original inputs
253
+ x_emb_final = zeta_x.to(device) * x_emb_final.to(device) + x_emb.to(device)
254
+ c_emb_final = zeta_c.to(device) * c_emb_final.to(device) + c_emb.to(device)
255
+
256
+ return x_emb_final, c_emb_final
257
+
258
+
259
+ class MotifDiT(nn.Module):
260
+ ENCODED_TEXT_DIM = 4096
261
+
262
+ def __init__(self, config):
263
+ super(MotifDiT, self).__init__()
264
+ self.patch_size = config.patch_size
265
+ self.h, self.w = config.height // config.vae_compression, config.width // config.vae_compression
266
+
267
+ self.latent_chennels = 16
268
+
269
+ # Embedding for (1) text; (2) input image; (3) time
270
+ self.text_cond = TextConditionModule(self.ENCODED_TEXT_DIM, config.hidden_dim)
271
+ self.patching = LatentPatchModule(config.patch_size, config.hidden_dim, self.latent_chennels, config.vae_type)
272
+ self.time_emb = TextTimeEmbedding(config.time_embed_dim, config.pooled_text_dim, config.modulation_dim)
273
+
274
+ # main multi-modal DiT blocks
275
+ self.mmdit_blocks = nn.ModuleList(
276
+ [
277
+ MotifDiTBlock(
278
+ config.hidden_dim, config.modulation_dim, config.hidden_dim, config.mlp_hidden_dim, config
279
+ )
280
+ for layer_idx in range(config.num_layers)
281
+ ]
282
+ )
283
+
284
+ self.final_modulation = nn.Linear(config.modulation_dim, config.hidden_dim * 2)
285
+ self.final_linear_SD3 = nn.Linear(config.hidden_dim, SD3_LATENT_CHANNEL * config.patch_size**2)
286
+ self.skip_register_token_num = config.skip_register_token_num
287
+
288
+ if getattr(config, "pos_emb_size", None):
289
+ pos_emb_size = config.pos_emb_size
290
+ else:
291
+ pos_emb_size = HIGH_RES_POSEMB_BASE_SIZE if config.height > 512 else LOW_RES_POSEMB_BASE_SIZE
292
+ logger.info(f"Positional embedding of Motif-DiT is set to {pos_emb_size}")
293
+
294
+ self.pos_embed = torch.from_numpy(
295
+ get_2d_sincos_pos_embed(
296
+ config.hidden_dim, (self.h // self.patch_size, self.w // self.patch_size), base_size=pos_emb_size
297
+ )
298
+ ).to(device="cuda", dtype=torch.bfloat16)
299
+
300
+ # set register tokens (https://arxiv.org/abs/2309.16588)
301
+ if config.register_token_num > 0:
302
+ self.register_token_num = config.register_token_num
303
+ self.register_tokens = nn.Parameter(torch.randn(1, self.register_token_num, config.hidden_dim))
304
+ self.register_parameter("register_tokens", self.register_tokens)
305
+
306
+ # if needed, add additional register tokens for higher resolution training
307
+ self.additional_register_token_num = config.additional_register_token_num
308
+ if config.additional_register_token_num > 0:
309
+ self.register_tokens_highres = nn.Parameter(
310
+ torch.randn(1, self.additional_register_token_num, config.hidden_dim)
311
+ )
312
+ self.register_parameter("register_tokens_highres", self.register_tokens_highres)
313
+
314
+ if config.use_final_layer_norm:
315
+ self.final_norm = nn.LayerNorm(config.hidden_dim)
316
+
317
+ if config.conv_header:
318
+ logger.info("use convolution header after de-patching")
319
+ self.depatching_conv_header = IdentityConv2d(SD3_LATENT_CHANNEL)
320
+
321
+ if config.use_time_token_in_attn:
322
+ self.t_token_proj = nn.Linear(config.modulation_dim, config.hidden_dim)
323
+
324
+ def forward(self, latent, t, text_embs: List[torch.Tensor], pooled_text_embs, guiding_feature=None):
325
+ """
326
+ latent (torch.Tensor)
327
+ t (torch.Tensor)
328
+ text_embs (List[torch.Tensor])
329
+ pooled_text_embs (torch.Tensor)
330
+ """
331
+ # 1. get inputs for the MMDiT blocks
332
+ emb_c = self.text_cond(*text_embs) # (N, L, D), text conditions
333
+ emb_t = self.time_emb(pooled_text_embs, t).to(emb_c.device) # (N, D), time and pooled text conditions
334
+
335
+ emb_x = (self.patching(latent) + self.pos_embed).to(
336
+ emb_c.device
337
+ ) # (N, T, D), where T = H*W / (patch_size ** 2), input latent patches
338
+
339
+ # additional "register" tokens, to convey the global information and prevent high-norm abnormal patch
340
+ # see https://openreview.net/forum?id=2dnO3LLiJ1
341
+ if hasattr(self, "register_tokens"):
342
+ if hasattr(self, "register_tokens_highres"):
343
+ emb_x = torch.cat(
344
+ (
345
+ self.register_tokens_highres.expand(emb_x.shape[0], -1, -1),
346
+ self.register_tokens.expand(emb_x.shape[0], -1, -1),
347
+ emb_x,
348
+ ),
349
+ dim=1,
350
+ )
351
+ else:
352
+ emb_x = torch.cat((self.register_tokens.expand(emb_x.shape[0], -1, -1), emb_x), dim=1)
353
+
354
+ # time embedding into text embedding
355
+ if hasattr(self, "use_time_token_in_attn"):
356
+ t_token = self.t_token_proj(emb_t).unsqueeze(1)
357
+ emb_c = torch.cat([emb_c, t_token], dim=1) # (N, [T_c + 1], C)
358
+
359
+ # 2. MMDiT Blocks
360
+ for block_idx, block in enumerate(self.mmdit_blocks):
361
+ emb_x, emb_c = block(emb_x, emb_c, emb_t)
362
+
363
+ # accumulating the feature_similarity loss
364
+ # TODO: add modeling_dit related test
365
+ if hasattr(self, "num_feature_align_layers") and block_idx == self.num_feature_align_layers:
366
+ self.feature_alignment_loss = self.feature_align_mlp(emb_x, guiding_feature) # exclude register tokens
367
+
368
+ # Remove the register tokens at the certain layer (the last layer as default).
369
+ if block_idx == len(self.mmdit_blocks) - (1 + self.skip_register_token_num):
370
+ if hasattr(self, "register_tokens_highres"):
371
+ emb_x = emb_x[
372
+ :, self.register_token_num + self.additional_register_token_num :
373
+ ] # remove the register tokens for the output layer
374
+ elif hasattr(self, "register_tokens"):
375
+ emb_x = emb_x[:, self.register_token_num :] # remove the register tokens for the output layer
376
+
377
+ # 3. final modulation (shift-and-scale)
378
+ scale, shift = self.final_modulation(emb_t).chunk(2, -1) # (N, D) x 2
379
+ scale, shift = scale.unsqueeze(1), shift.unsqueeze(1) # (N, 1, D) x 2
380
+
381
+ if hasattr(self, "final_norm"):
382
+ emb_x = self.final_norm(emb_x)
383
+
384
+ final_emb = (scale + 1) * emb_x + shift
385
+
386
+ # 4. final linear layer to reduce channel and un-patching
387
+ emb_x = self.final_linear_SD3(final_emb) # (N, T, D) to (N, T, out_channels * patch_size**2)
388
+ emb_x = self.patching.unpatchify(emb_x) # (N, out_channels, H, W)
389
+
390
+ if hasattr(self, "depatching_conv_header"):
391
+ emb_x = self.depatching_conv_header(emb_x)
392
+ return emb_x
393
+
394
+
395
+ class JointAttn(nn.Module):
396
+ """
397
+ SD3 style joint-attention layer
398
+ """
399
+
400
+ def __init__(self, config):
401
+ super().__init__()
402
+ self.config = config
403
+ self.hidden_size = config.hidden_dim
404
+ self.num_heads = config.num_attention_heads
405
+ self.head_dim = self.hidden_size // self.num_heads
406
+
407
+ if (self.head_dim * self.num_heads) != self.hidden_size:
408
+ raise ValueError(
409
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
410
+ f" and `num_heads`: {self.num_heads})."
411
+ )
412
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
413
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
414
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
415
+
416
+ self.add_q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
417
+ self.add_k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
418
+ self.add_v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
419
+
420
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
421
+ self.add_o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
422
+
423
+ self.q_norm_x = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
424
+ self.k_norm_x = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
425
+
426
+ self.q_norm_c = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
427
+ self.k_norm_c = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
428
+ self.q_scale = nn.Parameter(torch.ones(self.num_heads))
429
+
430
+ # Attention mode : {'sdpa', 'flash', None}
431
+ self.attn_mode = config.attn_mode
432
+
433
+ def forward(
434
+ self,
435
+ hidden_states: torch.FloatTensor,
436
+ encoder_hidden_states: torch.FloatTensor,
437
+ *args,
438
+ **kwargs,
439
+ ) -> torch.FloatTensor:
440
+ residual = hidden_states
441
+
442
+ input_ndim = hidden_states.ndim
443
+ if input_ndim == 4:
444
+ batch_size, channel, height, width = hidden_states.shape
445
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
446
+ context_input_ndim = encoder_hidden_states.ndim
447
+ if context_input_ndim == 4:
448
+ batch_size, channel, height, width = encoder_hidden_states.shape
449
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
450
+
451
+ batch_size = encoder_hidden_states.shape[0]
452
+
453
+ # `sample` projections.
454
+ query = self.q_proj(hidden_states)
455
+ key = self.k_proj(hidden_states)
456
+ value = self.v_proj(hidden_states)
457
+
458
+ # `context` projections.
459
+ query_c = self.add_q_proj(encoder_hidden_states)
460
+ key_c = self.add_k_proj(encoder_hidden_states)
461
+ value_c = self.add_v_proj(encoder_hidden_states)
462
+
463
+ # head first
464
+ inner_dim = key.shape[-1]
465
+ head_dim = inner_dim // self.num_heads
466
+
467
+ def norm_qk(x, f_norm):
468
+ x = x.view(batch_size, -1, self.num_heads, head_dim)
469
+ b, l, h, d_h = x.shape
470
+ x = x.reshape(b * l, h, d_h)
471
+ x = f_norm(x)
472
+ return x.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) # [b, h, l, d_h]
473
+
474
+ query = norm_qk(query, self.q_norm_x) # [b, h, l, d_h]
475
+ key = norm_qk(key, self.k_norm_x) # [b, h, l, d_h]
476
+ value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) # [b, h, l, d_h]
477
+
478
+ query_c = norm_qk(query_c, self.q_norm_c) * self.q_scale.reshape(1, self.num_heads, 1, 1) # [b, h, l_c, d]
479
+ key_c = norm_qk(key_c, self.k_norm_c) # [b, h, l_c, d]
480
+ value_c = value_c.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) # [b, h, l_c, d]
481
+
482
+ # attention
483
+ query = torch.cat([query, query_c], dim=2).contiguous() # [b, h, l + l_c, d]
484
+ key = torch.cat([key, key_c], dim=2).contiguous() # [b, h, l + l_c, d]
485
+ value = torch.cat([value, value_c], dim=2).contiguous() # [b, h, l + l_c, d]
486
+
487
+ # deprecated.
488
+ hidden_states = self.joint_attention(batch_size, query, key, value, head_dim)
489
+ hidden_states = hidden_states.to(query.dtype)
490
+
491
+ # Split the attention outputs.
492
+ hidden_states, encoder_hidden_states = (
493
+ hidden_states[:, : residual.shape[1]],
494
+ hidden_states[:, residual.shape[1] :],
495
+ )
496
+
497
+ # linear proj
498
+ hidden_states = self.o_proj(hidden_states)
499
+ encoder_hidden_states = self.add_o_proj(encoder_hidden_states)
500
+
501
+ if input_ndim == 4:
502
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
503
+ if context_input_ndim == 4:
504
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
505
+
506
+ return hidden_states, encoder_hidden_states
507
+
508
+ def joint_attention(self, batch_size, query, key, value, head_dim):
509
+ if self.attn_mode == "sdpa" and ScaledDotProductAttention is not None:
510
+ # NOTE: SDPA does not support high-resolution (long-context).
511
+ q_len = query.size(-2)
512
+ masked_bias = torch.zeros((batch_size, self.num_heads, query.size(-2), key.size(-2)), device="cuda")
513
+
514
+ query = query.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size).contiguous()
515
+ key = key.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size).contiguous()
516
+ value = value.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size).contiguous()
517
+
518
+ scale_factor = 1.0
519
+ scale_factor /= float(self.head_dim) ** 0.5
520
+
521
+ hidden_states = ScaledDotProductAttention(
522
+ query,
523
+ key,
524
+ value,
525
+ masked_bias,
526
+ dropout_rate=0.0,
527
+ training=self.training,
528
+ attn_weight_scale_factor=scale_factor,
529
+ num_kv_groups=1,
530
+ )
531
+ elif self.attn_mode == "flash" and MotifFlashAttention is not None:
532
+ query = query.permute(0, 2, 1, 3).contiguous() # [b, l + l_c, h, d]
533
+ key = key.permute(0, 2, 1, 3).contiguous() # [b, l + l_c, h, d]
534
+ value = value.permute(0, 2, 1, 3).contiguous() # [b, l + l_c, h, d]
535
+ scale_factor = 1.0 / math.sqrt(self.head_dim)
536
+
537
+ # NOTE (1): masking of motif flash-attention uses (`1`: un-mask, `0`: mask) and has [Batch, Seq] shape
538
+ # NOTE (2): Q,K,V must be [Batch, Seq, Heads, Dim] and contiguous.
539
+ mask = torch.ones((batch_size, query.size(-3))).cuda()
540
+ hidden_states = MotifFlashAttention(
541
+ query,
542
+ key,
543
+ value,
544
+ padding_mask=mask,
545
+ softmax_scale=scale_factor,
546
+ causal=False,
547
+ )
548
+ hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * head_dim).contiguous()
549
+ else:
550
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0)
551
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim)
552
+
553
+ return hidden_states
554
+
555
+ @staticmethod
556
+ def alt_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, scale=None) -> torch.Tensor:
557
+ """
558
+ Pure-pytorch version of the xformers.scaled_dot_product_attention
559
+ (or F.scaled_dot_product_attention from torch>2.0.0)
560
+
561
+ Args:
562
+ query (Tensor): query tensor
563
+ key (Tensor): key tensor
564
+ value (Tensor): value tensor
565
+ attn_mask (Tensor, optional): attention mask. Defaults to None.
566
+ dropout_p (float, optional): attention dropout probability. Defaults to 0.0.
567
+ scale (Tensor or float, optional): scaling for QK. Defaults to None.
568
+
569
+ Returns:
570
+ torch.Tensor: attention score (after softmax)
571
+ """
572
+ L, S = query.size(-2), key.size(-2)
573
+ scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
574
+ attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
575
+
576
+ if attn_mask is not None:
577
+ if attn_mask.dtype == torch.bool:
578
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
579
+ else:
580
+ attn_bias += attn_mask
581
+
582
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor # B, L, S
583
+ attn_weight += attn_bias
584
+ attn_weight = torch.softmax(attn_weight, dim=-1) # B, L, S
585
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
586
+ return attn_weight @ value # B, L, S * S, D -> B, L, D
587
+
588
+
589
+ # ===============================================
590
+ # Sine/Cosine Positional Embedding Functions
591
+ # ===============================================
592
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
593
+
594
+
595
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scale=1.0, base_size=None):
596
+ """
597
+ grid_size: int of the grid height and width
598
+ return:
599
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
600
+ """
601
+ if not isinstance(grid_size, tuple):
602
+ grid_size = (grid_size, grid_size)
603
+
604
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / scale
605
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / scale
606
+ if base_size is not None:
607
+ grid_h *= base_size / grid_size[0]
608
+ grid_w *= base_size / grid_size[1]
609
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
610
+ grid = np.stack(grid, axis=0)
611
+
612
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
613
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
614
+ if cls_token and extra_tokens > 0:
615
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
616
+ return pos_embed
617
+
618
+
619
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
620
+ assert embed_dim % 2 == 0
621
+
622
+ # use half of dimensions to encode grid_h
623
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
624
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
625
+
626
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
627
+ return emb
628
+
629
+
630
+ def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0):
631
+ pos = np.arange(0, length)[..., None] / scale
632
+ return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
633
+
634
+
635
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
636
+ """
637
+ embed_dim: output dimension for each position
638
+ pos: a list of positions to be encoded: size (M,)
639
+ out: (M, D)
640
+ """
641
+ assert embed_dim % 2 == 0
642
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
643
+ omega /= embed_dim / 2.0
644
+ omega = 1.0 / 10000**omega # (D/2,)
645
+
646
+ pos = pos.reshape(-1) # (M,)
647
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
648
+
649
+ emb_sin = np.sin(out) # (M, D/2)
650
+ emb_cos = np.cos(out) # (M, D/2)
651
+
652
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
653
+ return emb
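
A short shape check for the positional-embedding helpers above, using the mmdit_xlarge_hq.json geometry (grid size and width are illustrative; the import path assumes this module lives at models/modeling_dit.py as added here):

from models.modeling_dit import get_2d_sincos_pos_embed

hidden_dim = 1920
grid = (64, 64)   # (1024 // 8) // patch_size with patch_size = 2, i.e. the latent patch grid
pos = get_2d_sincos_pos_embed(hidden_dim, grid, base_size=64)
print(pos.shape)  # (4096, 1920): one sin/cos embedding per latent patch, matching MotifDiT.pos_embed
assert pos.shape == (grid[0] * grid[1], hidden_dim)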
models/modeling_motif_vision.py ADDED
@@ -0,0 +1,612 @@
1
+ from collections import defaultdict
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torchvision.transforms as T
8
+ import tqdm
9
+ from diffusers.models import AutoencoderKL
10
+ from diffusers.utils.torch_utils import randn_tensor
11
+ from loguru import logger
12
+ from PIL import Image, ImageFilter
13
+ from transformers import CLIPTextModel, CLIPTokenizerFast, T5EncoderModel, T5Tokenizer
14
+
15
+ from models.mixin.flow_mixin import FlowMixin
16
+ from models.modeling_dit import MotifDiT
17
+
18
+ TOKEN_MAX_LENGTH: int = 256
19
+ DROP_PROB: float = 0.1
20
+ LATENT_CHANNELS: int = 4
21
+ VAE_DOWNSCALE_FACTOR: int = 8
22
+ SD3_LATENT_CHANNEL: int = 16
23
+
24
+
25
+ def generate_intervals(steps, ratio, start=1.0):
26
+ intervals = torch.linspace(start, 0, steps=steps)
27
+ intervals = intervals.pow(ratio)
28
+ return intervals
29
+
30
+
31
+ class MotifVision(nn.Module, FlowMixin):
32
+ """
33
+ MotifVision Text-to-Image model.
34
+
35
+ This model combines a Diffusion transformer with a rectified flow loss and multiple text encoders.
36
+ It uses a VAE (Variational Autoencoder) for image encoding and decoding.
37
+
38
+ Args:
39
+ config (MMDiTConfig): Configuration object for the MMDiT model.
40
+
41
+ Attributes:
42
+ dit (MotifDiT): MotifDiT model instance.
43
+ noise_scheduler (DDPMScheduler): Noise scheduler for the diffusion process.
44
+ normalize_img (Callable): Function to normalize images from [-1, 1] range.
45
+ unnormalize_img (Callable): Function to unnormalize images to [0, 1] range.
46
+ cond_drop_prob (float): Probability of dropping text embeddings during training.
47
+ snr_gamma (str): Strategy for weighting the loss based on Signal-to-Noise Ratio (SNR).
48
+ loss_weight_strategy (str): Strategy for weighting the loss.
49
+ vae (AutoencoderKL): Variational Autoencoder for image encoding and decoding.
50
+ t5 (T5EncoderModel): T5 encoder model for text encoding.
51
+ t5_tokenizer (T5Tokenizer): T5 tokenizer for text tokenization.
52
+ clip_l (CLIPModel): CLIP (Contrastive Language-Image Pre-training) model (large) for text encoding.
53
+ clip_l_tokenizer (CLIPTokenizerFast): CLIP tokenizer (large) for text tokenization.
54
+ clip_g (CLIPModel): CLIP model (giant) for text encoding.
55
+ clip_g_tokenizer (CLIPTokenizerFast): CLIP tokenizer (giant) for text tokenization.
56
+ tokenizers (List[Union[T5Tokenizer, CLIPTokenizerFast]]): List of tokenizers.
57
+ text_encoders (List[Union[T5EncoderModel, CLIPModel]]): List of text encoder models.
58
+ """
59
+
60
+ def __init__(self, config):
61
+ super().__init__()
62
+ self.config = config
63
+ self.dit = MotifDiT(config)
64
+ self.cond_drop_prob = 0.1
65
+ self.use_weighting = False
66
+ self._get_encoders()
67
+ self._freeze_encoders()
68
+
69
+ def forward(self, images: torch.Tensor, raw_text: List[str]) -> List[torch.Tensor]:
70
+ """
71
+ Forward pass of the MotifVision model.
72
+
73
+ Args:
74
+ images (torch.Tensor): Input images tensor, [0-1] ranged.
75
+ raw_text (List[str]): List of input text strings.
76
+
77
+ Returns:
78
+ List[torch.Tensor]: A single-element list containing the rectified flow matching loss.
79
+ """
80
+ # 1. Encode images and texts
81
+ with torch.no_grad():
82
+ latents = self.vae.encode(images).latent_dist.sample() * self.vae.config.scaling_factor
83
+ tokens, masks = self.tokenization(raw_text)
84
+ tokens = [token.to(latents.device) for token in tokens]
85
+ masks = [mask.to(latents.device) for mask in masks]
86
+ text_embeddings, pooled_text_embeddings = self.text_encoding(tokens, masks)
87
+ text_embeddings = self._drop_text_emb(text_embeddings)
88
+ text_embeddings = [text_embedding.float() for text_embedding in text_embeddings]
89
+ pooled_text_embeddings = pooled_text_embeddings.float()
90
+
91
+ # 2. Get noisy input via the rectified flow
92
+ is_finetuning = self.config.height > 256
93
+ noise, noise_latents, t = self.get_noisy_input(latents, is_finetuning=is_finetuning)
94
+
95
+ timesteps = self.discritize_timestep(t, self.n_timesteps)
96
+
97
+ # 3. Forward pass through the dit
98
+ preds = self.dit(noise_latents, timesteps, text_embeddings, pooled_text_embeddings)
99
+
100
+ # 4. Rectified flow matching loss
101
+ loss = self.rectified_flow_loss(latents, noise, t, preds, use_weighting=self.use_weighting)
102
+
103
+ return [loss]
104
+
105
+ def _get_encoders(self) -> None:
106
+ """Initialize the VAE and text encoders."""
107
+ if self.config.vae_type == "SD3":
108
+ self.vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae")
109
+ elif self.config.vae_type == "SDXL":
110
+ self.vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
111
+ else:
112
+ raise ValueError(f"VAE type must be `SD3` or `SDXL` but self.config.vae_type is {self.config.vae_type}")
113
+
114
+ # Text encoders
115
+ # 1. T5-XXL from Google
116
+ self.t5 = T5EncoderModel.from_pretrained("google/flan-t5-xxl").to(dtype=torch.bfloat16)
117
+ self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
118
+
119
+ # 2. CLIP-L from OpenAI
120
+ self.clip_l = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(dtype=torch.bfloat16)
121
+ self.clip_l_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")
122
+
123
+ # 3. CLIP-G from LAION
124
+ self.clip_g = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k").to(dtype=torch.bfloat16)
125
+ self.clip_g_tokenizer = CLIPTokenizerFast.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
126
+
127
+ self.tokenizers = [self.t5_tokenizer, self.clip_l_tokenizer, self.clip_g_tokenizer]
128
+ self.text_encoders = [self.t5, self.clip_l, self.clip_g]
129
+
130
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
131
+ state_dict = super(MotifVision, self).state_dict(destination, prefix, keep_vars)
132
+ exclude_keys = ["t5.", "clip_l.", "clip_g.", "vae."]
133
+ for key in list(state_dict.keys()):
134
+ if any(key.startswith(exclude_key) for exclude_key in exclude_keys):
135
+ state_dict.pop(key)
136
+ return state_dict
137
+
138
+ def load_state_dict(self, state_dict, strict=False):
139
+ """
140
+ Load state dict and merge LoRA parameters if present.
141
+
142
+ Args:
143
+ state_dict (dict): State dictionary containing model parameters
144
+ strict (bool): Whether to strictly enforce that the keys in state_dict match the keys in this module
145
+
146
+ Returns:
147
+ tuple: (missing_keys, unexpected_keys) lists of parameters that were missing or unexpected
148
+ """
149
+ # Check if state_dict contains LoRA parameters
150
+ has_lora = any("lora_" in key for key in state_dict.keys())
151
+
152
+ if has_lora:
153
+ # If model doesn't have LoRA enabled but state_dict has LoRA params, enable it
154
+ if not hasattr(self.dit, "peft_config"):
155
+ logger.info("Enabling LoRA for parameter merging...")
156
+ # Use default values if not already configured
157
+ lora_rank = getattr(self.config, "lora_rank", 64)
158
+ lora_alpha = getattr(self.config, "lora_alpha", 8)
159
+ self.enable_lora(lora_rank, lora_alpha)
160
+
161
+ if has_lora:
162
+ try:
163
+ # Load LoRA parameters
164
+ # state_dict = {
165
+ # k.replace("base_layer.", ""): v
166
+ # for k, v in state_dict.items()
167
+ # if "lora_" not in k and "lora" not in k
168
+ # }
169
+ missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False)
170
+ # Merge LoRA weights with base model
171
+ logger.info("Merging LoRA parameters with base model...")
172
+ for name, module in self.dit.named_modules():
173
+ if hasattr(module, "merge_and_unload"):
174
+ module.merge_and_unload()
175
+
176
+ logger.info("Successfully merged LoRA parameters")
177
+
178
+ except Exception as e:
179
+ logger.error(f"Error merging LoRA parameters: {str(e)}")
180
+ raise
181
+ else:
182
+ missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False)
183
+
184
+ # Log summary of missing/unexpected parameters
185
+ missing_top_levels = set()
186
+ for key in missing_keys:
187
+ top_level_name = key.split(".")[0]
188
+ missing_top_levels.add(top_level_name)
189
+ if missing_top_levels:
190
+ logger.debug("Missing keys during loading at top level:")
191
+ for name in missing_top_levels:
192
+ logger.debug(name)
193
+
194
+ if unexpected_keys:
195
+ logger.debug("Unexpected keys found:")
196
+ for key in unexpected_keys:
197
+ logger.debug(key)
198
+
199
+ return missing_keys, unexpected_keys
200
+
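+ # Minimal usage sketch (hypothetical path and variable names): loading a
+ # checkpoint that contains `lora_` keys will transparently enable LoRA on the
+ # DiT, load the adapters, and merge them into the base weights:
+ #
+ #     state = torch.load("checkpoints/motif_vision_dpo.pt", map_location="cpu")
+ #     missing, unexpected = model.load_state_dict(state, strict=False)
+ #
+ # Afterwards the merged model can be saved via `model.state_dict()`, which
+ # excludes the frozen VAE / text-encoder weights.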
201
+ def _freeze_encoders(self) -> None:
202
+ """
203
+ Freeze the VAE and all text encoders (their parameters receive no gradients).
204
+ """
205
+ for encoder_module in [self.vae, self.clip_l, self.clip_g, self.t5]:
206
+ for param in encoder_module.parameters():
207
+ param.requires_grad = False
208
+
209
+ def tokenization(
210
+ self, raw_texts: List[str], repeat_if_short: bool = False
211
+ ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
212
+ """
213
+ Tokenizes a BATCH of input texts using multiple tokenizers efficiently.
214
+ Optionally repeats each text to fill the max length if it's shorter,
215
+ BEFORE passing the pre-processed batch to the tokenizer.
216
+
217
+ Args:
218
+ raw_texts (List[str]): A list of input text strings (the batch).
219
+ repeat_if_short (bool): If True and a text is short, repeat that text
220
+ to fill the context length. Defaults to False.
221
+
222
+ Returns:
223
+ Tuple[List[torch.Tensor], List[torch.Tensor]]:
224
+ - A list containing one batch tensor of input IDs per tokenizer.
225
+ Each tensor shape: [batch_size, max_length]
226
+ - A list containing one batch tensor of attention masks per tokenizer.
227
+ Each tensor shape: [batch_size, max_length]
228
+ """
229
+ final_batch_tokens = []
230
+ final_batch_masks = []
231
+
232
+ # Process the batch with each tokenizer
233
+ for tokenizer in self.tokenizers:
234
+ effective_max_length = min(TOKEN_MAX_LENGTH, tokenizer.model_max_length)
235
+
236
+ # 1. Pre-process the batch: Create a new list of potentially repeated strings.
237
+ processed_texts_for_tokenizer = []
238
+ for text_item in raw_texts:
239
+ # Start with the original text for this item
240
+ processed_text = text_item
241
+
242
+ if repeat_if_short:
243
+ # Apply repetition logic individually based on text_item's length
244
+ num_initial_tokens = len(text_item.split())
245
+ available_length = effective_max_length - 2 # Heuristic: leave room for special tokens (e.g. BOS/EOS)
246
+
247
+ if num_initial_tokens > 0 and num_initial_tokens < available_length:
248
+ num_additional_repeats = available_length // (num_initial_tokens + 1)
249
+ if num_additional_repeats > 0:
250
+ total_repeats = 1 + num_additional_repeats
251
+ processed_text = " ".join([text_item] * total_repeats)
252
+
253
+ # Add the processed text (original or repeated) to the list for this tokenizer
254
+ processed_texts_for_tokenizer.append(processed_text)
255
+
256
+ # 2. Tokenize the entire batch of processed texts at once.
257
+ # Pass the list `processed_texts_for_tokenizer` directly to the tokenizer.
258
+ # The tokenizer's __call__ method should handle the batch efficiently.
259
+ batch_tok_output = tokenizer( # Call the tokenizer ONCE with the full list
260
+ processed_texts_for_tokenizer,
261
+ padding="max_length",
262
+ max_length=effective_max_length,
263
+ return_tensors="pt",
264
+ truncation=True,
265
+ )
266
+
267
+ # 3. Store the resulting batch tensors directly.
268
+ # The tokenizer should return tensors with shape [batch_size, max_length].
269
+ final_batch_tokens.append(batch_tok_output.input_ids)
270
+ final_batch_masks.append(batch_tok_output.attention_mask)
271
+
272
+ return final_batch_tokens, final_batch_masks
273
+
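+ # Worked example of the repeat-if-short heuristic above: with a CLIP tokenizer
+ # (model_max_length = 77) and TOKEN_MAX_LENGTH >= 77, effective_max_length = 77
+ # and available_length = 75. For the 4-word prompt "a cat wearing sunglasses":
+ #     num_initial_tokens     = 4
+ #     num_additional_repeats = 75 // (4 + 1) = 15
+ #     total_repeats          = 16
+ # so the tokenizer receives the prompt repeated 16 times and truncates back to
+ # 77 tokens. Note the count is whitespace-based, not tokenizer-based.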
274
+ @torch.no_grad()
275
+ def text_encoding(
276
+ self, tokens: List[torch.Tensor], masks, noisy_pad=False, zero_masking=True
277
+ ) -> Tuple[List[torch.Tensor], torch.Tensor]:
278
+ """
279
+ Encode the tokenized text using multiple text encoders.
280
+
281
+ Args:
282
+ tokens (List[torch.Tensor]): List of tokenized input-ID tensors, one per tokenizer (T5, CLIP-L, CLIP-G).
+ masks (List[torch.Tensor]): List of attention-mask tensors matching `tokens`.
+ noisy_pad (bool): If True, add small Gaussian noise to padded / post-EOS positions.
+ zero_masking (bool): If True, zero out embeddings at padded / post-EOS positions.
283
+
284
+ Returns:
285
+ Tuple[List[torch.Tensor], torch.Tensor]: Tuple containing a list of text embeddings and pooled text embeddings.
286
+ """
287
+ t5_tokens, clip_l_tokens, clip_g_tokens = tokens
288
+ t5_masks, clip_l_masks, clip_g_masks = masks
289
+ t5_emb = self.t5(t5_tokens, attention_mask=t5_masks)[0]
290
+ if zero_masking:
291
+ t5_emb = t5_emb * (t5_tokens != self.t5_tokenizer.pad_token_id).unsqueeze(-1)
292
+ if noisy_pad:
293
+ t5_pad_noise = (
294
+ (t5_tokens == self.t5_tokenizer.pad_token_id).unsqueeze(-1) * torch.randn_like(t5_emb).cuda() * 0.008
295
+ )
296
+ t5_emb = t5_emb + t5_pad_noise
297
+
298
+ clip_l_emb = self.clip_l(input_ids=clip_l_tokens, output_hidden_states=True)
299
+ clip_g_emb = self.clip_g(input_ids=clip_g_tokens, output_hidden_states=True)
300
+ clip_l_emb_pooled = clip_l_emb.pooler_output # B x 768
301
+ clip_g_emb_pooled = clip_g_emb.pooler_output # B x 1280
302
+
303
+ clip_l_emb = clip_l_emb.last_hidden_state # B x L x 768,
304
+ clip_g_emb = clip_g_emb.last_hidden_state # B x L x 1280,
305
+
306
+ def masking_wo_first_eos(token, eos):
307
+ idx = (token != eos).sum(dim=1)
308
+ mask = token != eos
309
+ arange = torch.arange(mask.size(0)).cuda()
310
+ mask[arange, idx] = True
311
+ mask = mask.unsqueeze(-1) # B x L x 1
312
+ return mask
313
+
314
+ if zero_masking:
315
+ clip_l_emb = clip_l_emb * masking_wo_first_eos(
316
+ clip_l_tokens, self.clip_l_tokenizer.eos_token_id
317
+ ) # B x L x 768,
318
+ clip_g_emb = clip_g_emb * masking_wo_first_eos(
319
+ clip_g_tokens, self.clip_g_tokenizer.eos_token_id
320
+ ) # B x L x 1280,
321
+
322
+ if noisy_pad:
323
+ clip_l_pad_noise = (
324
+ ~masking_wo_first_eos(clip_l_tokens, self.clip_l_tokenizer.eos_token_id)
325
+ * torch.randn_like(clip_l_emb).cuda()
326
+ * 0.08
327
+ )
328
+ clip_g_pad_noise = (
329
+ ~masking_wo_first_eos(clip_g_tokens, self.clip_g_tokenizer.eos_token_id)
330
+ * torch.randn_like(clip_g_emb).cuda()
331
+ * 0.08
332
+ )
333
+ clip_l_emb = clip_l_emb + clip_l_pad_noise
334
+ clip_g_emb = clip_g_emb + clip_g_pad_noise
335
+
336
+ encodings = [t5_emb, clip_l_emb, clip_g_emb]
337
+ pooled_encodings = torch.cat([clip_l_emb_pooled, clip_g_emb_pooled], dim=-1) # cat by channel, B x 2048
338
+
339
+ return encodings, pooled_encodings
340
+
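+ # Worked example of `masking_wo_first_eos` above (toy token IDs, CLIP
+ # eos_token_id = 49407): for a padded row
+ #     token = [[101, 202, 49407, 49407, 49407]]
+ # the non-EOS count is idx = 2, so the mask keeps positions 0 and 1 plus the
+ # first EOS at index 2, and zeroes the trailing padding EOS tokens:
+ #     mask  = [[True, True, True, False, False]]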
341
+ @torch.no_grad()
342
+ def prompt_embedding(self, prompts: List[str], device, noisy_pad=False, zero_masking=True):
343
+ tokens, masks = self.tokenization(prompts)
344
+ tokens = [token.to(device) for token in tokens]
345
+ masks = [mask.to(device) for mask in masks]
346
+ text_embeddings, pooled_text_embeddings = self.text_encoding(
347
+ tokens, masks, noisy_pad=noisy_pad, zero_masking=zero_masking
348
+ )
349
+ text_embeddings = [text_embedding.bfloat16() for text_embedding in text_embeddings]
350
+ pooled_text_embeddings = pooled_text_embeddings.bfloat16()
351
+ return text_embeddings, pooled_text_embeddings
352
+
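+ # Example usage (illustrative; `model` is a constructed MotifVision instance):
+ #     embs, pooled = model.prompt_embedding(["a red bicycle"], device="cuda")
+ #     # embs: bfloat16 sequence embeddings (T5, CLIP-L, CLIP-G); pooled: bfloat16 [B, 2048]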
353
+ @torch.no_grad()
354
+ def sample(
355
+ self,
356
+ raw_text: List[str],
357
+ steps: int = 50,
358
+ guidance_scale: float = 7.5,
359
+ resolution: List[int] = (256, 256),
360
+ pre_latent=None,
361
+ pre_timestep=None,
362
+ step_scaling=1.0,
363
+ noisy_pad=False,
364
+ zero_masking=False,
365
+ negative_prompt: Optional[List[str]] = None,
366
+ device: str = "cuda",
367
+ rescale_cfg=-1.0,
368
+ clip_t=[0.0, 1.0],
369
+ use_linear_quadratic_schedule=False,
370
+ linear_quadratic_emulating_steps=250,
371
+ prompt_rewriter=None,
372
+ moderator=None,
373
+ get_intermediate_steps: bool = False,
374
+ ) -> Union[List[Image.Image], Tuple[List[Image.Image], List[List[Image.Image]]]]:
375
+ """
376
+ Sample images using flow matching. Optionally returns intermediate step images
377
+ calculated via the observed-average-velocity method.
378
+
379
+ Args:
380
+ raw_text (List[str]): raw text prompts
381
+ steps (int, optional): number of function evaluations for the flow-matching ODE. Defaults to 50.
382
+ guidance_scale (float, optional): classifier free guidance scale. Defaults to 7.5.
383
+ resolution (List[int], optional): input and output resolution of raw images. Defaults to (256, 256).
384
+ device (str, optional): Defaults to 'cuda'.
385
+ pre_latent (Tensor, optional): the optional input to generate image with pre-defined latents.
386
+ for instance, it can be used for denoising or image editing.
387
+ pre_timestep (float [0,1], optional): the pre-defined timestep. With `pre_latent`, image generation
388
+ can start from that intermediate timestep.
389
+ step_scaling (float, optional): scaling factor applied to each ODE step. Defaults to 1.0.
390
+ use_linear_quadratic_schedule (bool, optional): if True, use the linear-quadratic t schedule; otherwise use a linear t schedule. Defaults to False.
391
+ linear_quadratic_emulating_steps (int, optional): N value in the linear-quadratic t schedule from the Meta MovieGen paper. Defaults to 250.
392
+ Reference: (https://ai.meta.com/static-resource/movie-gen-research-paper) Figure 10
393
+ get_intermediate_steps (bool, optional): Whether to calculate and return intermediate step images.
394
+ Calculation is based on initial_noise - avg(velocity). Defaults to False.
395
+
396
+ Returns:
397
+ Union[List[PIL.Image], Tuple[List[PIL.Image], List[List[PIL.Image]]]]:
398
+ If get_intermediate_steps is False: Returns a list of final PIL images.
399
+ If get_intermediate_steps is True: Returns a tuple containing:
400
+ - List[PIL.Image]: Final output PIL images.
401
+ - List[List[PIL.Image]]: List of intermediate PIL images. Each inner list contains
402
+ the batch of images for one intermediate step.
403
+ """
404
+ if prompt_rewriter:
405
+ prompts = [prompt_rewriter.rewrite(prompt) for prompt in raw_text]
406
+ else:
407
+ prompts = raw_text
408
+
409
+ # Simplified check for rewriter status
410
+ if prompts == raw_text and prompt_rewriter is not None:
411
+ logger.debug("Prompt rewriter did not change the prompts.")
412
+ elif prompt_rewriter is None:
413
+ logger.debug("Prompt rewriter not provided.")
414
+
415
+ if moderator is None:
416
+ is_safe_prompt = [True for _ in prompts]
417
+ else:
418
+ is_safe_prompt = [moderator.is_safe_content(prompt, threshold=0.7) for prompt in prompts]
419
+ if not all(is_safe_prompt):
420
+ logger.warning("Noxious prompt detected. Output image(s) will be blurred.")
421
+
422
+ b = len(prompts)
423
+ h, w = resolution
424
+
425
+ # --- [Initial Latent Noise (e = x_1)] ---
426
+ latent_channels = 16
427
+ if pre_latent is None:
428
+ initial_noise = randn_tensor( # Store initial noise separately
429
+ (b, latent_channels, h // VAE_DOWNSCALE_FACTOR, w // VAE_DOWNSCALE_FACTOR),
430
+ device=device,
431
+ dtype=torch.float32, # Use float32 for calculations
432
+ )
433
+ else:
434
+ initial_noise = pre_latent.to(device=device, dtype=torch.float32)
435
+ if pre_timestep is not None and pre_timestep < 1.0: # Check if it's truly intermediate
436
+ logger.warning(
437
+ "Using pre_latent as initial_noise for average calculation, but pre_timestep suggests it's not pure noise. Results might be unexpected."
438
+ )
439
+
440
+ latents = initial_noise.clone() # Working latents for the ODE solver
441
+
442
+ # --- [Text Embeddings & CFG Setup] ---
443
+ text_embeddings, pooled_text_embeddings = self.prompt_embedding(
444
+ prompts, latents.device, noisy_pad=noisy_pad, zero_masking=zero_masking
445
+ )
446
+ text_embeddings = [emb.to(device=latents.device, dtype=torch.bfloat16) for emb in text_embeddings]
447
+ pooled_text_embeddings = pooled_text_embeddings.to(device=latents.device, dtype=torch.bfloat16)
448
+
449
+ do_classifier_free_guidance = guidance_scale > 1.0
450
+ if do_classifier_free_guidance:
451
+ negative_text_embeddings = [
452
+ torch.zeros_like(text_embedding, device=text_embedding.device) for text_embedding in text_embeddings
453
+ ]
454
+ negative_pooled_text_embeddings = torch.zeros_like(
455
+ pooled_text_embeddings, device=pooled_text_embeddings.device
456
+ )
457
+ text_embeddings = [
458
+ torch.cat([text_embedding, negative_text_embedding], dim=0)
459
+ for text_embedding, negative_text_embedding in zip(text_embeddings, negative_text_embeddings)
460
+ ]
461
+ pooled_text_embeddings = torch.cat([pooled_text_embeddings, negative_pooled_text_embeddings], dim=0)
462
+
463
+ # if negative_prompt is None:
464
+ # negative_prompt = [""] * b
465
+ # logger.debug("No negative prompt provided, using empty strings for CFG.")
466
+ # negative_text_embeddings, negative_pooled_text_embeddings = self.prompt_embedding(negative_prompt, latents.device)
467
+ # negative_text_embeddings = [emb.to(device=latents.device, dtype=torch.bfloat16) for emb in negative_text_embeddings]
468
+ # negative_pooled_text_embeddings = negative_pooled_text_embeddings.to(device=latents.device, dtype=torch.bfloat16)
469
+
470
+ # text_embeddings = [torch.cat([pos_emb, neg_emb], dim=0) for pos_emb, neg_emb in zip(text_embeddings, negative_text_embeddings)]
471
+ # pooled_text_embeddings = torch.cat([pooled_text_embeddings, negative_pooled_text_embeddings], dim=0)
472
+
473
+ # --- [Timestep Schedule (Sigmas)] ---
474
+ # linear t schedule
475
+ sigmas = torch.linspace(1, 0, steps + 1) if pre_timestep is None else torch.linspace(pre_timestep, 0, steps + 1)
476
+
477
+ if use_linear_quadratic_schedule:
478
+ # linear-quadratic t schedule
479
+ assert steps % 2 == 0
480
+ N = linear_quadratic_emulating_steps
481
+ sigmas = torch.concat(
482
+ [
483
+ torch.linspace(1, 0, N + 1)[: steps // 2],
484
+ torch.linspace(0, 1, steps // 2 + 1) ** 2 * (steps // 2 * 1 / N - 1) - (steps // 2 * 1 / N - 1),
485
+ ]
486
+ )
487
+
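+ # Worked example of the linear-quadratic schedule above (steps = 50,
+ # N = linear_quadratic_emulating_steps = 250): the first steps // 2 = 25 sigmas
+ # follow the fine-grained 250-step linear schedule (1.000, 0.996, ..., 0.904),
+ # and the remaining 26 values form the quadratic tail 0.9 * (1 - s^2) for
+ # s in linspace(0, 1, 26), decaying from 0.900 to 0. This emulates the small
+ # early steps of a 250-step solver without paying for 250 model evaluations.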
488
+ # --- [Initialization for Intermediate Step Calculation] ---
489
+ # intermediate_latents will store the latent states for intermediate steps
490
+ intermediate_latents = [] if get_intermediate_steps else None
491
+ predicted_velocities = [] # Store dx from each step
492
+ sigma_history = []
493
+ # --- [Sampling Loop] ---
494
+ for infer_step, t in tqdm.tqdm(enumerate(sigmas[:-1]), total=len(sigmas[:-1]), desc="Sampling"):
495
+ # Prepare input for DiT model
496
+ if do_classifier_free_guidance:
497
+ input_latents = torch.cat([latents] * 2, dim=0)
498
+ else:
499
+ input_latents = latents
500
+
501
+ # Prepare timestep input
502
+ timestep = (t * 1000).round().long().to(latents.device)
503
+ timestep = timestep.expand(input_latents.shape[0]).to(torch.bfloat16) # Ensure timestep is bfloat16
504
+
505
+ # Predict velocity dx = v(x_t, t) ≈ e - x_0
506
+ dx = self.dit(input_latents.to(torch.bfloat16), timestep, text_embeddings, pooled_text_embeddings)
507
+ dt = sigmas[infer_step + 1] - sigmas[infer_step] # dt is negative
508
+ sigma_history.append(dt)
509
+
510
+ # Apply Classifier-Free Guidance
511
+ if do_classifier_free_guidance:
512
+ cond_dx, uncond_dx = dx.chunk(2)
513
+ current_guidance_scale = guidance_scale if clip_t[0] <= t and t <= clip_t[1] else 1.0
514
+ dx = uncond_dx + current_guidance_scale * (cond_dx - uncond_dx)
515
+
516
+ if rescale_cfg > 0.0:
517
+ std_pos = torch.std(cond_dx, dim=[1, 2, 3], keepdim=True, unbiased=False) + 1e-5
518
+ std_cfg = torch.std(dx, dim=[1, 2, 3], keepdim=True, unbiased=False) + 1e-5
519
+ factor = std_pos / std_cfg
520
+ factor = rescale_cfg * factor + (1.0 - rescale_cfg)
521
+ dx = dx * factor
522
+
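+ # The rescaling above matches the "guidance rescale" technique (Lin et al.,
+ # "Common Diffusion Noise Schedules and Sample Steps are Flawed"): the CFG
+ # output is scaled so its per-sample std matches the conditional branch, then
+ # blended with the unrescaled output:
+ #     dx_rescaled = dx * (std(cond_dx) / std(dx))
+ #     dx          = rescale_cfg * dx_rescaled + (1 - rescale_cfg) * dx
+ # which reduces the over-saturation that large guidance scales can cause.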
523
+ # --- Store the predicted velocity for averaging ---
524
+ predicted_velocities.append(dx.clone())
525
+
526
+ # --- Update Latents using standard Euler step ---
527
+ latents = latents + dt * dx
528
+
529
+ # --- Calculate and Store Intermediate Latent State (if requested) ---
530
+ if get_intermediate_steps:
531
+ dxs = torch.stack(predicted_velocities)
532
+
533
+ sigma_sum = sum(sigma_history)
534
+ normalized_sigma_history = [s / (sigma_sum) for s in sigma_history]
535
+ dts = torch.tensor(normalized_sigma_history, device=dxs.device, dtype=dxs.dtype).view(-1, 1, 1, 1, 1)
536
+
537
+ avg_dx = torch.sum(dxs * dts, dim=0)
538
+ observed_state = initial_noise - avg_dx # Calculate the desired intermediate state
539
+ intermediate_latents.append(observed_state.clone()) # Store its latent representation
540
+
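+ # What the preview above computes: `avg_dx` is the step-size-weighted average
+ # of all velocities predicted so far, and
+ #     observed_state = initial_noise - avg_dx
+ # is a single jump from the pure-noise latent along that averaged direction,
+ # i.e. (assuming the velocity convention v ≈ noise - data) a running estimate
+ # of the clean latent, which tends to give cleaner previews than decoding the
+ # current x_t directly.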
541
+ # --- [Decode Final Latents to PIL Images] ---
542
+ self.vae = self.vae.to(device=latents.device, dtype=torch.float32) # Ensure VAE is ready
543
+ final_latents_scaled = latents.to(torch.float32) / self.vae.config.scaling_factor
544
+ final_image_tensors = self.vae.decode(final_latents_scaled, return_dict=False)[0] + self.vae.config.shift_factor
545
+ final_image_tensors = ((final_image_tensors + 1.0) / 2.0).clamp(0.0, 1.0)
546
+
547
+ final_pil_images = []
548
+ for i, image_tensor in enumerate(final_image_tensors):
549
+ img = T.ToPILImage()(image_tensor.cpu())
550
+ if not is_safe_prompt[i]:
551
+ img = img.filter(ImageFilter.GaussianBlur(radius=30))
552
+ final_pil_images.append(img)
553
+
554
+ # --- [Decode Intermediate Latents to PIL Images (if requested)] ---
555
+ if get_intermediate_steps:
556
+ intermediate_pil_images = []
557
+ # Ensure VAE is still ready (it should be from final decoding)
558
+ for step_latents in tqdm.tqdm(intermediate_latents, desc="Decoding intermediates"):
559
+ step_latents_scaled = (
560
+ step_latents.to(dtype=torch.float32, device="cuda") / self.vae.config.scaling_factor
561
+ )
562
+ step_image_tensors = (
563
+ self.vae.decode(step_latents_scaled, return_dict=False)[0] + self.vae.config.shift_factor
564
+ )
565
+ step_image_tensors = ((step_image_tensors + 1.0) / 2.0).clamp(0.0, 1.0)
566
+
567
+ current_step_pil = []
568
+ for i, image_tensor in enumerate(step_image_tensors):
569
+ img = T.ToPILImage()(image_tensor.cpu())
570
+ # Apply moderation blur consistency
571
+ if not is_safe_prompt[i]:
572
+ img = img.filter(ImageFilter.GaussianBlur(radius=30))
573
+ current_step_pil.append(img)
574
+ intermediate_pil_images.append(current_step_pil) # Append list of images for this step
575
+
576
+ return final_pil_images, intermediate_pil_images # Return both final and intermediate images
577
+ else:
578
+ return final_pil_images # Return only final images
579
+
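+ # Minimal sampling sketch (hypothetical setup; config construction and
+ # checkpoint loading are outside this file):
+ #
+ #     model = MotifVision(config).cuda().eval()
+ #     images = model.sample(
+ #         ["a watercolor painting of a lighthouse at dusk"],
+ #         steps=50,
+ #         guidance_scale=7.5,
+ #         resolution=(1024, 1024),
+ #     )
+ #     images[0].save("lighthouse.png")
+ #
+ #     # With get_intermediate_steps=True the call also returns per-step previews:
+ #     finals, intermediates = model.sample(["a red bicycle"], get_intermediate_steps=True)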
580
+ @torch.no_grad()
581
+ def eval_with_loss(self, images, raw_text):
582
+ latents = self.vae.encode(images).latent_dist.sample() * self.vae.config.scaling_factor
583
+
584
+ tokens, masks = self.tokenization(raw_text)
585
+ tokens = [token.to(latents.device) for token in tokens]
586
+ masks = [mask.to(latents.device) for mask in masks]
587
+ text_embeddings, pooled_text_embeddings = self.text_encoding(tokens, masks)
588
+ text_embeddings = [text_embedding.float() for text_embedding in text_embeddings]
589
+ pooled_text_embeddings = pooled_text_embeddings.float()
590
+
591
+ # 2. Get noisy input via the rectified flow
592
+ is_finetuning = self.config.height > 256
593
+ noise, noise_latents, t = self.get_noisy_input(latents, is_finetuning=is_finetuning)
594
+ timesteps = self.discritize_timestep(t, self.n_timesteps)
595
+
596
+ # 3. Forward pass through the dit
597
+ preds = self.dit(noise_latents, timesteps, text_embeddings, pooled_text_embeddings)
598
+
599
+ # 4. Rectified flow matching loss
600
+ loss = self.rectified_flow_loss(latents, noise, t, preds, reduce="none", use_weighting=False).mean(
601
+ dim=[1, 2, 3]
602
+ )
603
+
604
+ intervals = np.linspace(0, 1, 9)
605
+ t_interval = [(intervals[i], intervals[i + 1]) for i in range(len(intervals) - 1)]
606
+
607
+ loss_bins = defaultdict(list)
608
+ for i, interval in enumerate(t_interval, 0):
609
+ idx = (interval[0] < t) & (t < interval[1])
610
+ loss_bins[i].append(loss[idx])
611
+
612
+ return loss_bins
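+ # The returned `loss_bins` maps bin index 0..7 to the per-sample losses whose
+ # sampled t fell into the corresponding interval of np.linspace(0, 1, 9), i.e.
+ # (0.000, 0.125), (0.125, 0.250), ..., (0.875, 1.000) with strict inequalities,
+ # making it easy to see how the flow-matching loss varies across noise levels.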