# seamless-langpairs / modeling_seamless_langpairs.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import SeamlessM4TModel
import logging
logger = logging.getLogger(__name__)
class SeamlessLanguagePairsConfig(PretrainedConfig):
"""Configuration class for SeamlessLanguagePairs model."""
model_type = "seamless_language_pairs"
def __init__(
self,
seamless_model_name="facebook/hf-seamless-m4t-medium",
num_language_pairs=21,
hidden_size=1024,
dropout_prob=0.1,
**kwargs
):
super().__init__(**kwargs)
self.seamless_model_name = seamless_model_name
self.num_language_pairs = num_language_pairs
self.hidden_size = hidden_size
self.dropout_prob = dropout_prob
class HFSeamlessLanguagePairs(PreTrainedModel):
"""SeamlessM4T model with language pairs and translation features for HuggingFace Hub."""
config_class = SeamlessLanguagePairsConfig
supports_gradient_checkpointing = True
def __init__(self, config):
super().__init__(config)
self.config = config
# Load the underlying SeamlessM4T model
self.seamless_model = SeamlessM4TModel.from_pretrained(config.seamless_model_name)
self.seamless_model_speech_encoder = self.seamless_model.speech_encoder
self.seamless_model_text_encoder = self.seamless_model.text_encoder
# Freeze pre-trained models
for param in self.seamless_model_speech_encoder.parameters():
param.requires_grad = False
for param in self.seamless_model_text_encoder.parameters():
param.requires_grad = False
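        # Only the projection layers, the feature embeddings, and the head
        # defined below are trained; the SeamlessM4T encoders stay frozen.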
# Projection layers
self.audio_proj = nn.Linear(
self.seamless_model_speech_encoder.config.hidden_size,
config.hidden_size
)
self.text_proj = nn.Linear(
self.seamless_model_text_encoder.config.hidden_size,
config.hidden_size
)
# Translation feature embedding (binary)
self.translation_proj = nn.Linear(1, 32)
# Language pair embedding (categorical)
self.language_pair_embedding = nn.Embedding(config.num_language_pairs, 64)
        # Regression head over the fused vector:
        # 2 * hidden_size (audio + text) + 32 (translation) + 64 (language pair),
        # i.e. 2048 + 32 + 64 = 2144 with the default hidden_size of 1024.
        fusion_dim = 2 * config.hidden_size + 32 + 64
        self.fc = nn.Sequential(
            nn.Linear(fusion_dim, 1024),
nn.ReLU(),
nn.Dropout(config.dropout_prob),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Dropout(config.dropout_prob),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(config.dropout_prob),
nn.Linear(256, 1)
)
def forward(
self,
input_features,
input_ids,
text_attention_mask,
is_translation,
language_pair_id,
audio_attention_mask=None,
labels=None,
**kwargs
):
# Encode audio
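        # Mean-pool the encoder states over time to get one vector per clip.
        # Note: padded frames (if any) are included in the mean.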
audio_emb = self.seamless_model_speech_encoder(
input_features=input_features,
attention_mask=audio_attention_mask
).last_hidden_state.mean(dim=1)
audio_emb = self.audio_proj(audio_emb)
# Encode text
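        # Same recipe for text: mean-pool the token states into one vector.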
text_emb = self.seamless_model_text_encoder(
input_ids=input_ids,
attention_mask=text_attention_mask
).last_hidden_state.mean(dim=1)
text_emb = self.text_proj(text_emb)
        # Project the binary translation flag; cast to float so nn.Linear
        # also accepts integer/bool inputs: (batch_size,) -> (batch_size, 32)
        translation_emb = self.translation_proj(is_translation.float().unsqueeze(-1))
        # Look up the language-pair embedding: (batch_size,) -> (batch_size, 64)
        language_pair_emb = self.language_pair_embedding(language_pair_id)
# Combine features
combined = torch.cat([
audio_emb,
text_emb,
translation_emb,
language_pair_emb
], dim=1) # (batch_size, 2144)
        # Scalar score per example: (batch_size, 1) -> (batch_size,)
        logits = self.fc(combined).squeeze(-1)
        # Compute MSE loss if labels are provided: despite the
        # SequenceClassifierOutput container, the head is trained as a regressor.
        loss = F.mse_loss(logits, labels.float()) if labels is not None else None
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=None,
attentions=None
)
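

# ---------------------------------------------------------------------------
# Minimal usage sketch. The checkpoint reuse, the silent 16 kHz waveform, and
# the feature values below are illustrative placeholders, not the training
# setup used for this model.
if __name__ == "__main__":
    from transformers import AutoProcessor

    config = SeamlessLanguagePairsConfig()
    model = HFSeamlessLanguagePairs(config).eval()

    # The SeamlessM4T processor handles both audio features and tokenization.
    processor = AutoProcessor.from_pretrained(config.seamless_model_name)

    # One second of 16 kHz silence stands in for a real utterance.
    audio_inputs = processor(
        audios=torch.zeros(16000).numpy(), sampling_rate=16000, return_tensors="pt"
    )
    text_inputs = processor(text="hello world", src_lang="eng", return_tensors="pt")

    with torch.no_grad():
        out = model(
            input_features=audio_inputs.input_features,
            audio_attention_mask=audio_inputs.attention_mask,
            input_ids=text_inputs.input_ids,
            text_attention_mask=text_inputs.attention_mask,
            is_translation=torch.tensor([1.0]),   # binary flag as float
            language_pair_id=torch.tensor([0]),   # index < config.num_language_pairs
        )
    print(out.logits.shape)  # -> torch.Size([1])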