CodyBontecou committed
Commit d19a591 · 1 parent: fa994ec

__call__ return values

Files changed (1): handler.py (+18, -4)
handler.py CHANGED

@@ -13,9 +13,7 @@ class EndpointHandler:
         self.model = AutoModel.from_pretrained(
             model_dir,
             torch_dtype=torch.bfloat16,
-            low_cpu_mem_usage=True,
             trust_remote_code=True,
-            device_map="auto",
         ).eval()
 
         self.tokenizer = AutoTokenizer.from_pretrained(
@@ -23,9 +21,25 @@ class EndpointHandler:
         )
 
     def __call__(self, data: Dict[str, Any]) -> Any:
-        logger.info(f"Received incoming request with {data=}")
+        logger.info(f"Received incoming request with {data}")
+
+        # Extract input text from the request data
+        input_text = data.get("inputs", "")
+        if not input_text:
+            logger.warning("No input text provided")
+            return [{"generated_text": ""}]  # Return empty result but in valid format
+
+        # Tokenize the input
+        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
+
+        # Generate embeddings
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+
+        # Process outputs - this depends on your specific model and requirements
+        # For now, we'll just return the input as the output to fix the array format issue
+        return [{"input_text": input_text, "generated_text": outputs}]
 
 
 if __name__ == "__main__":
     handler = EndpointHandler(model_dir="GSAI-ML/LLaDA-8B-Instruct")
-    print(handler)
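
With low_cpu_mem_usage=True and device_map="auto" removed, from_pretrained now loads the model on CPU by default (both removed kwargs also depend on the accelerate package being installed). A minimal sketch of explicit single-device placement, in case that was the intent; the device choice below is an assumption, not part of the commit:

import torch
from transformers import AutoModel

# Assumption: pick one device explicitly instead of device_map="auto",
# which requires the accelerate package.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained(
    "GSAI-ML/LLaDA-8B-Instruct",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to(device).eval()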
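The new __call__ body defines a simple request/response contract: a dict with an "inputs" key in, a one-element list of dicts out. A hedged local smoke test (the payload text is illustrative):

from handler import EndpointHandler

handler = EndpointHandler(model_dir="GSAI-ML/LLaDA-8B-Instruct")

# Empty input exercises the early-return path: a valid but empty response.
print(handler({"inputs": ""}))  # -> [{"generated_text": ""}]

# Normal path: tokenize, forward pass under no_grad, list-wrapped response.
result = handler({"inputs": "Hello, LLaDA!"})
print(result[0]["input_text"])

The list-of-dicts wrapper matches what Hugging Face Inference Endpoints expect from a custom handler, which appears to be the "array format issue" the inline comment refers to.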
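One caveat: the returned outputs is a transformers model-output object, which will not JSON-encode as-is. A sketch of one way to make the response serializable, assuming the model's custom code exposes a last_hidden_state tensor (an assumption; the helper below is hypothetical and not part of the commit):

import torch

def pool_outputs(outputs):
    # Hypothetical helper: mean-pool the last hidden state into a flat
    # list of floats so the endpoint response can be JSON-encoded.
    hidden = getattr(outputs, "last_hidden_state", None)
    if hidden is None:
        return None
    # (batch, seq_len, hidden_dim) -> (hidden_dim,) -> list[float]
    return hidden.mean(dim=1).squeeze(0).float().cpu().tolist()

The final return would then become something like [{"input_text": input_text, "embedding": pool_outputs(outputs)}].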