File size: 29,642 Bytes
6cd6a16 327b52c 6cd6a16 327b52c 6cd6a16 327b52c 6cd6a16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 |
from collections import defaultdict
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
import tqdm
from diffusers.models import AutoencoderKL
from diffusers.utils.torch_utils import randn_tensor
from loguru import logger
from PIL import Image, ImageFilter
from transformers import CLIPTextModel, CLIPTokenizerFast, T5EncoderModel, T5Tokenizer
from models.mixin.flow_mixin import FlowMixin
from models.modeling_dit import MotifDiT
# Upper bound on tokens per prompt; each tokenizer is further capped by its own model_max_length.
TOKEN_MAX_LENGTH: int = 256
# Text-embedding drop probability; presumably for classifier-free-guidance training — verify against callers.
DROP_PROB: float = 0.1

# VAE latent-space geometry.
LATENT_CHANNELS: int = 4  # presumably the 4-channel (SDXL-style) VAE — TODO confirm
SD3_LATENT_CHANNEL: int = 16  # latent channels of the SD3 VAE
VAE_DOWNSCALE_FACTOR: int = 8  # pixel height/width divided by this gives latent height/width
def generate_intervals(steps, ratio, start=1.0):
    """Return ``steps`` points spaced linearly from ``start`` down to 0, each raised to ``ratio``.

    Args:
        steps: Number of points (passed to ``torch.linspace``).
        ratio: Exponent applied element-wise to the linear ramp.
        start: First value of the ramp before exponentiation.

    Returns:
        torch.Tensor: 1-D tensor of length ``steps``.
    """
    return torch.pow(torch.linspace(start, 0, steps=steps), ratio)
class MotifImage(nn.Module, FlowMixin):
    """
    MotifImage text-to-image model.

    Combines a diffusion transformer (``MotifDiT``) trained with a rectified-flow loss,
    three frozen text encoders (T5-XXL, CLIP-L, CLIP-G) and a frozen VAE that maps
    between pixel space and latent space.

    Args:
        config: Model configuration. It is passed to ``MotifDiT``; this class itself reads
            ``vae_type`` (``"SD3"`` or ``"SDXL"``), ``height`` and, optionally,
            ``lora_rank`` / ``lora_alpha``.

    Attributes:
        config: The configuration object.
        dit (MotifDiT): The diffusion transformer (the only component not frozen here).
        cond_drop_prob (float): Text-embedding drop probability for training
            (presumably consumed by ``_drop_text_emb`` — verify in FlowMixin).
        use_weighting (bool): Forwarded to ``rectified_flow_loss`` in ``forward``.
        vae (AutoencoderKL): Frozen VAE for image encoding and decoding.
        t5 (T5EncoderModel), t5_tokenizer (T5Tokenizer): Flan-T5-XXL encoder / tokenizer.
        clip_l (CLIPTextModel), clip_l_tokenizer (CLIPTokenizerFast): CLIP ViT-L/14 text tower.
        clip_g (CLIPTextModel), clip_g_tokenizer (CLIPTokenizerFast): CLIP ViT-bigG/14 text tower.
        tokenizers (list): ``[t5_tokenizer, clip_l_tokenizer, clip_g_tokenizer]``.
        text_encoders (list): ``[t5, clip_l, clip_g]``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dit = MotifDiT(config)
        self.cond_drop_prob = DROP_PROB
        self.use_weighting = False
        self._get_encoders()
        self._freeze_encoders()

    def forward(self, images: torch.Tensor, raw_text: List[str]) -> List[torch.Tensor]:
        """
        Compute the rectified-flow training loss for one batch.

        Args:
            images (torch.Tensor): Image batch, encoded directly by the VAE (no rescaling here).
            raw_text (List[str]): One caption per image.

        Returns:
            List[torch.Tensor]: A one-element list holding the rectified-flow matching loss.
        """
        # 1. Encode images and texts (frozen encoders, no gradients).
        with torch.no_grad():
            latents = self.vae.encode(images).latent_dist.sample() * self.vae.config.scaling_factor
            tokens, masks = self.tokenization(raw_text)
            tokens = [token.to(latents.device) for token in tokens]
            masks = [mask.to(latents.device) for mask in masks]
            text_embeddings, pooled_text_embeddings = self.text_encoding(tokens, masks)
            text_embeddings = self._drop_text_emb(text_embeddings)
            text_embeddings = [text_embedding.float() for text_embedding in text_embeddings]
            pooled_text_embeddings = pooled_text_embeddings.float()
        # 2. Get noisy input via the rectified flow.
        is_finetuning = self.config.height > 256
        noise, noise_latents, t = self.get_noisy_input(latents, is_finetuning=is_finetuning)
        timesteps = self.discritize_timestep(t, self.n_timesteps)
        # 3. Forward pass through the DiT.
        preds = self.dit(noise_latents, timesteps, text_embeddings, pooled_text_embeddings)
        # 4. Rectified flow matching loss.
        loss = self.rectified_flow_loss(latents, noise, t, preds, use_weighting=self.use_weighting)
        return [loss]

    def _get_encoders(self) -> None:
        """Load the VAE selected by ``config.vae_type`` and the three text encoders (bfloat16)."""
        if self.config.vae_type == "SD3":
            self.vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae")
        elif self.config.vae_type == "SDXL":
            self.vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
        else:
            raise ValueError(f"VAE type must be `SD3` or `SDXL` but self.config.vae_type is {self.config.vae_type}")
        # 1. T5-XXL from Google
        self.t5 = T5EncoderModel.from_pretrained("google/flan-t5-xxl").to(dtype=torch.bfloat16)
        self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
        # 2. CLIP-L from OpenAI
        self.clip_l = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(dtype=torch.bfloat16)
        self.clip_l_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")
        # 3. CLIP-G from LAION
        self.clip_g = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k").to(dtype=torch.bfloat16)
        self.clip_g_tokenizer = CLIPTokenizerFast.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
        # The order of these lists must match the unpacking order in `text_encoding`.
        self.tokenizers = [self.t5_tokenizer, self.clip_l_tokenizer, self.clip_g_tokenizer]
        self.text_encoders = [self.t5, self.clip_l, self.clip_g]

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """
        Return the module state dict without the frozen VAE and text-encoder weights.

        ``prefix`` is honoured when excluding keys, so the filtering also works when this
        model is nested inside another module (the parent passes a non-empty prefix and a
        shared ``destination``).
        """
        state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        excluded_prefixes = tuple(prefix + name for name in ("t5.", "clip_l.", "clip_g.", "vae."))
        for key in [k for k in state_dict if k.startswith(excluded_prefixes)]:
            state_dict.pop(key)
        return state_dict

    def load_state_dict(self, state_dict, strict=False):
        """
        Load weights, enabling and merging LoRA adapters if the checkpoint contains them.

        Loading is always non-strict: ``state_dict()`` omits the frozen encoders, so a
        strict load would always report them missing. ``strict`` is accepted for API
        compatibility but ignored.

        Args:
            state_dict (dict): Parameters to load. Keys containing ``"lora_"`` mark a LoRA checkpoint.
            strict (bool): Ignored (see above).

        Returns:
            tuple: ``(missing_keys, unexpected_keys)``.
        """
        has_lora = any("lora_" in key for key in state_dict)
        if has_lora and not hasattr(self.dit, "peft_config"):
            # Checkpoint has LoRA params but the model does not: enable LoRA so they can load.
            logger.info("Enabling LoRA for parameter merging...")
            lora_rank = getattr(self.config, "lora_rank", 64)
            lora_alpha = getattr(self.config, "lora_alpha", 8)
            self.enable_lora(lora_rank, lora_alpha)

        if has_lora:
            try:
                missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False)
                logger.info("Merging LoRA parameters with base model...")
                for module in self.dit.modules():
                    if hasattr(module, "merge_and_unload"):
                        module.merge_and_unload()
                logger.info("Successfully merged LoRA parameters")
            except Exception as e:
                logger.error(f"Error merging LoRA parameters: {e}")
                raise
        else:
            missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False)

        # Summarise missing keys by top-level submodule to keep the log short.
        missing_top_levels = {key.split(".")[0] for key in missing_keys}
        if missing_top_levels:
            logger.debug("Missing keys during loading at top level:")
            for name in missing_top_levels:
                logger.debug(name)
        if unexpected_keys:
            logger.debug("Unexpected keys found:")
            for key in unexpected_keys:
                logger.debug(key)
        return missing_keys, unexpected_keys

    def _freeze_encoders(self) -> None:
        """Disable gradients for the VAE and all three text encoders."""
        for encoder_module in [self.vae, self.clip_l, self.clip_g, self.t5]:
            for param in encoder_module.parameters():
                param.requires_grad = False

    @staticmethod
    def _repeat_text_to_fill(text: str, available_length: int) -> str:
        """Repeat ``text`` (space-joined) so its whitespace-word count roughly fills ``available_length``."""
        num_words = len(text.split())
        if 0 < num_words < available_length:
            num_additional_repeats = available_length // (num_words + 1)
            if num_additional_repeats > 0:
                return " ".join([text] * (1 + num_additional_repeats))
        return text

    def tokenization(
        self, raw_texts: List[str], repeat_if_short: bool = False
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Tokenize a batch of texts with every tokenizer in ``self.tokenizers``.

        Args:
            raw_texts (List[str]): The batch of prompts. A single ``str`` is treated as a
                batch of one (instead of being split into characters).
            repeat_if_short (bool): If True, repeat each short prompt before tokenization so
                it roughly fills the context; length is estimated by whitespace word count.
                Defaults to False.

        Returns:
            Tuple[List[torch.Tensor], List[torch.Tensor]]: One ``[batch_size, max_length]``
            tensor of input ids and one of attention masks per tokenizer, where
            ``max_length = min(TOKEN_MAX_LENGTH, tokenizer.model_max_length)``.
        """
        if isinstance(raw_texts, str):
            raw_texts = [raw_texts]
        final_batch_tokens = []
        final_batch_masks = []
        for tokenizer in self.tokenizers:
            effective_max_length = min(TOKEN_MAX_LENGTH, tokenizer.model_max_length)
            if repeat_if_short:
                # "- 2" is a heuristic allowance for special tokens.
                texts = [self._repeat_text_to_fill(text, effective_max_length - 2) for text in raw_texts]
            else:
                texts = list(raw_texts)
            batch_tok_output = tokenizer(
                texts,
                padding="max_length",
                max_length=effective_max_length,
                return_tensors="pt",
                truncation=True,
            )
            final_batch_tokens.append(batch_tok_output.input_ids)
            final_batch_masks.append(batch_tok_output.attention_mask)
        return final_batch_tokens, final_batch_masks

    @staticmethod
    def _mask_keep_first_eos(tokens: torch.Tensor, eos_token_id: int) -> torch.Tensor:
        """Return a ``B x L x 1`` bool mask: True for non-EOS tokens and for the first EOS of each row."""
        keep = tokens != eos_token_id
        # The non-EOS count equals the first EOS index when EOS/padding forms a suffix.
        # Clamp so a row containing no EOS cannot index past the end.
        first_eos = keep.sum(dim=1).clamp(max=keep.size(1) - 1)
        rows = torch.arange(keep.size(0), device=tokens.device)
        keep[rows, first_eos] = True
        return keep.unsqueeze(-1)

    @torch.no_grad()
    def text_encoding(
        self, tokens: List[torch.Tensor], masks: List[torch.Tensor], noisy_pad=False, zero_masking=True
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """
        Encode tokenized text with T5, CLIP-L and CLIP-G.

        Args:
            tokens (List[torch.Tensor]): ``[t5_ids, clip_l_ids, clip_g_ids]`` as produced by ``tokenization``.
            masks (List[torch.Tensor]): Matching attention masks. Only the T5 mask is used; the
                CLIP towers are called without an attention mask.
            noisy_pad (bool): Add Gaussian noise at padding positions (scale 0.008 for T5,
                0.08 for CLIP). All noise is allocated on the embeddings' own device.
            zero_masking (bool): Zero embeddings at padding positions; for CLIP the first EOS
                token of each row is kept.

        Returns:
            Tuple[List[torch.Tensor], torch.Tensor]: ``[t5_emb, clip_l_emb, clip_g_emb]`` and the
            CLIP-L / CLIP-G pooled outputs concatenated along the last dimension.
        """
        t5_tokens, clip_l_tokens, clip_g_tokens = tokens
        t5_masks, _, _ = masks

        t5_emb = self.t5(t5_tokens, attention_mask=t5_masks)[0]
        t5_is_pad = (t5_tokens == self.t5_tokenizer.pad_token_id).unsqueeze(-1)
        if zero_masking:
            t5_emb = t5_emb * ~t5_is_pad
        if noisy_pad:
            t5_emb = t5_emb + t5_is_pad * torch.randn_like(t5_emb) * 0.008

        clip_l_out = self.clip_l(input_ids=clip_l_tokens, output_hidden_states=True)
        clip_g_out = self.clip_g(input_ids=clip_g_tokens, output_hidden_states=True)
        clip_l_emb_pooled = clip_l_out.pooler_output  # B x 768
        clip_g_emb_pooled = clip_g_out.pooler_output  # B x 1280
        clip_l_emb = clip_l_out.last_hidden_state  # B x L x 768
        clip_g_emb = clip_g_out.last_hidden_state  # B x L x 1280

        if zero_masking or noisy_pad:
            clip_l_keep = self._mask_keep_first_eos(clip_l_tokens, self.clip_l_tokenizer.eos_token_id)
            clip_g_keep = self._mask_keep_first_eos(clip_g_tokens, self.clip_g_tokenizer.eos_token_id)
            if zero_masking:
                clip_l_emb = clip_l_emb * clip_l_keep
                clip_g_emb = clip_g_emb * clip_g_keep
            if noisy_pad:
                clip_l_emb = clip_l_emb + ~clip_l_keep * torch.randn_like(clip_l_emb) * 0.08
                clip_g_emb = clip_g_emb + ~clip_g_keep * torch.randn_like(clip_g_emb) * 0.08

        encodings = [t5_emb, clip_l_emb, clip_g_emb]
        pooled_encodings = torch.cat([clip_l_emb_pooled, clip_g_emb_pooled], dim=-1)  # B x 2048
        return encodings, pooled_encodings

    @torch.no_grad()
    def prompt_embedding(self, prompts: List[str], device, noisy_pad=False, zero_masking=True):
        """Tokenize and encode ``prompts`` on ``device``, returning bfloat16 embeddings."""
        tokens, masks = self.tokenization(prompts)
        tokens = [token.to(device) for token in tokens]
        masks = [mask.to(device) for mask in masks]
        text_embeddings, pooled_text_embeddings = self.text_encoding(
            tokens, masks, noisy_pad=noisy_pad, zero_masking=zero_masking
        )
        text_embeddings = [text_embedding.bfloat16() for text_embedding in text_embeddings]
        pooled_text_embeddings = pooled_text_embeddings.bfloat16()
        return text_embeddings, pooled_text_embeddings

    def _embed_for_sampling(self, prompts, device, noisy_pad, zero_masking):
        """Embed prompts and place all embeddings on ``device`` in bfloat16 for the DiT."""
        text_embeddings, pooled_text_embeddings = self.prompt_embedding(
            prompts, device, noisy_pad=noisy_pad, zero_masking=zero_masking
        )
        text_embeddings = [emb.to(device=device, dtype=torch.bfloat16) for emb in text_embeddings]
        pooled_text_embeddings = pooled_text_embeddings.to(device=device, dtype=torch.bfloat16)
        return text_embeddings, pooled_text_embeddings

    @staticmethod
    def _build_sigmas(steps, pre_timestep, use_linear_quadratic_schedule, linear_quadratic_emulating_steps):
        """
        Return the ``steps + 1`` time points (from high to 0) visited by the Euler solver.

        The linear schedule starts at ``pre_timestep`` when it is truthy, else at 1. The
        linear-quadratic schedule always starts at 1 and requires an even ``steps``.
        """
        if use_linear_quadratic_schedule:
            if steps % 2 != 0:
                raise ValueError(f"linear-quadratic schedule requires an even number of steps, got {steps}")
            n = linear_quadratic_emulating_steps
            half = steps // 2
            offset = half / n - 1
            return torch.cat(
                [
                    # first half: the first `half` points of an N-step linear schedule
                    torch.linspace(1, 0, n + 1)[:half],
                    # second half: quadratic decay from 1 - half/N down to 0
                    torch.linspace(0, 1, half + 1) ** 2 * offset - offset,
                ]
            )
        start = pre_timestep if pre_timestep else 1.0
        return torch.linspace(start, 0, steps + 1)

    def _latents_to_pil(self, latents: torch.Tensor, is_safe_prompt: List[bool]) -> List[Image.Image]:
        """Decode a latent batch with the VAE into [0, 1] PIL images, blurring entries flagged unsafe."""
        # NOTE(review): shift_factor is added to the *decoded image*, while `forward` applies no
        # shift when encoding. Kept for output parity; a VAE config whose shift_factor is None
        # (e.g. the SDXL VAE) now adds 0 instead of raising a TypeError.
        shift_factor = getattr(self.vae.config, "shift_factor", None) or 0.0
        scaled = latents.to(torch.float32) / self.vae.config.scaling_factor
        image_tensors = self.vae.decode(scaled, return_dict=False)[0] + shift_factor
        image_tensors = ((image_tensors + 1.0) / 2.0).clamp(0.0, 1.0)
        to_pil = T.ToPILImage()
        pil_images = []
        for i, image_tensor in enumerate(image_tensors):
            img = to_pil(image_tensor.cpu())
            if not is_safe_prompt[i]:
                img = img.filter(ImageFilter.GaussianBlur(radius=30))
            pil_images.append(img)
        return pil_images

    @torch.no_grad()
    def sample(
        self,
        raw_text: List[str],
        steps: int = 50,
        guidance_scale: float = 7.5,
        resolution: Tuple[int, int] = (256, 256),
        pre_latent=None,
        pre_timestep=None,
        step_scaling=1.0,
        noisy_pad=False,
        zero_masking=False,
        negative_prompt: Optional[List[str]] = None,
        device: str = "cuda",
        rescale_cfg=-1.0,
        clip_t=(0.0, 1.0),
        use_linear_quadratic_schedule=False,
        linear_quadratic_emulating_steps=250,
        prompt_rewriter=None,
        moderator=None,
        get_intermediate_steps: bool = False,
    ) -> Union[List[Image.Image], Tuple[List[Image.Image], List[List[Image.Image]]]]:
        """
        Sample images by integrating the flow-matching ODE with Euler steps.

        Args:
            raw_text (List[str]): Prompts, one image per prompt.
            steps (int): Number of ODE steps (DiT evaluations per CFG branch). Defaults to 50.
            guidance_scale (float): Classifier-free guidance scale; CFG is used when > 1. Defaults to 7.5.
            resolution (Tuple[int, int]): Output ``(height, width)`` in pixels. Defaults to (256, 256).
            pre_latent (Tensor, optional): Starting latent instead of fresh Gaussian noise
                (e.g. for denoising or editing).
            pre_timestep (float in [0, 1], optional): Start time of the linear schedule; combined
                with ``pre_latent`` this starts generation from an intermediate time.
            step_scaling (float): Currently unused; kept for API compatibility.
            noisy_pad (bool): Passed to ``text_encoding``.
            zero_masking (bool): Passed to ``text_encoding``.
            negative_prompt (List[str] or str, optional): Prompts for the unconditional CFG branch
                (a single string is broadcast to the batch). When None, all-zero embeddings are used.
            device (str): Device for the initial latents. Defaults to "cuda".
            rescale_cfg (float): When > 0, blend the guided velocity toward the conditional
                velocity's standard deviation by this amount.
            clip_t (Tuple[float, float]): Guidance is applied only while ``clip_t[0] <= t <= clip_t[1]``.
            use_linear_quadratic_schedule (bool): Use the linear-quadratic t schedule instead of
                the linear one. Requires an even ``steps``.
            linear_quadratic_emulating_steps (int): N of the linear-quadratic schedule (Movie Gen
                paper, Figure 10: https://ai.meta.com/static-resource/movie-gen-research-paper).
            prompt_rewriter (optional): Object with ``rewrite(prompt) -> str`` applied to each prompt.
            moderator (optional): Object with ``is_safe_content(prompt, threshold=...)``; outputs for
                unsafe prompts are Gaussian-blurred.
            get_intermediate_steps (bool): Also return per-step previews computed as
                ``initial_noise - avg(velocity)``. Defaults to False.

        Returns:
            If ``get_intermediate_steps`` is False, a list of final PIL images. Otherwise a tuple
            ``(final_images, intermediate_images)``, where ``intermediate_images`` holds one list
            of batch images per step.

        Raises:
            ValueError: If the linear-quadratic schedule is used with an odd ``steps``, or if
                ``negative_prompt`` does not have one entry per prompt.
        """
        # --- [Prompt rewriting & moderation] ---
        if prompt_rewriter is None:
            prompts = raw_text
            logger.debug("Prompt rewriter not provided.")
        else:
            prompts = [prompt_rewriter.rewrite(prompt) for prompt in raw_text]
            if prompts == raw_text:
                logger.debug("Prompt rewriter did not change the prompts.")

        if moderator is None:
            is_safe_prompt = [True] * len(prompts)
        else:
            is_safe_prompt = [moderator.is_safe_content(prompt, threshold=0.7) for prompt in prompts]
        if not all(is_safe_prompt):
            logger.warning("Noxious prompt detected. Output image(s) will be blurred.")

        b = len(prompts)
        h, w = resolution

        # --- [Initial latent (x_1)] ---
        if pre_latent is None:
            # Channel count follows the loaded VAE (16 for SD3, 4 for SDXL).
            latent_channels = getattr(self.vae.config, "latent_channels", SD3_LATENT_CHANNEL)
            initial_noise = randn_tensor(
                (b, latent_channels, h // VAE_DOWNSCALE_FACTOR, w // VAE_DOWNSCALE_FACTOR),
                device=device,
                dtype=torch.float32,
            )
        else:
            initial_noise = pre_latent.to(device=device, dtype=torch.float32)
            if pre_timestep is not None and pre_timestep < 1.0:
                logger.warning(
                    "Using pre_latent as initial_noise for average calculation, but pre_timestep suggests it's not pure noise. Results might be unexpected."
                )
        latents = initial_noise.clone()  # working state of the ODE solver

        # --- [Text embeddings & CFG setup] ---
        text_embeddings, pooled_text_embeddings = self._embed_for_sampling(
            prompts, latents.device, noisy_pad, zero_masking
        )
        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            if negative_prompt is None:
                negative_text_embeddings = [torch.zeros_like(emb) for emb in text_embeddings]
                negative_pooled_text_embeddings = torch.zeros_like(pooled_text_embeddings)
            else:
                if isinstance(negative_prompt, str):
                    negative_prompt = [negative_prompt] * b
                if len(negative_prompt) != b:
                    raise ValueError(
                        f"negative_prompt must have one entry per prompt ({b}), got {len(negative_prompt)}"
                    )
                negative_text_embeddings, negative_pooled_text_embeddings = self._embed_for_sampling(
                    negative_prompt, latents.device, noisy_pad, zero_masking
                )
            # Conditional batch first, unconditional second (matches dx.chunk(2) below).
            text_embeddings = [
                torch.cat([pos_emb, neg_emb], dim=0)
                for pos_emb, neg_emb in zip(text_embeddings, negative_text_embeddings)
            ]
            pooled_text_embeddings = torch.cat([pooled_text_embeddings, negative_pooled_text_embeddings], dim=0)

        # --- [Time schedule] ---
        sigmas = self._build_sigmas(
            steps, pre_timestep, use_linear_quadratic_schedule, linear_quadratic_emulating_steps
        )

        # --- [Sampling loop] ---
        intermediate_latents = [] if get_intermediate_steps else None
        for infer_step, t in tqdm.tqdm(enumerate(sigmas[:-1]), total=len(sigmas) - 1, desc="Sampling"):
            input_latents = torch.cat([latents] * 2, dim=0) if do_classifier_free_guidance else latents

            # The DiT receives t discretised to an integer in [0, 1000], cast to bfloat16.
            timestep = (t * 1000).round().long().to(latents.device)
            timestep = timestep.expand(input_latents.shape[0]).to(torch.bfloat16)

            dx = self.dit(input_latents.to(torch.bfloat16), timestep, text_embeddings, pooled_text_embeddings)
            dt = sigmas[infer_step + 1] - t  # negative: integrating from noise (high t) toward 0

            if do_classifier_free_guidance:
                cond_dx, uncond_dx = dx.chunk(2)
                current_guidance_scale = guidance_scale if clip_t[0] <= t <= clip_t[1] else 1.0
                dx = uncond_dx + current_guidance_scale * (cond_dx - uncond_dx)
                if rescale_cfg > 0.0:
                    std_pos = torch.std(cond_dx, dim=[1, 2, 3], keepdim=True, unbiased=False) + 1e-5
                    std_cfg = torch.std(dx, dim=[1, 2, 3], keepdim=True, unbiased=False) + 1e-5
                    factor = rescale_cfg * (std_pos / std_cfg) + (1.0 - rescale_cfg)
                    dx = dx * factor

            # Euler step.
            latents = latents + dt * dx

            if get_intermediate_steps:
                # The dt-weighted mean of all velocities so far equals the total Euler displacement
                # divided by the total elapsed time, so no per-step velocity history is needed.
                elapsed = sigmas[infer_step + 1] - sigmas[0]
                avg_dx = (latents - initial_noise) / elapsed
                intermediate_latents.append(initial_noise - avg_dx)

        # --- [Decode to PIL images] ---
        self.vae = self.vae.to(device=latents.device, dtype=torch.float32)
        final_pil_images = self._latents_to_pil(latents, is_safe_prompt)
        if not get_intermediate_steps:
            return final_pil_images
        intermediate_pil_images = [
            self._latents_to_pil(step_latents, is_safe_prompt)
            for step_latents in tqdm.tqdm(intermediate_latents, desc="Decoding intermediates")
        ]
        return final_pil_images, intermediate_pil_images

    @torch.no_grad()
    def eval_with_loss(self, images, raw_text):
        """
        Evaluate the per-sample rectified-flow loss and bucket it by sampled time ``t``.

        Mirrors ``forward`` (without text-embedding dropout) and passes the clean latents to
        ``rectified_flow_loss`` exactly as ``forward`` does.

        Returns:
            defaultdict(list): Maps bin index ``i`` (0..7, bins of width 1/8 over (0, 1)) to a
            one-element list holding the losses whose ``t`` falls strictly inside that bin.
        """
        latents = self.vae.encode(images).latent_dist.sample() * self.vae.config.scaling_factor
        tokens, masks = self.tokenization(raw_text)
        tokens = [token.to(latents.device) for token in tokens]
        masks = [mask.to(latents.device) for mask in masks]
        text_embeddings, pooled_text_embeddings = self.text_encoding(tokens, masks)
        text_embeddings = [text_embedding.float() for text_embedding in text_embeddings]
        pooled_text_embeddings = pooled_text_embeddings.float()
        # 2. Get noisy input via the rectified flow.
        is_finetuning = self.config.height > 256
        noise, noise_latents, t = self.get_noisy_input(latents, is_finetuning=is_finetuning)
        timesteps = self.discritize_timestep(t, self.n_timesteps)
        # 3. Forward pass through the DiT.
        preds = self.dit(noise_latents, timesteps, text_embeddings, pooled_text_embeddings)
        # 4. Per-sample rectified flow matching loss (same first argument as in `forward`).
        loss = self.rectified_flow_loss(latents, noise, t, preds, reduce="none", use_weighting=False).mean(
            dim=[1, 2, 3]
        )
        bin_edges = np.linspace(0, 1, 9)
        loss_bins = defaultdict(list)
        for i, (low, high) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
            idx = (low < t) & (t < high)
            loss_bins[i].append(loss[idx])
        return loss_bins
|