File size: 3,633 Bytes
d848202 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import SeamlessM4TModel
import logging
logger = logging.getLogger(__name__)
class SeamlessBasicConfig(PretrainedConfig):
    """Hyper-parameters for the ``HFSeamlessBasic`` model.

    Args:
        seamless_model_name: Identifier handed to
            ``SeamlessM4TModel.from_pretrained`` when the model is built.
        hidden_size: Output width of the audio and text projection layers.
        dropout_prob: Dropout probability used in the classification head.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "seamless_basic"

    def __init__(
        self,
        seamless_model_name="facebook/hf-seamless-m4t-medium",
        hidden_size=1024,
        dropout_prob=0.1,
        **kwargs,
    ):
        # Let the base class consume the generic options first.
        super().__init__(**kwargs)

        # Model-specific settings.
        self.dropout_prob = dropout_prob
        self.hidden_size = hidden_size
        self.seamless_model_name = seamless_model_name
class HFSeamlessBasic(PreTrainedModel):
    """Basic SeamlessM4T model for HuggingFace Hub - processes audio and text only.

    The speech and text encoders of a pre-trained ``SeamlessM4TModel`` are
    frozen. Their mean-pooled outputs are each projected to
    ``config.hidden_size``, concatenated, and passed through an MLP that
    emits a single value per example. When labels are given, the model is
    trained with a mean-squared-error loss.
    """

    config_class = SeamlessBasicConfig
    supports_gradient_checkpointing = True

    def __init__(self, config):
        """Build the frozen encoders, the projection layers and the head.

        Args:
            config: A ``SeamlessBasicConfig``. ``seamless_model_name`` selects
                the checkpoint to load, ``hidden_size`` sets the projection
                width and ``dropout_prob`` the dropout used in the head.
        """
        super().__init__(config)
        self.config = config

        # Load the underlying SeamlessM4T model and alias its two encoders.
        # The attribute names are kept as-is because they determine the
        # parameter names stored in checkpoints.
        self.seamless_model = SeamlessM4TModel.from_pretrained(config.seamless_model_name)
        self.seamless_model_speech_encoder = self.seamless_model.speech_encoder
        self.seamless_model_text_encoder = self.seamless_model.text_encoder

        # Freeze pre-trained encoders; only the projections and head train.
        for param in self.seamless_model_speech_encoder.parameters():
            param.requires_grad = False
        for param in self.seamless_model_text_encoder.parameters():
            param.requires_grad = False

        # Projection layers: encoder width -> config.hidden_size.
        self.audio_proj = nn.Linear(
            self.seamless_model_speech_encoder.config.hidden_size,
            config.hidden_size,
        )
        self.text_proj = nn.Linear(
            self.seamless_model_text_encoder.config.hidden_size,
            config.hidden_size,
        )

        # Classification head. Its input is the concatenation of the two
        # projections, so the width is derived from the config instead of
        # being hard-coded (2048 with the default hidden_size of 1024).
        fused_size = 2 * config.hidden_size
        self.fc = nn.Sequential(
            nn.Linear(fused_size, 1024),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(256, 1),
        )

    def forward(
        self,
        input_features,
        input_ids,
        text_attention_mask,
        audio_attention_mask=None,
        labels=None,
        **kwargs,  # Accept additional features but ignore them
    ):
        """Score a batch of paired audio and text inputs.

        Args:
            input_features: Audio features passed to the speech encoder.
            input_ids: Token ids passed to the text encoder.
            text_attention_mask: Attention mask passed to the text encoder.
            audio_attention_mask: Optional attention mask passed to the
                speech encoder.
            labels: Optional regression targets. They are cast to the dtype
                and device of the predictions and reshaped to the
                predictions' shape before the loss is computed.
            **kwargs: Ignored.

        Returns:
            A ``SequenceClassifierOutput`` whose ``logits`` hold the head
            output with its trailing size-1 dimension squeezed away and
            whose ``loss`` is the MSE against ``labels`` (``None`` when no
            labels are given). ``hidden_states`` and ``attentions`` are
            always ``None``.

        Raises:
            RuntimeError: If ``labels`` cannot be reshaped to the shape of
                the predictions.

        NOTE(review): both encoder outputs are averaged over every position
        along dim 1, which presumably includes padded positions. This is
        kept unchanged so existing checkpoints produce the same outputs.
        """
        # Encode audio, then mean-pool over dim 1 and project.
        audio_hidden = self.seamless_model_speech_encoder(
            input_features=input_features,
            attention_mask=audio_attention_mask,
        ).last_hidden_state
        audio_emb = self.audio_proj(audio_hidden.mean(dim=1))

        # Encode text, then mean-pool over dim 1 and project.
        text_hidden = self.seamless_model_text_encoder(
            input_ids=input_ids,
            attention_mask=text_attention_mask,
        ).last_hidden_state
        text_emb = self.text_proj(text_hidden.mean(dim=1))

        # Combine features: (batch_size, 2 * config.hidden_size).
        combined = torch.cat([audio_emb, text_emb], dim=1)
        logits = self.fc(combined).squeeze(-1)

        # Compute loss if labels are provided. Aligning dtype, device and
        # shape first prevents mse_loss from silently broadcasting
        # (batch, 1) labels against (batch,) predictions and lets integer
        # labels be used.
        loss = None
        if labels is not None:
            targets = labels.to(device=logits.device, dtype=logits.dtype)
            targets = targets.reshape(logits.shape)
            loss = F.mse_loss(logits, targets)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )