from collections import defaultdict from typing import List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn import torchvision.transforms as T import tqdm from diffusers.models import AutoencoderKL from diffusers.utils.torch_utils import randn_tensor from loguru import logger from PIL import Image, ImageFilter from transformers import CLIPTextModel, CLIPTokenizerFast, T5EncoderModel, T5Tokenizer from models.mixin.flow_mixin import FlowMixin from models.modeling_dit import MotifDiT TOKEN_MAX_LENGTH: int = 256 DROP_PROB: float = 0.1 LATENT_CHANNELS: int = 4 VAE_DOWNSCALE_FACTOR: int = 8 SD3_LATENT_CHANNEL: int = 16 def generate_intervals(steps, ratio, start=1.0): intervals = torch.linspace(start, 0, steps=steps) intervals = intervals.pow(ratio) return intervals class MotifImage(nn.Module, FlowMixin): """ MotifImage Text-to-Image model. This model combines a Diffusion transformer with a rectified flow loss and multiple text encoders. It uses a VAE (Variational Autoencoder) for image encoding and decoding. Args: config (MMDiTConfig): Configuration object for the MMDiT model. Attributes: dit (MotifDiT): MotifDiT model instance. noise_scheduler (DDPMScheduler): Noise scheduler for the diffusion process. normalize_img (Callable): Function to normalize images from [-1, 1] range. unnormalize_img (Callable): Function to unnormalize images to [0, 1] range. cond_drop_prob (float): Probability of dropping text embeddings during training. snr_gamma (str): Strategy for weighting the loss based on Signal-to-Noise Ratio (SNR). loss_weight_strategy (str): Strategy for weighting the loss. vae (AutoencoderKL): Variational Autoencoder for image encoding and decoding. t5 (T5EncoderModel): T5 encoder model for text encoding. t5_tokenizer (T5Tokenizer): T5 tokenizer for text tokenization. clip_l (CLIPModel): CLIP (Contrastive Language-Image Pre-training) model (large) for text encoding. clip_l_tokenizer (CLIPTokenizerFast): CLIP tokenizer (large) for text tokenization. clip_g (CLIPModel): CLIP model (giant) for text encoding. clip_g_tokenizer (CLIPTokenizerFast): CLIP tokenizer (giant) for text tokenization. tokenizers (List[Union[T5Tokenizer, CLIPTokenizerFast]]): List of tokenizers. text_encoders (List[Union[T5EncoderModel, CLIPModel]]): List of text encoder models. """ def __init__(self, config): super().__init__() self.config = config self.dit = MotifDiT(config) self.cond_drop_prob = 0.1 self.use_weighting = False self._get_encoders() self._freeze_encoders() def forward(self, images: torch.Tensor, raw_text: str) -> torch.Tensor: """ Forward pass of the MotifDiT model. Args: images (torch.Tensor): Input images tensor, [0-1] ranged. raw_text (List[str]): Input text string. Returns: torch.Tensor: Rectified flow matching loss. """ # 1. Encode images and texts with torch.no_grad(): latents = self.vae.encode(images).latent_dist.sample() * self.vae.config.scaling_factor tokens, masks = self.tokenization(raw_text) tokens = [token.to(latents.device) for token in tokens] masks = [mask.to(latents.device) for mask in masks] text_embeddings, pooled_text_embeddings = self.text_encoding(tokens, masks) text_embeddings = self._drop_text_emb(text_embeddings) text_embeddings = [text_embedding.float() for text_embedding in text_embeddings] pooled_text_embeddings = pooled_text_embeddings.float() # 2. Get noisy input via the rectified flow is_finetuning = self.config.height > 256 noise, noise_latents, t = self.get_noisy_input(latents, is_finetuning=is_finetuning) timesteps = self.discritize_timestep(t, self.n_timesteps) # 3. Forward pass through the dit preds = self.dit(noise_latents, timesteps, text_embeddings, pooled_text_embeddings) # 4. Rectified flow matching loss loss = self.rectified_flow_loss(latents, noise, t, preds, use_weighting=self.use_weighting) return [loss] def _get_encoders(self) -> None: """Initialize the VAE and text encoders.""" if self.config.vae_type == "SD3": self.vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae") elif self.config.vae_type == "SDXL": self.vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae") else: raise ValueError(f"VAE type must be `SD3` or `SDXL` but self.config.vae_type is {self.config.vae_type}") # Text encoders # 1. T5-XXL from Google self.t5 = T5EncoderModel.from_pretrained("google/flan-t5-xxl").to(dtype=torch.bfloat16) self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl") # 2. CLIP-L from OpenAI self.clip_l = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(dtype=torch.bfloat16) self.clip_l_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14") # 3. CLIP-G from LAION self.clip_g = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k").to(dtype=torch.bfloat16) self.clip_g_tokenizer = CLIPTokenizerFast.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") self.tokenizers = [self.t5_tokenizer, self.clip_l_tokenizer, self.clip_g_tokenizer] self.text_encoders = [self.t5, self.clip_l, self.clip_g] def state_dict(self, destination=None, prefix="", keep_vars=False): state_dict = super(MotifImage, self).state_dict(destination, prefix, keep_vars) exclude_keys = ["t5.", "clip_l.", "clip_g.", "vae."] for key in list(state_dict.keys()): if any(key.startswith(exclude_key) for exclude_key in exclude_keys): state_dict.pop(key) return state_dict def load_state_dict(self, state_dict, strict=False): """ Load state dict and merge LoRA parameters if present. Args: state_dict (dict): State dictionary containing model parameters strict (bool): Whether to strictly enforce that the keys in state_dict match the keys in this module Returns: tuple: (missing_keys, unexpected_keys) lists of parameters that were missing or unexpected """ # Check if state_dict contains LoRA parameters has_lora = any("lora_" in key for key in state_dict.keys()) if has_lora: # If model doesn't have LoRA enabled but state_dict has LoRA params, enable it if not hasattr(self.dit, "peft_config"): logger.info("Enabling LoRA for parameter merging...") # Use default values if not already configured lora_rank = getattr(self.config, "lora_rank", 64) lora_alpha = getattr(self.config, "lora_alpha", 8) self.enable_lora(lora_rank, lora_alpha) if has_lora: try: # Load LoRA parameters # state_dict = { # k.replace("base_layer.", ""): v # for k, v in state_dict.items() # if "lora_" not in k and "lora" not in k # } missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False) # Merge LoRA weights with base model logger.info("Merging LoRA parameters with base model...") for name, module in self.dit.named_modules(): if hasattr(module, "merge_and_unload"): module.merge_and_unload() logger.info("Successfully merged LoRA parameters") except Exception as e: logger.error(f"Error merging LoRA parameters: {str(e)}") raise else: missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False) # Log summary of missing/unexpected parameters missing_top_levels = set() for key in missing_keys: top_level_name = key.split(".")[0] missing_top_levels.add(top_level_name) if missing_top_levels: logger.debug("Missing keys during loading at top level:") for name in missing_top_levels: logger.debug(name) if unexpected_keys: logger.debug("Unexpected keys found:") for key in unexpected_keys: logger.debug(key) return missing_keys, unexpected_keys def _freeze_encoders(self) -> None: """ freeze all encoders """ for encoder_module in [self.vae, self.clip_l, self.clip_g, self.t5]: for param in encoder_module.parameters(): param.requires_grad = False def tokenization( self, raw_texts: List[str], repeat_if_short: bool = False ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: """ Tokenizes a BATCH of input texts using multiple tokenizers efficiently. Optionally repeats each text to fill the max length if it's shorter, BEFORE passing the pre-processed batch to the tokenizer. Args: raw_texts (List[str]): A list of input text strings (the batch). repeat_if_short (bool): If True and a text is short, repeat that text to fill the context length. Defaults to True. Returns: Tuple[List[torch.Tensor], List[torch.Tensor]]: - A list containing one batch tensor of input IDs per tokenizer. Each tensor shape: [batch_size, max_length] - A list containing one batch tensor of attention masks per tokenizer. Each tensor shape: [batch_size, max_length] """ final_batch_tokens = [] final_batch_masks = [] # Process the batch with each tokenizer for tokenizer in self.tokenizers: effective_max_length = min(TOKEN_MAX_LENGTH, tokenizer.model_max_length) # 1. Pre-process the batch: Create a new list of potentially repeated strings. processed_texts_for_tokenizer = [] for text_item in raw_texts: # Start with the original text for this item processed_text = text_item if repeat_if_short: # Apply repetition logic individually based on text_item's length num_initial_tokens = len(text_item.split()) available_length = effective_max_length - 2 # Heuristic if num_initial_tokens > 0 and num_initial_tokens < available_length: num_additional_repeats = available_length // (num_initial_tokens + 1) if num_additional_repeats > 0: total_repeats = 1 + num_additional_repeats processed_text = " ".join([text_item] * total_repeats) # Add the processed text (original or repeated) to the list for this tokenizer processed_texts_for_tokenizer.append(processed_text) # 2. Tokenize the entire batch of processed texts at once. # Pass the list `processed_texts_for_tokenizer` directly to the tokenizer. # The tokenizer's __call__ method should handle the batch efficiently. batch_tok_output = tokenizer( # Call the tokenizer ONCE with the full list processed_texts_for_tokenizer, padding="max_length", max_length=effective_max_length, return_tensors="pt", truncation=True, ) # 3. Store the resulting batch tensors directly. # The tokenizer should return tensors with shape [batch_size, max_length]. final_batch_tokens.append(batch_tok_output.input_ids) final_batch_masks.append(batch_tok_output.attention_mask) return final_batch_tokens, final_batch_masks @torch.no_grad() def text_encoding( self, tokens: List[torch.Tensor], masks, noisy_pad=False, zero_masking=True ) -> Tuple[List[torch.Tensor], torch.Tensor]: """ Encode the tokenized text using multiple text encoders. Args: tokens (List[torch.Tensor]): List of tokenized text tensors. Returns: Tuple[List[torch.Tensor], torch.Tensor]: Tuple containing a list of text embeddings and pooled text embeddings. """ t5_tokens, clip_l_tokens, clip_g_tokens = tokens t5_masks, clip_l_masks, clip_g_masks = masks t5_emb = self.t5(t5_tokens, attention_mask=t5_masks)[0] if zero_masking: t5_emb = t5_emb * (t5_tokens != self.t5_tokenizer.pad_token_id).unsqueeze(-1) if noisy_pad: t5_pad_noise = ( (t5_tokens == self.t5_tokenizer.pad_token_id).unsqueeze(-1) * torch.randn_like(t5_emb).cuda() * 0.008 ) t5_emb = t5_emb + t5_pad_noise clip_l_emb = self.clip_l(input_ids=clip_l_tokens, output_hidden_states=True) clip_g_emb = self.clip_g(input_ids=clip_g_tokens, output_hidden_states=True) clip_l_emb_pooled = clip_l_emb.pooler_output # B x 768 clip_g_emb_pooled = clip_g_emb.pooler_output # B x 1280 clip_l_emb = clip_l_emb.last_hidden_state # B x L x 768, clip_g_emb = clip_g_emb.last_hidden_state # B x L x 1280, def masking_wo_first_eos(token, eos): idx = (token != eos).sum(dim=1) mask = token != eos arange = torch.arange(mask.size(0)).cuda() mask[arange, idx] = True mask = mask.unsqueeze(-1) # B x L x 1 return mask if zero_masking: clip_l_emb = clip_l_emb * masking_wo_first_eos( clip_l_tokens, self.clip_l_tokenizer.eos_token_id ) # B x L x 768, clip_g_emb = clip_g_emb * masking_wo_first_eos( clip_g_tokens, self.clip_g_tokenizer.eos_token_id ) # B x L x 768, if noisy_pad: clip_l_pad_noise = ( ~masking_wo_first_eos(clip_l_tokens, self.clip_l_tokenizer.eos_token_id) * torch.randn_like(clip_l_emb).cuda() * 0.08 ) clip_g_pad_noise = ( ~masking_wo_first_eos(clip_g_tokens, self.clip_g_tokenizer.eos_token_id) * torch.randn_like(clip_g_emb).cuda() * 0.08 ) clip_l_emb = clip_l_emb + clip_l_pad_noise clip_g_emb = clip_g_emb + clip_g_pad_noise encodings = [t5_emb, clip_l_emb, clip_g_emb] pooled_encodings = torch.cat([clip_l_emb_pooled, clip_g_emb_pooled], dim=-1) # cat by channel, B x 2048 return encodings, pooled_encodings @torch.no_grad() def prompt_embedding(self, prompts: str, device, noisy_pad=False, zero_masking=True): tokens, masks = self.tokenization(prompts) tokens = [token.to(device) for token in tokens] masks = [mask.to(device) for mask in masks] text_embeddings, pooled_text_embeddings = self.text_encoding( tokens, masks, noisy_pad=noisy_pad, zero_masking=zero_masking ) text_embeddings = [text_embedding.bfloat16() for text_embedding in text_embeddings] pooled_text_embeddings = pooled_text_embeddings.bfloat16() return text_embeddings, pooled_text_embeddings @torch.no_grad() def sample( self, raw_text: List[str], steps: int = 50, guidance_scale: float = 7.5, resolution: List[int] = (256, 256), pre_latent=None, pre_timestep=None, step_scaling=1.0, noisy_pad=False, zero_masking=False, negative_prompt: Optional[List[str]] = None, device: str = "cuda", rescale_cfg=-1.0, clip_t=[0.0, 1.0], use_linear_quadratic_schedule=False, linear_quadratic_emulating_steps=250, prompt_rewriter=None, moderator=None, get_intermediate_steps: bool = False, # Defaulting to True based on user code ) -> Union[List[Image.Image], Tuple[List[Image.Image], List[List[Image.Image]]]]: # Updated return type hint """ Sample images using flow matching. Optionally returns intermediate step images calculated via observed average velocity method. Args: raw_text (List[str]): raw text prompts steps (int, optional): number of function estimations for flow matching ODE. Defaults to 50. guidance_scale (float, optional): classifier free guidance scale. Defaults to 7.5. resolution (List[int], optional): input and output resolution of raw images. Defaults to (256, 256). device (str, optional): Defaults to 'cuda'. pre_latent (Tensor, optional): the optional input to generate image with pre-defined latents. for instance, it would be utilized for denoising or image-editing. pre_timestep (float [0,1], optional): the pre-defined timestep. with `pre_latent`, image generation can be done by starting with intermediate timestep. step_scaling (float, default to 1.3): scaling factor for each ODE-solving. use_linear_quadratic_schedule (bool, default to false): boolean option to linear-quaratic t schdule. If false, then linear t schdule. linear_quadratic_emulating_steps (int, default to 250): N value in linear-quadratic t schedule from Meta moviegen paper Reference: (https://ai.meta.com/static-resource/movie-gen-research-paper) Figure 10 get_intermediate_steps (bool, optional): Whether to calculate and return intermediate step images. Calculation is based on initial_noise - avg(velocity). Defaults to True. Returns: Union[List[PIL.Image], Tuple[List[PIL.Image], List[List[PIL.Image]]]]: If get_intermediate_steps is False: Returns a list of final PIL images. If get_intermediate_steps is True: Returns a tuple containing: - List[PIL.Image]: Final output PIL images. - List[List[PIL.Image]]: List of intermediate PIL images. Each inner list contains the batch of images for one intermediate step. """ if prompt_rewriter: prompts = [prompt_rewriter.rewrite(prompt) for prompt in raw_text] else: prompts = raw_text # Simplified check for rewriter status if prompts == raw_text and prompt_rewriter is not None: logger.debug("Prompt rewriter did not change the prompts.") elif prompt_rewriter is None: logger.debug("Prompt rewriter not provided.") if moderator is None: is_safe_prompt = [True for _ in prompts] else: is_safe_prompt = [moderator and moderator.is_safe_content(prompt, threshold=0.7) for prompt in prompts] if not all(is_safe_prompt): logger.warning("Noxious prompt detected. Output image(s) will be blurred.") b = len(prompts) h, w = resolution # --- [Initial Latent Noise (e = x_1)] --- latent_channels = 16 if pre_latent is None: initial_noise = randn_tensor( # Store initial noise separately (b, latent_channels, h // VAE_DOWNSCALE_FACTOR, w // VAE_DOWNSCALE_FACTOR), device=device, dtype=torch.float32, # Use float32 for calculations ) else: initial_noise = pre_latent.to(device=device, dtype=torch.float32) if pre_timestep is not None and pre_timestep < 1.0: # Check if it's truly intermediate logger.warning( "Using pre_latent as initial_noise for average calculation, but pre_timestep suggests it's not pure noise. Results might be unexpected." ) latents = initial_noise.clone() # Working latents for the ODE solver # --- [Text Embeddings & CFG Setup] --- text_embeddings, pooled_text_embeddings = self.prompt_embedding( prompts, latents.device, noisy_pad=noisy_pad, zero_masking=zero_masking ) text_embeddings = [emb.to(device=latents.device, dtype=torch.bfloat16) for emb in text_embeddings] pooled_text_embeddings = pooled_text_embeddings.to(device=latents.device, dtype=torch.bfloat16) do_classifier_free_guidance = guidance_scale > 1.0 if do_classifier_free_guidance: negative_text_embeddings = [ torch.zeros_like(text_embedding, device=text_embedding.device) for text_embedding in text_embeddings ] negative_pooled_text_embeddings = torch.zeros_like( pooled_text_embeddings, device=pooled_text_embeddings.device ) text_embeddings = [ torch.cat([text_embedding, negative_text_embedding], dim=0) for text_embedding, negative_text_embedding in zip(text_embeddings, negative_text_embeddings) ] pooled_text_embeddings = torch.cat([pooled_text_embeddings, negative_pooled_text_embeddings], dim=0) # if negative_prompt is None: # negative_prompt = [""] * b # logger.debug("No negative prompt provided, using empty strings for CFG.") # negative_text_embeddings, negative_pooled_text_embeddings = self.prompt_embedding(negative_prompt, latents.device) # negative_text_embeddings = [emb.to(device=latents.device, dtype=torch.bfloat16) for emb in negative_text_embeddings] # negative_pooled_text_embeddings = negative_pooled_text_embeddings.to(device=latents.device, dtype=torch.bfloat16) # text_embeddings = [torch.cat([pos_emb, neg_emb], dim=0) for pos_emb, neg_emb in zip(text_embeddings, negative_text_embeddings)] # pooled_text_embeddings = torch.cat([pooled_text_embeddings, negative_pooled_text_embeddings], dim=0) # --- [Timestep Schedule (Sigmas)] --- # linear t schedule sigmas = torch.linspace(1, 0, steps + 1) if not pre_timestep else torch.linspace(pre_timestep, 0, steps + 1) if use_linear_quadratic_schedule: # liner-quadratic t schedule assert steps % 2 == 0 N = linear_quadratic_emulating_steps sigmas = torch.concat( [ torch.linspace(1, 0, N + 1)[: steps // 2], torch.linspace(0, 1, steps // 2 + 1) ** 2 * (steps // 2 * 1 / N - 1) - (steps // 2 * 1 / N - 1), ] ) # --- [Initialization for Intermediate Step Calculation] --- # intermediate_latents will store the latent states for intermediate steps intermediate_latents = [] if get_intermediate_steps else None predicted_velocities = [] # Store dx from each step sigma_history = [] # --- [Sampling Loop] --- for infer_step, t in tqdm.tqdm(enumerate(sigmas[:-1]), total=len(sigmas[:-1]), desc="Sampling"): # Prepare input for DiT model if do_classifier_free_guidance: input_latents = torch.cat([latents] * 2, dim=0) else: input_latents = latents # Prepare timestep input timestep = (t * 1000).round().long().to(latents.device) timestep = timestep.expand(input_latents.shape[0]).to(torch.bfloat16) # Ensure timestep is bfloat16 # Predict velocity dx = v(x_t, t) ≈ e - x_0 dx = self.dit(input_latents.to(torch.bfloat16), timestep, text_embeddings, pooled_text_embeddings) dt = sigmas[infer_step + 1] - sigmas[infer_step] # dt is negative sigma_history.append(dt) # Apply Classifier-Free Guidance if do_classifier_free_guidance: cond_dx, uncond_dx = dx.chunk(2) current_guidance_scale = guidance_scale if clip_t[0] <= t and t <= clip_t[1] else 1.0 dx = uncond_dx + current_guidance_scale * (cond_dx - uncond_dx) if rescale_cfg > 0.0: std_pos = torch.std(cond_dx, dim=[1, 2, 3], keepdim=True, unbiased=False) + 1e-5 std_cfg = torch.std(dx, dim=[1, 2, 3], keepdim=True, unbiased=False) + 1e-5 factor = std_pos / std_cfg factor = rescale_cfg * factor + (1.0 - rescale_cfg) dx = dx * factor # --- Store the predicted velocity for averaging --- predicted_velocities.append(dx.clone()) # --- Update Latents using standard Euler step --- latents = latents + dt * dx # --- Calculate and Store Intermediate Latent State (if requested) --- if get_intermediate_steps: dxs = torch.stack(predicted_velocities) sigma_sum = sum(sigma_history) normalized_sigma_history = [s / (sigma_sum) for s in sigma_history] dts = torch.tensor(normalized_sigma_history, device=dxs.device, dtype=dxs.dtype).view(-1, 1, 1, 1, 1) avg_dx = torch.sum(dxs * dts, dim=0) observed_state = initial_noise - avg_dx # Calculate the desired intermediate state intermediate_latents.append(observed_state.clone()) # Store its latent representation # --- [Decode Final Latents to PIL Images] --- self.vae = self.vae.to(device=latents.device, dtype=torch.float32) # Ensure VAE is ready final_latents_scaled = latents.to(torch.float32) / self.vae.config.scaling_factor final_image_tensors = self.vae.decode(final_latents_scaled, return_dict=False)[0] + self.vae.config.shift_factor final_image_tensors = ((final_image_tensors + 1.0) / 2.0).clamp(0.0, 1.0) final_pil_images = [] for i, image_tensor in enumerate(final_image_tensors): img = T.ToPILImage()(image_tensor.cpu()) if not is_safe_prompt[i]: img = img.filter(ImageFilter.GaussianBlur(radius=30)) final_pil_images.append(img) # --- [Decode Intermediate Latents to PIL Images (if requested)] --- if get_intermediate_steps: intermediate_pil_images = [] # Ensure VAE is still ready (it should be from final decoding) for step_latents in tqdm.tqdm(intermediate_latents, desc="Decoding intermediates"): step_latents_scaled = ( step_latents.to(dtype=torch.float32, device="cuda") / self.vae.config.scaling_factor ) step_image_tensors = ( self.vae.decode(step_latents_scaled, return_dict=False)[0] + self.vae.config.shift_factor ) step_image_tensors = ((step_image_tensors + 1.0) / 2.0).clamp(0.0, 1.0) current_step_pil = [] for i, image_tensor in enumerate(step_image_tensors): img = T.ToPILImage()(image_tensor.cpu()) # Apply moderation blur consistency if not is_safe_prompt[i]: img = img.filter(ImageFilter.GaussianBlur(radius=30)) current_step_pil.append(img) intermediate_pil_images.append(current_step_pil) # Append list of images for this step return final_pil_images, intermediate_pil_images # Return both final and intermediate images else: return final_pil_images # Return only final images @torch.no_grad() def eval_with_loss(self, images, raw_text): latents = self.vae.encode(images).latent_dist.sample() * self.vae.config.scaling_factor tokens, masks = self.tokenization(raw_text) tokens = [token.to(latents.device) for token in tokens] masks = [mask.to(latents.device) for mask in masks] text_embeddings, pooled_text_embeddings = self.text_encoding(tokens, masks) text_embeddings = [text_embedding for text_embedding in text_embeddings] pooled_text_embeddings = pooled_text_embeddings.float() # 2. Get noisy input via the rectified flow is_finetuning = self.config.height > 256 noise, noise_latents, t = self.get_noisy_input(latents, is_finetuning=is_finetuning) timesteps = self.discritize_timestep(t, self.n_timesteps) # 3. Forward pass through the dit preds = self.dit(noise_latents, timesteps, text_embeddings, pooled_text_embeddings) # 4. Rectified flow matching loss loss = self.rectified_flow_loss(noise_latents, noise, t, preds, reduce="none", use_weighting=False).mean( dim=[1, 2, 3] ) intervals = np.linspace(0, 1, 9) t_interval = [(intervals[i], intervals[i + 1]) for i in range(len(intervals) - 1)] loss_bins = defaultdict(list) for i, interval in enumerate(t_interval, 0): idx = (interval[0] < t) & (t < interval[1]) loss_bins[i].append(loss[idx]) return loss_bins