atharva27 committed on
Commit 28ef647 · verified · 1 Parent(s): 236bcb3

Create handler.py

Files changed (1)
  handler.py +104 -0
handler.py ADDED
@@ -0,0 +1,104 @@
import asyncio

import numpy as np
import torch
import torchaudio
from livekit import rtc
from transformers import AutoTokenizer, AutoModelForCausalLM


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the Orpheus TTS model and tokenizer from the given path (Hub repository).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        path = path or "atharva27/orpheus"
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, data: dict) -> list:
        # Extract the input text plus optional voice and LiveKit parameters.
        text_input = data.get("inputs") or data.get("text") or ""
        if not isinstance(text_input, str) or text_input.strip() == "":
            raise ValueError("No text input provided for TTS")
        voice = data.get("voice", "tara")  # default voice (e.g., "tara")

        # Format the prompt with the voice name (Orpheus expects prompts like "voice: text").
        prompt = f"{voice}: {text_input}"

        # Encode the prompt and generate output tokens with the TTS model.
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
        generate_kwargs = {
            "max_new_tokens": 1024,     # allow sufficient tokens for audio output
            "do_sample": True,
            "temperature": 0.8,
            "top_p": 0.95,
            "repetition_penalty": 1.1,  # >= 1.1 for stable speech generation
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        with torch.no_grad():
            output_ids = self.model.generate(input_ids, **generate_kwargs)
        # The generated sequence includes the prompt; isolate the newly generated tokens.
        generated_tokens = output_ids[0, input_ids.size(1):]
        output_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=False)

        # Extract audio token IDs from the decoded text. This is a placeholder:
        # replace it with the model's actual audio-token parsing logic.
        audio_token_ids = [int(m) for m in output_text.split() if m.isdigit()]

        # Convert the audio token IDs to waveform data.
        waveform = self.generate_waveform_from_tokens(audio_token_ids)

        # Save the result as a 24 kHz mono WAV file.
        torchaudio.save("output_audio.wav", waveform, 24000)

        # For real-time playback, stream the audio over LiveKit when credentials are given.
        lk_url = data.get("livekit_url")
        lk_token = data.get("livekit_token")
        room_name = data.get("livekit_room", "default-room")
        if lk_url and lk_token:
            asyncio.run(self.stream_audio(lk_url, lk_token, room_name, waveform))

        return [{"status": "success"}]

    def generate_waveform_from_tokens(self, audio_token_ids):
        """
        Convert audio tokens into a waveform (demonstration only).
        Replace this with a proper decoder that turns tokens into actual audio.
        """
        # Simulate a waveform with random data sized from the token count.
        num_samples = len(audio_token_ids) * 100  # rough estimate of samples per token
        waveform = torch.randn(1, num_samples)    # stand-in for decoded audio
        return waveform

    async def stream_audio(self, lk_url, lk_token, room_name, waveform):
        # Note: the LiveKit access token determines which room is joined;
        # room_name is kept only for reference.
        room = rtc.Room()
        try:
            await room.connect(lk_url, lk_token, options=rtc.RoomOptions(auto_subscribe=True))
        except Exception as e:
            return f"Failed to connect to LiveKit: {e}"

        # Create and publish an audio track for streaming the TTS output.
        source = rtc.AudioSource(sample_rate=24000, num_channels=1)
        track = rtc.LocalAudioTrack.create_audio_track("tts-audio", source)
        options = rtc.TrackPublishOptions()
        options.source = rtc.TrackSource.SOURCE_MICROPHONE
        await room.local_participant.publish_track(track, options)

        # Stream the waveform in chunks for real-time playback.
        frame_duration = 0.05                        # 50 ms per frame
        frame_samples = int(24000 * frame_duration)  # 50 ms of audio at 24 kHz
        total_samples = waveform.size(1)
        for start in range(0, total_samples, frame_samples):
            end = min(start + frame_samples, total_samples)
            # Clamp to [-1, 1] and scale the float waveform to 16-bit PCM.
            samples = waveform[0, start:end].numpy()
            chunk = (np.clip(samples, -1.0, 1.0) * 32767.0).astype(np.int16)

            # Create an AudioFrame and copy the PCM samples into its buffer.
            audio_frame = rtc.AudioFrame.create(24000, 1, len(chunk))
            np.copyto(np.frombuffer(audio_frame.data, dtype=np.int16), chunk)
            await source.capture_frame(audio_frame)

            # Sleep to keep the send rate close to real time.
            await asyncio.sleep(frame_duration)

        # Disconnect from the room once streaming finishes.
        await room.disconnect()
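
For anyone trying the handler outside an Inference Endpoint, here is a minimal, hypothetical smoke test (not part of the commit). It assumes the class above is importable from handler.py and that you have real LiveKit credentials; the server URL, API key/secret, and room below are placeholders, and minting the access token uses the separate livekit-api package.

# Hypothetical local smoke test -- illustrative only, not part of handler.py.
# Assumes the `livekit-api` package for token minting; all credentials are placeholders.
from livekit import api

from handler import EndpointHandler

if __name__ == "__main__":
    # Mint a LiveKit access token granting access to the target room.
    lk_token = (
        api.AccessToken(api_key="<api-key>", api_secret="<api-secret>")
        .with_identity("tts-bot")
        .with_grants(api.VideoGrants(room_join=True, room="default-room"))
        .to_jwt()
    )

    handler = EndpointHandler()
    result = handler({
        "inputs": "Hello from Orpheus text-to-speech.",
        "voice": "tara",
        "livekit_url": "wss://example.livekit.cloud",  # placeholder server URL
        "livekit_token": lk_token,
        "livekit_room": "default-room",
    })
    print(result)  # expected: [{'status': 'success'}]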