import os
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from snac import SNAC


class EndpointHandler:
    def __init__(self, path=""):
        # Load the Orpheus model and tokenizer
        self.model_name = "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit"
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16
        )

        # Move model to GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # Load SNAC model for audio decoding
        self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
        self.snac_model.to(self.device)

        # Special tokens
        self.start_token = torch.tensor([[128259]], dtype=torch.int64)  # Start of human
        self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)  # End of text, End of human
        self.padding_token = 128263
        self.start_audio_token = 128257  # Start of Audio token
        self.end_audio_token = 128258  # End of Audio token

        print(f"Model loaded on {self.device}")

    def preprocess(self, data):
        """
        Preprocess input data before inference
        """
        inputs = data.pop("inputs", data)

        # Extract parameters from request
        text = inputs.get("text", "")
        voice = inputs.get("voice", "tara")
        temperature = float(inputs.get("temperature", 0.6))
        top_p = float(inputs.get("top_p", 0.95))
        max_new_tokens = int(inputs.get("max_new_tokens", 1200))
        repetition_penalty = float(inputs.get("repetition_penalty", 1.1))

        # Format prompt with voice
        prompt = f"{voice}: {text}"

        # Tokenize
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids

        # Add special tokens
        modified_input_ids = torch.cat([self.start_token, input_ids, self.end_tokens], dim=1)

        # No need for padding as we're processing a single sequence
        input_ids = modified_input_ids.to(self.device)
        attention_mask = torch.ones_like(input_ids)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty
        }
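
    # Illustrative request payload for preprocess above; the field names mirror the
    # inputs.get(...) calls, and the values shown are just the defaults used there.
    #
    # {
    #     "inputs": {
    #         "text": "Hello there!",
    #         "voice": "tara",
    #         "temperature": 0.6,
    #         "top_p": 0.95,
    #         "max_new_tokens": 1200,
    #         "repetition_penalty": 1.1
    #     }
    # }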

    def inference(self, inputs):
        """
        Run model inference on the preprocessed inputs
        """
        # Extract parameters
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        temperature = inputs["temperature"]
        top_p = inputs["top_p"]
        max_new_tokens = inputs["max_new_tokens"]
        repetition_penalty = inputs["repetition_penalty"]

        # Generate output tokens
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                num_return_sequences=1,
                eos_token_id=self.end_audio_token,
            )

        return generated_ids

    def postprocess(self, generated_ids):
        """
        Process generated tokens into audio
        """
        # Find the last Start of Audio token and keep only the tokens after it
        token_indices = (generated_ids == self.start_audio_token).nonzero(as_tuple=True)
        if len(token_indices[1]) > 0:
            last_occurrence_idx = token_indices[1][-1].item()
            cropped_tensor = generated_ids[:, last_occurrence_idx + 1:]
        else:
            cropped_tensor = generated_ids

        # Remove End of Audio tokens
        processed_rows = []
        for row in cropped_tensor:
            masked_row = row[row != self.end_audio_token]
            processed_rows.append(masked_row)

        # Prepare audio codes
        code_lists = []
        for row in processed_rows:
            row_length = row.size(0)
            # Ensure length is a multiple of 7 for SNAC
            new_length = (row_length // 7) * 7
            trimmed_row = row[:new_length]
            trimmed_row = [t.item() - 128266 for t in trimmed_row]  # Shift token IDs into the SNAC code range
            code_lists.append(trimmed_row)

        # Generate audio from codes
        audio_samples = []
        for code_list in code_lists:
            audio = self.redistribute_codes(code_list)
            audio_samples.append(audio)

        # Return first (and only) audio sample
        audio_sample = audio_samples[0].detach().squeeze().cpu().numpy()

        # Convert to base64 for transmission
        import base64
        import io
        import wave

        # Convert float32 array to int16 for WAV format
        audio_int16 = (audio_sample * 32767).astype(np.int16)

        # Create WAV in memory
        with io.BytesIO() as wav_io:
            with wave.open(wav_io, 'wb') as wav_file:
                wav_file.setnchannels(1)  # Mono
                wav_file.setsampwidth(2)  # 16-bit
                wav_file.setframerate(24000)  # 24kHz
                wav_file.writeframes(audio_int16.tobytes())
            wav_data = wav_io.getvalue()

        # Encode as base64
        audio_b64 = base64.b64encode(wav_data).decode('utf-8')

        return {
            "audio_b64": audio_b64,
            "sample_rate": 24000
        }
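
    # Shape of the response returned by postprocess above (illustrative):
    # {"audio_b64": "<base64-encoded 16-bit mono 24 kHz WAV>", "sample_rate": 24000}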

    def redistribute_codes(self, code_list):
        """
        Reorganize a flat token stream into SNAC's three codebook layers and decode to audio.

        Each group of 7 tokens holds 1 coarse, 2 intermediate, and 4 fine codes,
        each offset by a multiple of 4096 according to its position in the group.
        """
        layer_1 = []  # Coarsest layer
        layer_2 = []  # Intermediate layer
        layer_3 = []  # Finest layer

        num_groups = len(code_list) // 7
        for i in range(num_groups):
            idx = 7 * i
            layer_1.append(code_list[idx])
            layer_2.append(code_list[idx + 1] - 4096)
            layer_3.append(code_list[idx + 2] - (2 * 4096))
            layer_3.append(code_list[idx + 3] - (3 * 4096))
            layer_2.append(code_list[idx + 4] - (4 * 4096))
            layer_3.append(code_list[idx + 5] - (5 * 4096))
            layer_3.append(code_list[idx + 6] - (6 * 4096))

        codes = [
            torch.tensor(layer_1).unsqueeze(0).to(self.device),
            torch.tensor(layer_2).unsqueeze(0).to(self.device),
            torch.tensor(layer_3).unsqueeze(0).to(self.device)
        ]

        # Decode audio
        audio_hat = self.snac_model.decode(codes)
        return audio_hat

    def __call__(self, data):
        """
        Main entry point for the handler
        """
        preprocessed_inputs = self.preprocess(data)
        model_outputs = self.inference(preprocessed_inputs)
        response = self.postprocess(model_outputs)
        return response
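

# Minimal local smoke-test sketch (assumption: run outside the Inference Endpoints
# runtime, with enough memory to download and load the model). The text and voice
# values below are placeholders; the payload shape matches what preprocess expects.
if __name__ == "__main__":
    import base64

    handler = EndpointHandler()
    result = handler({
        "inputs": {
            "text": "Hello, this is a quick synthesis test.",
            "voice": "tara"
        }
    })

    # Decode the base64 WAV returned by postprocess and write it to disk
    with open("output.wav", "wb") as f:
        f.write(base64.b64decode(result["audio_b64"]))
    print(f"Wrote output.wav at {result['sample_rate']} Hz")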