walterheart committed · Commit 31fdbaa · verified · 1 Parent(s): efe8448

Update handler.py

Files changed (1)
  1. handler.py +171 -12
handler.py CHANGED
@@ -1,16 +1,175 @@
- from transformers import pipeline
+ import os
+ import io
  import base64
+ import torch
+ import numpy as np
+ from transformers import BarkModel, BarkProcessor
+ from typing import Dict, List, Any

- class TTSHandler:
-     def __init__(self):
-         self.pipe = pipeline("text-to-speech", model="suno/bark")
-
-     def preprocess(self, text):
-         return {"inputs": text.strip()}
-
-     def inference(self, inputs):
-         return self.pipe(**inputs)
-
-     def postprocess(self, audio):
-         return {"audio": base64.b64encode(audio).decode("utf-8")}
+ class EndpointHandler:
+     def __init__(self, path=""):
+         """
+         Initialize the handler for the Bark text-to-speech model.
+
+         Args:
+             path (str, optional): Path to the model directory. Defaults to "".
+         """
+         self.path = path
+         self.model = None
+         self.processor = None
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.initialized = False
+
+     def setup(self, **kwargs):
+         """
+         Load the model and processor.
+
+         Args:
+             **kwargs: Additional arguments.
+         """
+         # Load model from the local directory
+         self.model = BarkModel.from_pretrained(self.path)
+         self.model.to(self.device)
+
+         # Load processor
+         self.processor = BarkProcessor.from_pretrained(self.path)
+
+         self.initialized = True
+         print(f"Bark model loaded on {self.device}")
+
+     def preprocess(self, request: Dict) -> Dict:
+         """
+         Process the input request before inference.
+
+         Args:
+             request (Dict): The request data containing text to convert to speech.
+
+         Returns:
+             Dict: Processed inputs for the model.
+         """
+         if not self.initialized:
+             self.setup()
+
+         inputs = {}
+
+         # Get text from the request
+         if "inputs" in request:
+             if isinstance(request["inputs"], str):
+                 # Single text input
+                 inputs["text"] = request["inputs"]
+             elif isinstance(request["inputs"], list):
+                 # List of text inputs
+                 inputs["text"] = request["inputs"][0]  # Take the first text
+
+         # Get optional parameters
+         params = request.get("parameters", {})
+
+         # Speaker ID / voice preset
+         if "speaker_id" in params:
+             inputs["speaker_id"] = params["speaker_id"]
+         elif "voice_preset" in params:
+             inputs["voice_preset"] = params["voice_preset"]
+
+         # Other generation parameters
+         if "temperature" in params:
+             inputs["temperature"] = params["temperature"]
+
+         return inputs
+
+     def inference(self, inputs: Dict) -> Dict:
+         """
+         Run model inference on the processed inputs.
+
+         Args:
+             inputs (Dict): Processed inputs for the model.
+
+         Returns:
+             Dict: Model outputs.
+         """
+         text = inputs.get("text", "")
+         if not text:
+             return {"error": "No text provided for speech generation"}
+
+         # Extract optional parameters. Bark selects a speaker through the
+         # processor's voice_preset argument, so a speaker_id is treated as
+         # an alias for a voice preset name.
+         voice_preset = inputs.get("voice_preset") or inputs.get("speaker_id")
+         temperature = inputs.get("temperature", 0.7)
+
+         # Prepare inputs for the model. The processor returns a batch of
+         # tensors that generate() consumes as keyword arguments.
+         processed = self.processor(
+             text, voice_preset=voice_preset, return_tensors="pt"
+         ).to(self.device)
+
+         # Generate speech; unprefixed kwargs such as temperature are
+         # forwarded to the sub-models' generation configs.
+         with torch.no_grad():
+             speech_output = self.model.generate(**processed, temperature=temperature)
+
+         # Convert to numpy array
+         audio_array = speech_output.cpu().numpy().squeeze()
+
+         return {"audio_array": audio_array, "sample_rate": self.model.generation_config.sample_rate}
+
+     def postprocess(self, inference_output: Dict) -> Dict:
+         """
+         Process the model outputs after inference.
+
+         Args:
+             inference_output (Dict): Model outputs.
+
+         Returns:
+             Dict: Processed outputs ready for the response.
+         """
+         if "error" in inference_output:
+             return {"error": inference_output["error"]}
+
+         audio_array = inference_output.get("audio_array")
+         sample_rate = inference_output.get("sample_rate", 24000)
+
+         # Convert audio array to WAV format
+         try:
+             import scipy.io.wavfile as wav
+             audio_buffer = io.BytesIO()
+             wav.write(audio_buffer, sample_rate, audio_array)
+             audio_buffer.seek(0)
+             audio_data = audio_buffer.read()
+
+             # Encode audio data to base64
+             audio_base64 = base64.b64encode(audio_data).decode("utf-8")
+
+             return {
+                 "audio": audio_base64,
+                 "sample_rate": sample_rate,
+                 "format": "wav"
+             }
+         except Exception as e:
+             return {"error": f"Error converting audio: {str(e)}"}
+
+     def __call__(self, data: Dict) -> Dict:
+         """
+         Main entry point for the handler.
+
+         Args:
+             data (Dict): Request data.
+
+         Returns:
+             Dict: Response data.
+         """
+         # Ensure the model is initialized
+         if not self.initialized:
+             self.setup()
+
+         # Process the request
+         try:
+             inputs = self.preprocess(data)
+             outputs = self.inference(inputs)
+             response = self.postprocess(outputs)
+             return response
+         except Exception as e:
+             return {"error": f"Error processing request: {str(e)}"}