Create handler.py
handler.py
ADDED
@@ -0,0 +1,104 @@
import asyncio
import re

import numpy as np
import torch
import torchaudio
from livekit import rtc
from transformers import AutoTokenizer, AutoModelForCausalLM


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the Orpheus TTS model and tokenizer from the given path (Hub repository).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Fall back to the default repository only when no path is supplied.
        path = path or "atharva27/orpheus"
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: dict) -> list:
        # Extract the input text plus optional voice and LiveKit parameters.
        text_input = data.get("inputs") or data.get("text") or ""
        if not isinstance(text_input, str) or text_input.strip() == "":
            raise ValueError("No text input provided for TTS")
        voice = data.get("voice", "tara")  # default voice (e.g. "tara")

        # Format the prompt with the voice name (Orpheus expects prompts like "voice: text").
        prompt = f"{voice}: {text_input}"

        # Encode the prompt and generate output tokens with the TTS model.
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        generate_kwargs = {
            "max_new_tokens": 1024,     # allow sufficient tokens for audio output
            "do_sample": True,
            "temperature": 0.8,
            "top_p": 0.95,
            "repetition_penalty": 1.1,  # >= 1.1 for stable speech generation
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        with torch.no_grad():
            output_ids = self.model.generate(input_ids, **generate_kwargs)
        # The generated sequence includes the prompt; isolate the newly generated tokens.
        generated_tokens = output_ids[0, input_ids.size(1):]
        output_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=False)

        # Extract audio token IDs from the decoded text.
        # Placeholder: pull out the numeric parts of the generated tokens;
        # replace with the actual Orpheus audio-token parsing logic.
        audio_token_ids = [int(m) for m in re.findall(r"\d+", output_text)]

        # Convert the audio token IDs to waveform data.
        waveform = self.generate_waveform_from_tokens(audio_token_ids)

        # Save the waveform as a 24 kHz audio file.
        torchaudio.save("output_audio.wav", waveform, 24000)

        # For real-time playback, stream the audio via LiveKit when credentials are provided.
        lk_url = data.get("livekit_url")
        lk_token = data.get("livekit_token")
        room_name = data.get("livekit_room", "default-room")
        if lk_url and lk_token:
            asyncio.run(self.stream_audio(lk_url, lk_token, room_name, waveform))

        return [{"status": "success"}]

    def generate_waveform_from_tokens(self, audio_token_ids):
        """
        Convert audio tokens into a waveform (this part is for demonstration).
        You should implement a proper method to decode tokens to actual audio.
        """
        # Simulate a waveform by generating random data sized from the token count;
        # replace this with the actual audio decoder.
        num_samples = len(audio_token_ids) * 100  # rough estimate of samples per token
        waveform = torch.randn(1, num_samples)    # simulated mono waveform
        return waveform

    async def stream_audio(self, lk_url, lk_token, room_name, waveform):
        # Note: the target room is determined by the access token; room_name is informational.
        room = rtc.Room()
        try:
            await room.connect(lk_url, lk_token, options=rtc.RoomOptions(auto_subscribe=True))
        except Exception as e:
            return f"Failed to connect to LiveKit: {e}"

        # Create and publish an audio track for the TTS output.
        source = rtc.AudioSource(sample_rate=24000, num_channels=1)
        track = rtc.LocalAudioTrack.create_audio_track("tts-audio", source)
        await room.local_participant.publish_track(
            track, rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
        )

        # Stream the waveform in chunks for real-time playback.
        frame_duration = 0.05                        # 50 ms per frame
        frame_samples = int(24000 * frame_duration)  # 50 ms of audio at 24 kHz
        total_samples = waveform.size(1)
        for start in range(0, total_samples, frame_samples):
            end = min(start + frame_samples, total_samples)
            # Scale the float waveform (clamped to [-1, 1]) to 16-bit PCM.
            pcm = np.clip(waveform[:, start:end].cpu().numpy(), -1.0, 1.0)
            chunk = (pcm * 32767.0).astype(np.int16)
            samples_per_channel = chunk.shape[1]

            # Create an AudioFrame, copy the PCM data into it, and send it to LiveKit.
            audio_frame = rtc.AudioFrame.create(24000, 1, samples_per_channel)
            np.copyto(np.frombuffer(audio_frame.data, dtype=np.int16), chunk[0])
            await source.capture_frame(audio_frame)

            # Sleep to maintain real-time pacing (synchronized with the frame duration).
            await asyncio.sleep(frame_duration)

        # Disconnect from the room after streaming is finished.
        await room.disconnect()
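
For reference, a request payload for this handler might look like the sketch below. The keys match the ones read in `__call__`; the LiveKit URL and token values are illustrative placeholders, not real credentials.

    handler = EndpointHandler()
    result = handler({
        "inputs": "Hello from Orpheus!",
        "voice": "tara",
        "livekit_url": "wss://your-project.livekit.cloud",  # placeholder URL
        "livekit_token": "<participant-access-token>",      # placeholder token
        "livekit_room": "demo-room",
    })
    print(result)  # -> [{"status": "success"}]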