# Motif-Image-6B-Preview/models/modeling_motifimage.py
from collections import defaultdict
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
import tqdm
from diffusers.models import AutoencoderKL
from diffusers.utils.torch_utils import randn_tensor
from loguru import logger
from PIL import Image, ImageFilter
from transformers import CLIPTextModel, CLIPTokenizerFast, T5EncoderModel, T5Tokenizer
from models.mixin.flow_mixin import FlowMixin
from models.modeling_dit import MotifDiT
TOKEN_MAX_LENGTH: int = 256
DROP_PROB: float = 0.1
LATENT_CHANNELS: int = 4
VAE_DOWNSCALE_FACTOR: int = 8
SD3_LATENT_CHANNEL: int = 16
def generate_intervals(steps, ratio, start=1.0):
intervals = torch.linspace(start, 0, steps=steps)
intervals = intervals.pow(ratio)
return intervals
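# Example (hedged sketch): `generate_intervals` builds a monotone schedule from `start`
# down to 0 whose spacing is skewed by `ratio`; with ratio=2.0 the values decay
# quadratically, so the spacing gets finer as t approaches 0.
#   >>> generate_intervals(steps=10, ratio=2.0)
#   tensor([1.0000, 0.7901, 0.6049, 0.4444, 0.3086, 0.1975, 0.1111, 0.0494, 0.0123, 0.0000])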
class MotifImage(nn.Module, FlowMixin):
"""
MotifImage Text-to-Image model.
This model combines a Diffusion transformer with a rectified flow loss and multiple text encoders.
It uses a VAE (Variational Autoencoder) for image encoding and decoding.
Args:
config (MMDiTConfig): Configuration object for the MMDiT model.
Attributes:
dit (MotifDiT): MotifDiT model instance.
noise_scheduler (DDPMScheduler): Noise scheduler for the diffusion process.
        normalize_img (Callable): Function to normalize images to the [-1, 1] range.
unnormalize_img (Callable): Function to unnormalize images to [0, 1] range.
cond_drop_prob (float): Probability of dropping text embeddings during training.
snr_gamma (str): Strategy for weighting the loss based on Signal-to-Noise Ratio (SNR).
loss_weight_strategy (str): Strategy for weighting the loss.
vae (AutoencoderKL): Variational Autoencoder for image encoding and decoding.
t5 (T5EncoderModel): T5 encoder model for text encoding.
t5_tokenizer (T5Tokenizer): T5 tokenizer for text tokenization.
        clip_l (CLIPTextModel): CLIP (Contrastive Language-Image Pre-training) text encoder (large).
        clip_l_tokenizer (CLIPTokenizerFast): CLIP tokenizer (large) for text tokenization.
        clip_g (CLIPTextModel): CLIP text encoder (bigG).
        clip_g_tokenizer (CLIPTokenizerFast): CLIP tokenizer (bigG) for text tokenization.
        tokenizers (List[Union[T5Tokenizer, CLIPTokenizerFast]]): List of tokenizers.
        text_encoders (List[Union[T5EncoderModel, CLIPTextModel]]): List of text encoder models.
"""
def __init__(self, config):
super().__init__()
self.config = config
self.dit = MotifDiT(config)
        self.cond_drop_prob = DROP_PROB
self.use_weighting = False
self._get_encoders()
self._freeze_encoders()
    def forward(self, images: torch.Tensor, raw_text: List[str]) -> List[torch.Tensor]:
        """
        Forward pass of the MotifImage model.
        Args:
            images (torch.Tensor): Input image tensor, [0, 1] ranged.
            raw_text (List[str]): Batch of input text prompts.
        Returns:
            List[torch.Tensor]: Single-element list containing the rectified flow matching loss.
        """
# 1. Encode images and texts
with torch.no_grad():
latents = self.vae.encode(images).latent_dist.sample() * self.vae.config.scaling_factor
tokens, masks = self.tokenization(raw_text)
tokens = [token.to(latents.device) for token in tokens]
masks = [mask.to(latents.device) for mask in masks]
text_embeddings, pooled_text_embeddings = self.text_encoding(tokens, masks)
text_embeddings = self._drop_text_emb(text_embeddings)
text_embeddings = [text_embedding.float() for text_embedding in text_embeddings]
pooled_text_embeddings = pooled_text_embeddings.float()
# 2. Get noisy input via the rectified flow
is_finetuning = self.config.height > 256
noise, noise_latents, t = self.get_noisy_input(latents, is_finetuning=is_finetuning)
timesteps = self.discritize_timestep(t, self.n_timesteps)
# 3. Forward pass through the dit
preds = self.dit(noise_latents, timesteps, text_embeddings, pooled_text_embeddings)
# 4. Rectified flow matching loss
loss = self.rectified_flow_loss(latents, noise, t, preds, use_weighting=self.use_weighting)
return [loss]
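    # Training usage (hedged sketch; `model`, `images`, and `prompts` are hypothetical names
    # supplied by the surrounding training loop, with `images` already on the model's device):
    #   (loss,) = model(images, prompts)
    #   loss.backward()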
def _get_encoders(self) -> None:
"""Initialize the VAE and text encoders."""
if self.config.vae_type == "SD3":
self.vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae")
elif self.config.vae_type == "SDXL":
self.vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
else:
raise ValueError(f"VAE type must be `SD3` or `SDXL` but self.config.vae_type is {self.config.vae_type}")
# Text encoders
# 1. T5-XXL from Google
self.t5 = T5EncoderModel.from_pretrained("google/flan-t5-xxl").to(dtype=torch.bfloat16)
self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
# 2. CLIP-L from OpenAI
self.clip_l = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(dtype=torch.bfloat16)
self.clip_l_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")
# 3. CLIP-G from LAION
self.clip_g = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k").to(dtype=torch.bfloat16)
self.clip_g_tokenizer = CLIPTokenizerFast.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
self.tokenizers = [self.t5_tokenizer, self.clip_l_tokenizer, self.clip_g_tokenizer]
self.text_encoders = [self.t5, self.clip_l, self.clip_g]
def state_dict(self, destination=None, prefix="", keep_vars=False):
state_dict = super(MotifImage, self).state_dict(destination, prefix, keep_vars)
exclude_keys = ["t5.", "clip_l.", "clip_g.", "vae."]
for key in list(state_dict.keys()):
if any(key.startswith(exclude_key) for exclude_key in exclude_keys):
state_dict.pop(key)
return state_dict
def load_state_dict(self, state_dict, strict=False):
"""
Load state dict and merge LoRA parameters if present.
Args:
state_dict (dict): State dictionary containing model parameters
strict (bool): Whether to strictly enforce that the keys in state_dict match the keys in this module
Returns:
tuple: (missing_keys, unexpected_keys) lists of parameters that were missing or unexpected
"""
# Check if state_dict contains LoRA parameters
has_lora = any("lora_" in key for key in state_dict.keys())
if has_lora:
# If model doesn't have LoRA enabled but state_dict has LoRA params, enable it
if not hasattr(self.dit, "peft_config"):
logger.info("Enabling LoRA for parameter merging...")
# Use default values if not already configured
lora_rank = getattr(self.config, "lora_rank", 64)
lora_alpha = getattr(self.config, "lora_alpha", 8)
self.enable_lora(lora_rank, lora_alpha)
if has_lora:
try:
# Load LoRA parameters
# state_dict = {
# k.replace("base_layer.", ""): v
# for k, v in state_dict.items()
# if "lora_" not in k and "lora" not in k
# }
missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False)
# Merge LoRA weights with base model
logger.info("Merging LoRA parameters with base model...")
for name, module in self.dit.named_modules():
if hasattr(module, "merge_and_unload"):
module.merge_and_unload()
logger.info("Successfully merged LoRA parameters")
except Exception as e:
logger.error(f"Error merging LoRA parameters: {str(e)}")
raise
else:
missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False)
# Log summary of missing/unexpected parameters
missing_top_levels = set()
for key in missing_keys:
top_level_name = key.split(".")[0]
missing_top_levels.add(top_level_name)
if missing_top_levels:
logger.debug("Missing keys during loading at top level:")
for name in missing_top_levels:
logger.debug(name)
if unexpected_keys:
logger.debug("Unexpected keys found:")
for key in unexpected_keys:
logger.debug(key)
return missing_keys, unexpected_keys
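    # Loading sketch (hypothetical checkpoint path): if the checkpoint contains LoRA
    # parameters, they are merged into the base DiT weights during loading.
    #   state = torch.load("checkpoint_with_lora.pt", map_location="cpu")
    #   missing, unexpected = model.load_state_dict(state)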
def _freeze_encoders(self) -> None:
"""
freeze all encoders
"""
for encoder_module in [self.vae, self.clip_l, self.clip_g, self.t5]:
for param in encoder_module.parameters():
param.requires_grad = False
def tokenization(
self, raw_texts: List[str], repeat_if_short: bool = False
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""
Tokenizes a BATCH of input texts using multiple tokenizers efficiently.
Optionally repeats each text to fill the max length if it's shorter,
BEFORE passing the pre-processed batch to the tokenizer.
Args:
raw_texts (List[str]): A list of input text strings (the batch).
            repeat_if_short (bool): If True and a text is short, repeat that text
                to fill the context length. Defaults to False.
Returns:
Tuple[List[torch.Tensor], List[torch.Tensor]]:
- A list containing one batch tensor of input IDs per tokenizer.
Each tensor shape: [batch_size, max_length]
- A list containing one batch tensor of attention masks per tokenizer.
Each tensor shape: [batch_size, max_length]
"""
final_batch_tokens = []
final_batch_masks = []
# Process the batch with each tokenizer
for tokenizer in self.tokenizers:
effective_max_length = min(TOKEN_MAX_LENGTH, tokenizer.model_max_length)
# 1. Pre-process the batch: Create a new list of potentially repeated strings.
processed_texts_for_tokenizer = []
for text_item in raw_texts:
# Start with the original text for this item
processed_text = text_item
if repeat_if_short:
# Apply repetition logic individually based on text_item's length
num_initial_tokens = len(text_item.split())
                    available_length = effective_max_length - 2  # heuristic: reserve room for special tokens (BOS/EOS)
if num_initial_tokens > 0 and num_initial_tokens < available_length:
num_additional_repeats = available_length // (num_initial_tokens + 1)
if num_additional_repeats > 0:
total_repeats = 1 + num_additional_repeats
processed_text = " ".join([text_item] * total_repeats)
# Add the processed text (original or repeated) to the list for this tokenizer
processed_texts_for_tokenizer.append(processed_text)
# 2. Tokenize the entire batch of processed texts at once.
# Pass the list `processed_texts_for_tokenizer` directly to the tokenizer.
# The tokenizer's __call__ method should handle the batch efficiently.
batch_tok_output = tokenizer( # Call the tokenizer ONCE with the full list
processed_texts_for_tokenizer,
padding="max_length",
max_length=effective_max_length,
return_tensors="pt",
truncation=True,
)
# 3. Store the resulting batch tensors directly.
# The tokenizer should return tensors with shape [batch_size, max_length].
final_batch_tokens.append(batch_tok_output.input_ids)
final_batch_masks.append(batch_tok_output.attention_mask)
return final_batch_tokens, final_batch_masks
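    # Shape sketch (hedged; `model` is a hypothetical MotifImage instance, B a batch size):
    # one [B, effective_max_length] id/mask tensor per tokenizer. TOKEN_MAX_LENGTH (256)
    # applies to T5, while the CLIP tokenizers are capped by their own model_max_length
    # (typically 77).
    #   tokens, masks = model.tokenization(["a photo of a corgi"] * B)
    #   tokens[0].shape  # T5 ids,     [B, 256]
    #   tokens[1].shape  # CLIP-L ids, [B, 77]
    #   tokens[2].shape  # CLIP-G ids, [B, 77]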
@torch.no_grad()
def text_encoding(
self, tokens: List[torch.Tensor], masks, noisy_pad=False, zero_masking=True
) -> Tuple[List[torch.Tensor], torch.Tensor]:
"""
Encode the tokenized text using multiple text encoders.
Args:
tokens (List[torch.Tensor]): List of tokenized text tensors.
Returns:
Tuple[List[torch.Tensor], torch.Tensor]: Tuple containing a list of text embeddings and pooled text embeddings.
"""
t5_tokens, clip_l_tokens, clip_g_tokens = tokens
t5_masks, clip_l_masks, clip_g_masks = masks
t5_emb = self.t5(t5_tokens, attention_mask=t5_masks)[0]
if zero_masking:
t5_emb = t5_emb * (t5_tokens != self.t5_tokenizer.pad_token_id).unsqueeze(-1)
        if noisy_pad:
            # Perturb padded positions with a small amount of noise instead of leaving them exactly zero.
            t5_pad_noise = (
                (t5_tokens == self.t5_tokenizer.pad_token_id).unsqueeze(-1) * torch.randn_like(t5_emb) * 0.008
            )
            t5_emb = t5_emb + t5_pad_noise
clip_l_emb = self.clip_l(input_ids=clip_l_tokens, output_hidden_states=True)
clip_g_emb = self.clip_g(input_ids=clip_g_tokens, output_hidden_states=True)
clip_l_emb_pooled = clip_l_emb.pooler_output # B x 768
clip_g_emb_pooled = clip_g_emb.pooler_output # B x 1280
clip_l_emb = clip_l_emb.last_hidden_state # B x L x 768,
clip_g_emb = clip_g_emb.last_hidden_state # B x L x 1280,
        def masking_wo_first_eos(token, eos):
            # Keep all non-EOS tokens and additionally the first EOS token, so its embedding is not zeroed out.
            idx = (token != eos).sum(dim=1)
            mask = token != eos
            arange = torch.arange(mask.size(0), device=token.device)
            mask[arange, idx] = True
            mask = mask.unsqueeze(-1)  # B x L x 1
            return mask
if zero_masking:
clip_l_emb = clip_l_emb * masking_wo_first_eos(
clip_l_tokens, self.clip_l_tokenizer.eos_token_id
) # B x L x 768,
            clip_g_emb = clip_g_emb * masking_wo_first_eos(
                clip_g_tokens, self.clip_g_tokenizer.eos_token_id
            )  # B x L x 1280
if noisy_pad:
            clip_l_pad_noise = (
                ~masking_wo_first_eos(clip_l_tokens, self.clip_l_tokenizer.eos_token_id)
                * torch.randn_like(clip_l_emb)
                * 0.08
            )
            clip_g_pad_noise = (
                ~masking_wo_first_eos(clip_g_tokens, self.clip_g_tokenizer.eos_token_id)
                * torch.randn_like(clip_g_emb)
                * 0.08
            )
clip_l_emb = clip_l_emb + clip_l_pad_noise
clip_g_emb = clip_g_emb + clip_g_pad_noise
encodings = [t5_emb, clip_l_emb, clip_g_emb]
pooled_encodings = torch.cat([clip_l_emb_pooled, clip_g_emb_pooled], dim=-1) # cat by channel, B x 2048
return encodings, pooled_encodings
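    # Shape sketch (hedged; hidden sizes of the default encoders): for batch size B,
    #   t5_emb            [B, 256, 4096]   (flan-t5-xxl)
    #   clip_l_emb        [B, 77, 768]
    #   clip_g_emb        [B, 77, 1280]
    #   pooled_encodings  [B, 2048]        (768 + 1280, concatenated)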
@torch.no_grad()
    def prompt_embedding(self, prompts: List[str], device, noisy_pad=False, zero_masking=True):
tokens, masks = self.tokenization(prompts)
tokens = [token.to(device) for token in tokens]
masks = [mask.to(device) for mask in masks]
text_embeddings, pooled_text_embeddings = self.text_encoding(
tokens, masks, noisy_pad=noisy_pad, zero_masking=zero_masking
)
text_embeddings = [text_embedding.bfloat16() for text_embedding in text_embeddings]
pooled_text_embeddings = pooled_text_embeddings.bfloat16()
return text_embeddings, pooled_text_embeddings
@torch.no_grad()
def sample(
self,
raw_text: List[str],
steps: int = 50,
guidance_scale: float = 7.5,
resolution: List[int] = (256, 256),
pre_latent=None,
pre_timestep=None,
step_scaling=1.0,
noisy_pad=False,
zero_masking=False,
negative_prompt: Optional[List[str]] = None,
device: str = "cuda",
rescale_cfg=-1.0,
clip_t=[0.0, 1.0],
use_linear_quadratic_schedule=False,
linear_quadratic_emulating_steps=250,
prompt_rewriter=None,
moderator=None,
        get_intermediate_steps: bool = False,
    ) -> Union[List[Image.Image], Tuple[List[Image.Image], List[List[Image.Image]]]]:
"""
Sample images using flow matching. Optionally returns intermediate step images
calculated via observed average velocity method.
Args:
raw_text (List[str]): raw text prompts
steps (int, optional): number of function estimations for flow matching ODE. Defaults to 50.
guidance_scale (float, optional): classifier free guidance scale. Defaults to 7.5.
resolution (List[int], optional): input and output resolution of raw images. Defaults to (256, 256).
device (str, optional): Defaults to 'cuda'.
pre_latent (Tensor, optional): the optional input to generate image with pre-defined latents.
for instance, it would be utilized for denoising or image-editing.
pre_timestep (float [0,1], optional): the pre-defined timestep. with `pre_latent`, image generation
can be done by starting with intermediate timestep.
            step_scaling (float, optional): scaling factor for each ODE-solving step. Defaults to 1.0.
            use_linear_quadratic_schedule (bool, optional): if True, use a linear-quadratic t schedule;
                otherwise a linear t schedule. Defaults to False.
            linear_quadratic_emulating_steps (int, optional): N value in the linear-quadratic t schedule from
                the Meta Movie Gen paper. Defaults to 250.
                Reference: (https://ai.meta.com/static-resource/movie-gen-research-paper) Figure 10
            get_intermediate_steps (bool, optional): Whether to calculate and return intermediate step images.
                Calculation is based on initial_noise - avg(velocity). Defaults to False.
Returns:
Union[List[PIL.Image], Tuple[List[PIL.Image], List[List[PIL.Image]]]]:
If get_intermediate_steps is False: Returns a list of final PIL images.
If get_intermediate_steps is True: Returns a tuple containing:
- List[PIL.Image]: Final output PIL images.
- List[List[PIL.Image]]: List of intermediate PIL images. Each inner list contains
the batch of images for one intermediate step.
"""
if prompt_rewriter:
prompts = [prompt_rewriter.rewrite(prompt) for prompt in raw_text]
else:
prompts = raw_text
# Simplified check for rewriter status
if prompts == raw_text and prompt_rewriter is not None:
logger.debug("Prompt rewriter did not change the prompts.")
elif prompt_rewriter is None:
logger.debug("Prompt rewriter not provided.")
if moderator is None:
is_safe_prompt = [True for _ in prompts]
else:
            is_safe_prompt = [moderator.is_safe_content(prompt, threshold=0.7) for prompt in prompts]
if not all(is_safe_prompt):
logger.warning("Noxious prompt detected. Output image(s) will be blurred.")
b = len(prompts)
h, w = resolution
# --- [Initial Latent Noise (e = x_1)] ---
        latent_channels = SD3_LATENT_CHANNEL  # the SD3 VAE used for decoding has 16 latent channels
if pre_latent is None:
initial_noise = randn_tensor( # Store initial noise separately
(b, latent_channels, h // VAE_DOWNSCALE_FACTOR, w // VAE_DOWNSCALE_FACTOR),
device=device,
dtype=torch.float32, # Use float32 for calculations
)
else:
initial_noise = pre_latent.to(device=device, dtype=torch.float32)
if pre_timestep is not None and pre_timestep < 1.0: # Check if it's truly intermediate
logger.warning(
"Using pre_latent as initial_noise for average calculation, but pre_timestep suggests it's not pure noise. Results might be unexpected."
)
latents = initial_noise.clone() # Working latents for the ODE solver
# --- [Text Embeddings & CFG Setup] ---
text_embeddings, pooled_text_embeddings = self.prompt_embedding(
prompts, latents.device, noisy_pad=noisy_pad, zero_masking=zero_masking
)
text_embeddings = [emb.to(device=latents.device, dtype=torch.bfloat16) for emb in text_embeddings]
pooled_text_embeddings = pooled_text_embeddings.to(device=latents.device, dtype=torch.bfloat16)
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
negative_text_embeddings = [
torch.zeros_like(text_embedding, device=text_embedding.device) for text_embedding in text_embeddings
]
negative_pooled_text_embeddings = torch.zeros_like(
pooled_text_embeddings, device=pooled_text_embeddings.device
)
text_embeddings = [
torch.cat([text_embedding, negative_text_embedding], dim=0)
for text_embedding, negative_text_embedding in zip(text_embeddings, negative_text_embeddings)
]
pooled_text_embeddings = torch.cat([pooled_text_embeddings, negative_pooled_text_embeddings], dim=0)
# if negative_prompt is None:
# negative_prompt = [""] * b
# logger.debug("No negative prompt provided, using empty strings for CFG.")
# negative_text_embeddings, negative_pooled_text_embeddings = self.prompt_embedding(negative_prompt, latents.device)
# negative_text_embeddings = [emb.to(device=latents.device, dtype=torch.bfloat16) for emb in negative_text_embeddings]
# negative_pooled_text_embeddings = negative_pooled_text_embeddings.to(device=latents.device, dtype=torch.bfloat16)
# text_embeddings = [torch.cat([pos_emb, neg_emb], dim=0) for pos_emb, neg_emb in zip(text_embeddings, negative_text_embeddings)]
# pooled_text_embeddings = torch.cat([pooled_text_embeddings, negative_pooled_text_embeddings], dim=0)
# --- [Timestep Schedule (Sigmas)] ---
# linear t schedule
        sigmas = torch.linspace(1, 0, steps + 1) if pre_timestep is None else torch.linspace(pre_timestep, 0, steps + 1)
        if use_linear_quadratic_schedule:
            # linear-quadratic t schedule: linear spacing (at the resolution of an N-step schedule)
            # for the first half of the steps, then a quadratic segment down to 0.
            assert steps % 2 == 0
            N = linear_quadratic_emulating_steps
            sigmas = torch.concat(
                [
                    torch.linspace(1, 0, N + 1)[: steps // 2],
                    torch.linspace(0, 1, steps // 2 + 1) ** 2 * (steps // 2 * 1 / N - 1) - (steps // 2 * 1 / N - 1),
                ]
            )
# --- [Initialization for Intermediate Step Calculation] ---
# intermediate_latents will store the latent states for intermediate steps
intermediate_latents = [] if get_intermediate_steps else None
predicted_velocities = [] # Store dx from each step
sigma_history = []
# --- [Sampling Loop] ---
for infer_step, t in tqdm.tqdm(enumerate(sigmas[:-1]), total=len(sigmas[:-1]), desc="Sampling"):
# Prepare input for DiT model
if do_classifier_free_guidance:
input_latents = torch.cat([latents] * 2, dim=0)
else:
input_latents = latents
# Prepare timestep input
timestep = (t * 1000).round().long().to(latents.device)
timestep = timestep.expand(input_latents.shape[0]).to(torch.bfloat16) # Ensure timestep is bfloat16
# Predict velocity dx = v(x_t, t) ≈ e - x_0
dx = self.dit(input_latents.to(torch.bfloat16), timestep, text_embeddings, pooled_text_embeddings)
dt = sigmas[infer_step + 1] - sigmas[infer_step] # dt is negative
sigma_history.append(dt)
# Apply Classifier-Free Guidance
if do_classifier_free_guidance:
cond_dx, uncond_dx = dx.chunk(2)
current_guidance_scale = guidance_scale if clip_t[0] <= t and t <= clip_t[1] else 1.0
dx = uncond_dx + current_guidance_scale * (cond_dx - uncond_dx)
if rescale_cfg > 0.0:
std_pos = torch.std(cond_dx, dim=[1, 2, 3], keepdim=True, unbiased=False) + 1e-5
std_cfg = torch.std(dx, dim=[1, 2, 3], keepdim=True, unbiased=False) + 1e-5
factor = std_pos / std_cfg
factor = rescale_cfg * factor + (1.0 - rescale_cfg)
dx = dx * factor
# --- Store the predicted velocity for averaging ---
predicted_velocities.append(dx.clone())
# --- Update Latents using standard Euler step ---
latents = latents + dt * dx
# --- Calculate and Store Intermediate Latent State (if requested) ---
if get_intermediate_steps:
dxs = torch.stack(predicted_velocities)
sigma_sum = sum(sigma_history)
normalized_sigma_history = [s / (sigma_sum) for s in sigma_history]
dts = torch.tensor(normalized_sigma_history, device=dxs.device, dtype=dxs.dtype).view(-1, 1, 1, 1, 1)
avg_dx = torch.sum(dxs * dts, dim=0)
observed_state = initial_noise - avg_dx # Calculate the desired intermediate state
intermediate_latents.append(observed_state.clone()) # Store its latent representation
# --- [Decode Final Latents to PIL Images] ---
self.vae = self.vae.to(device=latents.device, dtype=torch.float32) # Ensure VAE is ready
final_latents_scaled = latents.to(torch.float32) / self.vae.config.scaling_factor
final_image_tensors = self.vae.decode(final_latents_scaled, return_dict=False)[0] + self.vae.config.shift_factor
final_image_tensors = ((final_image_tensors + 1.0) / 2.0).clamp(0.0, 1.0)
final_pil_images = []
for i, image_tensor in enumerate(final_image_tensors):
img = T.ToPILImage()(image_tensor.cpu())
if not is_safe_prompt[i]:
img = img.filter(ImageFilter.GaussianBlur(radius=30))
final_pil_images.append(img)
# --- [Decode Intermediate Latents to PIL Images (if requested)] ---
if get_intermediate_steps:
intermediate_pil_images = []
# Ensure VAE is still ready (it should be from final decoding)
for step_latents in tqdm.tqdm(intermediate_latents, desc="Decoding intermediates"):
                step_latents_scaled = (
                    step_latents.to(dtype=torch.float32, device=latents.device) / self.vae.config.scaling_factor
                )
step_image_tensors = (
self.vae.decode(step_latents_scaled, return_dict=False)[0] + self.vae.config.shift_factor
)
step_image_tensors = ((step_image_tensors + 1.0) / 2.0).clamp(0.0, 1.0)
current_step_pil = []
for i, image_tensor in enumerate(step_image_tensors):
img = T.ToPILImage()(image_tensor.cpu())
# Apply moderation blur consistency
if not is_safe_prompt[i]:
img = img.filter(ImageFilter.GaussianBlur(radius=30))
current_step_pil.append(img)
intermediate_pil_images.append(current_step_pil) # Append list of images for this step
return final_pil_images, intermediate_pil_images # Return both final and intermediate images
else:
return final_pil_images # Return only final images
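    # Inference sketch (hypothetical prompt, resolution, and output path; assumes the module
    # and its encoders are on a CUDA device):
    #   images = model.sample(["a watercolor painting of a fox"], steps=50,
    #                         guidance_scale=7.5, resolution=(1024, 1024))
    #   images[0].save("fox.png")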
@torch.no_grad()
    def eval_with_loss(self, images, raw_text):
        """Compute per-sample rectified flow losses and bucket them into eight timestep-interval bins."""
latents = self.vae.encode(images).latent_dist.sample() * self.vae.config.scaling_factor
tokens, masks = self.tokenization(raw_text)
tokens = [token.to(latents.device) for token in tokens]
masks = [mask.to(latents.device) for mask in masks]
text_embeddings, pooled_text_embeddings = self.text_encoding(tokens, masks)
        text_embeddings = [text_embedding.float() for text_embedding in text_embeddings]
pooled_text_embeddings = pooled_text_embeddings.float()
# 2. Get noisy input via the rectified flow
is_finetuning = self.config.height > 256
noise, noise_latents, t = self.get_noisy_input(latents, is_finetuning=is_finetuning)
timesteps = self.discritize_timestep(t, self.n_timesteps)
# 3. Forward pass through the dit
preds = self.dit(noise_latents, timesteps, text_embeddings, pooled_text_embeddings)
# 4. Rectified flow matching loss
        loss = self.rectified_flow_loss(latents, noise, t, preds, reduce="none", use_weighting=False).mean(
dim=[1, 2, 3]
)
intervals = np.linspace(0, 1, 9)
t_interval = [(intervals[i], intervals[i + 1]) for i in range(len(intervals) - 1)]
loss_bins = defaultdict(list)
for i, interval in enumerate(t_interval, 0):
idx = (interval[0] < t) & (t < interval[1])
loss_bins[i].append(loss[idx])
return loss_bins
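    # Evaluation sketch (hedged; `model`, `images`, and `prompts` are hypothetical):
    # eval_with_loss() buckets per-sample losses into eight equal-width timestep intervals
    # over (0, 1); averaging each bucket shows where along the flow the loss is highest
    # (empty buckets yield NaN and can be skipped).
    #   bins = model.eval_with_loss(images, prompts)
    #   per_bin_mean = {i: torch.cat(v).mean().item() for i, v in bins.items()}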