|
import torch |
|
import numpy as np |
|
from transformers import AutoProcessor |
|
from typing import Dict, List, Union |
|
import logging |
|
|
|
# Module-scoped logger; its name follows this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
class DataCollatorSimpleSeamless:
    """Collate raw audio/text examples into padded tensors for model input.

    Audio waveforms and text strings are each passed through a Hugging Face
    processor; optional per-example metadata (``is_translation``,
    ``language_pair_id``) and numeric ``labels`` are stacked into tensors
    alongside the processor outputs.
    """

    # Accepted values for ``normalization_type``.
    _NORMALIZATION_TYPES = ("log1p", "none")

    def __init__(
        self,
        processor: str,
        sample_rate: int = 16000,
        max_audio_length_sec: float = 8.0,
        max_text_length: int = 256,
        normalization_type: str = "none",
    ):
        """Initialize the data collator.

        Args:
            processor: Name or path passed to ``AutoProcessor.from_pretrained``.
            sample_rate: Audio sample rate (samples per second); forwarded to
                the processor as ``sampling_rate``.
            max_audio_length_sec: Maximum audio length in seconds. Waveforms
                longer than ``int(max_audio_length_sec * sample_rate)`` samples
                are truncated before feature extraction.
            max_text_length: Maximum text length in tokens; longer sequences
                are truncated by the processor.
            normalization_type: Type of normalization to apply to labels.
                Options: "log1p", "none".

        Raises:
            ValueError: If ``normalization_type`` is not a supported option.
        """
        # Validate before the (potentially slow) processor load so a typo
        # fails immediately instead of on the first labelled batch.
        if normalization_type not in self._NORMALIZATION_TYPES:
            raise ValueError(f"Unknown normalization type: {normalization_type}")

        logger.info("Loading processor: %s", processor)
        self.processor = AutoProcessor.from_pretrained(processor)

        self.sample_rate = sample_rate
        self.max_audio_sample_length = int(max_audio_length_sec * sample_rate)
        self.max_text_length = max_text_length
        self.normalization_type = normalization_type

    def __call__(
        self, batch: List[Dict[str, Union[np.ndarray, str, float]]]
    ) -> Dict[str, torch.Tensor]:
        """Process a batch of raw features into model inputs.

        Args:
            batch: Examples, each a dict with ``'raw_audio'`` (a waveform —
                assumed 1-D mono; TODO confirm against the dataset) and
                ``'raw_text'`` (str). ``'is_translation'`` and
                ``'language_pair_id'`` are optional and default to 0.
                ``'labels'`` is included only if the FIRST example has it,
                in which case every example must have it.

        Returns:
            Dict with ``input_features``, ``audio_attention_mask`` (``None``
            when the processor returns no attention mask), ``input_ids``,
            ``text_attention_mask``, ``is_translation`` (float32),
            ``language_pair_id`` (int64) and, when present, ``labels``
            (float32, normalized per ``normalization_type``).

        Raises:
            KeyError: If a required key is missing from an example.
            ValueError: If labels are present and ``normalization_type`` is
                not a supported option.
        """
        # Truncate raw waveforms here instead of via the processor's
        # ``truncation``/``max_length``: feature extractors such as
        # SeamlessM4T's apply ``max_length`` to the extracted filter-bank
        # frames, not to audio samples, so a sample count passed there would
        # not enforce ``max_audio_length_sec``. numpy arrays are the
        # feature extractor's documented input type.
        raw_audios = [
            np.asarray(item['raw_audio'])[: self.max_audio_sample_length]
            for item in batch
        ]
        raw_texts = [item['raw_text'] for item in batch]

        audio_inputs = self.processor(
            audios=raw_audios,
            sampling_rate=self.sample_rate,
            return_tensors="pt",
            padding="longest",
        )

        text_inputs = self.processor(
            text=raw_texts,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_text_length,
        )

        is_translation = torch.tensor(
            [item.get('is_translation', 0) for item in batch], dtype=torch.float32
        )
        language_pair_id = torch.tensor(
            [item.get('language_pair_id', 0) for item in batch], dtype=torch.long
        )

        outputs = {
            'input_features': audio_inputs['input_features'],
            'audio_attention_mask': audio_inputs.get('attention_mask'),
            'input_ids': text_inputs['input_ids'],
            'text_attention_mask': text_inputs['attention_mask'],
            'is_translation': is_translation,
            'language_pair_id': language_pair_id,
        }

        if 'labels' in batch[0]:
            labels = torch.tensor([item['labels'] for item in batch], dtype=torch.float32)
            outputs['labels'] = self._normalize_labels(labels)

        return outputs

    def _normalize_labels(self, labels: torch.Tensor) -> torch.Tensor:
        """Apply the configured ``normalization_type`` to a float label tensor."""
        if self.normalization_type == "log1p":
            return torch.log1p(labels)
        if self.normalization_type == "none":
            return labels
        raise ValueError(f"Unknown normalization type: {self.normalization_type}")