import torch
from diffusers.models import AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizerFast, T5EncoderModel, T5Tokenizer
class EncoderMixin:
"""Mixin class for handling various encoders in the MotifDiT model.
This mixin provides functionality for:
1. Loading and initializing encoders (VAE, T5, CLIP-L, CLIP-G)
2. Text tokenization and encoding
3. Managing encoder parameters and state
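
    Example (illustrative sketch; assumes this mixin is combined with the MotifDiT model
    class and that `images` is a [B, 3, H, W] float tensor on the encoders' device):
        >>> vae, t5, clip_l, clip_g, t5_tok, clip_l_tok, clip_g_tok = (
        ...     self.get_freezed_encoders_and_tokenizers("SD3")
        ... )
        >>> latents, text_embs, pooled = self.prepare_embeddings(
        ...     images, ["a photo of a cat"], vae, t5, clip_l, clip_g,
        ...     t5_tok, clip_l_tok, clip_g_tok, is_training=False,
        ... )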
"""
TOKEN_MAX_LENGTH: int = 256
def prepare_embeddings(
self,
images: torch.Tensor,
raw_text: list[str],
vae: AutoencoderKL,
t5: T5EncoderModel,
clip_l: CLIPTextModel,
clip_g: CLIPTextModel,
t5_tokenizer: T5Tokenizer,
clip_l_tokenizer: CLIPTokenizerFast,
clip_g_tokenizer: CLIPTokenizerFast,
        is_training: bool,
) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
"""Prepare image latents and text embeddings for model input.
Args:
images (torch.Tensor): Input images tensor with shape [B, C=3, H, W].
            raw_text (list[str]): List of raw text strings with length B.
            vae (AutoencoderKL): Frozen SD3 VAE used to encode images into latents.
            t5 (T5EncoderModel): Frozen T5-XXL text encoder.
            clip_l (CLIPTextModel): Frozen CLIP-L text encoder.
            clip_g (CLIPTextModel): Frozen CLIP-G text encoder.
            t5_tokenizer (T5Tokenizer): Tokenizer paired with the T5 encoder.
            clip_l_tokenizer (CLIPTokenizerFast): Tokenizer paired with CLIP-L.
            clip_g_tokenizer (CLIPTokenizerFast): Tokenizer paired with CLIP-G.
            is_training (bool): Whether this call happens during training.
        Returns:
            tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]: Image latents of shape
            [B, 16, H//8, W//8], per-encoder text embeddings, and pooled text embeddings.
        """
with torch.no_grad():
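            # The SD3 VAE config provides shift_factor and scaling_factor; latents are shifted then
            # scaled so they are approximately standardized before being fed to the diffusion model.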
latents: torch.Tensor = (
vae.encode(images).latent_dist.sample() - vae.config.shift_factor
) * vae.config.scaling_factor # Latents shape: [B, 16, H//8, W//8]
# Tokenize the input text and move tokens and masks to the same device as latents
tokenizers = [t5_tokenizer, clip_l_tokenizer, clip_g_tokenizer]
tokens, masks = self.tokenization(raw_text, tokenizers)
tokens = [token.to(latents.device) for token in tokens]
masks = [mask.to(latents.device) for mask in masks]
# Encode the text and drop unnecessary embeddings
text_embeddings, pooled_text_embeddings = self.text_encoding(
tokens,
masks,
t5,
clip_l,
clip_g,
t5_tokenizer.pad_token_id,
clip_l_tokenizer.eos_token_id,
clip_g_tokenizer.eos_token_id,
is_training,
)
        # Randomly drop per-token text embeddings (training-time classifier-free-guidance dropout)
        if is_training:
            text_embeddings = self.drop_text_emb(text_embeddings)
        # Upcast text embeddings and pooled text embeddings to float32
        text_embeddings = [text_embedding.float() for text_embedding in text_embeddings]
        pooled_text_embeddings = pooled_text_embeddings.float()
return latents, text_embeddings, pooled_text_embeddings
def get_freezed_encoders_and_tokenizers(
self, vae_type: str
) -> tuple[
AutoencoderKL, T5EncoderModel, CLIPTextModel, CLIPTextModel, T5Tokenizer, CLIPTokenizerFast, CLIPTokenizerFast
]:
"""Initialize the VAE and text encoders."""
if vae_type != "SD3":
raise ValueError(
f"VAE type must be `SD3` but self.config.vae_type is {vae_type}."
f" note that the VAE type SDXL is deprecated."
)
vae: AutoencoderKL = AutoencoderKL.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae"
)
# Text encoders
# 1. T5-XXL from Google
t5 = T5EncoderModel.from_pretrained("google/flan-t5-xxl").to(dtype=torch.bfloat16)
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
# 2. CLIP-L from OpenAI
clip_l = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(dtype=torch.bfloat16)
clip_l_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")
# 3. CLIP-G from LAION
clip_g = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k").to(dtype=torch.bfloat16)
clip_g_tokenizer = CLIPTokenizerFast.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
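        # The text encoders are cast to bfloat16 to save memory; the VAE is left in its default precision.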
# Freeze all encoders
for encoder_module in [vae, clip_l, clip_g, t5]:
for param in encoder_module.parameters():
param.requires_grad = False
return vae, t5, clip_l, clip_g, t5_tokenizer, clip_l_tokenizer, clip_g_tokenizer
def tokenization(
self, raw_text: list[str], tokenizers: list[T5Tokenizer | CLIPTokenizerFast]
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""Tokenize the input text using multiple tokenizers.
Args:
            raw_text (list[str]): List of raw text strings.
            tokenizers (list[T5Tokenizer | CLIPTokenizerFast]): Tokenizers to apply, in order.
Returns:
Tuple[List[torch.Tensor], List[torch.Tensor]]: Lists of tokenized text tensors and attention masks.
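
        Example (illustrative):
            >>> tokens, masks = self.tokenization(
            ...     ["a photo of a cat"], [t5_tokenizer, clip_l_tokenizer, clip_g_tokenizer]
            ... )
            >>> len(tokens)  # one [B, min(256, model_max_length)] tensor per tokenizer
            3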
"""
tokens, masks = [], []
for tokenizer in tokenizers:
tok = tokenizer(
raw_text,
padding="max_length",
max_length=min(EncoderMixin.TOKEN_MAX_LENGTH, tokenizer.model_max_length),
return_tensors="pt",
truncation=True,
)
tokens.append(tok.input_ids)
masks.append(tok.attention_mask)
return tokens, masks
@torch.no_grad()
def text_encoding(
self,
tokens: list[torch.Tensor],
masks: list[torch.Tensor],
t5: T5EncoderModel,
clip_l: CLIPTextModel,
clip_g: CLIPTextModel,
t5_pad_token_id: int = 0,
clip_l_tokenizer_eos_token_id: int = 49407,
clip_g_tokenizer_eos_token_id: int = 49407,
is_training: bool = False,
) -> tuple[list[torch.Tensor], torch.Tensor]:
"""Encode the tokenized text using multiple text encoders.
Args:
            tokens (list[torch.Tensor]): Tokenized text tensors in the order [T5, CLIP-L, CLIP-G].
            masks (list[torch.Tensor]): Attention masks in the same order.
            is_training (bool): If True, the pooled CLIP embeddings are randomly dropped.
        Returns:
            tuple[list[torch.Tensor], torch.Tensor]: Per-encoder text embeddings
            [t5_emb, clip_l_emb, clip_g_emb] and pooled text embeddings of shape [B, 2048].
"""
t5_tokens, clip_l_tokens, clip_g_tokens = tokens
t5_masks, _, _ = masks
# T5 encoding
t5_emb = t5(t5_tokens, attention_mask=t5_masks)[0]
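        # Zero out the embeddings at T5 padding positions (flan-t5-xxl hidden size is 4096, so t5_emb is B x L x 4096)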
t5_emb = t5_emb * (t5_tokens != t5_pad_token_id).unsqueeze(-1)
# CLIP encodings
clip_l_emb = clip_l(input_ids=clip_l_tokens, output_hidden_states=True)
clip_g_emb = clip_g(input_ids=clip_g_tokens, output_hidden_states=True)
# Get pooled outputs
clip_l_emb_pooled = clip_l_emb.pooler_output # B x 768
clip_g_emb_pooled = clip_g_emb.pooler_output # B x 1280
if is_training:
clip_l_emb_pooled = self.drop_text_emb(clip_l_emb_pooled)
clip_g_emb_pooled = self.drop_text_emb(clip_g_emb_pooled)
clip_l_emb = clip_l_emb.last_hidden_state # B x L x 768
clip_g_emb = clip_g_emb.last_hidden_state # B x L x 1280
def masking_wo_first_eos(token, eos):
"""Create attention mask without first EOS token."""
idx = (token != eos).sum(dim=1)
mask = token != eos
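            # e.g. [bos, a, b, eos, eos, eos] -> mask keeps [bos, a, b, eos] and zeroes the trailing EOS padding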
            arange = torch.arange(mask.size(0), device=token.device)
            # Re-enable the first EOS position; skip sequences that contain no EOS token at all.
            in_range = idx < mask.size(1)
            mask[arange[in_range], idx[in_range]] = True
return mask.unsqueeze(-1) # B x L x 1
# Apply masking
clip_l_emb = clip_l_emb * masking_wo_first_eos(clip_l_tokens, clip_l_tokenizer_eos_token_id)
clip_g_emb = clip_g_emb * masking_wo_first_eos(clip_g_tokens, clip_g_tokenizer_eos_token_id)
encodings = [t5_emb, clip_l_emb, clip_g_emb]
pooled_encodings = torch.cat([clip_l_emb_pooled, clip_g_emb_pooled], dim=-1) # B x 2048
return encodings, pooled_encodings
@torch.no_grad()
def drop_text_emb(
self, text_embeddings: list[torch.Tensor] | torch.Tensor, drop_prob: float = 0.464
) -> list[torch.Tensor] | torch.Tensor:
"""Randomly drop text embeddings with a specified probability.
Args:
text_embeddings (Union[List[torch.Tensor], torch.Tensor]): Text embeddings to be dropped.
drop_prob (float, optional): Probability of dropping text embeddings. Defaults to 0.464.
Returns:
Union[List[torch.Tensor], torch.Tensor]: Text embeddings with dropped elements.
"""
        if isinstance(text_embeddings, list):
            # For B x L x C features: drop each embedding and write the result back into the list.
            for i, text_embedding in enumerate(text_embeddings):
                probs = torch.ones(text_embedding.shape[0], device=text_embedding.device) * (1 - drop_prob)
                masks = torch.bernoulli(probs)
                while len(masks.shape) < len(text_embedding.shape):
                    masks = masks.unsqueeze(-1)
                text_embeddings[i] = text_embedding * masks
        else:
            # For a pooled B x C feature
            probs = torch.ones(text_embeddings.shape[0], device=text_embeddings.device) * (1 - drop_prob)
            masks = torch.bernoulli(probs)
            while len(masks.shape) < len(text_embeddings.shape):
                masks = masks.unsqueeze(-1)
            text_embeddings = text_embeddings * masks
return text_embeddings