import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel, SeamlessM4TModel
from transformers.modeling_outputs import SequenceClassifierOutput

logger = logging.getLogger(__name__)


class SeamlessLanguagePairsConfig(PretrainedConfig):
    """Configuration class for the SeamlessLanguagePairs model.

    Args:
        seamless_model_name: HF Hub id of the frozen SeamlessM4T backbone.
        num_language_pairs: Number of distinct language-pair categories.
        hidden_size: Width of the shared audio/text projection space.
        dropout_prob: Dropout probability between classifier layers.
        translation_embed_dim: Width of the binary translation-flag embedding.
        language_pair_embed_dim: Width of the language-pair embedding.
    """

    model_type = "seamless_language_pairs"

    def __init__(
        self,
        seamless_model_name="facebook/hf-seamless-m4t-medium",
        num_language_pairs=21,
        hidden_size=1024,
        dropout_prob=0.1,
        translation_embed_dim=32,
        language_pair_embed_dim=64,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.seamless_model_name = seamless_model_name
        self.num_language_pairs = num_language_pairs
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob
        self.translation_embed_dim = translation_embed_dim
        self.language_pair_embed_dim = language_pair_embed_dim


class HFSeamlessLanguagePairs(PreTrainedModel):
    """SeamlessM4T model with language-pair and translation features for the HF Hub.

    Frozen SeamlessM4T speech/text encoders are mean-pooled and projected into
    a shared space, concatenated with a translation-flag embedding and a
    language-pair embedding, then passed through an MLP that emits one scalar
    per example. When ``labels`` are provided, the loss is MSE (regression).
    """

    config_class = SeamlessLanguagePairsConfig
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Load the underlying SeamlessM4T model; only its encoders are used.
        self.seamless_model = SeamlessM4TModel.from_pretrained(config.seamless_model_name)
        self.seamless_model_speech_encoder = self.seamless_model.speech_encoder
        self.seamless_model_text_encoder = self.seamless_model.text_encoder

        # Freeze the pre-trained encoders; only the layers below are trained.
        # NOTE(review): dropout inside the frozen encoders remains active in
        # train mode — call .eval() on them per step if that is undesired.
        for param in self.seamless_model_speech_encoder.parameters():
            param.requires_grad = False
        for param in self.seamless_model_text_encoder.parameters():
            param.requires_grad = False

        # Project both modalities into a shared `hidden_size` space.
        self.audio_proj = nn.Linear(
            self.seamless_model_speech_encoder.config.hidden_size, config.hidden_size
        )
        self.text_proj = nn.Linear(
            self.seamless_model_text_encoder.config.hidden_size, config.hidden_size
        )

        # Binary translation-flag feature -> small dense embedding.
        # getattr fallbacks keep configs saved before these fields existed
        # loading with the historical 32/64 widths.
        translation_dim = getattr(config, "translation_embed_dim", 32)
        self.translation_proj = nn.Linear(1, translation_dim)

        # Categorical language-pair id -> learned embedding.
        language_pair_dim = getattr(config, "language_pair_embed_dim", 64)
        self.language_pair_embedding = nn.Embedding(
            config.num_language_pairs, language_pair_dim
        )

        # Classification head. Input width is computed from the config
        # (audio + text projections + the two feature embeddings) instead of
        # the previous hard-coded 2144, which silently assumed
        # hidden_size == 1024 and broke for any other value.
        combined_dim = 2 * config.hidden_size + translation_dim + language_pair_dim
        self.fc = nn.Sequential(
            nn.Linear(combined_dim, 1024),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(256, 1),
        )

    def forward(
        self,
        input_features,
        input_ids,
        text_attention_mask,
        is_translation,
        language_pair_id,
        audio_attention_mask=None,
        labels=None,
        **kwargs,
    ):
        """Run the model on one batch.

        Args:
            input_features: Audio features for the speech encoder.
            input_ids: Token ids for the text encoder.
            text_attention_mask: Attention mask for ``input_ids``.
            is_translation: Per-example binary flag tensor (any numeric dtype).
            language_pair_id: Per-example language-pair index tensor.
            audio_attention_mask: Optional attention mask for the audio input.
            labels: Optional regression targets; when given, MSE loss is
                computed against the logits.

        Returns:
            SequenceClassifierOutput with scalar ``logits`` of shape
            ``(batch_size,)`` and ``loss`` when labels were supplied.
        """
        # Encode audio: mean-pool encoder states over time, then project.
        # NOTE(review): the mean includes padded positions — consider a
        # mask-weighted mean if padding varies a lot within a batch.
        audio_emb = self.seamless_model_speech_encoder(
            input_features=input_features, attention_mask=audio_attention_mask
        ).last_hidden_state.mean(dim=1)
        audio_emb = self.audio_proj(audio_emb)

        # Encode text: mean-pool encoder states over tokens, then project.
        text_emb = self.seamless_model_text_encoder(
            input_ids=input_ids, attention_mask=text_attention_mask
        ).last_hidden_state.mean(dim=1)
        text_emb = self.text_proj(text_emb)

        # Binary flag -> embedding. Cast to the projection's dtype so
        # integer/bool flag tensors do not crash the Linear layer.
        translation_emb = self.translation_proj(
            is_translation.to(self.translation_proj.weight.dtype).unsqueeze(-1)
        )

        # Categorical language-pair feature.
        language_pair_emb = self.language_pair_embedding(language_pair_id)

        # Concatenate all features; width matches the head's first Linear.
        combined = torch.cat(
            [audio_emb, text_emb, translation_emb, language_pair_emb], dim=1
        )

        logits = self.fc(combined).squeeze(-1)

        # MSE regression loss when targets are supplied; cast labels to the
        # prediction dtype so integer targets are accepted.
        loss = None
        if labels is not None:
            loss = F.mse_loss(logits, labels.to(logits.dtype))

        return SequenceClassifierOutput(
            loss=loss, logits=logits, hidden_states=None, attentions=None
        )