# handler.py
import torch
from transformers import AutoProcessor, AutoModelForTextToWaveform
import soundfile as sf
import io
import numpy as np
import os


class EndpointHandler:
    def __init__(self, path=""):
        """
        Initializes the handler by loading the model and processor.
        'path' is the directory where the model files are located.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Hugging Face Endpoints provide the model directory via the 'path'
        # argument or the HF_MODEL_DIR environment variable.
        model_path = path if path else os.getenv("HF_MODEL_DIR", ".")
        # --- Load Model Components ---
        # Adjust these classes to whatever your Orpheus checkpoint actually needs;
        # it might be AutoModelForSpeechSeq2Seq, VitsModel, BarkModel, etc.
        # Make sure the class names exist in the transformers library or in the
        # library your model relies on.
        try:
            print(f"Loading processor from: {model_path}")
            self.processor = AutoProcessor.from_pretrained(model_path)
            print(f"Loading model from: {model_path}")
            self.model = AutoModelForTextToWaveform.from_pretrained(model_path)
            self.model.to(self.device)
            self.model.eval()  # inference only; disables dropout etc.
            print("Model and processor loaded successfully.")

            # --- Get Sampling Rate ---
            # Try the model config first, then the processor's feature extractor,
            # then fall back to a default. Adjust this if your model architecture
            # stores the sampling rate elsewhere.
            self.sampling_rate = getattr(self.model.config, "sampling_rate", None)
            if self.sampling_rate is None and hasattr(self.processor, "feature_extractor"):
                self.sampling_rate = getattr(self.processor.feature_extractor, "sampling_rate", 16000)
            elif self.sampling_rate is None:
                self.sampling_rate = 16000  # Default fallback when no config value is found
            print(f"Using sampling rate: {self.sampling_rate}")
        except Exception as e:
            print(f"Error loading model or processor: {e}")
            raise RuntimeError(f"Failed to load model/processor from {model_path}") from e

    def __call__(self, data: dict) -> bytes:
        """
        Runs inference on the input data.
        'data' is the dictionary parsed from the incoming JSON request payload.
        Returns raw audio bytes (WAV format).
        """
        try:
            # --- Get Inputs ---
            # Extract the text to synthesize; adjust the 'inputs' key if your
            # clients send a different payload shape.
            inputs_text = data.pop("inputs", None)
            if inputs_text is None:
                raise ValueError("Missing 'inputs' key in request data")

            # Optional generation parameters passed alongside the text.
            parameters = data.pop("parameters", {})

            # --- Preprocess Text ---
            # Use the processor to prepare model inputs.
            processed_inputs = self.processor(text=inputs_text, return_tensors="pt").to(self.device)
            # --- Generate Speech ---
            # Adjust the generation arguments as needed (e.g. speaker embeddings);
            # the output key can vary by architecture ('waveform', 'audio', 'speech', etc.).
            with torch.no_grad():
                # If your model needs specific args such as speaker embeddings, handle them here, e.g.:
                #   speaker_embeddings = self.load_speaker_embedding(...)
                #   output = self.model.generate(**processed_inputs, speaker_embeddings=speaker_embeddings, **parameters)
                output = self.model.generate(**processed_inputs, **parameters)

            # --- Postprocess Audio ---
            # Move the output to the CPU and convert it to a NumPy array.
            # The exact handling depends on the model's output format; here we
            # assume the tensor (or dictionary entry) contains the waveform.
            if isinstance(output, torch.Tensor):
                speech_waveform = output.cpu().numpy().squeeze()
            # Some models and pipelines return a dictionary instead of a raw tensor.
            elif isinstance(output, dict) and "audio" in output:
                speech_waveform = output["audio"].cpu().numpy().squeeze()
            elif isinstance(output, dict) and "waveform" in output:
                speech_waveform = output["waveform"].cpu().numpy().squeeze()
            else:
                # Extend this branch if your model returns another output type.
                raise TypeError(f"Unexpected model output type: {type(output)}")
            # Normalize if necessary: soundfile expects float samples in [-1, 1],
            # but some models do not constrain their output to that range.
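            # A minimal peak-normalization sketch (an assumption: only relevant if
            # your checkpoint can emit samples outside [-1, 1]).
            peak = float(np.max(np.abs(speech_waveform))) if speech_waveform.size else 0.0
            if peak > 1.0:
                speech_waveform = speech_waveform / peak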
            # --- Convert to WAV Bytes ---
            # Write the waveform into an in-memory WAV buffer.
            buffer = io.BytesIO()
            sf.write(buffer, speech_waveform, self.sampling_rate, format="WAV")
            buffer.seek(0)
            wav_bytes = buffer.read()
            print(f"Generated {len(wav_bytes)} bytes of WAV audio.")
            return wav_bytes
        except Exception as e:
            print(f"Error during inference: {e}")
            # Re-raise so the endpoint framework registers the failure.
            raise RuntimeError(f"Inference failed: {e}") from e
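
# A minimal local smoke test (a development-only assumption: the model files and
# dependencies must be available locally; "sample.wav" and the example sentence
# are placeholders, not part of the endpoint contract).
if __name__ == "__main__":
    handler = EndpointHandler(path=os.getenv("HF_MODEL_DIR", "."))
    wav_bytes = handler({"inputs": "Hello from the Orpheus endpoint handler."})
    with open("sample.wav", "wb") as f:
        f.write(wav_bytes)
    print(f"Wrote {len(wav_bytes)} bytes to sample.wav")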