import torch

from diffusers.models import AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizerFast, T5EncoderModel, T5Tokenizer


class EncoderMixin:
    """Mixin class for handling various encoders in the MotifDiT model.

    This mixin provides functionality for:
    1. Loading and initializing encoders (VAE, T5, CLIP-L, CLIP-G)
    2. Text tokenization and encoding
    3. Managing encoder parameters and state
    """

    TOKEN_MAX_LENGTH: int = 256

    def prepare_embeddings(
        self,
        images: torch.Tensor,
        raw_text: list[str],
        vae: AutoencoderKL,
        t5: T5EncoderModel,
        clip_l: CLIPTextModel,
        clip_g: CLIPTextModel,
        t5_tokenizer: T5Tokenizer,
        clip_l_tokenizer: CLIPTokenizerFast,
        clip_g_tokenizer: CLIPTokenizerFast,
        is_training: bool,
    ) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
        """Prepare image latents and text embeddings for model input.

        Args:
            images (torch.Tensor): Input images tensor with shape [B, C=3, H, W].
            raw_text (list[str]): List of raw text strings with length B.

        Returns:
            tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]: VAE image latents, per-encoder
            text embeddings (T5, CLIP-L, CLIP-G), and pooled CLIP text embeddings.
        """
        with torch.no_grad():
            latents: torch.Tensor = (
                vae.encode(images).latent_dist.sample() - vae.config.shift_factor
            ) * vae.config.scaling_factor
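
        # Tokenize the prompts with all three tokenizers and move tokens/masks to the latent device.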
        tokenizers = [t5_tokenizer, clip_l_tokenizer, clip_g_tokenizer]
        tokens, masks = self.tokenization(raw_text, tokenizers)
        tokens = [token.to(latents.device) for token in tokens]
        masks = [mask.to(latents.device) for mask in masks]

        text_embeddings, pooled_text_embeddings = self.text_encoding(
            tokens,
            masks,
            t5,
            clip_l,
            clip_g,
            t5_tokenizer.pad_token_id,
            clip_l_tokenizer.eos_token_id,
            clip_g_tokenizer.eos_token_id,
            is_training,
        )
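
        # Randomly zero whole-prompt embeddings during training (classifier-free-guidance-style text dropout).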
        if is_training:
            text_embeddings = self.drop_text_emb(text_embeddings)

        text_embeddings = [text_embedding.float() for text_embedding in text_embeddings]
        pooled_text_embeddings = pooled_text_embeddings.float()

        return latents, text_embeddings, pooled_text_embeddings

    def get_freezed_encoders_and_tokenizers(
        self, vae_type: str
    ) -> tuple[
        AutoencoderKL, T5EncoderModel, CLIPTextModel, CLIPTextModel, T5Tokenizer, CLIPTokenizerFast, CLIPTokenizerFast
    ]:
        """Initialize the frozen VAE and text encoders along with their tokenizers."""
        if vae_type != "SD3":
            raise ValueError(
                f"VAE type must be `SD3` but self.config.vae_type is {vae_type}."
                " Note that the SDXL VAE type is deprecated."
            )

        vae: AutoencoderKL = AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae"
        )
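
        # Text encoders are cast to bfloat16 to save memory; the VAE stays in its default precision.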
        t5 = T5EncoderModel.from_pretrained("google/flan-t5-xxl").to(dtype=torch.bfloat16)
        t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")

        clip_l = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(dtype=torch.bfloat16)
        clip_l_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")

        clip_g = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k").to(dtype=torch.bfloat16)
        clip_g_tokenizer = CLIPTokenizerFast.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
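
        # Freeze every encoder parameter; these modules are used only for feature extraction.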
        for encoder_module in [vae, clip_l, clip_g, t5]:
            for param in encoder_module.parameters():
                param.requires_grad = False

        return vae, t5, clip_l, clip_g, t5_tokenizer, clip_l_tokenizer, clip_g_tokenizer

    def tokenization(
        self, raw_text: list[str], tokenizers: list[T5Tokenizer | CLIPTokenizerFast]
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Tokenize the input text using multiple tokenizers.

        Each tokenizer pads/truncates to min(TOKEN_MAX_LENGTH, tokenizer.model_max_length),
        so the CLIP streams are effectively capped at the 77-token CLIP context length.

        Args:
            raw_text (list[str]): List of input text strings.
            tokenizers (list[T5Tokenizer | CLIPTokenizerFast]): Tokenizers to apply, one per text encoder.

        Returns:
            tuple[list[torch.Tensor], list[torch.Tensor]]: Lists of tokenized text tensors and attention masks.
        """
        tokens, masks = [], []
        for tokenizer in tokenizers:
            tok = tokenizer(
                raw_text,
                padding="max_length",
                max_length=min(EncoderMixin.TOKEN_MAX_LENGTH, tokenizer.model_max_length),
                return_tensors="pt",
                truncation=True,
            )
            tokens.append(tok.input_ids)
            masks.append(tok.attention_mask)
        return tokens, masks

    @torch.no_grad()
    def text_encoding(
        self,
        tokens: list[torch.Tensor],
        masks: list[torch.Tensor],
        t5: T5EncoderModel,
        clip_l: CLIPTextModel,
        clip_g: CLIPTextModel,
        t5_pad_token_id: int = 0,
        clip_l_tokenizer_eos_token_id: int = 49407,
        clip_g_tokenizer_eos_token_id: int = 49407,
        is_training: bool = False,
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        """Encode the tokenized text using multiple text encoders.

        Args:
            tokens (list[torch.Tensor]): List of tokenized text tensors in (T5, CLIP-L, CLIP-G) order.
            masks (list[torch.Tensor]): List of attention masks in the same order.

        Returns:
            tuple[list[torch.Tensor], torch.Tensor]: Text embeddings and pooled text embeddings.
        """
        t5_tokens, clip_l_tokens, clip_g_tokens = tokens
        t5_masks, _, _ = masks

        # T5 sequence embeddings; zero out positions that correspond to padding tokens.
        t5_emb = t5(t5_tokens, attention_mask=t5_masks)[0]
        t5_emb = t5_emb * (t5_tokens != t5_pad_token_id).unsqueeze(-1)

        clip_l_emb = clip_l(input_ids=clip_l_tokens, output_hidden_states=True)
        clip_g_emb = clip_g(input_ids=clip_g_tokens, output_hidden_states=True)

        clip_l_emb_pooled = clip_l_emb.pooler_output
        clip_g_emb_pooled = clip_g_emb.pooler_output

        if is_training:
            clip_l_emb_pooled = self.drop_text_emb(clip_l_emb_pooled)
            clip_g_emb_pooled = self.drop_text_emb(clip_g_emb_pooled)

        clip_l_emb = clip_l_emb.last_hidden_state
        clip_g_emb = clip_g_emb.last_hidden_state
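
        # Keep the first EOS token (end-of-text) visible and mask any repeated EOS positions
        # that follow it (e.g. when EOS is also used for padding).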
        def masking_wo_first_eos(token: torch.Tensor, eos: int) -> torch.Tensor:
            """Build a mask that keeps the first EOS token and hides every later EOS position."""
            mask = token != eos
            # Index of the first EOS token in each sequence (number of leading non-EOS tokens).
            idx = mask.sum(dim=1)
            arange = torch.arange(mask.size(0), device=token.device)
            # Re-enable the first EOS position wherever it actually falls within the sequence length.
            valid = idx < mask.size(1)
            mask[arange[valid], idx[valid]] = True
            return mask.unsqueeze(-1)

        clip_l_emb = clip_l_emb * masking_wo_first_eos(clip_l_tokens, clip_l_tokenizer_eos_token_id)
        clip_g_emb = clip_g_emb * masking_wo_first_eos(clip_g_tokens, clip_g_tokenizer_eos_token_id)

        encodings = [t5_emb, clip_l_emb, clip_g_emb]
        pooled_encodings = torch.cat([clip_l_emb_pooled, clip_g_emb_pooled], dim=-1)

        return encodings, pooled_encodings

    @torch.no_grad()
    def drop_text_emb(
        self, text_embeddings: list[torch.Tensor] | torch.Tensor, drop_prob: float = 0.464
    ) -> list[torch.Tensor] | torch.Tensor:
        """Randomly drop text embeddings with a specified probability.

        Args:
            text_embeddings (list[torch.Tensor] | torch.Tensor): Text embeddings to be dropped.
            drop_prob (float, optional): Probability of dropping text embeddings. Defaults to 0.464.

        Returns:
            list[torch.Tensor] | torch.Tensor: Text embeddings with randomly zeroed samples.
        """
        if isinstance(text_embeddings, list):
            dropped = []
            for text_embedding in text_embeddings:
                # Keep each sample with probability (1 - drop_prob); dropped samples become all zeros.
                probs = torch.ones(text_embedding.shape[0], device=text_embedding.device) * (1 - drop_prob)
                masks = torch.bernoulli(probs)
                while len(masks.shape) < len(text_embedding.shape):
                    masks = masks.unsqueeze(-1)
                dropped.append(text_embedding * masks)
            text_embeddings = dropped
        else:
            probs = torch.ones(text_embeddings.shape[0], device=text_embeddings.device) * (1 - drop_prob)
            masks = torch.bernoulli(probs)
            while len(masks.shape) < len(text_embeddings.shape):
                masks = masks.unsqueeze(-1)
            text_embeddings = text_embeddings * masks

        return text_embeddings
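

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training pipeline. It assumes the pretrained
    # checkpoints referenced above can be downloaded and that enough memory is available
    # (flan-t5-xxl alone is very large). `_Encoders` is only a stand-in for the real model
    # class that mixes in EncoderMixin.
    class _Encoders(EncoderMixin):
        pass

    enc = _Encoders()
    vae, t5, clip_l, clip_g, t5_tok, clip_l_tok, clip_g_tok = enc.get_freezed_encoders_and_tokenizers("SD3")
    images = torch.randn(1, 3, 256, 256)
    latents, text_embs, pooled = enc.prepare_embeddings(
        images, ["a photo of a cat"], vae, t5, clip_l, clip_g, t5_tok, clip_l_tok, clip_g_tok, is_training=False
    )
    print(latents.shape, [emb.shape for emb in text_embs], pooled.shape)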