import torch
from transformers import AutoProcessor, AutoModelForTextToWaveform
import soundfile as sf
import io
import numpy as np
import os


class EndpointHandler():
    def __init__(self, path=""):
        """
        Initializes the handler. Loads the model and processor.
        'path' is the directory where the model files are located.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Define the model path explicitly if needed, or rely on 'path'.
        # HF Endpoints provide the model dir via the 'path' argument or an env var.
        model_path = path if path else os.getenv("HF_MODEL_DIR", ".")

        # --- Load Model Components ---
        # Adjust these lines based on the specific classes your Orpheus model needs.
        # It might be AutoModelForSpeechSeq2Seq, VitsModel, BarkModel, etc.
        # AutoModelForTextToWaveform is used here because the postprocessing below
        # expects a raw waveform; make sure you use the correct class names from the
        # transformers library or the library your model relies on.
        try:
            print(f"Loading processor from: {model_path}")
            self.processor = AutoProcessor.from_pretrained(model_path)

            print(f"Loading model from: {model_path}")
            self.model = AutoModelForTextToWaveform.from_pretrained(model_path)
            self.model.to(self.device)
            print("Model and processor loaded successfully.")

            # --- Get Sampling Rate ---
            # Try to get the sampling rate from the config; provide a default if not found.
            # Common locations: model.config.sampling_rate or processor.feature_extractor.sampling_rate.
            # Adjust this based on your specific model architecture!
            self.sampling_rate = getattr(self.model.config, 'sampling_rate', None)
            if self.sampling_rate is None and hasattr(self.processor, 'feature_extractor'):
                self.sampling_rate = getattr(self.processor.feature_extractor, 'sampling_rate', 16000)  # Default fallback
            elif self.sampling_rate is None:
                self.sampling_rate = 16000  # Default fallback if no config found
            print(f"Using sampling rate: {self.sampling_rate}")

        except Exception as e:
            print(f"Error loading model or processor: {e}")
            raise RuntimeError(f"Failed to load model/processor from {model_path}: {e}") from e

    def __call__(self, data: dict) -> bytes:
        """
        Runs inference on the input data.
        'data' is the dictionary parsed from the incoming JSON request payload.
        Should return raw audio bytes (e.g., WAV format).
        """
        try:
            # --- Get Inputs ---
            # Extract the text input - adjust the 'inputs' key if necessary.
            inputs_text = data.pop("inputs", None)
            if inputs_text is None:
                raise ValueError("Missing 'inputs' key in request data")

            # Optional: handle other parameters passed in the request.
            parameters = data.pop("parameters", {})

            # --- Preprocess Text ---
            # Use the processor to prepare model inputs.
            processed_inputs = self.processor(text=inputs_text, return_tensors="pt").to(self.device)

            # --- Generate Speech ---
            # Adjust generation parameters if needed (e.g., speaker embeddings).
            # The output key might vary ('waveform', 'audio', 'speech', etc.).
            with torch.no_grad():
                # If your model needs specific args like speaker_embeddings, handle them here.
                # Example: speaker_embeddings = self.load_speaker_embedding(...)
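                # A minimal, commented-out sketch of one way to obtain speaker embeddings,
                # assuming a SpeechT5-style model and the 'Matthijs/cmu-arctic-xvectors'
                # dataset (both are assumptions; skip this if your model is single-speaker):
                # from datasets import load_dataset
                # xvector_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
                # speaker_embeddings = torch.tensor(xvector_dataset[7306]["xvector"]).unsqueeze(0).to(self.device)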
                # output = self.model.generate(**processed_inputs, speaker_embeddings=speaker_embeddings, **parameters)
                output = self.model.generate(**processed_inputs, **parameters)

            # --- Postprocess Audio ---
            # Ensure the output is on the CPU and convert it to a numpy array.
            # The exact processing depends on the model output format;
            # here the output tensor is assumed to contain the waveform.
            if isinstance(output, torch.Tensor):
                speech_waveform = output.cpu().numpy().squeeze()
            # Handle cases where the output might be a dictionary (common with pipelines).
            elif isinstance(output, dict) and 'audio' in output:
                speech_waveform = output['audio'].cpu().numpy().squeeze()
            elif isinstance(output, dict) and 'waveform' in output:
                speech_waveform = output['waveform'].cpu().numpy().squeeze()
            else:
                # Add handling for other potential output types if needed.
                raise TypeError(f"Unexpected model output type: {type(output)}")

            # Normalize if necessary (some models output in range [-1, 1], others don't).
            # If the waveform isn't in [-1, 1], soundfile might require normalization.
            # Example check:
            # if np.max(np.abs(speech_waveform)) > 1.0:
            #     speech_waveform = speech_waveform / np.max(np.abs(speech_waveform))

            # --- Convert to WAV Bytes ---
            # Use an in-memory buffer to store the WAV file.
            buffer = io.BytesIO()
            sf.write(buffer, speech_waveform, self.sampling_rate, format='WAV')
            buffer.seek(0)
            wav_bytes = buffer.read()

            print(f"Generated {len(wav_bytes)} bytes of WAV audio.")
            return wav_bytes

        except Exception as e:
            print(f"Error during inference: {e}")
            # Re-raise the exception so the endpoint framework knows it failed.
            raise RuntimeError(f"Inference failed: {e}") from e
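

# Optional local smoke test. This is a minimal sketch, assuming the model files live in
# the current directory and the model can synthesize the sample sentence without extra
# parameters; the payload shape mirrors what the endpoint would receive, and the output
# filename is only illustrative. The block never runs when the handler is imported by
# the inference toolkit.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    audio_bytes = handler({"inputs": "Hello, this is a test of the endpoint handler."})
    with open("test_output.wav", "wb") as f:
        f.write(audio_bytes)
    print(f"Wrote {len(audio_bytes)} bytes to test_output.wav")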