import torch
from diffusers.models import AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizerFast, T5EncoderModel, T5Tokenizer


class EncoderMixin:
    """Mixin class for handling various encoders in the MotifDiT model.

    This mixin provides functionality for:
    1. Loading and initializing encoders (VAE, T5, CLIP-L, CLIP-G)
    2. Text tokenization and encoding
    3. Managing encoder parameters and state
    """

    TOKEN_MAX_LENGTH: int = 256

    def prepare_embeddings(
        self,
        images: torch.Tensor,
        raw_text: list[str],
        vae: AutoencoderKL,
        t5: T5EncoderModel,
        clip_l: CLIPTextModel,
        clip_g: CLIPTextModel,
        t5_tokenizer: T5Tokenizer,
        clip_l_tokenizer: CLIPTokenizerFast,
        clip_g_tokenizer: CLIPTokenizerFast,
        is_training: bool,
    ) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
        """Prepare image latents and text embeddings for model input.

        Args:
            images (torch.Tensor): Input images tensor with shape [B, C=3, H, W].
            raw_text (list[str]): List of raw text strings with length B.

        Returns:
            tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]: Image latents,
                per-encoder text embeddings, and pooled text embeddings.
        """
        with torch.no_grad():
            latents: torch.Tensor = (
                vae.encode(images).latent_dist.sample() - vae.config.shift_factor
            ) * vae.config.scaling_factor  # Latents shape: [B, 16, H//8, W//8]

        # Tokenize the input text and move tokens and masks to the same device as the latents
        tokenizers = [t5_tokenizer, clip_l_tokenizer, clip_g_tokenizer]
        tokens, masks = self.tokenization(raw_text, tokenizers)
        tokens = [token.to(latents.device) for token in tokens]
        masks = [mask.to(latents.device) for mask in masks]

        # Encode the text; during training, randomly drop embeddings for
        # classifier-free-guidance-style conditioning dropout
        text_embeddings, pooled_text_embeddings = self.text_encoding(
            tokens,
            masks,
            t5,
            clip_l,
            clip_g,
            t5_tokenizer.pad_token_id,
            clip_l_tokenizer.eos_token_id,
            clip_g_tokenizer.eos_token_id,
            is_training,
        )
        if is_training:
            text_embeddings = self.drop_text_emb(text_embeddings)

        # Convert text embeddings and pooled text embeddings to float32
        text_embeddings = [text_embedding.float() for text_embedding in text_embeddings]
        pooled_text_embeddings = pooled_text_embeddings.float()

        return latents, text_embeddings, pooled_text_embeddings

    def get_freezed_encoders_and_tokenizers(
        self, vae_type: str
    ) -> tuple[
        AutoencoderKL,
        T5EncoderModel,
        CLIPTextModel,
        CLIPTextModel,
        T5Tokenizer,
        CLIPTokenizerFast,
        CLIPTokenizerFast,
    ]:
        """Load the frozen VAE, text encoders, and their tokenizers."""
        if vae_type != "SD3":
            raise ValueError(
                f"VAE type must be `SD3` but self.config.vae_type is {vae_type}."
                " Note that the VAE type SDXL is deprecated."
            )
        vae: AutoencoderKL = AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae"
        )

        # Text encoders
        # 1. T5-XXL from Google
        t5 = T5EncoderModel.from_pretrained("google/flan-t5-xxl").to(dtype=torch.bfloat16)
        t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
        # 2. CLIP-L from OpenAI
        clip_l = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(dtype=torch.bfloat16)
        clip_l_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")
        # 3. CLIP-G from LAION
        clip_g = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k").to(dtype=torch.bfloat16)
        clip_g_tokenizer = CLIPTokenizerFast.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

        # Freeze all encoders
        for encoder_module in [vae, clip_l, clip_g, t5]:
            for param in encoder_module.parameters():
                param.requires_grad = False

        return vae, t5, clip_l, clip_g, t5_tokenizer, clip_l_tokenizer, clip_g_tokenizer
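
    # Note on the latent normalization above: `prepare_embeddings` maps VAE samples
    # into the model's working space via z = (sample - shift_factor) * scaling_factor,
    # so decoding must apply the inverse before `vae.decode`. A minimal sketch of that
    # inverse, assuming the same SD3 VAE config as loaded above:
    #
    #     sample = latents / vae.config.scaling_factor + vae.config.shift_factor
    #     images = vae.decode(sample).sample  # [B, 3, H, W]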

    def tokenization(
        self,
        raw_text: list[str],
        tokenizers: list[T5Tokenizer | CLIPTokenizerFast],
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Tokenize the input text using multiple tokenizers.

        Args:
            raw_text (list[str]): Input text strings, one per batch element.
            tokenizers (list[T5Tokenizer | CLIPTokenizerFast]): Tokenizers to apply, in order.

        Returns:
            tuple[list[torch.Tensor], list[torch.Tensor]]: Lists of token-ID tensors
                and attention masks, one entry per tokenizer.
        """
        tokens, masks = [], []
        for tokenizer in tokenizers:
            tok = tokenizer(
                raw_text,
                padding="max_length",
                max_length=min(EncoderMixin.TOKEN_MAX_LENGTH, tokenizer.model_max_length),
                return_tensors="pt",
                truncation=True,
            )
            tokens.append(tok.input_ids)
            masks.append(tok.attention_mask)
        return tokens, masks

    @torch.no_grad()
    def text_encoding(
        self,
        tokens: list[torch.Tensor],
        masks: list[torch.Tensor],
        t5: T5EncoderModel,
        clip_l: CLIPTextModel,
        clip_g: CLIPTextModel,
        t5_pad_token_id: int = 0,
        clip_l_tokenizer_eos_token_id: int = 49407,
        clip_g_tokenizer_eos_token_id: int = 49407,
        is_training: bool = False,
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        """Encode the tokenized text using multiple text encoders.

        Args:
            tokens (list[torch.Tensor]): List of tokenized text tensors.
            masks (list[torch.Tensor]): List of attention masks.

        Returns:
            tuple[list[torch.Tensor], torch.Tensor]: Per-encoder text embeddings
                and pooled text embeddings.
        """
        t5_tokens, clip_l_tokens, clip_g_tokens = tokens
        t5_masks, _, _ = masks

        # T5 encoding; zero out the embeddings at padding positions
        t5_emb = t5(t5_tokens, attention_mask=t5_masks)[0]
        t5_emb = t5_emb * (t5_tokens != t5_pad_token_id).unsqueeze(-1)

        # CLIP encodings
        clip_l_emb = clip_l(input_ids=clip_l_tokens, output_hidden_states=True)
        clip_g_emb = clip_g(input_ids=clip_g_tokens, output_hidden_states=True)

        # Get pooled outputs
        clip_l_emb_pooled = clip_l_emb.pooler_output  # B x 768
        clip_g_emb_pooled = clip_g_emb.pooler_output  # B x 1280
        if is_training:
            clip_l_emb_pooled = self.drop_text_emb(clip_l_emb_pooled)
            clip_g_emb_pooled = self.drop_text_emb(clip_g_emb_pooled)

        clip_l_emb = clip_l_emb.last_hidden_state  # B x L x 768
        clip_g_emb = clip_g_emb.last_hidden_state  # B x L x 1280

        def masking_wo_first_eos(token: torch.Tensor, eos: int) -> torch.Tensor:
            """Create a mask that keeps non-EOS tokens plus the first EOS token."""
            mask = token != eos
            idx = mask.sum(dim=1)  # per-sample index of the first EOS token
            arange = torch.arange(mask.size(0), device=token.device)
            in_range = idx < mask.size(1)  # guard: a fully truncated prompt has no EOS
            mask[arange[in_range], idx[in_range]] = True
            return mask.unsqueeze(-1)  # B x L x 1

        # Zero out everything after the first EOS token (i.e., the padding positions)
        clip_l_emb = clip_l_emb * masking_wo_first_eos(clip_l_tokens, clip_l_tokenizer_eos_token_id)
        clip_g_emb = clip_g_emb * masking_wo_first_eos(clip_g_tokens, clip_g_tokenizer_eos_token_id)

        encodings = [t5_emb, clip_l_emb, clip_g_emb]
        pooled_encodings = torch.cat([clip_l_emb_pooled, clip_g_emb_pooled], dim=-1)  # B x 2048
        return encodings, pooled_encodings
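
    # Toy illustration of `masking_wo_first_eos` (hypothetical token IDs, eos=3;
    # these CLIP tokenizers pad with the EOS token, so trailing 3s are padding):
    #
    #     token = torch.tensor([[5, 6, 3, 3, 3]])
    #     masking_wo_first_eos(token, eos=3).squeeze(-1)
    #     # -> tensor([[True, True, True, False, False]])
    #
    # The first EOS is kept because it carries sentence-level information in CLIP;
    # only the padding positions after it are zeroed out.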

    @torch.no_grad()
    def drop_text_emb(
        self,
        text_embeddings: list[torch.Tensor] | torch.Tensor,
        drop_prob: float = 0.464,
    ) -> list[torch.Tensor] | torch.Tensor:
        """Randomly drop text embeddings with a specified probability.

        Args:
            text_embeddings (list[torch.Tensor] | torch.Tensor): Text embeddings to be dropped.
            drop_prob (float, optional): Probability of dropping text embeddings. Defaults to 0.464.

        Returns:
            list[torch.Tensor] | torch.Tensor: Text embeddings with randomly zeroed samples.
        """
        if isinstance(text_embeddings, list):
            # For a list of B x L x C features
            for i, text_embedding in enumerate(text_embeddings):
                probs = torch.full((text_embedding.shape[0],), 1 - drop_prob, device=text_embedding.device)
                masks = torch.bernoulli(probs)
                while masks.dim() < text_embedding.dim():
                    masks = masks.unsqueeze(-1)
                # Write back into the list; reassigning the loop variable alone
                # would silently discard the dropped embeddings
                text_embeddings[i] = text_embedding * masks
        else:
            # For a pooled B x C feature
            probs = torch.full((text_embeddings.shape[0],), 1 - drop_prob, device=text_embeddings.device)
            masks = torch.bernoulli(probs)
            while masks.dim() < text_embeddings.dim():
                masks = masks.unsqueeze(-1)
            text_embeddings = text_embeddings * masks
        return text_embeddings
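

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how a host model might wire the mixin
# together. `_Demo`, the random batch, and the captions are assumptions; the
# real host (MotifDiT) and data pipeline live elsewhere, and running this
# downloads several multi-GB checkpoints.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _Demo(EncoderMixin):
        """Minimal host class; the real model would also be an nn.Module."""

    demo = _Demo()
    vae, t5, clip_l, clip_g, t5_tok, clip_l_tok, clip_g_tok = demo.get_freezed_encoders_and_tokenizers("SD3")

    images = torch.randn(2, 3, 256, 256)  # B x C x H x W, in the VAE's input range
    captions = ["a photo of a cat", "a watercolor landscape"]
    latents, text_embs, pooled = demo.prepare_embeddings(
        images, captions, vae, t5, clip_l, clip_g,
        t5_tok, clip_l_tok, clip_g_tok, is_training=False,
    )
    print(latents.shape)                 # torch.Size([2, 16, 32, 32])
    print([e.shape for e in text_embs])  # T5 2x256x4096, CLIP-L 2x77x768, CLIP-G 2x77x1280
    print(pooled.shape)                  # torch.Size([2, 2048])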