import base64
import io
import wave

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from snac import SNAC


class EndpointHandler:
    def __init__(self, path=""):
        # Load the Orpheus model and tokenizer
        self.model_name = "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit"
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16
        )

        # Move model to GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # Load SNAC model for audio decoding
        self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
        self.snac_model.to(self.device)

        # Special tokens
        self.start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
        self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
        self.padding_token = 128263
        self.start_audio_token = 128257  # Start of Audio token
        self.end_audio_token = 128258  # End of Audio token

        print(f"Model loaded on {self.device}")

    def preprocess(self, data):
        """
        Preprocess input data before inference
        """
        inputs = data.pop("inputs", data)

        # Extract parameters from request
        text = inputs.get("text", "")
        voice = inputs.get("voice", "tara")
        temperature = float(inputs.get("temperature", 0.6))
        top_p = float(inputs.get("top_p", 0.95))
        max_new_tokens = int(inputs.get("max_new_tokens", 1200))
        repetition_penalty = float(inputs.get("repetition_penalty", 1.1))

        # Format prompt with voice
        prompt = f"{voice}: {text}"

        # Tokenize
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids

        # Add special tokens
        modified_input_ids = torch.cat([self.start_token, input_ids, self.end_tokens], dim=1)

        # No need for padding as we're processing a single sequence
        input_ids = modified_input_ids.to(self.device)
        attention_mask = torch.ones_like(input_ids)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty
        }

    def inference(self, inputs):
        """
        Run model inference on the preprocessed inputs
        """
        # Extract parameters
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        temperature = inputs["temperature"]
        top_p = inputs["top_p"]
        max_new_tokens = inputs["max_new_tokens"]
        repetition_penalty = inputs["repetition_penalty"]

        # Generate output tokens
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                eos_token_id=self.end_audio_token,
            )

        return generated_ids

    def postprocess(self, generated_ids):
        """
        Process generated tokens into audio
        """
        # Find Start of Audio token
        token_indices = (generated_ids == self.start_audio_token).nonzero(as_tuple=True)

        if len(token_indices[1]) > 0:
            last_occurrence_idx = token_indices[1][-1].item()
            cropped_tensor = generated_ids[:, last_occurrence_idx + 1:]
        else:
            cropped_tensor = generated_ids

        # Remove End of Audio tokens
        processed_rows = []
        for row in cropped_tensor:
            masked_row = row[row != self.end_audio_token]
            processed_rows.append(masked_row)

        # Prepare audio codes
        code_lists = []
        for row in processed_rows:
            row_length = row.size(0)
            # Ensure length is a multiple of 7 for SNAC
            new_length = (row_length // 7) * 7
            trimmed_row = row[:new_length]
            # Adjust token values down to raw SNAC code indices
            trimmed_row = [t.item() - 128266 for t in trimmed_row]
            code_lists.append(trimmed_row)

        # Generate audio from codes
        audio_samples = []
        for code_list in code_lists:
            audio = self.redistribute_codes(code_list)
            audio_samples.append(audio)

        # Return first (and only) audio sample
        audio_sample = audio_samples[0].detach().squeeze().cpu().numpy()

        # Convert float32 array to int16 for WAV format
        audio_int16 = (audio_sample * 32767).astype(np.int16)

        # Create WAV in memory
        with io.BytesIO() as wav_io:
            with wave.open(wav_io, 'wb') as wav_file:
                wav_file.setnchannels(1)      # Mono
                wav_file.setsampwidth(2)      # 16-bit
                wav_file.setframerate(24000)  # 24 kHz
                wav_file.writeframes(audio_int16.tobytes())
            wav_data = wav_io.getvalue()

        # Encode as base64 for transmission
        audio_b64 = base64.b64encode(wav_data).decode('utf-8')

        return {
            "audio_b64": audio_b64,
            "sample_rate": 24000
        }

    def redistribute_codes(self, code_list):
        """
        Reorganize tokens for SNAC decoding
        """
        layer_1 = []  # Coarsest layer
        layer_2 = []  # Intermediate layer
        layer_3 = []  # Finest layer

        # Each group of 7 tokens interleaves the three SNAC codebook layers
        num_groups = len(code_list) // 7
        for i in range(num_groups):
            idx = 7 * i
            layer_1.append(code_list[idx])
            layer_2.append(code_list[idx + 1] - 4096)
            layer_3.append(code_list[idx + 2] - (2 * 4096))
            layer_3.append(code_list[idx + 3] - (3 * 4096))
            layer_2.append(code_list[idx + 4] - (4 * 4096))
            layer_3.append(code_list[idx + 5] - (5 * 4096))
            layer_3.append(code_list[idx + 6] - (6 * 4096))

        codes = [
            torch.tensor(layer_1).unsqueeze(0).to(self.device),
            torch.tensor(layer_2).unsqueeze(0).to(self.device),
            torch.tensor(layer_3).unsqueeze(0).to(self.device)
        ]

        # Decode audio
        audio_hat = self.snac_model.decode(codes)
        return audio_hat

    def __call__(self, data):
        """
        Main entry point for the handler
        """
        preprocessed_inputs = self.preprocess(data)
        model_outputs = self.inference(preprocessed_inputs)
        response = self.postprocess(model_outputs)
        return response
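

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the handler contract): a minimal
# local smoke test, assuming the checkpoints above are downloadable and that
# the payload mirrors what a Hugging Face Inference Endpoint would POST.
# The "output.wav" filename and the example text/voice values are arbitrary.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler()
    response = handler({
        "inputs": {
            "text": "Hello from Orpheus.",
            "voice": "tara",
            "temperature": 0.6,
            "max_new_tokens": 1200,
        }
    })

    # The handler returns a base64-encoded 16-bit mono WAV at 24 kHz;
    # decode it and write it to disk for listening.
    with open("output.wav", "wb") as f:
        f.write(base64.b64decode(response["audio_b64"]))
    print(f"Wrote output.wav ({response['sample_rate']} Hz)")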