File size: 3,633 Bytes
d848202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import SeamlessM4TModel
import logging

logger = logging.getLogger(__name__)


class SeamlessBasicConfig(PretrainedConfig):
    """Hyper-parameters for the ``HFSeamlessBasic`` model.

    Args:
        seamless_model_name: Identifier passed to
            ``SeamlessM4TModel.from_pretrained`` to obtain the pretrained
            speech and text encoders.
        hidden_size: Width each modality embedding is projected to.
        dropout_prob: Dropout probability used between the prediction-head
            layers.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "seamless_basic"

    def __init__(
        self,
        seamless_model_name="facebook/hf-seamless-m4t-medium",
        hidden_size=1024,
        dropout_prob=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Model-specific settings, kept as instance attributes so the model
        # can read them back as ``config.<name>``.
        self.seamless_model_name, self.hidden_size, self.dropout_prob = (
            seamless_model_name,
            hidden_size,
            dropout_prob,
        )


class HFSeamlessBasic(PreTrainedModel):
    """Basic SeamlessM4T model for HuggingFace Hub - processes audio and text only.

    Audio and text are encoded with the frozen SeamlessM4T speech and text
    encoders, mean-pooled, projected to ``config.hidden_size`` each,
    concatenated, and passed through an MLP that emits one scalar per example.
    When labels are given, the scalar is trained with an MSE loss, so the
    head is a regression head.
    """

    config_class = SeamlessBasicConfig
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Load the full pretrained SeamlessM4T model; only its speech and text
        # encoders are used in ``forward``. ``self.seamless_model`` is kept as
        # an attribute so the module/state_dict layout stays unchanged.
        self.seamless_model = SeamlessM4TModel.from_pretrained(config.seamless_model_name)
        self.seamless_model_speech_encoder = self.seamless_model.speech_encoder
        self.seamless_model_text_encoder = self.seamless_model.text_encoder

        # Freeze the pretrained encoders so only the projections and the head
        # receive gradients.
        for encoder in (self.seamless_model_speech_encoder, self.seamless_model_text_encoder):
            for param in encoder.parameters():
                param.requires_grad = False

        # Project each modality's pooled embedding to config.hidden_size.
        self.audio_proj = nn.Linear(
            self.seamless_model_speech_encoder.config.hidden_size,
            config.hidden_size,
        )
        self.text_proj = nn.Linear(
            self.seamless_model_text_encoder.config.hidden_size,
            config.hidden_size,
        )

        # Regression head over the concatenated [audio; text] embedding.
        # Its input width is derived from config.hidden_size (2048 for the
        # default 1024) rather than hard-coded, so non-default hidden sizes no
        # longer cause a shape mismatch. Layer order/indices are unchanged so
        # existing checkpoints still load.
        self.fc = nn.Sequential(
            nn.Linear(2 * config.hidden_size, 1024),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(config.dropout_prob),
            nn.Linear(256, 1)
        )

    def forward(
        self,
        input_features,
        input_ids,
        text_attention_mask,
        audio_attention_mask=None,
        labels=None,
        **kwargs  # Accept additional features but ignore them
    ):
        """Predict one scalar per (audio, text) pair.

        Args:
            input_features: Audio features for the SeamlessM4T speech encoder
                (presumably produced by the SeamlessM4T feature extractor —
                confirm against the data pipeline).
            input_ids: Token ids for the SeamlessM4T text encoder.
            text_attention_mask: Attention mask passed with ``input_ids``.
            audio_attention_mask: Optional attention mask passed with
                ``input_features``.
            labels: Optional regression targets with one value per example.
                Any shape holding ``batch`` elements (e.g. ``(batch,)`` or
                ``(batch, 1)``) is accepted. The targets are cast to the
                logits' dtype.
            **kwargs: Extra inputs are accepted and ignored.

        Returns:
            SequenceClassifierOutput whose ``logits`` has shape ``(batch,)``
            and whose ``loss`` is the MSE against ``labels``, or ``None`` when
            no labels are given.

        Raises:
            RuntimeError: if ``labels`` does not hold exactly one value per
                example.
        """
        # NOTE(review): both encoders are mean-pooled over dim 1 without using
        # the attention masks, so padded positions contribute to the pooled
        # embedding. Left unchanged to keep outputs identical for existing
        # weights.

        # Encode audio
        audio_emb = self.seamless_model_speech_encoder(
            input_features=input_features,
            attention_mask=audio_attention_mask
        ).last_hidden_state.mean(dim=1)
        audio_emb = self.audio_proj(audio_emb)

        # Encode text
        text_emb = self.seamless_model_text_encoder(
            input_ids=input_ids,
            attention_mask=text_attention_mask
        ).last_hidden_state.mean(dim=1)
        text_emb = self.text_proj(text_emb)

        # Combine features: (batch, 2 * config.hidden_size)
        combined = torch.cat([audio_emb, text_emb], dim=1)

        # (batch, 1) -> (batch,)
        logits = self.fc(combined).squeeze(-1)

        loss = None
        if labels is not None:
            # Align labels with logits before the loss. Without the reshape,
            # (batch, 1) labels would broadcast against (batch,) logits into a
            # (batch, batch) difference and silently give a wrong loss. The
            # cast handles integer/float64 labels that differ from the logits'
            # dtype.
            labels = labels.to(logits.dtype).reshape(logits.shape)
            loss = F.mse_loss(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None
        )