File size: 4,365 Bytes
87f6e47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import SeamlessM4TModel
import logging
logger = logging.getLogger(__name__)
class SeamlessLanguagePairsConfig(PretrainedConfig):
    """HF configuration for the SeamlessLanguagePairs model.

    Records which SeamlessM4T checkpoint backs the encoders plus the
    hyperparameters of the task-specific head (projection width, number
    of language-pair categories, dropout probability).
    """

    model_type = "seamless_language_pairs"

    def __init__(
        self,
        seamless_model_name="facebook/hf-seamless-m4t-medium",
        num_language_pairs=21,
        hidden_size=1024,
        dropout_prob=0.1,
        **kwargs,
    ):
        # Unrecognized kwargs are forwarded to PretrainedConfig as usual.
        super().__init__(**kwargs)
        self.dropout_prob = dropout_prob
        self.hidden_size = hidden_size
        self.num_language_pairs = num_language_pairs
        self.seamless_model_name = seamless_model_name
class HFSeamlessLanguagePairs(PreTrainedModel):
    """SeamlessM4T-based regression model for HuggingFace Hub.

    Mean-pools the outputs of frozen SeamlessM4T speech and text encoders,
    projects both into a common space, concatenates them with a binary
    translation feature and a language-pair embedding, and scores the
    result with an MLP head. Loss (when labels are given) is MSE, so the
    single logit is treated as a regression target.
    """

    config_class = SeamlessLanguagePairsConfig
    supports_gradient_checkpointing = True

    # Widths of the auxiliary feature projections (kept as class constants
    # so the head input size below stays consistent with them).
    _TRANSLATION_DIM = 32
    _LANG_PAIR_DIM = 64

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Load the underlying SeamlessM4T model and keep handles to its two
        # encoders. The full model is retained so the encoder submodules stay
        # tied to their pretrained weights.
        self.seamless_model = SeamlessM4TModel.from_pretrained(config.seamless_model_name)
        self.seamless_model_speech_encoder = self.seamless_model.speech_encoder
        self.seamless_model_text_encoder = self.seamless_model.text_encoder

        # Freeze the pre-trained encoders: only the projections, embeddings
        # and MLP head below receive gradients.
        for param in self.seamless_model_speech_encoder.parameters():
            param.requires_grad = False
        for param in self.seamless_model_text_encoder.parameters():
            param.requires_grad = False

        # Project each encoder's pooled output into a common hidden space.
        self.audio_proj = nn.Linear(
            self.seamless_model_speech_encoder.config.hidden_size,
            config.hidden_size,
        )
        self.text_proj = nn.Linear(
            self.seamless_model_text_encoder.config.hidden_size,
            config.hidden_size,
        )

        # Translation feature embedding (binary scalar -> 32 dims).
        self.translation_proj = nn.Linear(1, self._TRANSLATION_DIM)
        # Language pair embedding (categorical id -> 64 dims).
        self.language_pair_embedding = nn.Embedding(
            config.num_language_pairs, self._LANG_PAIR_DIM
        )

        # Classification head. The input width is derived from the config
        # rather than hard-coded (originally 2144), so non-default
        # hidden_size values work: 2*hidden_size + 32 + 64 = 2144 when
        # hidden_size == 1024.
        fused_dim = 2 * config.hidden_size + self._TRANSLATION_DIM + self._LANG_PAIR_DIM
        self.fc = nn.Sequential(
            nn.Linear(fused_dim, 1024),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(256, 1),
        )

    def forward(
        self,
        input_features,
        input_ids,
        text_attention_mask,
        is_translation,
        language_pair_id,
        audio_attention_mask=None,
        labels=None,
        **kwargs
    ):
        """Score a batch of (audio, text, metadata) inputs.

        Args:
            input_features: speech-encoder input features.
            input_ids: token ids for the text encoder.
            text_attention_mask: attention mask for ``input_ids``.
            is_translation: binary per-example flag (any numeric/bool dtype).
            language_pair_id: integer id in ``[0, num_language_pairs)``.
            audio_attention_mask: optional mask for the speech encoder.
            labels: optional regression targets; enables MSE loss.

        Returns:
            SequenceClassifierOutput with ``logits`` of shape (batch,) and
            ``loss`` when ``labels`` is provided.
        """
        # Encode audio: mean-pool the frozen encoder output over time,
        # then project to the shared hidden space.
        audio_emb = self.seamless_model_speech_encoder(
            input_features=input_features,
            attention_mask=audio_attention_mask
        ).last_hidden_state.mean(dim=1)
        audio_emb = self.audio_proj(audio_emb)

        # Encode text the same way.
        text_emb = self.seamless_model_text_encoder(
            input_ids=input_ids,
            attention_mask=text_attention_mask
        ).last_hidden_state.mean(dim=1)
        text_emb = self.text_proj(text_emb)

        # Process translation feature. Cast so bool/int flags don't crash
        # nn.Linear (no-op when the caller already passes floats).
        translation_emb = self.translation_proj(
            is_translation.unsqueeze(-1).to(audio_emb.dtype)
        )

        # Process language pair feature.
        language_pair_emb = self.language_pair_embedding(language_pair_id)

        # Combine features: (batch, 2*hidden_size + 32 + 64).
        combined = torch.cat([
            audio_emb,
            text_emb,
            translation_emb,
            language_pair_emb
        ], dim=1)
        logits = self.fc(combined).squeeze(-1)

        # MSE regression loss; cast labels so integer targets don't error
        # (no-op for float labels).
        loss = None
        if labels is not None:
            loss = F.mse_loss(logits, labels.to(logits.dtype))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None
        )