|
import base64
import io
import os
import wave

import numpy as np
import torch
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
class EndpointHandler:
    """Inference Endpoints handler for the Hypa Orpheus text-to-speech model.

    A causal LM turns a ``"<voice>: <text>"`` prompt into a stream of audio
    tokens. Those tokens are regrouped into three SNAC codebook layers, decoded
    to a mono waveform, and returned as a base64-encoded 16-bit WAV file.
    """

    # Token id of audio code 0. An audio token at position ``k`` (0..6) of a
    # frame has id ``AUDIO_TOKEN_OFFSET + k * CODEBOOK_SIZE + code``.
    AUDIO_TOKEN_OFFSET = 128266
    # Size of each SNAC codebook assumed by the token layout above.
    CODEBOOK_SIZE = 4096
    # Audio tokens per frame: 1 code for layer 1, 2 for layer 2, 4 for layer 3.
    TOKENS_PER_FRAME = 7
    # Sample rate written into the WAV header (the decoder is ``snac_24khz``).
    SAMPLE_RATE = 24000

    def __init__(self, path=""):
        """Load the language model, its tokenizer and the SNAC audio decoder.

        Args:
            path: Repository path passed in by Inference Endpoints. It is not
                used; the model ids loaded here are fixed.
        """
        self.model_name = "hypaai/Hypa_Orpheus-3b-0.1-ft-unsloth-merged_16bit"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
        )
        self.model.to(self.device)
        self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        self.snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
        self.snac_model.to(self.device)
        self.snac_model.eval()

        # Marker tokens placed before / after the tokenized prompt. The ids are
        # specific to this fine-tuned model's prompt format.
        self.start_token = torch.tensor([[128259]], dtype=torch.int64)
        self.end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
        # Pad id for generate(); also stripped from the audio codes.
        self.padding_token = 128263
        # Audio codes are read from after the LAST occurrence of this token.
        self.start_audio_token = 128257
        # Stops generation (used as eos) and is stripped from the audio codes.
        self.end_audio_token = 128258

        print(f"Model loaded on {self.device}")

    def preprocess(self, data):
        """Build generation inputs from a request payload.

        Accepted payload shapes:
            * ``{"inputs": {"text": ..., "voice": ..., ...}}``
            * ``{"inputs": "some text", "parameters": {...}}``
            * a flat dict carrying the keys directly.

        Recognised keys are ``text``, ``voice`` (default ``"tara"``),
        ``temperature`` (0.6), ``top_p`` (0.95), ``max_new_tokens`` (1200) and
        ``repetition_penalty`` (1.1). Keys inside ``inputs`` take precedence
        over keys in ``parameters``. The caller's dict is not modified.

        Returns:
            dict with ``input_ids`` / ``attention_mask`` tensors on
            ``self.device`` plus the sampling settings.

        Raises:
            ValueError: if ``inputs`` is neither a string nor a dict.
        """
        inputs = data.get("inputs", data)
        if isinstance(inputs, str):
            inputs = {"text": inputs}
        elif not isinstance(inputs, dict):
            raise ValueError(
                f"'inputs' must be a string or an object, got {type(inputs).__name__}"
            )
        parameters = data.get("parameters")
        if not isinstance(parameters, dict):
            parameters = {}
        options = {**parameters, **inputs}

        text = options.get("text", "")
        voice = options.get("voice", "tara")
        temperature = float(options.get("temperature", 0.6))
        top_p = float(options.get("top_p", 0.95))
        max_new_tokens = int(options.get("max_new_tokens", 1200))
        repetition_penalty = float(options.get("repetition_penalty", 1.1))

        prompt = f"{voice}: {text}"
        prompt_ids = self.tokenizer(prompt, return_tensors="pt").input_ids

        # Wrap the tokenized prompt in the model's fixed marker tokens.
        input_ids = torch.cat(
            [self.start_token, prompt_ids, self.end_tokens], dim=1
        ).to(self.device)
        attention_mask = torch.ones_like(input_ids)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
        }

    def inference(self, inputs):
        """Sample audio tokens from the language model.

        Args:
            inputs: dict produced by :meth:`preprocess`.

        Returns:
            The ``(1, seq_len)`` id tensor returned by ``generate``, i.e. the
            prompt followed by the generated tokens.
        """
        with torch.no_grad():
            return self.model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=inputs["max_new_tokens"],
                do_sample=True,
                temperature=inputs["temperature"],
                top_p=inputs["top_p"],
                repetition_penalty=inputs["repetition_penalty"],
                num_return_sequences=1,
                eos_token_id=self.end_audio_token,
                # Explicit pad id, so generate() does not fall back to eos
                # with a warning.
                pad_token_id=self.padding_token,
            )

    def postprocess(self, generated_ids):
        """Convert generated token ids into a base64-encoded WAV response.

        Only the first sequence is decoded, since it is the only one returned.

        Returns:
            ``{"audio_b64": <base64 WAV>, "sample_rate": 24000}``

        Raises:
            ValueError: if no start-of-audio token was generated or a
                generated code is outside the valid codebook range.
        """
        code_list = self._extract_audio_codes(generated_ids[0])
        audio = self.redistribute_codes(code_list)
        # Flatten the decoder output (presumably (1, 1, T)) to a 1-D float array.
        audio_sample = audio.detach().reshape(-1).float().cpu().numpy()
        wav_data = self._encode_wav(audio_sample, self.SAMPLE_RATE)
        return {
            "audio_b64": base64.b64encode(wav_data).decode("utf-8"),
            "sample_rate": self.SAMPLE_RATE,
        }

    def _extract_audio_codes(self, token_row):
        """Return de-offset audio codes following the last start-of-audio token.

        End-of-audio and padding tokens are dropped. Trailing codes that do
        not fill a whole 7-token frame are discarded.

        Raises:
            ValueError: if ``token_row`` contains no start-of-audio token.
        """
        starts = (token_row == self.start_audio_token).nonzero(as_tuple=True)[0]
        if starts.numel() == 0:
            # Without this check the prompt tokens would be decoded as audio,
            # producing negative codebook indices inside SNAC.
            raise ValueError(
                "Model output contains no start-of-audio token; no audio was generated."
            )
        audio_tokens = token_row[starts[-1].item() + 1:]
        keep = (audio_tokens != self.end_audio_token) & (
            audio_tokens != self.padding_token
        )
        audio_tokens = audio_tokens[keep]
        usable = (audio_tokens.numel() // self.TOKENS_PER_FRAME) * self.TOKENS_PER_FRAME
        return (audio_tokens[:usable] - self.AUDIO_TOKEN_OFFSET).tolist()

    def redistribute_codes(self, code_list):
        """Split 7-code frames into SNAC's three layers and decode them to audio.

        Each frame is laid out as ``[l1, l2, l3, l3, l2, l3, l3]``. Code ``k``
        of a frame carries an extra ``k * CODEBOOK_SIZE`` offset, which is
        removed here. Trailing codes that do not fill a whole frame are ignored.

        Returns:
            The tensor returned by ``self.snac_model.decode``. If ``code_list``
            holds no complete frame, an empty ``(1, 1, 0)`` float tensor is
            returned instead.

        Raises:
            ValueError: if a de-offset code is outside ``[0, CODEBOOK_SIZE)``.
                Checking here avoids an opaque index error (or a CUDA
                device-side assert) inside the decoder.
        """
        layer_1, layer_2, layer_3 = [], [], []
        # Target layer for each of the 7 positions within a frame.
        layer_of_slot = (layer_1, layer_2, layer_3, layer_3, layer_2, layer_3, layer_3)

        num_frames = len(code_list) // self.TOKENS_PER_FRAME
        if num_frames == 0:
            return torch.zeros((1, 1, 0), device=self.device)

        for frame in range(num_frames):
            base = frame * self.TOKENS_PER_FRAME
            for slot, layer in enumerate(layer_of_slot):
                code = code_list[base + slot] - slot * self.CODEBOOK_SIZE
                if not 0 <= code < self.CODEBOOK_SIZE:
                    raise ValueError(
                        f"Invalid audio code {code_list[base + slot]} at position "
                        f"{base + slot} (slot {slot}); expected a value in "
                        f"[{slot * self.CODEBOOK_SIZE}, {(slot + 1) * self.CODEBOOK_SIZE})"
                    )
                layer.append(code)

        codes = [
            torch.tensor(layer, dtype=torch.long, device=self.device).unsqueeze(0)
            for layer in (layer_1, layer_2, layer_3)
        ]
        # No gradients are needed; avoid building an autograd graph per request.
        with torch.inference_mode():
            return self.snac_model.decode(codes)

    @staticmethod
    def _encode_wav(audio, sample_rate):
        """Encode float samples (nominally in [-1, 1]) as mono 16-bit PCM WAV bytes.

        Samples are clipped first, so out-of-range values saturate instead of
        wrapping around when cast to int16. The PCM is written little-endian,
        as the WAV format requires.
        """
        pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype("<i2")
        with io.BytesIO() as wav_io:
            with wave.open(wav_io, "wb") as wav_file:
                wav_file.setnchannels(1)
                wav_file.setsampwidth(2)
                wav_file.setframerate(sample_rate)
                wav_file.writeframes(pcm.tobytes())
            return wav_io.getvalue()

    def __call__(self, data):
        """Run the full request pipeline: preprocess -> inference -> postprocess."""
        preprocessed_inputs = self.preprocess(data)
        model_outputs = self.inference(preprocessed_inputs)
        return self.postprocess(model_outputs)