Update handler.py
Browse files - handler.py (+31 -11)
handler.py
CHANGED
@@ -37,15 +37,23 @@ class EndpointHandler:
|
|
37 |
"""
|
38 |
Preprocess input data before inference
|
39 |
"""
|
40 |
-
inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
# Extract parameters from request
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
repetition_penalty = float(inputs.get("repetition_penalty", 1.1))
|
49 |
|
50 |
# Format prompt with voice
|
51 |
prompt = f"{voice}: {text}"
|
@@ -193,7 +201,19 @@ class EndpointHandler:
|
|
193 |
"""
|
194 |
Main entry point for the handler
|
195 |
"""
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
"""
|
38 |
Preprocess input data before inference
|
39 |
"""
|
40 |
+
# HF Inference API format: 'inputs' is the text, 'parameters' contains the config
|
41 |
+
# Handle both direct access and standardized HF format
|
42 |
+
if isinstance(data, dict) and "inputs" in data:
|
43 |
+
# Standard HF format
|
44 |
+
text = data["inputs"]
|
45 |
+
parameters = data.get("parameters", {})
|
46 |
+
else:
|
47 |
+
# Direct access (fallback)
|
48 |
+
text = data
|
49 |
+
parameters = {}
|
50 |
|
51 |
# Extract parameters from request
|
52 |
+
voice = parameters.get("voice", "tara")
|
53 |
+
temperature = float(parameters.get("temperature", 0.6))
|
54 |
+
top_p = float(parameters.get("top_p", 0.95))
|
55 |
+
max_new_tokens = int(parameters.get("max_new_tokens", 1200))
|
56 |
+
repetition_penalty = float(parameters.get("repetition_penalty", 1.1))
|
|
|
57 |
|
58 |
# Format prompt with voice
|
59 |
prompt = f"{voice}: {text}"
|
|
|
201 |
"""
|
202 |
Main entry point for the handler
|
203 |
"""
|
204 |
+
try:
|
205 |
+
logger.info(f"Received request: {type(data)}")
|
206 |
+
|
207 |
+
# Check if we need to handle the health check route
|
208 |
+
if data == "ping" or data == {"inputs": "ping"}:
|
209 |
+
return {"status": "ok"}
|
210 |
+
|
211 |
+
preprocessed_inputs = self.preprocess(data)
|
212 |
+
model_outputs = self.inference(preprocessed_inputs)
|
213 |
+
response = self.postprocess(model_outputs)
|
214 |
+
return response
|
215 |
+
except Exception as e:
|
216 |
+
logger.error(f"Error processing request: {str(e)}")
|
217 |
+
import traceback
|
218 |
+
logger.error(traceback.format_exc())
|
219 |
+
return {"error": str(e)}
|