|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from transformers import PreTrainedModel, PretrainedConfig |
|
from transformers.modeling_outputs import SequenceClassifierOutput |
|
from transformers import SeamlessM4TModel |
|
import logging |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class SeamlessBasicConfig(PretrainedConfig):
    """Configuration for the SeamlessBasic audio+text model.

    Args:
        seamless_model_name: HF Hub id of the SeamlessM4T checkpoint whose
            speech/text encoders are used as frozen feature extractors.
        hidden_size: Dimension each encoder output is projected to before
            the two modalities are concatenated.
        dropout_prob: Dropout probability used in the regression head.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "seamless_basic"

    def __init__(self,
                 seamless_model_name="facebook/hf-seamless-m4t-medium",
                 hidden_size=1024,
                 dropout_prob=0.1,
                 **kwargs):
        super().__init__(**kwargs)
        # Plain attribute assignments; PretrainedConfig handles serialization.
        self.seamless_model_name = seamless_model_name
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob
|
|
|
|
|
class HFSeamlessBasic(PreTrainedModel):
    """Basic SeamlessM4T model for HuggingFace Hub - processes audio and text only.

    Uses the (frozen) speech and text encoders of a pretrained SeamlessM4T
    model, projects each modality to ``config.hidden_size``, concatenates
    them, and regresses a single scalar per example through an MLP head.
    """

    config_class = SeamlessBasicConfig
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # NOTE(review): the full SeamlessM4TModel (including decoders we never
        # use) is kept alive on `self.seamless_model`; only the two encoders
        # below are actually used. Consider dropping the rest to save memory.
        self.seamless_model = SeamlessM4TModel.from_pretrained(config.seamless_model_name)
        self.seamless_model_speech_encoder = self.seamless_model.speech_encoder
        self.seamless_model_text_encoder = self.seamless_model.text_encoder

        # Freeze both encoders: they act as fixed feature extractors and only
        # the projections + regression head are trained.
        for param in self.seamless_model_speech_encoder.parameters():
            param.requires_grad = False
        for param in self.seamless_model_text_encoder.parameters():
            param.requires_grad = False

        # Project each encoder's output to a common dimension.
        self.audio_proj = nn.Linear(
            self.seamless_model_speech_encoder.config.hidden_size,
            config.hidden_size
        )
        self.text_proj = nn.Linear(
            self.seamless_model_text_encoder.config.hidden_size,
            config.hidden_size
        )

        # Regression head over the concatenated [audio; text] embedding.
        # BUGFIX: the input width was hardcoded to 2048, which only matched
        # the default hidden_size=1024; derive it from the config instead so
        # any hidden_size works.
        self.fc = nn.Sequential(
            nn.Linear(config.hidden_size * 2, 1024),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(256, 1)
        )

    def forward(
        self,
        input_features,
        input_ids,
        text_attention_mask,
        audio_attention_mask=None,
        labels=None,
        **kwargs
    ):
        """Run both encoders, fuse the modalities, and regress a scalar.

        Args:
            input_features: Audio features for the speech encoder.
            input_ids: Token ids for the text encoder.
            text_attention_mask: Attention mask for the text input.
            audio_attention_mask: Optional attention mask for the audio input.
            labels: Optional regression targets; if given, MSE loss is
                computed against the logits.

        Returns:
            SequenceClassifierOutput with ``logits`` of shape (batch,) and
            ``loss`` set when ``labels`` is provided.
        """
        # Mean-pool the speech encoder's sequence output over time.
        # NOTE(review): mean over dim=1 ignores audio_attention_mask padding
        # positions — confirm whether masked mean-pooling is intended.
        audio_emb = self.seamless_model_speech_encoder(
            input_features=input_features,
            attention_mask=audio_attention_mask
        ).last_hidden_state.mean(dim=1)
        audio_emb = self.audio_proj(audio_emb)

        # Mean-pool the text encoder's sequence output over tokens.
        text_emb = self.seamless_model_text_encoder(
            input_ids=input_ids,
            attention_mask=text_attention_mask
        ).last_hidden_state.mean(dim=1)
        text_emb = self.text_proj(text_emb)

        # Fuse modalities: (batch, 2 * hidden_size).
        combined = torch.cat([audio_emb, text_emb], dim=1)

        # Head outputs (batch, 1); squeeze to (batch,) for regression.
        logits = self.fc(combined).squeeze(-1)

        # BUGFIX: cast labels to the logits dtype — mse_loss raises on dtype
        # mismatch (e.g. integer labels coming from a dataset).
        loss = None
        if labels is not None:
            loss = F.mse_loss(logits, labels.to(logits.dtype))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None
        )