import logging
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers import SeamlessM4TModel
from transformers.modeling_outputs import SequenceClassifierOutput

logger = logging.getLogger(__name__)


class SeamlessBasicConfig(PretrainedConfig):
    """Configuration class for SeamlessBasic model.

    Args:
        seamless_model_name: Name or path handed to
            ``SeamlessM4TModel.from_pretrained`` by the model.
        hidden_size: Output width of the audio and text projection layers.
        dropout_prob: Dropout probability used in the classification head.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "seamless_basic"

    def __init__(
        self,
        seamless_model_name: str = "facebook/hf-seamless-m4t-medium",
        hidden_size: int = 1024,
        dropout_prob: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.seamless_model_name = seamless_model_name
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob


class HFSeamlessBasic(PreTrainedModel):
    """Basic SeamlessM4T model for HuggingFace Hub - processes audio and text only.

    Audio and text are encoded by the frozen SeamlessM4T speech and text
    encoders, mean-pooled over dim 1, projected to ``config.hidden_size``,
    concatenated, and passed through an MLP that emits one scalar per example.
    """

    config_class = SeamlessBasicConfig
    supports_gradient_checkpointing = True

    def __init__(self, config):
        """Load the SeamlessM4T backbone and build the trainable layers.

        Args:
            config: A ``SeamlessBasicConfig`` (or compatible) instance.
        """
        super().__init__(config)
        self.config = config

        # Load the underlying SeamlessM4T model
        self.seamless_model = SeamlessM4TModel.from_pretrained(
            config.seamless_model_name
        )
        self.seamless_model_speech_encoder = self.seamless_model.speech_encoder
        self.seamless_model_text_encoder = self.seamless_model.text_encoder

        # Freeze the whole pre-trained backbone. The two encoders are
        # sub-modules of `seamless_model`, so they are covered too; the
        # remaining sub-modules are never called in `forward` and must not be
        # reported as trainable parameters.
        self.seamless_model.requires_grad_(False)

        # Projection layers
        self.audio_proj = nn.Linear(
            self.seamless_model_speech_encoder.config.hidden_size,
            config.hidden_size,
        )
        self.text_proj = nn.Linear(
            self.seamless_model_text_encoder.config.hidden_size,
            config.hidden_size,
        )

        # Classification head. Its input is the concatenation of both
        # projections, i.e. 2 * hidden_size (2048 with the default 1024).
        self.fc = nn.Sequential(
            nn.Linear(2 * config.hidden_size, 1024),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(256, 1),
        )

    def forward(
        self,
        input_features: torch.Tensor,
        input_ids: torch.Tensor,
        text_attention_mask: torch.Tensor,
        audio_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,  # Accept additional features but ignore them
    ) -> SequenceClassifierOutput:
        """Score a batch of paired audio/text inputs.

        Args:
            input_features: Audio features passed to the speech encoder.
            input_ids: Token ids passed to the text encoder.
            text_attention_mask: Attention mask passed to the text encoder.
            audio_attention_mask: Optional attention mask passed to the
                speech encoder.
            labels: Optional regression targets, one value per example. They
                are cast to the dtype of the predictions and reshaped to
                match them before the loss is computed.
            **kwargs: Accepted and ignored.

        Returns:
            A ``SequenceClassifierOutput`` whose ``logits`` hold one value per
            example and whose ``loss`` is the MSE against ``labels`` (``None``
            when no labels are given). ``hidden_states`` and ``attentions``
            are always ``None``.
        """
        # NOTE(review): both poolings below are plain means over dim 1; the
        # attention masks are only forwarded to the encoders, so padded
        # positions contribute to the average. Left as-is because existing
        # checkpoints depend on it - confirm whether masked pooling is wanted.

        # Encode audio
        audio_emb = self.seamless_model_speech_encoder(
            input_features=input_features,
            attention_mask=audio_attention_mask,
        ).last_hidden_state.mean(dim=1)
        audio_emb = self.audio_proj(audio_emb)

        # Encode text
        text_emb = self.seamless_model_text_encoder(
            input_ids=input_ids,
            attention_mask=text_attention_mask,
        ).last_hidden_state.mean(dim=1)
        text_emb = self.text_proj(text_emb)

        # Combine features: (batch_size, 2 * hidden_size)
        combined = torch.cat([audio_emb, text_emb], dim=1)
        logits = self.fc(combined).squeeze(-1)

        # Compute loss if labels are provided
        loss = None
        if labels is not None:
            # Align dtype and shape so integer labels do not raise and
            # (batch, 1) labels do not broadcast against (batch,) logits.
            targets = labels.to(dtype=logits.dtype).reshape(logits.shape)
            loss = F.mse_loss(logits, targets)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )