import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, SeamlessM4TModel
from transformers.modeling_outputs import SequenceClassifierOutput

logger = logging.getLogger(__name__)


class SeamlessLanguagePairsConfig(PretrainedConfig):
    """Configuration class for the SeamlessLanguagePairs model."""

    model_type = "seamless_language_pairs"

    def __init__(
        self,
        seamless_model_name="facebook/hf-seamless-m4t-medium",
        num_language_pairs=21,
        hidden_size=1024,
        dropout_prob=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.seamless_model_name = seamless_model_name
        self.num_language_pairs = num_language_pairs
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob


class HFSeamlessLanguagePairs(PreTrainedModel):
    """SeamlessM4T model with language-pair and translation features for the HuggingFace Hub."""

    config_class = SeamlessLanguagePairsConfig
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Load the pretrained SeamlessM4T backbone and keep handles to its two encoders.
        self.seamless_model = SeamlessM4TModel.from_pretrained(config.seamless_model_name)
        self.seamless_model_speech_encoder = self.seamless_model.speech_encoder
        self.seamless_model_text_encoder = self.seamless_model.text_encoder

        # Freeze both encoders; only the projection layers and the fusion head are trained.
        for param in self.seamless_model_speech_encoder.parameters():
            param.requires_grad = False
        for param in self.seamless_model_text_encoder.parameters():
            param.requires_grad = False

        # Project the mean-pooled encoder outputs to a common hidden size.
        self.audio_proj = nn.Linear(
            self.seamless_model_speech_encoder.config.hidden_size,
            config.hidden_size,
        )
        self.text_proj = nn.Linear(
            self.seamless_model_text_encoder.config.hidden_size,
            config.hidden_size,
        )

        # Lift the scalar is_translation flag into a 32-dim feature.
        self.translation_proj = nn.Linear(1, 32)

        # 64-dim learned embedding per language pair.
        self.language_pair_embedding = nn.Embedding(config.num_language_pairs, 64)

        # Fusion head over the concatenated [audio, text, translation, language-pair]
        # features. Deriving the input width from the config instead of hardcoding
        # 2144 keeps the head consistent if hidden_size changes; with the defaults
        # it is 1024 + 1024 + 32 + 64 = 2144.
        fusion_dim = 2 * config.hidden_size + 32 + 64
        self.fc = nn.Sequential(
            nn.Linear(fusion_dim, 1024),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(256, 1),
        )

    def forward(
        self,
        input_features,
        input_ids,
        text_attention_mask,
        is_translation,
        language_pair_id,
        audio_attention_mask=None,
        labels=None,
        **kwargs,
    ):
        # Speech branch: mean-pool the encoder states over time, then project.
        audio_emb = self.seamless_model_speech_encoder(
            input_features=input_features,
            attention_mask=audio_attention_mask,
        ).last_hidden_state.mean(dim=1)
        audio_emb = self.audio_proj(audio_emb)

        # Text branch: mean-pool the encoder states over tokens, then project.
        text_emb = self.seamless_model_text_encoder(
            input_ids=input_ids,
            attention_mask=text_attention_mask,
        ).last_hidden_state.mean(dim=1)
        text_emb = self.text_proj(text_emb)

        # nn.Linear expects a floating-point input, so cast the flag before projecting.
        translation_emb = self.translation_proj(is_translation.float().unsqueeze(-1))

        language_pair_emb = self.language_pair_embedding(language_pair_id)

        # Fuse all features and regress to one score per example.
        combined = torch.cat(
            [audio_emb, text_emb, translation_emb, language_pair_emb],
            dim=1,
        )
        logits = self.fc(combined).squeeze(-1)

        # MSE regression loss; cast labels to float to match the logits dtype.
        loss = F.mse_loss(logits, labels.float()) if labels is not None else None

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )
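

# --- Minimal usage sketch (illustrative smoke test, not part of the model) ---
# This assumes inputs follow the SeamlessM4TProcessor conventions for the
# medium checkpoint: speech `input_features` of shape (batch, frames, 160)
# (80 log-mel bins stacked with stride 2) and tokenizer-produced `input_ids`.
# All shapes and values below are dummies chosen for the sketch; running it
# downloads the pretrained SeamlessM4T weights.
if __name__ == "__main__":
    config = SeamlessLanguagePairsConfig()
    model = HFSeamlessLanguagePairs(config)
    model.eval()

    batch_size, num_frames, seq_len = 2, 50, 16
    dummy_batch = {
        "input_features": torch.randn(batch_size, num_frames, 160),
        "input_ids": torch.randint(4, 1000, (batch_size, seq_len)),
        "text_attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
        "is_translation": torch.tensor([1.0, 0.0]),
        "language_pair_id": torch.tensor([0, 3]),
        "labels": torch.tensor([4.5, 2.0]),
    }
    with torch.no_grad():
        output = model(**dummy_batch)
    print("logits shape:", tuple(output.logits.shape), "| loss:", output.loss.item())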