# seamless-basic / data_collator.py
# Uploaded by giuseppe-tanzi using huggingface_hub (commit 8525e7c).
import torch
import numpy as np
from transformers import AutoProcessor
from typing import Dict, List, Union
import logging
# Module-level logger; DataCollatorSimpleSeamless.__init__ uses it to report which processor is loaded.
logger: logging.Logger = logging.getLogger(__name__)
class DataCollatorSimpleSeamless:
    """Collate raw (audio, text) examples into padded model inputs.

    Each batch item is a dict with at least ``'raw_audio'`` and ``'raw_text'``.
    The keys ``'is_translation'``, ``'language_pair_id'`` and ``'labels'`` are
    optional.
    """

    def __init__(
        self,
        processor: str,
        sample_rate: int = 16000,
        max_audio_length_sec: float = 8.0,
        max_text_length: int = 256,
        normalization_type: str = "none",
    ):
        """Initialize the data collator.

        Args:
            processor: Name or path passed to ``AutoProcessor.from_pretrained``.
            sample_rate: Sampling rate of the raw audio, in Hz.
            max_audio_length_sec: Maximum audio duration in seconds. Longer
                waveforms are truncated to
                ``int(max_audio_length_sec * sample_rate)`` samples.
            max_text_length: Maximum number of text tokens (longer texts are
                truncated by the processor).
            normalization_type: Transform applied to labels. Options:
                ``"log1p"``, ``"none"``. The value is only validated when a
                batch that contains labels is collated.
        """
        logger.info("Loading processor: %s", processor)
        self.processor = AutoProcessor.from_pretrained(processor)
        self.sample_rate = sample_rate
        self.max_audio_sample_length = int(max_audio_length_sec * sample_rate)
        self.max_text_length = max_text_length
        self.normalization_type = normalization_type

    def __call__(self, batch: List[Dict[str, Union[np.ndarray, str, float]]]) -> Dict[str, torch.Tensor]:
        """Process a batch of raw features into model inputs.

        Args:
            batch: List of example dicts (see the class docstring).

        Returns:
            A dict with ``input_features``, ``audio_attention_mask`` (``None``
            if the processor returns no audio mask), ``input_ids``,
            ``text_attention_mask``, ``is_translation`` (float32),
            ``language_pair_id`` (int64) and, only when the first item has a
            ``'labels'`` key, ``labels`` (float32, normalized).

        Raises:
            ValueError: if ``batch`` is empty, or if labels are present and
                ``normalization_type`` is not recognised.
        """
        if not batch:
            raise ValueError("Cannot collate an empty batch")

        # Truncate in the sample domain *before* feature extraction. A feature
        # extractor's own ``max_length`` is measured in its output units
        # (filterbank frames for SeamlessM4T), so passing a sample count there
        # never bounds the audio length.
        # NOTE(review): assumes mono waveforms with time on axis 0 — confirm.
        raw_audios = [
            np.asarray(item['raw_audio'], dtype=np.float32)[: self.max_audio_sample_length]
            for item in batch
        ]
        raw_texts = [item['raw_text'] for item in batch]

        audio_inputs = self.processor(
            audios=raw_audios,
            sampling_rate=self.sample_rate,
            return_tensors="pt",
            padding="longest",
        )
        text_inputs = self.processor(
            text=raw_texts,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_text_length,
        )

        # Optional per-example metadata; missing values default to 0.
        is_translation = torch.tensor(
            [item.get('is_translation', 0) for item in batch], dtype=torch.float32
        )
        language_pair_id = torch.tensor(
            [item.get('language_pair_id', 0) for item in batch], dtype=torch.long
        )

        collated = {
            'input_features': audio_inputs['input_features'],
            'audio_attention_mask': audio_inputs.get('attention_mask'),
            'input_ids': text_inputs['input_ids'],
            'text_attention_mask': text_inputs['attention_mask'],
            'is_translation': is_translation,
            'language_pair_id': language_pair_id,
        }
        # Only the first item is inspected: the batch is assumed to be either
        # fully labeled or fully unlabeled.
        if 'labels' in batch[0]:
            labels = torch.tensor([item['labels'] for item in batch], dtype=torch.float32)
            collated['labels'] = self._normalize_labels(labels)
        return collated

    def _normalize_labels(self, labels: torch.Tensor) -> torch.Tensor:
        """Apply the configured ``normalization_type`` transform to ``labels``.

        Raises:
            ValueError: if ``normalization_type`` is neither ``"log1p"`` nor
                ``"none"``.
        """
        if self.normalization_type == "log1p":
            return torch.log1p(labels)
        if self.normalization_type == "none":
            return labels
        raise ValueError(f"Unknown normalization type: {self.normalization_type}")