Commit d19a591 · __call__ return values
Parent(s): fa994ec

handler.py (+18 -4)
@@ -13,9 +13,7 @@ class EndpointHandler:
         self.model = AutoModel.from_pretrained(
             model_dir,
             torch_dtype=torch.bfloat16,
-            low_cpu_mem_usage=True,
             trust_remote_code=True,
-            device_map="auto",
         ).eval()
 
         self.tokenizer = AutoTokenizer.from_pretrained(
@@ -23,9 +21,25 @@ class EndpointHandler:
         )
 
     def __call__(self, data: Dict[str, Any]) -> Any:
-        logger.info(f"Received incoming request with {data
+        logger.info(f"Received incoming request with {data}")
+
+        # Extract input text from the request data
+        input_text = data.get("inputs", "")
+        if not input_text:
+            logger.warning("No input text provided")
+            return [{"generated_text": ""}]  # Return empty result but in valid format
+
+        # Tokenize the input
+        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
+
+        # Generate embeddings
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+
+        # Process outputs - this depends on your specific model and requirements
+        # For now, we'll just return the input as the output to fix the array format issue
+        return [{"input_text": input_text, "generated_text": outputs}]
 
 
 if __name__ == "__main__":
     handler = EndpointHandler(model_dir="GSAI-ML/LLaDA-8B-Instruct")
-    print(handler)
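A minimal sketch of how the updated handler could be exercised locally, following the request shape that __call__ now reads (data.get("inputs", ...)) and the list-of-dicts response it returns. The prompt string and the import path are illustrative assumptions, not part of the commit:

# Hypothetical local smoke test (not part of the commit); assumes handler.py
# is importable and the environment can load GSAI-ML/LLaDA-8B-Instruct
# with transformers and torch installed.
from handler import EndpointHandler

handler = EndpointHandler(model_dir="GSAI-ML/LLaDA-8B-Instruct")

# __call__ expects a dict with an "inputs" key and returns a list of dicts.
result = handler({"inputs": "Hello from a local test"})
print(result[0]["input_text"], result[0]["generated_text"])

Returning a list with one dict per request keeps the response in the array format that the inline comments in the diff refer to.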