Commit 6cd6a16
Parent(s): 0e812de

Add MotifVision model for text-to-image generation
- Implemented the MotifVision class, combining a Diffusion transformer with rectified flow loss and multiple text encoders.
- Integrated Variational Autoencoder (VAE) for image encoding and decoding.
- Added methods for tokenization, text encoding, and sampling images with flow matching.
- Included support for multiple text encoders: T5 and CLIP models.
- Implemented functionality for handling LoRA parameters during state dict loading.
- Added evaluation method to compute loss during inference.
- configs/configuration_mmdit.py +89 -0
- configs/mmdit_xlarge_hq.json +26 -0
- models/mixin/encoder_mixin.py +222 -0
- models/mixin/flow_mixin.py +175 -0
- models/modeling_dit.py +653 -0
- models/modeling_motif_vision.py +612 -0
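
For orientation, below is a minimal, hypothetical wiring sketch of the files added in this commit. Only `MMDiTConfig.from_json_file`, the added JSON path, and the `models.modeling_motif_vision` module are taken from the diff; the exact `MotifVision(config)` constructor signature is inferred from the class docstring and is an assumption.

# Hypothetical wiring sketch (not part of this commit).
from configs.configuration_mmdit import MMDiTConfig
from models.modeling_motif_vision import MotifVision

config = MMDiTConfig.from_json_file("configs/mmdit_xlarge_hq.json")
model = MotifVision(config)  # assumed constructor: the class docstring documents `config (MMDiTConfig)`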
configs/configuration_mmdit.py
ADDED
@@ -0,0 +1,89 @@
import json
from dataclasses import dataclass

ENCODED_TEXT_DIM = 4096
POOLED_TEXT_DIM = 2048
VAE_COMPRESSION_RATIO = 8


@dataclass
class MMDiTConfig:
    # General
    num_layers: int = 12
    hidden_dim: int = 768  # common hidden dimension for the transformer arch
    patch_size: int = 2
    image_dim: int = 224
    in_channel: int = 4
    out_channel: int = 4
    modulation_dim: int = ENCODED_TEXT_DIM  # input dimension for modulation layer (shifting and scaling)
    height: int = 1024
    width: int = 1024
    vae_compression: int = VAE_COMPRESSION_RATIO  # reducing resolution with the VAE
    vae_type: str = "SD3"  # SDXL or SD3
    pos_emb_size: int = None
    conv_header: bool = False

    # Outside of the MMDiT block
    time_embed_dim: int = 2048  # Initial projection (discrete_time embedding) output dim
    pooled_text_dim: int = POOLED_TEXT_DIM
    text_emb_dim: int = 768

    # MMDiTBlock
    t_emb_dim: int = 256
    attn_embed_dim: int = 768  # hidden dimension during the attention
    mlp_hidden_dim: int = 2048
    attn_mode: str = None  # {'flash', 'sdpa', None}
    use_final_layer_norm: bool = False
    use_time_token_in_attn: bool = False

    # GroupedQueryAttention
    num_attention_heads: int = 12
    num_key_value_heads: int = 6
    use_scaled_dot_product_attention: bool = True
    dropout: float = 0.0

    # Modulation
    use_modulation: bool = True
    modulation_type: str = "film"  # Choose from 'film', 'adain', or 'spade'

    # Register tokens
    register_token_num: int = 4
    additional_register_token_num: int = 12

    # use dinov2 feature-align loss
    dinov2_feature_align_loss: bool = False
    feature_align_loss_weight: float = 0.5
    num_feature_align_layers: int = 8  # number of transformer layers to calculate feature-align loss

    # Personalization related
    image_encoder_name: str = None  # if set, the personalized image encoder will be loaded
    freeze_dit_backbone: bool = False

    # Preference optimization
    preference_train: bool = False
    lora_rank: int = 64
    lora_alpha: int = 8

    skip_register_token_num: int = 0

    @classmethod
    def from_json_file(cls, json_file):
        """
        Instantiates an [`MMDiTConfig`] from the path to a JSON file of parameters.

        Args:
            json_file (`str` or `os.PathLike`):
                Path to the JSON file containing the parameters.

        Returns:
            [`MMDiTConfig`]: The configuration object instantiated from that JSON file.
        """
        config_dict = cls._dict_from_json_file(json_file)
        return cls(**config_dict)

    @classmethod
    def _dict_from_json_file(cls, json_file):
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)
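
A short sketch of the two ways the configuration above can be built: directly as a dataclass with keyword overrides, or from a JSON file via `from_json_file`. The override values (24 layers, 1536 dims, "sdpa") are illustrative assumptions; the JSON path is the one added in this commit.

from configs.configuration_mmdit import MMDiTConfig

# Direct construction with keyword overrides; unspecified fields keep their dataclass defaults.
default_cfg = MMDiTConfig(num_layers=24, hidden_dim=1536, attn_mode="sdpa")

# Construction from the JSON file added below.
json_cfg = MMDiTConfig.from_json_file("configs/mmdit_xlarge_hq.json")
print(default_cfg.num_attention_heads, json_cfg.num_layers)  # 12 30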
configs/mmdit_xlarge_hq.json
ADDED
@@ -0,0 +1,26 @@
{
    "num_layers": 30,
    "hidden_dim": 1920,
    "patch_size": 2,
    "in_channel": 4,
    "out_channel": 4,
    "time_embed_dim": 4096,
    "attn_embed_dim": 4096,
    "num_attention_heads": 30,
    "num_key_value_heads": 30,
    "use_scaled_dot_product_attention": true,
    "dropout": 0.0,
    "mlp_hidden_dim": 7680,
    "use_modulation": true,
    "modulation_type": "film",
    "register_token_num": 4,
    "additional_register_token_num": 0,
    "skip_register_token_num": 0,
    "height": 1024,
    "width": 1024,
    "attn_mode": "flash",
    "use_final_layer_norm": false,
    "pos_emb_size": 64,
    "conv_header": false,
    "use_time_token_in_attn": true
}
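
A quick consistency check on the xlarge_hq values, as a sketch: the attention head dimension must divide evenly, which `JointAttn.__init__` in models/modeling_dit.py also enforces with a ValueError. The relative file path is an assumption about where the script is run.

import json

with open("configs/mmdit_xlarge_hq.json") as f:
    cfg = json.load(f)

head_dim = cfg["hidden_dim"] // cfg["num_attention_heads"]
assert head_dim * cfg["num_attention_heads"] == cfg["hidden_dim"]
print(head_dim)  # 64: 1920 hidden dims split across 30 heads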
models/mixin/encoder_mixin.py
ADDED
@@ -0,0 +1,222 @@
import torch
from diffusers.models import AutoencoderKL

from transformers import CLIPTextModel, CLIPTokenizerFast, T5EncoderModel, T5Tokenizer


class EncoderMixin:
    """Mixin class for handling various encoders in the MotifDiT model.

    This mixin provides functionality for:
    1. Loading and initializing encoders (VAE, T5, CLIP-L, CLIP-G)
    2. Text tokenization and encoding
    3. Managing encoder parameters and state
    """

    TOKEN_MAX_LENGTH: int = 256

    def prepare_embeddings(
        self,
        images: torch.Tensor,
        raw_text: list[str],
        vae: AutoencoderKL,
        t5: T5EncoderModel,
        clip_l: CLIPTextModel,
        clip_g: CLIPTextModel,
        t5_tokenizer: T5Tokenizer,
        clip_l_tokenizer: CLIPTokenizerFast,
        clip_g_tokenizer: CLIPTokenizerFast,
        is_training,
    ) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
        """Prepare image latents and text embeddings for model input.

        Args:
            images (torch.Tensor): Input images tensor with shape [B, C=3, H, W].
            raw_text (List[str]): List of raw text strings with length B.
        """
        with torch.no_grad():
            latents: torch.Tensor = (
                vae.encode(images).latent_dist.sample() - vae.config.shift_factor
            ) * vae.config.scaling_factor  # Latents shape: [B, 16, H//8, W//8]

        # Tokenize the input text and move tokens and masks to the same device as latents
        tokenizers = [t5_tokenizer, clip_l_tokenizer, clip_g_tokenizer]
        tokens, masks = self.tokenization(raw_text, tokenizers)
        tokens = [token.to(latents.device) for token in tokens]
        masks = [mask.to(latents.device) for mask in masks]

        # Encode the text and drop unnecessary embeddings
        text_embeddings, pooled_text_embeddings = self.text_encoding(
            tokens,
            masks,
            t5,
            clip_l,
            clip_g,
            t5_tokenizer.pad_token_id,
            clip_l_tokenizer.eos_token_id,
            clip_g_tokenizer.eos_token_id,
            is_training,
        )
        text_embeddings = self.drop_text_emb(text_embeddings)

        # Convert text embeddings to float
        text_embeddings = [text_embedding.float() for text_embedding in text_embeddings]

        # Convert pooled text embeddings to float
        pooled_text_embeddings = pooled_text_embeddings.float()

        return latents, text_embeddings, pooled_text_embeddings

    def get_freezed_encoders_and_tokenizers(
        self, vae_type: str
    ) -> tuple[
        AutoencoderKL, T5EncoderModel, CLIPTextModel, CLIPTextModel, T5Tokenizer, CLIPTokenizerFast, CLIPTokenizerFast
    ]:
        """Initialize the VAE and text encoders."""
        if vae_type != "SD3":
            raise ValueError(
                f"VAE type must be `SD3` but self.config.vae_type is {vae_type}."
                f" Note that the VAE type SDXL is deprecated."
            )

        vae: AutoencoderKL = AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae"
        )

        # Text encoders
        # 1. T5-XXL from Google
        t5 = T5EncoderModel.from_pretrained("google/flan-t5-xxl").to(dtype=torch.bfloat16)
        t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")

        # 2. CLIP-L from OpenAI
        clip_l = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(dtype=torch.bfloat16)
        clip_l_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")

        # 3. CLIP-G from LAION
        clip_g = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k").to(dtype=torch.bfloat16)
        clip_g_tokenizer = CLIPTokenizerFast.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

        # Freeze all encoders
        for encoder_module in [vae, clip_l, clip_g, t5]:
            for param in encoder_module.parameters():
                param.requires_grad = False

        return vae, t5, clip_l, clip_g, t5_tokenizer, clip_l_tokenizer, clip_g_tokenizer

    def tokenization(
        self, raw_text: list[str], tokenizers: list[T5Tokenizer | CLIPTokenizerFast]
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Tokenize the input text using multiple tokenizers.

        Args:
            raw_text (List[str]): List of input text strings.

        Returns:
            Tuple[List[torch.Tensor], List[torch.Tensor]]: Lists of tokenized text tensors and attention masks.
        """
        tokens, masks = [], []
        for tokenizer in tokenizers:
            tok = tokenizer(
                raw_text,
                padding="max_length",
                max_length=min(EncoderMixin.TOKEN_MAX_LENGTH, tokenizer.model_max_length),
                return_tensors="pt",
                truncation=True,
            )
            tokens.append(tok.input_ids)
            masks.append(tok.attention_mask)
        return tokens, masks

    @torch.no_grad()
    def text_encoding(
        self,
        tokens: list[torch.Tensor],
        masks: list[torch.Tensor],
        t5: T5EncoderModel,
        clip_l: CLIPTextModel,
        clip_g: CLIPTextModel,
        t5_pad_token_id: int = 0,
        clip_l_tokenizer_eos_token_id: int = 49407,
        clip_g_tokenizer_eos_token_id: int = 49407,
        is_training: bool = False,
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        """Encode the tokenized text using multiple text encoders.

        Args:
            tokens (List[torch.Tensor]): List of tokenized text tensors.
            masks (List[torch.Tensor]): List of attention masks.

        Returns:
            Tuple[List[torch.Tensor], torch.Tensor]: Text embeddings and pooled text embeddings.
        """
        t5_tokens, clip_l_tokens, clip_g_tokens = tokens
        t5_masks, _, _ = masks

        # T5 encoding
        t5_emb = t5(t5_tokens, attention_mask=t5_masks)[0]
        t5_emb = t5_emb * (t5_tokens != t5_pad_token_id).unsqueeze(-1)

        # CLIP encodings
        clip_l_emb = clip_l(input_ids=clip_l_tokens, output_hidden_states=True)
        clip_g_emb = clip_g(input_ids=clip_g_tokens, output_hidden_states=True)

        # Get pooled outputs
        clip_l_emb_pooled = clip_l_emb.pooler_output  # B x 768
        clip_g_emb_pooled = clip_g_emb.pooler_output  # B x 1280

        if is_training:
            clip_l_emb_pooled = self.drop_text_emb(clip_l_emb_pooled)
            clip_g_emb_pooled = self.drop_text_emb(clip_g_emb_pooled)

        clip_l_emb = clip_l_emb.last_hidden_state  # B x L x 768
        clip_g_emb = clip_g_emb.last_hidden_state  # B x L x 1280

        def masking_wo_first_eos(token, eos):
            """Create an attention mask that keeps the first EOS token and masks the remaining EOS/padding tokens."""
            idx = (token != eos).sum(dim=1)  # per-row index of the first EOS token
            mask = token != eos
            arange = torch.arange(mask.size(0)).cuda()
            in_range = idx < mask.size(1)  # only rows that actually contain an EOS token
            mask[arange[in_range], idx[in_range]] = True
            return mask.unsqueeze(-1)  # B x L x 1

        # Apply masking
        clip_l_emb = clip_l_emb * masking_wo_first_eos(clip_l_tokens, clip_l_tokenizer_eos_token_id)
        clip_g_emb = clip_g_emb * masking_wo_first_eos(clip_g_tokens, clip_g_tokenizer_eos_token_id)

        encodings = [t5_emb, clip_l_emb, clip_g_emb]
        pooled_encodings = torch.cat([clip_l_emb_pooled, clip_g_emb_pooled], dim=-1)  # B x 2048

        return encodings, pooled_encodings

    @torch.no_grad()
    def drop_text_emb(
        self, text_embeddings: list[torch.Tensor] | torch.Tensor, drop_prob: float = 0.464
    ) -> list[torch.Tensor] | torch.Tensor:
        """Randomly drop text embeddings with a specified probability.

        Args:
            text_embeddings (Union[List[torch.Tensor], torch.Tensor]): Text embeddings to be dropped.
            drop_prob (float, optional): Probability of dropping text embeddings. Defaults to 0.464.

        Returns:
            Union[List[torch.Tensor], torch.Tensor]: Text embeddings with dropped elements.
        """
        if isinstance(text_embeddings, list):
            # For BxLxC features
            for i, text_embedding in enumerate(text_embeddings):
                probs = torch.ones((text_embedding.shape[0])).cuda() * (1 - drop_prob)
                masks = torch.bernoulli(probs).cuda()
                while len(masks.shape) < len(text_embedding.shape):
                    masks = masks.unsqueeze(-1)
                text_embeddings[i] = text_embedding * masks  # write back so the drop takes effect
        else:
            # For a pooled BxC feature
            probs = torch.ones((text_embeddings.shape[0])).cuda() * (1 - drop_prob)
            masks = torch.bernoulli(probs).cuda()
            while len(masks.shape) < len(text_embeddings.shape):
                masks = masks.unsqueeze(-1)
            text_embeddings = text_embeddings * masks

        return text_embeddings
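
To make the EOS-masking step in text_encoding concrete, here is a standalone CPU illustration of the same idea on dummy token ids. It re-creates the logic outside the mixin (which runs on CUDA); the token values and the EOS id are illustrative, with 49407 taken from text_encoding's defaults.

import torch

eos = 49407  # CLIP EOS/pad id assumed by text_encoding's defaults
tokens = torch.tensor([[320, 1125, eos, eos, eos],
                       [320, eos, eos, eos, eos]])

first_eos = (tokens != eos).sum(dim=1)   # index of the first EOS per row
mask = tokens != eos                     # hide every EOS/pad position ...
rows = torch.arange(tokens.size(0))
mask[rows, first_eos] = True             # ... except the first EOS, which carries the pooled signal

print(mask.int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 0, 0, 0]])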
models/mixin/flow_mixin.py
ADDED
@@ -0,0 +1,175 @@
"""
Dongpin Oh: [email protected]
"""

import time

import numpy as np
import torch
import torch.nn.functional as F
from diffusers.utils.torch_utils import randn_tensor
from loguru import logger


class FlowMixin:
    """Mixin class for flow-based models."""

    _is_noise_initialized = False
    MIN_STD: float = 0.0  # minimum size of std for the flow matching
    CLAMP_CONTINUOUS_TIME: float = 0.0

    def _timestep_shifting(self, timesteps: torch.Tensor, alpha: float = 3.0) -> torch.Tensor:
        """
        Adjust the timesteps for higher-resolution images by adding more noise.
        Higher resolutions have more pixels, so more noise is needed to destroy their signal.

        NOTE that the timestep must be reversed, unlike the original SD3-style timestep shifting,
        since the flow timestep is reversed (t=0: original image; t=1: pure noise).

        Args:
            timesteps (torch.Tensor): The original timesteps.
            alpha (float, optional): Scaling factor for timestep shifting. Defaults to 3.0.

        Returns:
            torch.Tensor: The reversed and shifted timesteps.
        """
        shifted_t = alpha * timesteps / (1 + (alpha - 1) * timesteps)
        reversed_t = 1 - shifted_t
        return reversed_t

    def get_noisy_input(
        self,
        input: torch.Tensor,
        normal_mean: float = 0.0,
        normal_std: float = 1.0,
        is_finetuning: bool = False,
        t: torch.Tensor | None = None,
        n_timesteps: int = 1000,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Generate noisy input, based on the optimal-transport flow.

        Args:
            input (torch.Tensor): Input tensor.
            normal_mean (float, optional): Mean of the normal distribution for noise. Defaults to 0.0.
            normal_std (float, optional): Standard deviation of the normal distribution for noise. Defaults to 1.0.
            is_finetuning (bool, optional): Whether the model is in finetuning mode. Defaults to False.
            t (torch.Tensor | None, optional): Predefined timesteps. If None, timesteps are sampled. Defaults to None.
            n_timesteps (int, optional): Number of discrete timesteps. Defaults to 1000.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing noise, noisy input, and timestep.
                - The timestep t is a continuous timestep in [0, 1] which cannot directly be used
                  for the timestep embedding (it needs to be discretized by self.discritize_timestep()).
        """
        b = input.shape[0]

        if not FlowMixin._is_noise_initialized:
            logger.warning("The torch random seed is changed when generating the initial random noise.")
            current_time = int(time.time() * 1000)
            torch.manual_seed(current_time)
            FlowMixin._is_noise_initialized = True

        noise = randn_tensor(input.shape).cuda()

        # Sample the timestep from a logit-normal distribution with mean 0 and std 1
        if t is None:
            # NOTE: the timestep is sampled from a logit-normal distribution to make the model
            # focus on the intermediate timesteps, which are the most informative part of the flow-ODE
            t = torch.randn(b).cuda() * normal_std + normal_mean
            t = torch.sigmoid(t)

        # Clamp t to be within the interval [0, 1] for numerical stability
        t = t.clamp(0 + self.CLAMP_CONTINUOUS_TIME, 1 - self.CLAMP_CONTINUOUS_TIME)

        if is_finetuning:
            t = self._timestep_shifting(t)

        # Reshape t to match the dimensions required
        for _ in range(len(noise.shape) - 1):
            t = t.unsqueeze(1)

        # Generate the noisy input
        noisy_input = (1 - t) * input + (self.MIN_STD + (1 - self.MIN_STD) * t) * noise

        t_squeezed = t.squeeze()
        if t_squeezed.dim() == 0:
            t_squeezed = t_squeezed.unsqueeze(0)
        return noise, noisy_input, t_squeezed

    @torch.no_grad()
    def _logit_norm(self, t: torch.Tensor, m: float = 0, s: float = 1) -> torch.Tensor:
        """
        Compute the loss weight for the flow-matching loss.
        It focuses (gives high weights) on the intermediate timesteps, since
        such timesteps are the hardest to match, according to https://arxiv.org/pdf/2403.03206.pdf

        Args:
            t (torch.Tensor): Timestep tensor. (0 to 1)
            m (float, optional): Mean of the logit distribution. Defaults to 0.
            s (float, optional): Standard deviation of the logit distribution. Defaults to 1.

        Returns:
            torch.Tensor: Weight tensor for the flow-matching loss.
        """
        coef = (1 / (s * ((2 * np.pi) ** 0.5))) * (1 / (t * (1 - t)))

        def logit(x):
            return torch.log(x) - torch.log(1 - x)

        exp = torch.exp(-((logit(t) - m) ** 2) / (2 * s**2))
        return coef * exp

    def rectified_flow_loss(
        self,
        input: torch.Tensor,
        noise: torch.Tensor,
        t: torch.Tensor,
        preds: torch.Tensor,
        use_weighting: bool = False,
        reduce: str = "mean",
    ) -> torch.Tensor:
        """
        Compute the rectified flow loss, https://arxiv.org/pdf/2403.03206.pdf

        Args:
            input (torch.Tensor): Input tensor.
            noise (torch.Tensor): Noise tensor.
            t (torch.Tensor): Timestep tensor.
            preds (torch.Tensor): Predicted tensor.
            use_weighting (bool, optional): Whether to use weighting for the loss. Defaults to False.
            reduce (str, optional): Reduction method for the loss. Options are 'mean' or 'none'. Defaults to 'mean'.

        Returns:
            torch.Tensor: Rectified flow loss.
        """

        # Matching dimension for broadcasting
        t = t.reshape(t.shape[0], *[1 for _ in range(len(input.shape) - len(t.shape))])

        target_flow = (1 - self.MIN_STD) * noise - input
        loss = F.mse_loss(preds.float(), target_flow.float(), reduction="none")
        if use_weighting:
            weight = self._logit_norm(t).detach()
            loss = loss * weight
        if reduce == "mean":
            loss = loss.mean()
        elif reduce == "none":
            pass
        else:
            raise NotImplementedError

        return loss

    def discritize_timestep(self, t: torch.Tensor, n_timesteps: int = 1000) -> torch.Tensor:
        """
        Discretize the continuous timestep.

        Args:
            t (torch.Tensor): Continuous timestep.
            n_timesteps (int, optional): Number of discrete timesteps. Defaults to 1000.

        Returns:
            torch.Tensor: Discretized timestep tensor.
        """
        return (t * n_timesteps).round().long()
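
A CPU sketch of the rectified-flow objective that FlowMixin implements, written out on dummy tensors for the MIN_STD = 0 case: the noisy sample is a straight-line interpolation between data and noise, and the regression target is noise minus input. The zero prediction is only a stand-in for the DiT output.

import torch
import torch.nn.functional as F

x = torch.randn(4, 16, 8, 8)        # clean latents (dummy)
noise = torch.randn_like(x)
t = torch.sigmoid(torch.randn(4))   # logit-normal timestep sampling, as in get_noisy_input
t_b = t.view(-1, 1, 1, 1)

noisy = (1 - t_b) * x + t_b * noise   # matches get_noisy_input with MIN_STD = 0
target = noise - x                    # matches rectified_flow_loss's target_flow
pred = torch.zeros_like(x)            # stand-in for the DiT prediction
loss = F.mse_loss(pred, target)       # the reduce="mean" path
print(loss.item())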
models/modeling_dit.py
ADDED
@@ -0,0 +1,653 @@
import math
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from loguru import logger

try:
    motif_ops = torch.ops.motif
    MotifRMSNorm = motif_ops.T5LayerNorm
    ScaledDotProductAttention = None
    MotifFlashAttention = motif_ops.flash_attention
except ImportError:  # if motif_ops is not available
    MotifRMSNorm = None
    ScaledDotProductAttention = None
    MotifFlashAttention = None

NUM_MODULATIONS = 6
SD3_LATENT_CHANNEL = 16
LOW_RES_POSEMB_BASE_SIZE = 16
HIGH_RES_POSEMB_BASE_SIZE = 64


class IdentityConv2d(nn.Module):
    def __init__(self, channels, kernel_size=3, stride=1, padding=1, bias=True):
        super().__init__()

        self.conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )

        self._initialize_identity()

    def _initialize_identity(self):
        k = self.conv.kernel_size[0]

        nn.init.zeros_(self.conv.weight)

        center = k // 2
        for i in range(self.conv.in_channels):
            self.conv.weight.data[i, i, center, center] = 1.0

        if self.conv.bias is not None:
            nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        return self.conv(x)


class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.mask = None

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float)
        if self.mask is not None:
            hidden_states = self.mask.to(hidden_states.device).to(hidden_states.dtype) * hidden_states
        variance = hidden_states.pow(2).sum(-1, keepdim=True)
        if self.mask is not None:
            variance /= torch.count_nonzero(self.mask)
        else:
            variance /= hidden_states.shape[-1]
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class MLP(nn.Module):
    def __init__(self, input_size, hidden_size=None):
        super().__init__()
        if hidden_size is None:
            self.input_size, self.hidden_size = input_size, input_size * 4
        else:
            self.input_size, self.hidden_size = input_size, hidden_size

        self.gate_proj = nn.Linear(self.input_size, self.hidden_size)
        self.down_proj = nn.Linear(self.hidden_size, self.input_size)

        self.act_fn = nn.SiLU()

    def forward(self, x):
        down_proj = self.act_fn(self.gate_proj(x))
        down_proj = self.down_proj(down_proj)

        return down_proj


class TextTimeEmbToGlobalParams(nn.Module):
    def __init__(self, emb_dim, hidden_dim):
        super().__init__()
        self.projection = nn.Linear(emb_dim, hidden_dim * NUM_MODULATIONS)

    def forward(self, emb):
        emb = F.silu(emb)  # emb: B x D
        params = self.projection(emb)  # emb: B x C
        params = params.reshape(params.shape[0], NUM_MODULATIONS, params.shape[-1] // NUM_MODULATIONS)  # emb: B x 6 x C
        return params.chunk(6, dim=1)  # [B x 1 x C] x 6


class TextTimeEmbedding(nn.Module):
    """
    Input:
        pooled_text_emb (B x C_l)
        time_steps (B)

    Output:
        ()
    """

    def __init__(self, time_channel, text_channel, embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0):
        super().__init__()
        self.time_proj = Timesteps(
            time_channel, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=downscale_freq_shift
        )
        self.time_emb = TimestepEmbedding(time_channel, time_channel * 4, out_dim=embed_dim)  # Encode time emb with MLP
        self.pooled_text_emb = TimestepEmbedding(
            text_channel, text_channel * 4, out_dim=embed_dim
        )  # Encode pooled text with MLP

    def forward(self, pooled_text_emb, time_steps):
        time_steps = self.time_proj(time_steps)
        time_emb = self.time_emb(time_steps.to(dtype=torch.bfloat16))
        pooled_text_emb = self.pooled_text_emb(pooled_text_emb)

        return time_emb + pooled_text_emb


class LatentPatchModule(nn.Module):
    def __init__(self, patch_size, embedding_dim, latent_channels, vae_type):
        super().__init__()
        self.patch_size = patch_size
        self.embedding_dim = embedding_dim
        self.projection_SD3 = nn.Conv2d(SD3_LATENT_CHANNEL, embedding_dim, kernel_size=patch_size, stride=patch_size)
        self.latent_channels = latent_channels

    def forward(self, x):
        assert (
            x.shape[1] == SD3_LATENT_CHANNEL
        ), f"VAE-Latent channel is not matched with '{SD3_LATENT_CHANNEL}'. current shape: {x.shape}"
        patches = self.projection_SD3(
            x.to(dtype=torch.bfloat16)
        )  # Shape: (B, embedding_dim, num_patches_h, num_patches_w)
        patches = patches.to(dtype=torch.bfloat16)
        patches = patches.contiguous()
        patches = patches.flatten(2)  # Shape: (B, embedding_dim, num_patches)

        patches = patches.transpose(1, 2)  # Shape: (B, num_patches, embedding_dim)
        patches = patches.contiguous()
        return patches

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        n = x.shape[0]
        c = self.latent_channels
        p = self.patch_size

        # check the valid patching
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.contiguous()
        # (N x T x [C * patch_size**2]) -> (N x H x W x P_1 x P_2 x C)
        x = x.reshape(shape=(n, h, w, p, p, c))
        # x = torch.einsum('nhwpqc->nchpwq', x)  # Note that einsum could possibly be the problem.

        # (N x H x W x P_1 x P_2 x C) -> (N x C x H x P_1 x W x P_2)
        # (0   1   2   3     4     5) -> (0   5   1   3     2     4)
        x = x.permute(0, 5, 1, 3, 2, 4)
        return x.reshape(shape=(n, c, h * p, h * p)).contiguous()


class TextConditionModule(nn.Module):
    def __init__(self, text_dim, latent_dim):
        super().__init__()
        self.projection = nn.Linear(text_dim, latent_dim)

    def forward(self, t5_xxl, clip_a, clip_b):
        clip_emb = torch.cat([clip_a, clip_b], dim=-1)
        clip_emb = torch.nn.functional.pad(clip_emb, (0, t5_xxl.shape[-1] - clip_emb.shape[-1]))
        text_emb = torch.cat([clip_emb, t5_xxl], dim=-2)
        text_emb = self.projection(text_emb.to(torch.bfloat16))
        return text_emb


class MotifDiTBlock(nn.Module):
    def __init__(self, emb_dim, t_emb_dim, attn_emb_dim, mlp_dim, attn_config, text_dim=4096):
        super().__init__()
        self.affine_params_c = TextTimeEmbToGlobalParams(t_emb_dim, emb_dim)
        self.affine_params_x = TextTimeEmbToGlobalParams(t_emb_dim, emb_dim)

        self.norm_1_c = nn.LayerNorm(emb_dim, elementwise_affine=False)
        self.norm_1_x = nn.LayerNorm(emb_dim, elementwise_affine=False)
        self.linear_1_c = nn.Linear(emb_dim, attn_emb_dim)
        self.linear_1_x = nn.Linear(emb_dim, attn_emb_dim)

        self.attn = JointAttn(attn_config)
        self.norm_2_c = nn.LayerNorm(emb_dim, elementwise_affine=False)
        self.norm_2_x = nn.LayerNorm(emb_dim, elementwise_affine=False)
        self.mlp_3_c = MLP(emb_dim, mlp_dim)
        self.mlp_3_x = MLP(emb_dim, mlp_dim)

    def forward(self, x_emb, c_emb, t_emb, perturbed=False):
        """
        x_emb (N, TOKEN_LENGTH x 2, C)
        c_emb (N, T + REGISTER_TOKENS, C)
        t_emb (N, modulation_dim)
        """

        device = x_emb.device

        # get global affine transformation parameters
        alpha_x, beta_x, gamma_x, delta_x, epsilon_x, zeta_x = self.affine_params_x(t_emb)  # scale and shift for image
        alpha_c, beta_c, gamma_c, delta_c, epsilon_c, zeta_c = self.affine_params_c(t_emb)  # scale and shift for text

        # projection and affine transform before attention
        x_emb_pre_attn = self.linear_1_x((1 + alpha_x) * self.norm_1_x(x_emb) + beta_x)
        c_emb_pre_attn = self.linear_1_c((1 + alpha_c) * self.norm_1_c(c_emb) + beta_c)

        # attn_output, attn_weight (None), past_key_value (None)
        x_emb_post_attn, c_emb_post_attn = self.attn(
            x_emb_pre_attn, c_emb_pre_attn, perturbed
        )  # mixed feature for both text and image (N, [T_x + T_c], C)

        # scale with gamma and residual with the original inputs
        x_emb_post_attn = x_emb_post_attn.to(gamma_x.device)
        x_emb_post_attn = (1 + gamma_x) * x_emb_post_attn + x_emb  # NOTE: nan loss for self.linear_2_x.bias
        c_emb_post_attn = c_emb_post_attn.to(gamma_c.device)
        c_emb_post_attn = (1 + gamma_c) * c_emb_post_attn + c_emb

        # norm the features -> affine transform with modulation -> MLP
        normalized_x_emb = self.norm_2_x(x_emb_post_attn).to(delta_x.device)
        normalized_c_emb = self.norm_2_c(c_emb_post_attn).to(delta_c.device)
        x_emb_final = self.mlp_3_x(delta_x * normalized_x_emb + epsilon_x)
        c_emb_final = self.mlp_3_c(delta_c * normalized_c_emb + epsilon_c)

        # final scaling with zeta and residual with the original inputs
        x_emb_final = zeta_x.to(device) * x_emb_final.to(device) + x_emb.to(device)
        c_emb_final = zeta_c.to(device) * c_emb_final.to(device) + c_emb.to(device)

        return x_emb_final, c_emb_final


class MotifDiT(nn.Module):
    ENCODED_TEXT_DIM = 4096

    def __init__(self, config):
        super(MotifDiT, self).__init__()
        self.patch_size = config.patch_size
        self.h, self.w = config.height // config.vae_compression, config.width // config.vae_compression

        self.latent_chennels = 16

        # Embedding for (1) text; (2) input image; (3) time
        self.text_cond = TextConditionModule(self.ENCODED_TEXT_DIM, config.hidden_dim)
        self.patching = LatentPatchModule(config.patch_size, config.hidden_dim, self.latent_chennels, config.vae_type)
        self.time_emb = TextTimeEmbedding(config.time_embed_dim, config.pooled_text_dim, config.modulation_dim)

        # main multi-modal DiT blocks
        self.mmdit_blocks = nn.ModuleList(
            [
                MotifDiTBlock(
                    config.hidden_dim, config.modulation_dim, config.hidden_dim, config.mlp_hidden_dim, config
                )
                for layer_idx in range(config.num_layers)
            ]
        )

        self.final_modulation = nn.Linear(config.modulation_dim, config.hidden_dim * 2)
        self.final_linear_SD3 = nn.Linear(config.hidden_dim, SD3_LATENT_CHANNEL * config.patch_size**2)
        self.skip_register_token_num = config.skip_register_token_num

        if getattr(config, "pos_emb_size", None):
            pos_emb_size = config.pos_emb_size
        else:
            pos_emb_size = HIGH_RES_POSEMB_BASE_SIZE if config.height > 512 else LOW_RES_POSEMB_BASE_SIZE
        logger.info(f"Positional embedding of Motif-DiT is set to {pos_emb_size}")

        self.pos_embed = torch.from_numpy(
            get_2d_sincos_pos_embed(
                config.hidden_dim, (self.h // self.patch_size, self.w // self.patch_size), base_size=pos_emb_size
            )
        ).to(device="cuda", dtype=torch.bfloat16)

        # set register tokens (https://arxiv.org/abs/2309.16588)
        if config.register_token_num > 0:
            self.register_token_num = config.register_token_num
            self.register_tokens = nn.Parameter(torch.randn(1, self.register_token_num, config.hidden_dim))
            self.register_parameter("register_tokens", self.register_tokens)

        # if needed, add additional register tokens for higher resolution training
        self.additional_register_token_num = config.additional_register_token_num
        if config.additional_register_token_num > 0:
            self.register_tokens_highres = nn.Parameter(
                torch.randn(1, self.additional_register_token_num, config.hidden_dim)
            )
            self.register_parameter("register_tokens_highres", self.register_tokens_highres)

        if config.use_final_layer_norm:
            self.final_norm = nn.LayerNorm(config.hidden_dim)

        if config.conv_header:
            logger.info("use convolution header after de-patching")
            self.depatching_conv_header = IdentityConv2d(SD3_LATENT_CHANNEL)

        if config.use_time_token_in_attn:
            self.t_token_proj = nn.Linear(config.modulation_dim, config.hidden_dim)

    def forward(self, latent, t, text_embs: List[torch.Tensor], pooled_text_embs, guiding_feature=None):
        """
        latent (torch.Tensor)
        t (torch.Tensor)
        text_embs (List[torch.Tensor])
        pooled_text_embs (torch.Tensor)
        """
        # 1. get inputs for the MMDiT blocks
        emb_c = self.text_cond(*text_embs)  # (N, L, D), text conditions
        emb_t = self.time_emb(pooled_text_embs, t).to(emb_c.device)  # (N, D), time and pooled text conditions

        emb_x = (self.patching(latent) + self.pos_embed).to(
            emb_c.device
        )  # (N, T, D), where T = H*W / (patch_size ** 2), input latent patches

        # additional "register" tokens, to convey the global information and prevent high-norm abnormal patches
        # see https://openreview.net/forum?id=2dnO3LLiJ1
        if hasattr(self, "register_tokens"):
            if hasattr(self, "register_tokens_highres"):
                emb_x = torch.cat(
                    (
                        self.register_tokens_highres.expand(emb_x.shape[0], -1, -1),
                        self.register_tokens.expand(emb_x.shape[0], -1, -1),
                        emb_x,
                    ),
                    dim=1,
                )
            else:
                emb_x = torch.cat((self.register_tokens.expand(emb_x.shape[0], -1, -1), emb_x), dim=1)

        # time embedding into text embedding
        if hasattr(self, "use_time_token_in_attn"):
            t_token = self.t_token_proj(emb_t).unsqueeze(1)
            emb_c = torch.cat([emb_c, t_token], dim=1)  # (N, [T_c + 1], C)

        # 2. MMDiT Blocks
        for block_idx, block in enumerate(self.mmdit_blocks):
            emb_x, emb_c = block(emb_x, emb_c, emb_t)

            # accumulating the feature_similarity loss
            # TODO: add modeling_dit related test
            if hasattr(self, "num_feature_align_layers") and block_idx == self.num_feature_align_layers:
                self.feature_alignment_loss = self.feature_align_mlp(emb_x, guiding_feature)  # exclude register tokens

            # Remove the register tokens at a certain layer (the last layer as default).
            if block_idx == len(self.mmdit_blocks) - (1 + self.skip_register_token_num):
                if hasattr(self, "register_tokens_highres"):
                    emb_x = emb_x[
                        :, self.register_token_num + self.additional_register_token_num :
                    ]  # remove the register tokens for the output layer
                elif hasattr(self, "register_tokens"):
                    emb_x = emb_x[:, self.register_token_num :]  # remove the register tokens for the output layer

        # 3. final modulation (shift-and-scale)
        scale, shift = self.final_modulation(emb_t).chunk(2, -1)  # (N, D) x 2
        scale, shift = scale.unsqueeze(1), shift.unsqueeze(1)  # (N, 1, D) x 2

        if hasattr(self, "final_norm"):
            emb_x = self.final_norm(emb_x)

        final_emb = (scale + 1) * emb_x + shift

        # 4. final linear layer to reduce channel and un-patching
        emb_x = self.final_linear_SD3(final_emb)  # (N, T, D) to (N, T, out_channels * patch_size**2)
        emb_x = self.patching.unpatchify(emb_x)  # (N, out_channels, H, W)

        if hasattr(self, "depatching_conv_header"):
            emb_x = self.depatching_conv_header(emb_x)
        return emb_x


class JointAttn(nn.Module):
    """
    SD3-style joint-attention layer
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_dim
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)

        self.add_q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.add_k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.add_v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)

        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.add_o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.q_norm_x = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
        self.k_norm_x = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)

        self.q_norm_c = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
        self.k_norm_c = MotifRMSNorm(self.head_dim) if MotifRMSNorm else RMSNorm(self.head_dim)
        self.q_scale = nn.Parameter(torch.ones(self.num_heads))

        # Attention mode : {'sdpa', 'flash', None}
        self.attn_mode = config.attn_mode

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        residual = hidden_states

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]

        # `sample` projections.
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        # `context` projections.
        query_c = self.add_q_proj(encoder_hidden_states)
        key_c = self.add_k_proj(encoder_hidden_states)
        value_c = self.add_v_proj(encoder_hidden_states)

        # head first
        inner_dim = key.shape[-1]
        head_dim = inner_dim // self.num_heads

        def norm_qk(x, f_norm):
            x = x.view(batch_size, -1, self.num_heads, head_dim)
            b, l, h, d_h = x.shape
            x = x.reshape(b * l, h, d_h)
            x = f_norm(x)
            return x.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)  # [b, h, l, d_h]

        query = norm_qk(query, self.q_norm_x)  # [b, h, l, d_h]
        key = norm_qk(key, self.k_norm_x)  # [b, h, l, d_h]
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)  # [b, h, l, d_h]

        query_c = norm_qk(query_c, self.q_norm_c) * self.q_scale.reshape(1, self.num_heads, 1, 1)  # [b, h, l_c, d]
        key_c = norm_qk(key_c, self.k_norm_c)  # [b, h, l_c, d]
        value_c = value_c.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)  # [b, h, l_c, d]

        # attention
        query = torch.cat([query, query_c], dim=2).contiguous()  # [b, h, l + l_c, d]
        key = torch.cat([key, key_c], dim=2).contiguous()  # [b, h, l + l_c, d]
        value = torch.cat([value, value_c], dim=2).contiguous()  # [b, h, l + l_c, d]

        # deprecated.
        hidden_states = self.joint_attention(batch_size, query, key, value, head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],
            hidden_states[:, residual.shape[1] :],
        )

        # linear proj
        hidden_states = self.o_proj(hidden_states)
        encoder_hidden_states = self.add_o_proj(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states, encoder_hidden_states

    def joint_attention(self, batch_size, query, key, value, head_dim):
        if self.attn_mode == "sdpa" and ScaledDotProductAttention is not None:
            # NOTE: SDPA does not support high-resolution (long-context).
            q_len = query.size(-2)
            masked_bias = torch.zeros((batch_size, self.num_heads, query.size(-2), key.size(-2)), device="cuda")

            query = query.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size).contiguous()
            key = key.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size).contiguous()
            value = value.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size).contiguous()

            scale_factor = 1.0
            scale_factor /= float(self.head_dim) ** 0.5

            hidden_states = ScaledDotProductAttention(
                query,
                key,
                value,
                masked_bias,
                dropout_rate=0.0,
                training=self.training,
                attn_weight_scale_factor=scale_factor,
                num_kv_groups=1,
            )
        elif self.attn_mode == "flash" and MotifFlashAttention is not None:
            query = query.permute(0, 2, 1, 3).contiguous()  # [b, l + l_c, h, d]
            key = key.permute(0, 2, 1, 3).contiguous()  # [b, l + l_c, h, d]
            value = value.permute(0, 2, 1, 3).contiguous()  # [b, l + l_c, h, d]
            scale_factor = 1.0 / math.sqrt(self.head_dim)

            # NOTE (1): masking of motif flash-attention uses (`1`: un-mask, `0`: mask) and has [Batch, Seq] shape
            # NOTE (2): Q, K, V must be [Batch, Seq, Heads, Dim] and contiguous.
            mask = torch.ones((batch_size, query.size(-3))).cuda()
            hidden_states = MotifFlashAttention(
                query,
                key,
                value,
                padding_mask=mask,
                softmax_scale=scale_factor,
                causal=False,
            )
            hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * head_dim).contiguous()
        else:
            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0)
            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim)

        return hidden_states

    @staticmethod
    def alt_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, scale=None) -> torch.Tensor:
        """
        Pure-PyTorch version of the xformers scaled_dot_product_attention
        (or F.scaled_dot_product_attention from torch>2.0.0)

        Args:
            query (Tensor): query tensor
            key (Tensor): key tensor
            value (Tensor): value tensor
            attn_mask (Tensor, optional): attention mask. Defaults to None.
            dropout_p (float, optional): attention dropout probability. Defaults to 0.0.
            scale (Tensor or float, optional): scaling for QK. Defaults to None.

        Returns:
            torch.Tensor: attention score (after softmax)
        """
        L, S = query.size(-2), key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn_weight = query @ key.transpose(-2, -1) * scale_factor  # B, L, S
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)  # B, L, S
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
        return attn_weight @ value  # B, L, S * S, D -> B, L, D


# ===============================================
# Sine/Cosine Positional Embedding Functions
# ===============================================
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scale=1.0, base_size=None):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if not isinstance(grid_size, tuple):
        grid_size = (grid_size, grid_size)

    grid_h = np.arange(grid_size[0], dtype=np.float32) / scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / scale
    if base_size is not None:
        grid_h *= base_size / grid_size[0]
        grid_w *= base_size / grid_size[1]
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0):
    pos = np.arange(0, length)[..., None] / scale
    return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
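
A shape sanity-check sketch for the xlarge_hq settings: a 1024x1024 image becomes a 128x128 latent at VAE compression 8, which patchifies into a 64x64 grid of tokens with patch_size 2, matching the positional-embedding grid built in MotifDiT.__init__. The import assumes the script runs from the repository root.

from models.modeling_dit import get_2d_sincos_pos_embed

hidden_dim, patch_size, vae_compression = 1920, 2, 8
h = w = 1024 // vae_compression            # 128 latent pixels per side
grid = h // patch_size                     # 64 patches per side -> 4096 image tokens
pos = get_2d_sincos_pos_embed(hidden_dim, (grid, grid), base_size=64)
print(grid * grid, pos.shape)              # 4096 (4096, 1920)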
models/modeling_motif_vision.py
ADDED
@@ -0,0 +1,612 @@
from collections import defaultdict
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
import tqdm
from diffusers.models import AutoencoderKL
from diffusers.utils.torch_utils import randn_tensor
from loguru import logger
from PIL import Image, ImageFilter
from transformers import CLIPTextModel, CLIPTokenizerFast, T5EncoderModel, T5Tokenizer

from models.mixin.flow_mixin import FlowMixin
from models.modeling_dit import MotifDiT

TOKEN_MAX_LENGTH: int = 256
DROP_PROB: float = 0.1
LATENT_CHANNELS: int = 4
VAE_DOWNSCALE_FACTOR: int = 8
SD3_LATENT_CHANNEL: int = 16


def generate_intervals(steps, ratio, start=1.0):
    intervals = torch.linspace(start, 0, steps=steps)
    intervals = intervals.pow(ratio)
    return intervals


class MotifVision(nn.Module, FlowMixin):
    """
    MotifVision Text-to-Image model.

    This model combines a Diffusion transformer with a rectified flow loss and multiple text encoders.
    It uses a VAE (Variational Autoencoder) for image encoding and decoding.

    Args:
        config (MMDiTConfig): Configuration object for the MMDiT model.

    Attributes:
        dit (MotifDiT): MotifDiT model instance.
        noise_scheduler (DDPMScheduler): Noise scheduler for the diffusion process.
        normalize_img (Callable): Function to normalize images to the [-1, 1] range.
        unnormalize_img (Callable): Function to unnormalize images back to the [0, 1] range.
        cond_drop_prob (float): Probability of dropping text embeddings during training.
        snr_gamma (str): Strategy for weighting the loss based on Signal-to-Noise Ratio (SNR).
        loss_weight_strategy (str): Strategy for weighting the loss.
        vae (AutoencoderKL): Variational Autoencoder for image encoding and decoding.
        t5 (T5EncoderModel): T5 encoder model for text encoding.
        t5_tokenizer (T5Tokenizer): T5 tokenizer for text tokenization.
        clip_l (CLIPTextModel): CLIP (Contrastive Language-Image Pre-training) text encoder (large).
        clip_l_tokenizer (CLIPTokenizerFast): CLIP tokenizer (large) for text tokenization.
        clip_g (CLIPTextModel): CLIP text encoder (giant).
        clip_g_tokenizer (CLIPTokenizerFast): CLIP tokenizer (giant) for text tokenization.
        tokenizers (List[Union[T5Tokenizer, CLIPTokenizerFast]]): List of tokenizers.
        text_encoders (List[Union[T5EncoderModel, CLIPTextModel]]): List of text encoder models.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dit = MotifDiT(config)
        self.cond_drop_prob = 0.1
        self.use_weighting = False
        self._get_encoders()
        self._freeze_encoders()

    def forward(self, images: torch.Tensor, raw_text: List[str]) -> List[torch.Tensor]:
        """
        Forward pass of the MotifDiT model.

        Args:
            images (torch.Tensor): Input images tensor, [0-1] ranged.
            raw_text (List[str]): Input text strings.

        Returns:
            List[torch.Tensor]: Single-element list containing the rectified flow matching loss.
        """
        # 1. Encode images and texts
        with torch.no_grad():
            latents = self.vae.encode(images).latent_dist.sample() * self.vae.config.scaling_factor
            tokens, masks = self.tokenization(raw_text)
            tokens = [token.to(latents.device) for token in tokens]
            masks = [mask.to(latents.device) for mask in masks]
            text_embeddings, pooled_text_embeddings = self.text_encoding(tokens, masks)
            text_embeddings = self._drop_text_emb(text_embeddings)
            text_embeddings = [text_embedding.float() for text_embedding in text_embeddings]
            pooled_text_embeddings = pooled_text_embeddings.float()

        # 2. Get noisy input via the rectified flow
        is_finetuning = self.config.height > 256
        noise, noise_latents, t = self.get_noisy_input(latents, is_finetuning=is_finetuning)

        timesteps = self.discritize_timestep(t, self.n_timesteps)

        # 3. Forward pass through the dit
        preds = self.dit(noise_latents, timesteps, text_embeddings, pooled_text_embeddings)

        # 4. Rectified flow matching loss
        loss = self.rectified_flow_loss(latents, noise, t, preds, use_weighting=self.use_weighting)

        return [loss]
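
    # Illustrative training-step sketch (not part of the original file). Assumes a
    # `MotifVision` instance `model` on GPU, an optimizer over `model.dit.parameters()`,
    # and a dataloader yielding (images, captions) with images in the [0, 1] range:
    #
    #   images = images.to("cuda")
    #   loss = model(images, captions)[0]   # forward() returns a single-element list
    #   loss.backward()
    #   optimizer.step()
    #   optimizer.zero_grad()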
    def _get_encoders(self) -> None:
        """Initialize the VAE and text encoders."""
        if self.config.vae_type == "SD3":
            self.vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae")
        elif self.config.vae_type == "SDXL":
            self.vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
        else:
            raise ValueError(f"VAE type must be `SD3` or `SDXL` but self.config.vae_type is {self.config.vae_type}")

        # Text encoders
        # 1. T5-XXL from Google
        self.t5 = T5EncoderModel.from_pretrained("google/flan-t5-xxl").to(dtype=torch.bfloat16)
        self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")

        # 2. CLIP-L from OpenAI
        self.clip_l = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(dtype=torch.bfloat16)
        self.clip_l_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")

        # 3. CLIP-G from LAION
        self.clip_g = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k").to(dtype=torch.bfloat16)
        self.clip_g_tokenizer = CLIPTokenizerFast.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

        self.tokenizers = [self.t5_tokenizer, self.clip_l_tokenizer, self.clip_g_tokenizer]
        self.text_encoders = [self.t5, self.clip_l, self.clip_g]

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state_dict = super(MotifVision, self).state_dict(destination, prefix, keep_vars)
        exclude_keys = ["t5.", "clip_l.", "clip_g.", "vae."]
        for key in list(state_dict.keys()):
            if any(key.startswith(exclude_key) for exclude_key in exclude_keys):
                state_dict.pop(key)
        return state_dict

    def load_state_dict(self, state_dict, strict=False):
        """
        Load state dict and merge LoRA parameters if present.

        Args:
            state_dict (dict): State dictionary containing model parameters
            strict (bool): Whether to strictly enforce that the keys in state_dict match the keys in this module

        Returns:
            tuple: (missing_keys, unexpected_keys) lists of parameters that were missing or unexpected
        """
        # Check if state_dict contains LoRA parameters
        has_lora = any("lora_" in key for key in state_dict.keys())

        if has_lora:
            # If model doesn't have LoRA enabled but state_dict has LoRA params, enable it
            if not hasattr(self.dit, "peft_config"):
                logger.info("Enabling LoRA for parameter merging...")
                # Use default values if not already configured
                lora_rank = getattr(self.config, "lora_rank", 64)
                lora_alpha = getattr(self.config, "lora_alpha", 8)
                self.enable_lora(lora_rank, lora_alpha)

        if has_lora:
            try:
                # Load LoRA parameters
                # state_dict = {
                #     k.replace("base_layer.", ""): v
                #     for k, v in state_dict.items()
                #     if "lora_" not in k and "lora" not in k
                # }
                missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False)
                # Merge LoRA weights with base model
                logger.info("Merging LoRA parameters with base model...")
                for name, module in self.dit.named_modules():
                    if hasattr(module, "merge_and_unload"):
                        module.merge_and_unload()

                logger.info("Successfully merged LoRA parameters")

            except Exception as e:
                logger.error(f"Error merging LoRA parameters: {str(e)}")
                raise
        else:
            missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False)

        # Log summary of missing/unexpected parameters
        missing_top_levels = set()
        for key in missing_keys:
            top_level_name = key.split(".")[0]
            missing_top_levels.add(top_level_name)
        if missing_top_levels:
            logger.debug("Missing keys during loading at top level:")
            for name in missing_top_levels:
                logger.debug(name)

        if unexpected_keys:
            logger.debug("Unexpected keys found:")
            for key in unexpected_keys:
                logger.debug(key)

        return missing_keys, unexpected_keys

    def _freeze_encoders(self) -> None:
        """
        Freeze all encoders.
        """
        for encoder_module in [self.vae, self.clip_l, self.clip_g, self.t5]:
            for param in encoder_module.parameters():
                param.requires_grad = False
    def tokenization(
        self, raw_texts: List[str], repeat_if_short: bool = False
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Tokenizes a BATCH of input texts using multiple tokenizers efficiently.
        Optionally repeats each text to fill the max length if it's shorter,
        BEFORE passing the pre-processed batch to the tokenizer.

        Args:
            raw_texts (List[str]): A list of input text strings (the batch).
            repeat_if_short (bool): If True and a text is short, repeat that text
                                    to fill the context length. Defaults to False.

        Returns:
            Tuple[List[torch.Tensor], List[torch.Tensor]]:
                - A list containing one batch tensor of input IDs per tokenizer.
                  Each tensor shape: [batch_size, max_length]
                - A list containing one batch tensor of attention masks per tokenizer.
                  Each tensor shape: [batch_size, max_length]
        """
        final_batch_tokens = []
        final_batch_masks = []

        # Process the batch with each tokenizer
        for tokenizer in self.tokenizers:
            effective_max_length = min(TOKEN_MAX_LENGTH, tokenizer.model_max_length)

            # 1. Pre-process the batch: Create a new list of potentially repeated strings.
            processed_texts_for_tokenizer = []
            for text_item in raw_texts:
                # Start with the original text for this item
                processed_text = text_item

                if repeat_if_short:
                    # Apply repetition logic individually based on text_item's length
                    num_initial_tokens = len(text_item.split())
                    available_length = effective_max_length - 2  # Heuristic

                    if num_initial_tokens > 0 and num_initial_tokens < available_length:
                        num_additional_repeats = available_length // (num_initial_tokens + 1)
                        if num_additional_repeats > 0:
                            total_repeats = 1 + num_additional_repeats
                            processed_text = " ".join([text_item] * total_repeats)

                # Add the processed text (original or repeated) to the list for this tokenizer
                processed_texts_for_tokenizer.append(processed_text)

            # 2. Tokenize the entire batch of processed texts at once.
            # Pass the list `processed_texts_for_tokenizer` directly to the tokenizer.
            # The tokenizer's __call__ method should handle the batch efficiently.
            batch_tok_output = tokenizer(  # Call the tokenizer ONCE with the full list
                processed_texts_for_tokenizer,
                padding="max_length",
                max_length=effective_max_length,
                return_tensors="pt",
                truncation=True,
            )

            # 3. Store the resulting batch tensors directly.
            # The tokenizer should return tensors with shape [batch_size, max_length].
            final_batch_tokens.append(batch_tok_output.input_ids)
            final_batch_masks.append(batch_tok_output.attention_mask)

        return final_batch_tokens, final_batch_masks
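
    # Worked example of the repeat_if_short heuristic above (illustrative, not part of
    # the original file). For a CLIP tokenizer with model_max_length = 77 and the 5-word
    # prompt "a photo of a cat":
    #   effective_max_length = min(256, 77) = 77, available_length = 75
    #   num_additional_repeats = 75 // (5 + 1) = 12, total_repeats = 13
    # so the prompt is space-joined 13 times before tokenization, and the tokenizer then
    # truncates the result to 77 tokens.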
    @torch.no_grad()
    def text_encoding(
        self, tokens: List[torch.Tensor], masks, noisy_pad=False, zero_masking=True
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """
        Encode the tokenized text using multiple text encoders.

        Args:
            tokens (List[torch.Tensor]): List of tokenized text tensors.
            masks (List[torch.Tensor]): List of attention-mask tensors, one per tokenizer.

        Returns:
            Tuple[List[torch.Tensor], torch.Tensor]: Tuple containing a list of text embeddings and pooled text embeddings.
        """
        t5_tokens, clip_l_tokens, clip_g_tokens = tokens
        t5_masks, clip_l_masks, clip_g_masks = masks
        t5_emb = self.t5(t5_tokens, attention_mask=t5_masks)[0]
        if zero_masking:
            t5_emb = t5_emb * (t5_tokens != self.t5_tokenizer.pad_token_id).unsqueeze(-1)
        if noisy_pad:
            t5_pad_noise = (
                (t5_tokens == self.t5_tokenizer.pad_token_id).unsqueeze(-1) * torch.randn_like(t5_emb).cuda() * 0.008
            )
            t5_emb = t5_emb + t5_pad_noise

        clip_l_emb = self.clip_l(input_ids=clip_l_tokens, output_hidden_states=True)
        clip_g_emb = self.clip_g(input_ids=clip_g_tokens, output_hidden_states=True)
        clip_l_emb_pooled = clip_l_emb.pooler_output  # B x 768
        clip_g_emb_pooled = clip_g_emb.pooler_output  # B x 1280

        clip_l_emb = clip_l_emb.last_hidden_state  # B x L x 768
        clip_g_emb = clip_g_emb.last_hidden_state  # B x L x 1280

        def masking_wo_first_eos(token, eos):
            idx = (token != eos).sum(dim=1)
            mask = token != eos
            arange = torch.arange(mask.size(0)).cuda()
            mask[arange, idx] = True
            mask = mask.unsqueeze(-1)  # B x L x 1
            return mask

        if zero_masking:
            clip_l_emb = clip_l_emb * masking_wo_first_eos(
                clip_l_tokens, self.clip_l_tokenizer.eos_token_id
            )  # B x L x 768
            clip_g_emb = clip_g_emb * masking_wo_first_eos(
                clip_g_tokens, self.clip_g_tokenizer.eos_token_id
            )  # B x L x 1280

        if noisy_pad:
            clip_l_pad_noise = (
                ~masking_wo_first_eos(clip_l_tokens, self.clip_l_tokenizer.eos_token_id)
                * torch.randn_like(clip_l_emb).cuda()
                * 0.08
            )
            clip_g_pad_noise = (
                ~masking_wo_first_eos(clip_g_tokens, self.clip_g_tokenizer.eos_token_id)
                * torch.randn_like(clip_g_emb).cuda()
                * 0.08
            )
            clip_l_emb = clip_l_emb + clip_l_pad_noise
            clip_g_emb = clip_g_emb + clip_g_pad_noise

        encodings = [t5_emb, clip_l_emb, clip_g_emb]
        pooled_encodings = torch.cat([clip_l_emb_pooled, clip_g_emb_pooled], dim=-1)  # cat by channel, B x 2048

        return encodings, pooled_encodings

    @torch.no_grad()
    def prompt_embedding(self, prompts: List[str], device, noisy_pad=False, zero_masking=True):
        tokens, masks = self.tokenization(prompts)
        tokens = [token.to(device) for token in tokens]
        masks = [mask.to(device) for mask in masks]
        text_embeddings, pooled_text_embeddings = self.text_encoding(
            tokens, masks, noisy_pad=noisy_pad, zero_masking=zero_masking
        )
        text_embeddings = [text_embedding.bfloat16() for text_embedding in text_embeddings]
        pooled_text_embeddings = pooled_text_embeddings.bfloat16()
        return text_embeddings, pooled_text_embeddings
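
    # Shape summary for text_encoding / prompt_embedding (illustrative, not part of the
    # original file), with B = batch size and L = effective token length:
    #   t5_emb:     B x L x 4096   (T5-XXL hidden states)
    #   clip_l_emb: B x L x 768,   clip_g_emb: B x L x 1280
    #   pooled:     B x 2048       (concatenation of the two CLIP pooled outputs)
    # These correspond to the 4096-dim encoded-text and 2048-dim pooled-text inputs
    # expected by the MMDiT configuration.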
    @torch.no_grad()
    def sample(
        self,
        raw_text: List[str],
        steps: int = 50,
        guidance_scale: float = 7.5,
        resolution: List[int] = (256, 256),
        pre_latent=None,
        pre_timestep=None,
        step_scaling=1.0,
        noisy_pad=False,
        zero_masking=False,
        negative_prompt: Optional[List[str]] = None,
        device: str = "cuda",
        rescale_cfg=-1.0,
        clip_t=[0.0, 1.0],
        use_linear_quadratic_schedule=False,
        linear_quadratic_emulating_steps=250,
        prompt_rewriter=None,
        moderator=None,
        get_intermediate_steps: bool = False,
    ) -> Union[List[Image.Image], Tuple[List[Image.Image], List[List[Image.Image]]]]:
        """
        Sample images using flow matching. Optionally returns intermediate step images
        calculated via the observed average velocity method.

        Args:
            raw_text (List[str]): raw text prompts
            steps (int, optional): number of function estimations for the flow matching ODE. Defaults to 50.
            guidance_scale (float, optional): classifier-free guidance scale. Defaults to 7.5.
            resolution (List[int], optional): input and output resolution of raw images. Defaults to (256, 256).
            device (str, optional): Defaults to 'cuda'.
            pre_latent (Tensor, optional): optional pre-defined latents to start generation from,
                e.g. for denoising or image editing.
            pre_timestep (float [0,1], optional): the pre-defined timestep. Together with `pre_latent`,
                image generation can start from an intermediate timestep.
            step_scaling (float, defaults to 1.0): scaling factor for each ODE-solving step.
            use_linear_quadratic_schedule (bool, defaults to False): use a linear-quadratic t schedule.
                If False, a linear t schedule is used.
            linear_quadratic_emulating_steps (int, defaults to 250): N value in the linear-quadratic t schedule
                from the Meta MovieGen paper.
                Reference: (https://ai.meta.com/static-resource/movie-gen-research-paper) Figure 10
            get_intermediate_steps (bool, optional): Whether to calculate and return intermediate step images.
                Calculation is based on initial_noise - avg(velocity). Defaults to False.

        Returns:
            Union[List[PIL.Image], Tuple[List[PIL.Image], List[List[PIL.Image]]]]:
                If get_intermediate_steps is False: Returns a list of final PIL images.
                If get_intermediate_steps is True: Returns a tuple containing:
                    - List[PIL.Image]: Final output PIL images.
                    - List[List[PIL.Image]]: List of intermediate PIL images. Each inner list contains
                      the batch of images for one intermediate step.
        """
        if prompt_rewriter:
            prompts = [prompt_rewriter.rewrite(prompt) for prompt in raw_text]
        else:
            prompts = raw_text

        # Simplified check for rewriter status
        if prompts == raw_text and prompt_rewriter is not None:
            logger.debug("Prompt rewriter did not change the prompts.")
        elif prompt_rewriter is None:
            logger.debug("Prompt rewriter not provided.")

        if moderator is None:
            is_safe_prompt = [True for _ in prompts]
        else:
            is_safe_prompt = [moderator and moderator.is_safe_content(prompt, threshold=0.7) for prompt in prompts]
        if not all(is_safe_prompt):
            logger.warning("Noxious prompt detected. Output image(s) will be blurred.")

        b = len(prompts)
        h, w = resolution

        # --- [Initial Latent Noise (e = x_1)] ---
        latent_channels = 16
        if pre_latent is None:
            initial_noise = randn_tensor(  # Store initial noise separately
                (b, latent_channels, h // VAE_DOWNSCALE_FACTOR, w // VAE_DOWNSCALE_FACTOR),
                device=device,
                dtype=torch.float32,  # Use float32 for calculations
            )
        else:
            initial_noise = pre_latent.to(device=device, dtype=torch.float32)
            if pre_timestep is not None and pre_timestep < 1.0:  # Check if it's truly intermediate
                logger.warning(
                    "Using pre_latent as initial_noise for average calculation, but pre_timestep suggests it's not pure noise. Results might be unexpected."
                )

        latents = initial_noise.clone()  # Working latents for the ODE solver

        # --- [Text Embeddings & CFG Setup] ---
        text_embeddings, pooled_text_embeddings = self.prompt_embedding(
            prompts, latents.device, noisy_pad=noisy_pad, zero_masking=zero_masking
        )
        text_embeddings = [emb.to(device=latents.device, dtype=torch.bfloat16) for emb in text_embeddings]
        pooled_text_embeddings = pooled_text_embeddings.to(device=latents.device, dtype=torch.bfloat16)

        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            negative_text_embeddings = [
                torch.zeros_like(text_embedding, device=text_embedding.device) for text_embedding in text_embeddings
            ]
            negative_pooled_text_embeddings = torch.zeros_like(
                pooled_text_embeddings, device=pooled_text_embeddings.device
            )
            text_embeddings = [
                torch.cat([text_embedding, negative_text_embedding], dim=0)
                for text_embedding, negative_text_embedding in zip(text_embeddings, negative_text_embeddings)
            ]
            pooled_text_embeddings = torch.cat([pooled_text_embeddings, negative_pooled_text_embeddings], dim=0)

        # if negative_prompt is None:
        #     negative_prompt = [""] * b
        #     logger.debug("No negative prompt provided, using empty strings for CFG.")
        # negative_text_embeddings, negative_pooled_text_embeddings = self.prompt_embedding(negative_prompt, latents.device)
        # negative_text_embeddings = [emb.to(device=latents.device, dtype=torch.bfloat16) for emb in negative_text_embeddings]
        # negative_pooled_text_embeddings = negative_pooled_text_embeddings.to(device=latents.device, dtype=torch.bfloat16)

        # text_embeddings = [torch.cat([pos_emb, neg_emb], dim=0) for pos_emb, neg_emb in zip(text_embeddings, negative_text_embeddings)]
        # pooled_text_embeddings = torch.cat([pooled_text_embeddings, negative_pooled_text_embeddings], dim=0)

        # --- [Timestep Schedule (Sigmas)] ---
        # linear t schedule
        sigmas = torch.linspace(1, 0, steps + 1) if not pre_timestep else torch.linspace(pre_timestep, 0, steps + 1)

        if use_linear_quadratic_schedule:
            # linear-quadratic t schedule
            assert steps % 2 == 0
            N = linear_quadratic_emulating_steps
            sigmas = torch.concat(
                [
                    torch.linspace(1, 0, N + 1)[: steps // 2],
                    torch.linspace(0, 1, steps // 2 + 1) ** 2 * (steps // 2 * 1 / N - 1) - (steps // 2 * 1 / N - 1),
                ]
            )

        # --- [Initialization for Intermediate Step Calculation] ---
        # intermediate_latents will store the latent states for intermediate steps
        intermediate_latents = [] if get_intermediate_steps else None
        predicted_velocities = []  # Store dx from each step
        sigma_history = []
        # --- [Sampling Loop] ---
        for infer_step, t in tqdm.tqdm(enumerate(sigmas[:-1]), total=len(sigmas[:-1]), desc="Sampling"):
            # Prepare input for DiT model
            if do_classifier_free_guidance:
                input_latents = torch.cat([latents] * 2, dim=0)
            else:
                input_latents = latents

            # Prepare timestep input
            timestep = (t * 1000).round().long().to(latents.device)
            timestep = timestep.expand(input_latents.shape[0]).to(torch.bfloat16)  # Ensure timestep is bfloat16

            # Predict velocity dx = v(x_t, t) ≈ e - x_0
            dx = self.dit(input_latents.to(torch.bfloat16), timestep, text_embeddings, pooled_text_embeddings)
            dt = sigmas[infer_step + 1] - sigmas[infer_step]  # dt is negative
            sigma_history.append(dt)

            # Apply Classifier-Free Guidance
            if do_classifier_free_guidance:
                cond_dx, uncond_dx = dx.chunk(2)
                current_guidance_scale = guidance_scale if clip_t[0] <= t and t <= clip_t[1] else 1.0
                dx = uncond_dx + current_guidance_scale * (cond_dx - uncond_dx)

                if rescale_cfg > 0.0:
                    std_pos = torch.std(cond_dx, dim=[1, 2, 3], keepdim=True, unbiased=False) + 1e-5
                    std_cfg = torch.std(dx, dim=[1, 2, 3], keepdim=True, unbiased=False) + 1e-5
                    factor = std_pos / std_cfg
                    factor = rescale_cfg * factor + (1.0 - rescale_cfg)
                    dx = dx * factor

            # --- Store the predicted velocity for averaging ---
            predicted_velocities.append(dx.clone())

            # --- Update Latents using standard Euler step ---
            latents = latents + dt * dx

            # --- Calculate and Store Intermediate Latent State (if requested) ---
            if get_intermediate_steps:
                dxs = torch.stack(predicted_velocities)

                sigma_sum = sum(sigma_history)
                normalized_sigma_history = [s / sigma_sum for s in sigma_history]
                dts = torch.tensor(normalized_sigma_history, device=dxs.device, dtype=dxs.dtype).view(-1, 1, 1, 1, 1)

                avg_dx = torch.sum(dxs * dts, dim=0)
                observed_state = initial_noise - avg_dx  # Calculate the desired intermediate state
                intermediate_latents.append(observed_state.clone())  # Store its latent representation

        # --- [Decode Final Latents to PIL Images] ---
        self.vae = self.vae.to(device=latents.device, dtype=torch.float32)  # Ensure VAE is ready
        final_latents_scaled = latents.to(torch.float32) / self.vae.config.scaling_factor
        final_image_tensors = self.vae.decode(final_latents_scaled, return_dict=False)[0] + self.vae.config.shift_factor
        final_image_tensors = ((final_image_tensors + 1.0) / 2.0).clamp(0.0, 1.0)

        final_pil_images = []
        for i, image_tensor in enumerate(final_image_tensors):
            img = T.ToPILImage()(image_tensor.cpu())
            if not is_safe_prompt[i]:
                img = img.filter(ImageFilter.GaussianBlur(radius=30))
            final_pil_images.append(img)

        # --- [Decode Intermediate Latents to PIL Images (if requested)] ---
        if get_intermediate_steps:
            intermediate_pil_images = []
            # Ensure VAE is still ready (it should be from final decoding)
            for step_latents in tqdm.tqdm(intermediate_latents, desc="Decoding intermediates"):
                step_latents_scaled = (
                    step_latents.to(dtype=torch.float32, device="cuda") / self.vae.config.scaling_factor
                )
                step_image_tensors = (
                    self.vae.decode(step_latents_scaled, return_dict=False)[0] + self.vae.config.shift_factor
                )
                step_image_tensors = ((step_image_tensors + 1.0) / 2.0).clamp(0.0, 1.0)

                current_step_pil = []
                for i, image_tensor in enumerate(step_image_tensors):
                    img = T.ToPILImage()(image_tensor.cpu())
                    # Apply moderation blur consistently
                    if not is_safe_prompt[i]:
                        img = img.filter(ImageFilter.GaussianBlur(radius=30))
                    current_step_pil.append(img)
                intermediate_pil_images.append(current_step_pil)  # Append list of images for this step

            return final_pil_images, intermediate_pil_images  # Return both final and intermediate images
        else:
            return final_pil_images  # Return only final images
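
    # The sampling loop above is a plain Euler solver over the sigma schedule
    # (illustrative summary, not part of the original file). With s = guidance scale
    # and dt = sigma[k+1] - sigma[k] (negative):
    #   v      = v_uncond + s * (v_cond - v_uncond)    # classifier-free guidance
    #   x_next = x + dt * v                            # Euler step toward sigma = 0
    # The optional intermediate previews use the |dt|-weighted average of all
    # velocities so far: x_hat = initial_noise - sum_k (|dt_k| / sum|dt|) * v_k.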
    @torch.no_grad()
    def eval_with_loss(self, images, raw_text):
        latents = self.vae.encode(images).latent_dist.sample() * self.vae.config.scaling_factor

        tokens, masks = self.tokenization(raw_text)
        tokens = [token.to(latents.device) for token in tokens]
        masks = [mask.to(latents.device) for mask in masks]
        text_embeddings, pooled_text_embeddings = self.text_encoding(tokens, masks)
        text_embeddings = [text_embedding for text_embedding in text_embeddings]
        pooled_text_embeddings = pooled_text_embeddings.float()

        # 2. Get noisy input via the rectified flow
        is_finetuning = self.config.height > 256
        noise, noise_latents, t = self.get_noisy_input(latents, is_finetuning=is_finetuning)
        timesteps = self.discritize_timestep(t, self.n_timesteps)

        # 3. Forward pass through the dit
        preds = self.dit(noise_latents, timesteps, text_embeddings, pooled_text_embeddings)

        # 4. Rectified flow matching loss
        loss = self.rectified_flow_loss(noise_latents, noise, t, preds, reduce="none", use_weighting=False).mean(
            dim=[1, 2, 3]
        )

        intervals = np.linspace(0, 1, 9)
        t_interval = [(intervals[i], intervals[i + 1]) for i in range(len(intervals) - 1)]

        loss_bins = defaultdict(list)
        for i, interval in enumerate(t_interval, 0):
            idx = (interval[0] < t) & (t < interval[1])
            loss_bins[i].append(loss[idx])

        return loss_bins
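
A minimal inference sketch for the class above (illustrative only; the checkpoint path and the JSON-to-dataclass loading of configs/mmdit_xlarge_hq.json are assumptions, not part of this commit):

import json
import torch

from configs.configuration_mmdit import MMDiTConfig
from models.modeling_motif_vision import MotifVision

# Hypothetical paths; adjust to your setup.
config = MMDiTConfig(**json.load(open("configs/mmdit_xlarge_hq.json")))
model = MotifVision(config).to("cuda").eval()
model.load_state_dict(torch.load("checkpoint.pt", map_location="cpu"))  # placeholder checkpoint

# sample() returns a list of PIL images when get_intermediate_steps is False.
images = model.sample(
    ["a watercolor painting of a fox in a misty forest"],
    steps=50,
    guidance_scale=7.5,
    resolution=(1024, 1024),
)
images[0].save("sample.png")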