import torch
import numpy as np
from transformers import AutoProcessor
from typing import Dict, List, Union
import logging

logger = logging.getLogger(__name__)


class DataCollatorSimpleSeamless:
    """Collate raw audio/text examples into padded tensors.

    Each example is a dict holding ``'raw_audio'`` and ``'raw_text'`` and,
    optionally, ``'is_translation'``, ``'language_pair_id'`` and ``'labels'``.
    Audio and text are run through the processor loaded in ``__init__``;
    the remaining fields are stacked into tensors.
    """

    def __init__(
        self,
        processor: str,
        sample_rate: int = 16000,
        max_audio_length_sec: float = 8.0,
        max_text_length: int = 256,
        normalization_type: str = "none"
    ):
        """Initialize the data collator.

        Args:
            processor: Name or path handed to ``AutoProcessor.from_pretrained``.
            sample_rate: Audio sample rate passed to the processor.
            max_audio_length_sec: Maximum audio length in seconds; converted
                to a sample count using ``sample_rate``.
            max_text_length: Maximum text length passed to the processor as
                ``max_length``.
            normalization_type: Type of normalization to apply to labels.
                Options: "log1p", "none". The value is checked when labels
                are collated, not here.
        """
        logger.info("Loading processor: %s", processor)
        self.processor = AutoProcessor.from_pretrained(processor)
        self.sample_rate = sample_rate
        self.max_audio_sample_length = int(max_audio_length_sec * sample_rate)
        self.max_text_length = max_text_length
        self.normalization_type = normalization_type

    def _collate_labels(self, batch: List[Dict[str, Union[np.ndarray, str, float]]]) -> torch.Tensor:
        """Stack the ``'labels'`` of every item into a float32 tensor and normalize it.

        Raises:
            KeyError: If an item has no ``'labels'`` entry.
            ValueError: If ``self.normalization_type`` is not "log1p" or "none".
        """
        labels = torch.tensor([item['labels'] for item in batch], dtype=torch.float32)
        if self.normalization_type == "log1p":
            return torch.log1p(labels)
        if self.normalization_type == "none":
            return labels
        raise ValueError(f"Unknown normalization type: {self.normalization_type}")

    def __call__(self, batch: List[Dict[str, Union[np.ndarray, str, float]]]) -> Dict[str, torch.Tensor]:
        """Process a batch of raw features into model inputs.

        Args:
            batch: Non-empty list of example dicts (see the class docstring).

        Returns:
            A dict with ``'input_features'``, ``'audio_attention_mask'``
            (``None`` when the processor returns no audio attention mask),
            ``'input_ids'``, ``'text_attention_mask'``, ``'is_translation'``
            (float32), ``'language_pair_id'`` (long) and, only when the first
            item carries labels, ``'labels'`` (float32, normalized).

        Raises:
            ValueError: If ``batch`` is empty, or labels are present and the
                normalization type is unknown.
            KeyError: If an item lacks ``'raw_audio'``/``'raw_text'``, or
                lacks ``'labels'`` while the first item has them.
        """
        if not batch:
            raise ValueError("Cannot collate an empty batch")

        # Extract raw data
        raw_audios = [item['raw_audio'] for item in batch]
        raw_texts = [item['raw_text'] for item in batch]

        # Clip every waveform to the configured number of samples (last axis).
        # The processor's ``max_length`` is not guaranteed to be counted in
        # samples: extractors that pad/truncate after feature extraction
        # (SeamlessM4T's appears to) count frames, which would leave the
        # seconds-based limit unenforced. Clipping here is a no-op for
        # extractors that already truncate in samples.
        raw_audios = [
            torch.tensor(audio)[..., :self.max_audio_sample_length]
            for audio in raw_audios
        ]

        audio_inputs = self.processor(
            audios=raw_audios,
            sampling_rate=self.sample_rate,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_audio_sample_length,
        )

        text_inputs = self.processor(
            text=raw_texts,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_text_length,
        )

        # Extract translation features (missing values default to 0)
        is_translation = torch.tensor(
            [item.get('is_translation', 0) for item in batch],
            dtype=torch.float32,
        )

        # Extract language pair features (missing values default to 0)
        language_pair_id = torch.tensor(
            [item.get('language_pair_id', 0) for item in batch],
            dtype=torch.long,
        )

        # Only the first item decides whether the batch is treated as labeled.
        labels = self._collate_labels(batch) if 'labels' in batch[0] else None

        collated = {
            'input_features': audio_inputs['input_features'],
            'audio_attention_mask': audio_inputs.get('attention_mask'),
            'input_ids': text_inputs['input_ids'],
            'text_attention_mask': text_inputs['attention_mask'],
            'is_translation': is_translation,
            'language_pair_id': language_pair_id,
        }
        if labels is not None:
            collated['labels'] = labels
        return collated