|
from collections import defaultdict |
|
from typing import List, Optional, Tuple, Union |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torchvision.transforms as T |
|
import tqdm |
|
from diffusers.models import AutoencoderKL |
|
from diffusers.utils.torch_utils import randn_tensor |
|
from loguru import logger |
|
from PIL import Image, ImageFilter |
|
from transformers import CLIPTextModel, CLIPTokenizerFast, T5EncoderModel, T5Tokenizer |
|
|
|
from models.mixin.flow_mixin import FlowMixin |
|
from models.modeling_dit import MotifDiT |
|
|
|
TOKEN_MAX_LENGTH: int = 256 |
|
DROP_PROB: float = 0.1 |
|
LATENT_CHANNELS: int = 4 |
|
VAE_DOWNSCALE_FACTOR: int = 8 |
|
SD3_LATENT_CHANNEL: int = 16 |
|
|
|
|
|
def generate_intervals(steps: int, ratio: float, start: float = 1.0) -> torch.Tensor:
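# Generates a monotonically decreasing schedule of `steps` values from `start` to 0,
# raised elementwise to `ratio`: with start=1, a ratio > 1 packs values more densely
# near 0, while a ratio < 1 packs them near 1.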
|
intervals = torch.linspace(start, 0, steps=steps) |
|
intervals = intervals.pow(ratio) |
|
return intervals |
|
|
|
|
|
class MotifImage(nn.Module, FlowMixin): |
|
""" |
|
MotifImage Text-to-Image model. |
|
|
|
This model combines a Diffusion transformer with a rectified flow loss and multiple text encoders. |
|
It uses a VAE (Variational Autoencoder) for image encoding and decoding. |
|
|
|
Args: |
|
config (MMDiTConfig): Configuration object for the MMDiT model. |
|
|
|
Attributes: |
|
dit (MotifDiT): MotifDiT model instance. |
|
noise_scheduler (DDPMScheduler): Noise scheduler for the diffusion process. |
|
normalize_img (Callable): Function to normalize images into the [-1, 1] range.
|
unnormalize_img (Callable): Function to unnormalize images to [0, 1] range. |
|
cond_drop_prob (float): Probability of dropping text embeddings during training. |
|
snr_gamma (str): Strategy for weighting the loss based on Signal-to-Noise Ratio (SNR). |
|
loss_weight_strategy (str): Strategy for weighting the loss. |
|
vae (AutoencoderKL): Variational Autoencoder for image encoding and decoding. |
|
t5 (T5EncoderModel): T5 encoder model for text encoding. |
|
t5_tokenizer (T5Tokenizer): T5 tokenizer for text tokenization. |
|
clip_l (CLIPTextModel): CLIP (Contrastive Language-Image Pre-training) text encoder (ViT-L/14).
|
clip_l_tokenizer (CLIPTokenizerFast): CLIP tokenizer (large) for text tokenization. |
|
clip_g (CLIPTextModel): CLIP text encoder (ViT-bigG/14) for text encoding.
|
clip_g_tokenizer (CLIPTokenizerFast): CLIP tokenizer (bigG) for text tokenization.
|
tokenizers (List[Union[T5Tokenizer, CLIPTokenizerFast]]): List of tokenizers. |
|
text_encoders (List[Union[T5EncoderModel, CLIPTextModel]]): List of text encoder models.
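
    Example (illustrative sketch; assumes a populated `config` for MotifDiT and a CUDA
    device with the pretrained VAE/text-encoder weights available):

        >>> model = MotifImage(config).cuda().eval()
        >>> images = model.sample(["a watercolor painting of a lighthouse"],
        ...                       steps=50, guidance_scale=7.5, resolution=(256, 256))
        >>> images[0].save("lighthouse.png")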
|
""" |
|
|
|
def __init__(self, config): |
|
super().__init__() |
|
self.config = config |
|
self.dit = MotifDiT(config) |
|
self.cond_drop_prob = DROP_PROB
|
self.use_weighting = False |
|
self._get_encoders() |
|
self._freeze_encoders() |
|
|
|
def forward(self, images: torch.Tensor, raw_text: List[str]) -> List[torch.Tensor]:
|
""" |
|
Forward pass of the MotifDiT model. |
|
|
|
Args: |
|
images (torch.Tensor): Input images tensor, [0-1] ranged. |
|
raw_text (List[str]): Batch of input text prompts.
|
|
|
Returns: |
|
List[torch.Tensor]: Single-element list containing the rectified flow matching loss.
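
        Example (illustrative sketch; assumes `model` is a MotifImage on GPU, a DataLoader
        yielding [0, 1]-ranged image tensors with caption strings, and a standard optimizer):

            >>> images, captions = next(iter(dataloader))
            >>> loss, = model(images.cuda(), list(captions))
            >>> loss.backward()
            >>> optimizer.step()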
|
""" |
|
|
|
with torch.no_grad(): |
|
latents = self.vae.encode(images).latent_dist.sample() * self.vae.config.scaling_factor |
|
tokens, masks = self.tokenization(raw_text) |
|
tokens = [token.to(latents.device) for token in tokens] |
|
masks = [mask.to(latents.device) for mask in masks] |
|
text_embeddings, pooled_text_embeddings = self.text_encoding(tokens, masks) |
|
text_embeddings = self._drop_text_emb(text_embeddings) |
|
text_embeddings = [text_embedding.float() for text_embedding in text_embeddings] |
|
pooled_text_embeddings = pooled_text_embeddings.float() |
|
|
|
|
|
is_finetuning = self.config.height > 256 |
|
noise, noise_latents, t = self.get_noisy_input(latents, is_finetuning=is_finetuning) |
|
|
|
timesteps = self.discritize_timestep(t, self.n_timesteps) |
|
|
|
|
|
preds = self.dit(noise_latents, timesteps, text_embeddings, pooled_text_embeddings) |
|
|
|
|
|
loss = self.rectified_flow_loss(latents, noise, t, preds, use_weighting=self.use_weighting) |
|
|
|
return [loss] |
|
|
|
def _get_encoders(self) -> None: |
|
"""Initialize the VAE and text encoders.""" |
|
if self.config.vae_type == "SD3": |
|
self.vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae") |
|
elif self.config.vae_type == "SDXL": |
|
self.vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae") |
|
else: |
|
raise ValueError(f"VAE type must be `SD3` or `SDXL` but self.config.vae_type is {self.config.vae_type}") |
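# Three frozen text encoders provide SD3-style conditioning: T5-XXL supplies long
# sequence embeddings, while CLIP-L and OpenCLIP-bigG supply sequence embeddings plus
# the pooled embeddings that are concatenated downstream.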
|
|
|
|
|
|
|
self.t5 = T5EncoderModel.from_pretrained("google/flan-t5-xxl").to(dtype=torch.bfloat16) |
|
self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl") |
|
|
|
|
|
self.clip_l = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(dtype=torch.bfloat16) |
|
self.clip_l_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14") |
|
|
|
|
|
self.clip_g = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k").to(dtype=torch.bfloat16) |
|
self.clip_g_tokenizer = CLIPTokenizerFast.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") |
|
|
|
self.tokenizers = [self.t5_tokenizer, self.clip_l_tokenizer, self.clip_g_tokenizer] |
|
self.text_encoders = [self.t5, self.clip_l, self.clip_g] |
|
|
|
def state_dict(self, destination=None, prefix="", keep_vars=False): |
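# Checkpoints only need the trainable DiT weights, so the frozen VAE and text-encoder
# parameters are stripped from the state dict before it is returned.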
|
state_dict = super(MotifImage, self).state_dict(destination, prefix, keep_vars) |
|
exclude_keys = ["t5.", "clip_l.", "clip_g.", "vae."] |
|
for key in list(state_dict.keys()): |
|
if any(key.startswith(exclude_key) for exclude_key in exclude_keys): |
|
state_dict.pop(key) |
|
return state_dict |
|
|
|
def load_state_dict(self, state_dict, strict=False): |
|
""" |
|
Load state dict and merge LoRA parameters if present. |
|
|
|
Args: |
|
state_dict (dict): State dictionary containing model parameters |
|
strict (bool): Whether to strictly enforce that the keys in state_dict match the keys in this module |
|
|
|
Returns: |
|
tuple: (missing_keys, unexpected_keys) lists of parameters that were missing or unexpected |
|
""" |
|
|
|
has_lora = any("lora_" in key for key in state_dict.keys()) |
|
|
|
if has_lora: |
|
|
|
if not hasattr(self.dit, "peft_config"): |
|
logger.info("Enabling LoRA for parameter merging...") |
|
|
|
lora_rank = getattr(self.config, "lora_rank", 64) |
|
lora_alpha = getattr(self.config, "lora_alpha", 8) |
|
self.enable_lora(lora_rank, lora_alpha) |
|
|
|
if has_lora: |
|
try: |
|
|
|
|
|
|
|
|
|
|
|
|
|
missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False) |
|
|
|
logger.info("Merging LoRA parameters with base model...") |
|
for name, module in self.dit.named_modules(): |
|
if hasattr(module, "merge_and_unload"): |
|
module.merge_and_unload() |
|
|
|
logger.info("Successfully merged LoRA parameters") |
|
|
|
except Exception as e: |
|
logger.error(f"Error merging LoRA parameters: {str(e)}") |
|
raise |
|
else: |
|
missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False) |
|
|
|
|
|
missing_top_levels = set() |
|
for key in missing_keys: |
|
top_level_name = key.split(".")[0] |
|
missing_top_levels.add(top_level_name) |
|
if missing_top_levels: |
|
logger.debug("Missing keys during loading at top level:") |
|
for name in missing_top_levels: |
|
logger.debug(name) |
|
|
|
if unexpected_keys: |
|
logger.debug("Unexpected keys found:") |
|
for key in unexpected_keys: |
|
logger.debug(key) |
|
|
|
return missing_keys, unexpected_keys |
|
|
|
def _freeze_encoders(self) -> None: |
|
""" |
|
Freeze the VAE and all text encoders.
|
""" |
|
for encoder_module in [self.vae, self.clip_l, self.clip_g, self.t5]: |
|
for param in encoder_module.parameters(): |
|
param.requires_grad = False |
|
|
|
def tokenization( |
|
self, raw_texts: List[str], repeat_if_short: bool = False |
|
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: |
|
""" |
|
Tokenizes a BATCH of input texts using multiple tokenizers efficiently. |
|
Optionally repeats each text to fill the max length if it's shorter, |
|
BEFORE passing the pre-processed batch to the tokenizer. |
|
|
|
Args: |
|
raw_texts (List[str]): A list of input text strings (the batch). |
|
repeat_if_short (bool): If True and a text is short, repeat that text |
|
to fill the context length. Defaults to False.
|
|
|
Returns: |
|
Tuple[List[torch.Tensor], List[torch.Tensor]]: |
|
- A list containing one batch tensor of input IDs per tokenizer. |
|
Each tensor shape: [batch_size, max_length] |
|
- A list containing one batch tensor of attention masks per tokenizer. |
|
Each tensor shape: [batch_size, max_length] |
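
        Example (illustrative; exact sequence lengths depend on each tokenizer's model_max_length):

            >>> tokens, masks = model.tokenization(["a cat", "a corgi surfing a wave"])
            >>> len(tokens), len(masks)  # one batch tensor per tokenizer (T5, CLIP-L, CLIP-G)
            (3, 3)
            >>> tokens[0].shape[0]  # batch dimension
            2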
|
""" |
|
final_batch_tokens = [] |
|
final_batch_masks = [] |
|
|
|
|
|
for tokenizer in self.tokenizers: |
|
effective_max_length = min(TOKEN_MAX_LENGTH, tokenizer.model_max_length) |
|
|
|
|
|
processed_texts_for_tokenizer = [] |
|
for text_item in raw_texts: |
|
|
|
processed_text = text_item |
|
|
|
if repeat_if_short: |
|
|
|
num_initial_tokens = len(text_item.split()) |
|
available_length = effective_max_length - 2 |
|
|
|
if num_initial_tokens > 0 and num_initial_tokens < available_length: |
|
num_additional_repeats = available_length // (num_initial_tokens + 1) |
|
if num_additional_repeats > 0: |
|
total_repeats = 1 + num_additional_repeats |
|
processed_text = " ".join([text_item] * total_repeats) |
|
|
|
|
|
processed_texts_for_tokenizer.append(processed_text) |
|
|
|
|
|
|
|
|
|
batch_tok_output = tokenizer( |
|
processed_texts_for_tokenizer, |
|
padding="max_length", |
|
max_length=effective_max_length, |
|
return_tensors="pt", |
|
truncation=True, |
|
) |
|
|
|
|
|
|
|
final_batch_tokens.append(batch_tok_output.input_ids) |
|
final_batch_masks.append(batch_tok_output.attention_mask) |
|
|
|
return final_batch_tokens, final_batch_masks |
|
|
|
@torch.no_grad() |
|
def text_encoding( |
|
self, tokens: List[torch.Tensor], masks, noisy_pad=False, zero_masking=True |
|
) -> Tuple[List[torch.Tensor], torch.Tensor]: |
|
""" |
|
Encode the tokenized text using multiple text encoders. |
|
|
|
Args: |
|
tokens (List[torch.Tensor]): List of tokenized text tensors, one per tokenizer (T5, CLIP-L, CLIP-G).
masks (List[torch.Tensor]): Attention masks corresponding to each entry in `tokens`.
noisy_pad (bool): If True, add small Gaussian noise to the embeddings at padding positions.
zero_masking (bool): If True, zero out the embeddings at padding positions.
|
|
|
Returns: |
|
Tuple[List[torch.Tensor], torch.Tensor]: Tuple containing a list of text embeddings and pooled text embeddings. |
|
""" |
|
t5_tokens, clip_l_tokens, clip_g_tokens = tokens |
|
t5_masks, clip_l_masks, clip_g_masks = masks |
|
t5_emb = self.t5(t5_tokens, attention_mask=t5_masks)[0] |
|
if zero_masking: |
|
t5_emb = t5_emb * (t5_tokens != self.t5_tokenizer.pad_token_id).unsqueeze(-1) |
|
if noisy_pad: |
|
t5_pad_noise = ( |
|
(t5_tokens == self.t5_tokenizer.pad_token_id).unsqueeze(-1) * torch.randn_like(t5_emb) * 0.008
|
) |
|
t5_emb = t5_emb + t5_pad_noise |
|
|
|
clip_l_emb = self.clip_l(input_ids=clip_l_tokens, output_hidden_states=True) |
|
clip_g_emb = self.clip_g(input_ids=clip_g_tokens, output_hidden_states=True) |
|
clip_l_emb_pooled = clip_l_emb.pooler_output |
|
clip_g_emb_pooled = clip_g_emb.pooler_output |
|
|
|
clip_l_emb = clip_l_emb.last_hidden_state |
|
clip_g_emb = clip_g_emb.last_hidden_state |
|
|
|
def masking_wo_first_eos(token, eos): |
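# Keep every non-EOS token plus the first EOS (whose hidden state carries CLIP's
# sequence summary); everything after the first EOS is masked to zero.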
|
idx = (token != eos).sum(dim=1) |
|
mask = token != eos |
|
arange = torch.arange(mask.size(0), device=token.device)
|
mask[arange, idx] = True |
|
mask = mask.unsqueeze(-1) |
|
return mask |
|
|
|
if zero_masking: |
|
clip_l_emb = clip_l_emb * masking_wo_first_eos( |
|
clip_l_tokens, self.clip_l_tokenizer.eos_token_id |
|
) |
|
clip_g_emb = clip_g_emb * masking_wo_first_eos( |
|
clip_g_tokens, self.clip_g_tokenizer.eos_token_id |
|
) |
|
|
|
if noisy_pad: |
|
clip_l_pad_noise = ( |
|
~masking_wo_first_eos(clip_l_tokens, self.clip_l_tokenizer.eos_token_id) |
|
* torch.randn_like(clip_l_emb)
|
* 0.08 |
|
) |
|
clip_g_pad_noise = ( |
|
~masking_wo_first_eos(clip_g_tokens, self.clip_g_tokenizer.eos_token_id) |
|
* torch.randn_like(clip_g_emb)
|
* 0.08 |
|
) |
|
clip_l_emb = clip_l_emb + clip_l_pad_noise |
|
clip_g_emb = clip_g_emb + clip_g_pad_noise |
|
|
|
encodings = [t5_emb, clip_l_emb, clip_g_emb] |
|
pooled_encodings = torch.cat([clip_l_emb_pooled, clip_g_emb_pooled], dim=-1) |
|
|
|
return encodings, pooled_encodings |
|
|
|
@torch.no_grad() |
|
def prompt_embedding(self, prompts: List[str], device, noisy_pad=False, zero_masking=True):
|
tokens, masks = self.tokenization(prompts) |
|
tokens = [token.to(device) for token in tokens] |
|
masks = [mask.to(device) for mask in masks] |
|
text_embeddings, pooled_text_embeddings = self.text_encoding( |
|
tokens, masks, noisy_pad=noisy_pad, zero_masking=zero_masking |
|
) |
|
text_embeddings = [text_embedding.bfloat16() for text_embedding in text_embeddings] |
|
pooled_text_embeddings = pooled_text_embeddings.bfloat16() |
|
return text_embeddings, pooled_text_embeddings |
|
|
|
@torch.no_grad() |
|
def sample( |
|
self, |
|
raw_text: List[str], |
|
steps: int = 50, |
|
guidance_scale: float = 7.5, |
|
resolution: Tuple[int, int] = (256, 256),
|
pre_latent=None, |
|
pre_timestep=None, |
|
step_scaling=1.0, |
|
noisy_pad=False, |
|
zero_masking=False, |
|
negative_prompt: Optional[List[str]] = None, |
|
device: str = "cuda", |
|
rescale_cfg=-1.0, |
|
clip_t=[0.0, 1.0], |
|
use_linear_quadratic_schedule=False, |
|
linear_quadratic_emulating_steps=250, |
|
prompt_rewriter=None, |
|
moderator=None, |
|
get_intermediate_steps: bool = False, |
|
) -> Union[List[Image.Image], Tuple[List[Image.Image], List[List[Image.Image]]]]: |
|
""" |
|
Sample images using flow matching. Optionally returns intermediate step images |
|
calculated via observed average velocity method. |
|
|
|
Args: |
|
raw_text (List[str]): raw text prompts |
|
steps (int, optional): number of function estimations for flow matching ODE. Defaults to 50. |
|
guidance_scale (float, optional): classifier free guidance scale. Defaults to 7.5. |
|
resolution (Tuple[int, int], optional): output resolution (height, width) of generated images. Defaults to (256, 256).
|
device (str, optional): Defaults to 'cuda'. |
|
pre_latent (Tensor, optional): the optional input to generate image with pre-defined latents. |
|
for instance, it would be utilized for denoising or image-editing. |
|
pre_timestep (float [0,1], optional): the pre-defined timestep. with `pre_latent`, image generation |
|
can be done by starting with intermediate timestep. |
|
step_scaling (float, defaults to 1.0): scaling factor for each ODE step (currently unused in this implementation).
|
use_linear_quadratic_schedule (bool, defaults to False): if True, use the linear-quadratic t schedule; otherwise use a linear t schedule.
|
linear_quadratic_emulating_steps (int, defaults to 250): N value of the linear-quadratic t schedule from the Meta MovieGen paper.
|
Reference: (https://ai.meta.com/static-resource/movie-gen-research-paper) Figure 10 |
|
get_intermediate_steps (bool, optional): Whether to calculate and return intermediate step images. |
|
Calculation is based on initial_noise - avg(velocity). Defaults to False.
|
|
|
Returns: |
|
Union[List[PIL.Image], Tuple[List[PIL.Image], List[List[PIL.Image]]]]: |
|
If get_intermediate_steps is False: Returns a list of final PIL images. |
|
If get_intermediate_steps is True: Returns a tuple containing: |
|
- List[PIL.Image]: Final output PIL images. |
|
- List[List[PIL.Image]]: List of intermediate PIL images. Each inner list contains |
|
the batch of images for one intermediate step. |
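
        Example (illustrative sketch; assumes a CUDA device and loaded weights):

            >>> finals, intermediates = model.sample(
            ...     ["an astronaut riding a horse"],
            ...     steps=20,
            ...     guidance_scale=7.5,
            ...     resolution=(512, 512),
            ...     get_intermediate_steps=True,
            ... )
            >>> finals[0].save("astronaut.png")
            >>> len(intermediates)  # one list of PIL images per sampling step
            20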
|
""" |
|
if prompt_rewriter: |
|
prompts = [prompt_rewriter.rewrite(prompt) for prompt in raw_text] |
|
else: |
|
prompts = raw_text |
|
|
|
|
|
if prompts == raw_text and prompt_rewriter is not None: |
|
logger.debug("Prompt rewriter did not change the prompts.") |
|
elif prompt_rewriter is None: |
|
logger.debug("Prompt rewriter not provided.") |
|
|
|
if moderator is None: |
|
is_safe_prompt = [True for _ in prompts] |
|
else: |
|
is_safe_prompt = [moderator.is_safe_content(prompt, threshold=0.7) for prompt in prompts]
|
if not all(is_safe_prompt): |
|
logger.warning("Noxious prompt detected. Output image(s) will be blurred.") |
|
|
|
b = len(prompts) |
|
h, w = resolution |
|
|
|
|
|
latent_channels = SD3_LATENT_CHANNEL if self.config.vae_type == "SD3" else LATENT_CHANNELS  # 16 for the SD3 VAE, 4 for SDXL
|
if pre_latent is None: |
|
initial_noise = randn_tensor( |
|
(b, latent_channels, h // VAE_DOWNSCALE_FACTOR, w // VAE_DOWNSCALE_FACTOR), |
|
device=device, |
|
dtype=torch.float32, |
|
) |
|
else: |
|
initial_noise = pre_latent.to(device=device, dtype=torch.float32) |
|
if pre_timestep is not None and pre_timestep < 1.0: |
|
logger.warning( |
|
"Using pre_latent as initial_noise for average calculation, but pre_timestep suggests it's not pure noise. Results might be unexpected." |
|
) |
|
|
|
latents = initial_noise.clone() |
|
|
|
|
|
text_embeddings, pooled_text_embeddings = self.prompt_embedding( |
|
prompts, latents.device, noisy_pad=noisy_pad, zero_masking=zero_masking |
|
) |
|
text_embeddings = [emb.to(device=latents.device, dtype=torch.bfloat16) for emb in text_embeddings] |
|
pooled_text_embeddings = pooled_text_embeddings.to(device=latents.device, dtype=torch.bfloat16) |
|
|
|
do_classifier_free_guidance = guidance_scale > 1.0 |
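# The unconditional branch for classifier-free guidance uses all-zero text embeddings
# rather than encodings of an empty (or negative) prompt; the `negative_prompt`
# argument is currently not routed through the text encoders.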
|
if do_classifier_free_guidance: |
|
negative_text_embeddings = [ |
|
torch.zeros_like(text_embedding, device=text_embedding.device) for text_embedding in text_embeddings |
|
] |
|
negative_pooled_text_embeddings = torch.zeros_like( |
|
pooled_text_embeddings, device=pooled_text_embeddings.device |
|
) |
|
text_embeddings = [ |
|
torch.cat([text_embedding, negative_text_embedding], dim=0) |
|
for text_embedding, negative_text_embedding in zip(text_embeddings, negative_text_embeddings) |
|
] |
|
pooled_text_embeddings = torch.cat([pooled_text_embeddings, negative_pooled_text_embeddings], dim=0) |
|
|
# Flow-matching time runs from sigma = 1 (pure noise) down to sigma = 0 (data).
sigmas = torch.linspace(1, 0, steps + 1) if pre_timestep is None else torch.linspace(pre_timestep, 0, steps + 1)
|
|
|
if use_linear_quadratic_schedule: |
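# Linear-quadratic schedule (MovieGen, Figure 10): the first half of the steps follows
# the first steps // 2 points of an N-step linear schedule, and the remaining range is
# covered quadratically so the full [1, 0] span is traversed in only `steps` evaluations.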
|
|
|
assert steps % 2 == 0 |
|
N = linear_quadratic_emulating_steps |
|
sigmas = torch.concat( |
|
[ |
|
torch.linspace(1, 0, N + 1)[: steps // 2], |
|
torch.linspace(0, 1, steps // 2 + 1) ** 2 * (steps // 2 * 1 / N - 1) - (steps // 2 * 1 / N - 1), |
|
] |
|
) |
|
|
|
|
|
|
|
intermediate_latents = [] if get_intermediate_steps else None |
|
predicted_velocities = [] |
|
sigma_history = [] |
|
|
|
for infer_step, t in tqdm.tqdm(enumerate(sigmas[:-1]), total=len(sigmas[:-1]), desc="Sampling"): |
|
|
|
if do_classifier_free_guidance: |
|
input_latents = torch.cat([latents] * 2, dim=0) |
|
else: |
|
input_latents = latents |
|
|
|
|
|
timestep = (t * 1000).round().long().to(latents.device) |
|
timestep = timestep.expand(input_latents.shape[0]).to(torch.bfloat16) |
|
|
|
|
|
dx = self.dit(input_latents.to(torch.bfloat16), timestep, text_embeddings, pooled_text_embeddings) |
|
dt = sigmas[infer_step + 1] - sigmas[infer_step] |
|
sigma_history.append(dt) |
|
|
|
|
|
if do_classifier_free_guidance: |
|
cond_dx, uncond_dx = dx.chunk(2) |
|
current_guidance_scale = guidance_scale if clip_t[0] <= t <= clip_t[1] else 1.0
|
dx = uncond_dx + current_guidance_scale * (cond_dx - uncond_dx) |
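# Optional guidance rescaling: match the per-sample std of the guided velocity to that
# of the conditional prediction, then blend with the unrescaled guided velocity by
# `rescale_cfg` to damp over-saturation at high guidance scales.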
|
|
|
if rescale_cfg > 0.0: |
|
std_pos = torch.std(cond_dx, dim=[1, 2, 3], keepdim=True, unbiased=False) + 1e-5 |
|
std_cfg = torch.std(dx, dim=[1, 2, 3], keepdim=True, unbiased=False) + 1e-5 |
|
factor = std_pos / std_cfg |
|
factor = rescale_cfg * factor + (1.0 - rescale_cfg) |
|
dx = dx * factor |
|
|
|
|
|
predicted_velocities.append(dx.clone()) |
|
|
|
|
|
latents = latents + dt * dx |
|
|
|
|
|
if get_intermediate_steps: |
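# Observed-average-velocity preview: weight each predicted velocity by its normalized
# step size and subtract the weighted average from the initial noise, giving a
# single-jump estimate of the current sample for visualization.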
|
dxs = torch.stack(predicted_velocities) |
|
|
|
sigma_sum = sum(sigma_history) |
|
normalized_sigma_history = [s / (sigma_sum) for s in sigma_history] |
|
dts = torch.tensor(normalized_sigma_history, device=dxs.device, dtype=dxs.dtype).view(-1, 1, 1, 1, 1) |
|
|
|
avg_dx = torch.sum(dxs * dts, dim=0) |
|
observed_state = initial_noise - avg_dx |
|
intermediate_latents.append(observed_state.clone()) |
|
|
|
|
|
self.vae = self.vae.to(device=latents.device, dtype=torch.float32) |
|
final_latents_scaled = latents.to(torch.float32) / self.vae.config.scaling_factor |
|
final_image_tensors = self.vae.decode(final_latents_scaled, return_dict=False)[0] + self.vae.config.shift_factor |
|
final_image_tensors = ((final_image_tensors + 1.0) / 2.0).clamp(0.0, 1.0) |
|
|
|
final_pil_images = [] |
|
for i, image_tensor in enumerate(final_image_tensors): |
|
img = T.ToPILImage()(image_tensor.cpu()) |
|
if not is_safe_prompt[i]: |
|
img = img.filter(ImageFilter.GaussianBlur(radius=30)) |
|
final_pil_images.append(img) |
|
|
|
|
|
if get_intermediate_steps: |
|
intermediate_pil_images = [] |
|
|
|
for step_latents in tqdm.tqdm(intermediate_latents, desc="Decoding intermediates"): |
|
step_latents_scaled = ( |
|
step_latents.to(dtype=torch.float32, device=latents.device) / self.vae.config.scaling_factor
|
) |
|
step_image_tensors = ( |
|
self.vae.decode(step_latents_scaled, return_dict=False)[0] + self.vae.config.shift_factor |
|
) |
|
step_image_tensors = ((step_image_tensors + 1.0) / 2.0).clamp(0.0, 1.0) |
|
|
|
current_step_pil = [] |
|
for i, image_tensor in enumerate(step_image_tensors): |
|
img = T.ToPILImage()(image_tensor.cpu()) |
|
|
|
if not is_safe_prompt[i]: |
|
img = img.filter(ImageFilter.GaussianBlur(radius=30)) |
|
current_step_pil.append(img) |
|
intermediate_pil_images.append(current_step_pil) |
|
|
|
return final_pil_images, intermediate_pil_images |
|
else: |
|
return final_pil_images |
|
|
|
@torch.no_grad() |
|
def eval_with_loss(self, images, raw_text): |
|
latents = self.vae.encode(images).latent_dist.sample() * self.vae.config.scaling_factor |
|
|
|
tokens, masks = self.tokenization(raw_text) |
|
tokens = [token.to(latents.device) for token in tokens] |
|
masks = [mask.to(latents.device) for mask in masks] |
|
text_embeddings, pooled_text_embeddings = self.text_encoding(tokens, masks) |
|
text_embeddings = [text_embedding.float() for text_embedding in text_embeddings]
|
pooled_text_embeddings = pooled_text_embeddings.float() |
|
|
|
|
|
is_finetuning = self.config.height > 256 |
|
noise, noise_latents, t = self.get_noisy_input(latents, is_finetuning=is_finetuning) |
|
timesteps = self.discritize_timestep(t, self.n_timesteps) |
|
|
|
|
|
preds = self.dit(noise_latents, timesteps, text_embeddings, pooled_text_embeddings) |
|
|
|
|
|
loss = self.rectified_flow_loss(latents, noise, t, preds, reduce="none", use_weighting=False).mean(
|
dim=[1, 2, 3] |
|
) |
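# Bin the per-sample losses into 8 uniform timestep intervals so evaluation can report
# the loss as a function of noise level.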
|
|
|
intervals = np.linspace(0, 1, 9) |
|
t_interval = [(intervals[i], intervals[i + 1]) for i in range(len(intervals) - 1)] |
|
|
|
loss_bins = defaultdict(list) |
|
for i, interval in enumerate(t_interval, 0): |
|
idx = (interval[0] < t) & (t < interval[1]) |
|
loss_bins[i].append(loss[idx]) |
|
|
|
return loss_bins |
|
|