hypaai committed on
Commit
21ba720
·
verified ·
1 Parent(s): b754e90

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +31 -11
handler.py CHANGED
@@ -37,15 +37,23 @@ class EndpointHandler:
37
  """
38
  Preprocess input data before inference
39
  """
40
- inputs = data.pop("inputs", data)
 
 
 
 
 
 
 
 
 
41
 
42
  # Extract parameters from request
43
- text = inputs.get("text", "")
44
- voice = inputs.get("voice", "tara")
45
- temperature = float(inputs.get("temperature", 0.6))
46
- top_p = float(inputs.get("top_p", 0.95))
47
- max_new_tokens = int(inputs.get("max_new_tokens", 1200))
48
- repetition_penalty = float(inputs.get("repetition_penalty", 1.1))
49
 
50
  # Format prompt with voice
51
  prompt = f"{voice}: {text}"
@@ -193,7 +201,19 @@ class EndpointHandler:
193
  """
194
  Main entry point for the handler
195
  """
196
- preprocessed_inputs = self.preprocess(data)
197
- model_outputs = self.inference(preprocessed_inputs)
198
- response = self.postprocess(model_outputs)
199
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  """
38
  Preprocess input data before inference
39
  """
40
+ # HF Inference API format: 'inputs' is the text, 'parameters' contains the config
41
+ # Handle both direct access and standardized HF format
42
+ if isinstance(data, dict) and "inputs" in data:
43
+ # Standard HF format
44
+ text = data["inputs"]
45
+ parameters = data.get("parameters", {})
46
+ else:
47
+ # Direct access (fallback)
48
+ text = data
49
+ parameters = {}
50
 
51
  # Extract parameters from request
52
+ voice = parameters.get("voice", "tara")
53
+ temperature = float(parameters.get("temperature", 0.6))
54
+ top_p = float(parameters.get("top_p", 0.95))
55
+ max_new_tokens = int(parameters.get("max_new_tokens", 1200))
56
+ repetition_penalty = float(parameters.get("repetition_penalty", 1.1))
 
57
 
58
  # Format prompt with voice
59
  prompt = f"{voice}: {text}"
 
201
  """
202
  Main entry point for the handler
203
  """
204
+ try:
205
+ logger.info(f"Received request: {type(data)}")
206
+
207
+ # Check if we need to handle the health check route
208
+ if data == "ping" or data == {"inputs": "ping"}:
209
+ return {"status": "ok"}
210
+
211
+ preprocessed_inputs = self.preprocess(data)
212
+ model_outputs = self.inference(preprocessed_inputs)
213
+ response = self.postprocess(model_outputs)
214
+ return response
215
+ except Exception as e:
216
+ logger.error(f"Error processing request: {str(e)}")
217
+ import traceback
218
+ logger.error(traceback.format_exc())
219
+ return {"error": str(e)}