import torch
from transformers import AutoProcessor, AutoModelForTextToSpeech
import soundfile as sf
import io
import numpy as np
import os


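# Custom inference handler for a Transformers text-to-speech model. The
# EndpointHandler class exposing __init__(path) and __call__(data) -> bytes
# matches the custom handler (handler.py) contract used by Hugging Face
# Inference Endpoints; that deployment target is inferred from this shape,
# not stated in the code itself.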
class EndpointHandler:
    def __init__(self, path=""):
        """
        Initializes the handler. Loads the model and processor.
        'path' is the directory where the model files are located.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Fall back to the HF_MODEL_DIR environment variable, then the
        # current directory, when no explicit path is given.
        model_path = path if path else os.getenv("HF_MODEL_DIR", ".")

        try:
            print(f"Loading processor from: {model_path}")
            self.processor = AutoProcessor.from_pretrained(model_path)
            print(f"Loading model from: {model_path}")
            self.model = AutoModelForTextToSpeech.from_pretrained(model_path)
            self.model.to(self.device)
            print("Model and processor loaded successfully.")

            # Resolve the output sampling rate: prefer the model config, then
            # the processor's feature extractor, then fall back to 16 kHz.
            self.sampling_rate = getattr(self.model.config, 'sampling_rate', None)
            if self.sampling_rate is None and hasattr(self.processor, 'feature_extractor'):
                self.sampling_rate = getattr(self.processor.feature_extractor, 'sampling_rate', 16000)
            elif self.sampling_rate is None:
                self.sampling_rate = 16000
            print(f"Using sampling rate: {self.sampling_rate}")
        except Exception as e:
            print(f"Error loading model or processor: {e}")
            raise RuntimeError(f"Failed to load model/processor from {model_path}") from e

    def __call__(self, data: dict) -> bytes:
        """
        Runs inference on the input data.
        'data' is the dictionary parsed from the incoming JSON request payload.
        Returns raw audio bytes in WAV format.
        """
        try:
            inputs_text = data.pop("inputs", None)
            if inputs_text is None:
                raise ValueError("Missing 'inputs' key in request data")

            # Optional generation kwargs, forwarded verbatim to generate().
            parameters = data.pop("parameters", {})

            # Tokenize the text and move the tensors to the model's device.
            processed_inputs = self.processor(text=inputs_text, return_tensors="pt").to(self.device)

            with torch.no_grad():
                output = self.model.generate(**processed_inputs, **parameters)

            # Normalize the output to a 1-D numpy waveform: depending on the
            # architecture, generate() may return a raw tensor or a dict-like
            # object keyed by 'audio' or 'waveform'.
            if isinstance(output, torch.Tensor):
                speech_waveform = output.cpu().numpy().squeeze()
            elif isinstance(output, dict) and 'audio' in output:
                speech_waveform = output['audio'].cpu().numpy().squeeze()
            elif isinstance(output, dict) and 'waveform' in output:
                speech_waveform = output['waveform'].cpu().numpy().squeeze()
            else:
                raise TypeError(f"Unexpected model output type: {type(output)}")

            # Encode the waveform as WAV in an in-memory buffer.
            buffer = io.BytesIO()
            sf.write(buffer, speech_waveform, self.sampling_rate, format='WAV')
            buffer.seek(0)
            wav_bytes = buffer.read()

            print(f"Generated {len(wav_bytes)} bytes of WAV audio.")
            return wav_bytes

        except Exception as e:
            print(f"Error during inference: {e}")
            raise RuntimeError(f"Inference failed: {e}") from e
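

# Minimal local smoke test: a sketch assuming a compatible text-to-speech
# model is available at the path in MODEL_DIR (hypothetical environment
# variable, defaulting to the current directory). Not part of the endpoint
# contract; run directly with `python handler.py` to verify loading and
# generation end to end.
if __name__ == "__main__":
    handler = EndpointHandler(path=os.getenv("MODEL_DIR", "."))
    wav = handler({"inputs": "Hello from the text-to-speech endpoint."})
    with open("test_output.wav", "wb") as f:
        f.write(wav)
    print(f"Wrote {len(wav)} bytes to test_output.wav")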