import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import SeamlessM4TModel
import logging

logger = logging.getLogger(__name__)


class SeamlessLanguagePairsConfig(PretrainedConfig):
    """Configuration class for SeamlessLanguagePairs model."""
    
    model_type = "seamless_language_pairs"
    
    def __init__(
        self,
        seamless_model_name="facebook/hf-seamless-m4t-medium",
        num_language_pairs=21,
        hidden_size=1024,
        dropout_prob=0.1,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.seamless_model_name = seamless_model_name
        self.num_language_pairs = num_language_pairs
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob

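# Illustrative only: overriding the defaults (e.g. a different number of language
# pairs) is an assumption about your own data, not something fixed by the model:
#   config = SeamlessLanguagePairsConfig(num_language_pairs=30, dropout_prob=0.2)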

class HFSeamlessLanguagePairs(PreTrainedModel):
    """SeamlessM4T model with language pairs and translation features for HuggingFace Hub."""
    
    config_class = SeamlessLanguagePairsConfig
    supports_gradient_checkpointing = True
    
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        
        # Load the underlying SeamlessM4T model
        self.seamless_model = SeamlessM4TModel.from_pretrained(config.seamless_model_name)
        self.seamless_model_speech_encoder = self.seamless_model.speech_encoder
        self.seamless_model_text_encoder = self.seamless_model.text_encoder
        
        # Freeze pre-trained models
        for param in self.seamless_model_speech_encoder.parameters():
            param.requires_grad = False
        for param in self.seamless_model_text_encoder.parameters():
            param.requires_grad = False
        
        # Projection layers
        self.audio_proj = nn.Linear(
            self.seamless_model_speech_encoder.config.hidden_size, 
            config.hidden_size
        )
        self.text_proj = nn.Linear(
            self.seamless_model_text_encoder.config.hidden_size, 
            config.hidden_size
        )
        
        # Translation feature embedding (binary)
        self.translation_proj = nn.Linear(1, 32)
        
        # Language pair embedding (categorical)
        self.language_pair_embedding = nn.Embedding(config.num_language_pairs, 64)
        
        # Classification head: audio + text projections (2 * hidden_size)
        # plus the translation (32) and language-pair (64) features
        head_input_dim = 2 * config.hidden_size + 32 + 64  # 2144 with the defaults
        self.fc = nn.Sequential(
            nn.Linear(head_input_dim, 1024),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(256, 1)
        )
    
    def forward(
        self,
        input_features,
        input_ids,
        text_attention_mask,
        is_translation,
        language_pair_id,
        audio_attention_mask=None,
        labels=None,
        **kwargs
    ):
        # Encode audio
        audio_emb = self.seamless_model_speech_encoder(
            input_features=input_features,
            attention_mask=audio_attention_mask
        ).last_hidden_state.mean(dim=1)
        audio_emb = self.audio_proj(audio_emb)
        
        # Encode text
        text_emb = self.seamless_model_text_encoder(
            input_ids=input_ids,
            attention_mask=text_attention_mask
        ).last_hidden_state.mean(dim=1)
        text_emb = self.text_proj(text_emb)
        
        # Process translation feature (cast so integer/bool flags work with nn.Linear)
        translation_emb = self.translation_proj(is_translation.float().unsqueeze(-1))
        
        # Process language pair feature
        language_pair_emb = self.language_pair_embedding(language_pair_id)
        
        # Combine features
        combined = torch.cat([
            audio_emb, 
            text_emb, 
            translation_emb, 
            language_pair_emb
        ], dim=1)  # (batch_size, 2 * hidden_size + 32 + 64)
        
        logits = self.fc(combined).squeeze(-1)
        
        # Compute loss if labels are provided
        loss = F.mse_loss(logits, labels.float()) if labels is not None else None
        
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None
        )
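

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). Assumptions: the dummy
# `input_features` follow the SeamlessM4TFeatureExtractor layout for the medium
# checkpoint (80 log-mel bins stacked in pairs -> 160 per frame); real inputs
# would come from SeamlessM4TFeatureExtractor and SeamlessM4TTokenizer.
# Running this downloads the SeamlessM4T weights.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = SeamlessLanguagePairsConfig()
    model = HFSeamlessLanguagePairs(config)
    model.eval()

    batch_size, num_frames, feat_dim, text_len = 2, 120, 160, 16
    dummy_batch = {
        "input_features": torch.randn(batch_size, num_frames, feat_dim),
        "input_ids": torch.randint(0, 1000, (batch_size, text_len)),
        "text_attention_mask": torch.ones(batch_size, text_len, dtype=torch.long),
        "is_translation": torch.tensor([1.0, 0.0]),
        "language_pair_id": torch.tensor([0, 5]),
        "labels": torch.tensor([3.5, 4.0]),
    }

    with torch.no_grad():
        output = model(**dummy_batch)

    # One scalar prediction per example, plus the MSE loss against the labels.
    print("logits:", output.logits.shape)  # torch.Size([2])
    print("loss:", output.loss.item())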