jschwab21 committed
Commit 351aa3d · verified · 1 Parent(s): 4d90ad4

Create handler.py

Files changed (1)
  1. handler.py +31 -0
handler.py ADDED
@@ -0,0 +1,31 @@
+ from typing import Dict, Any, List
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+
+ class EndpointHandler:
+     def __init__(self, path: str = ""):
+         # Load the model in FP16 to reduce memory usage while retaining performance.
+         self.model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16)
+         self.tokenizer = AutoTokenizer.from_pretrained(path)
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """
+         data args:
+             inputs (str): the text input or prompt for the model
+         Return:
+             A list containing the generated responses.
+         """
+         # Extract the input text from the request
+         inputs = data.get("inputs", "")
+         if not inputs:
+             return [{"error": "No input provided"}]
+
+         # Tokenize the input and move it to the model's device; the token
+         # IDs must stay integer, so they are not cast to float16.
+         tokens = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
+         output_tokens = self.model.generate(**tokens)
+
+         # Decode the generated tokens back to text
+         output_text = self.tokenizer.decode(output_tokens[0], skip_special_tokens=True)
+
+         # Return the generated response as a list (required format)
+         return [{"generated_text": output_text}]
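
For a quick sanity check, a handler like this can be exercised locally before it is deployed as a custom inference endpoint. The sketch below is illustrative only: the model directory "." and the prompt are assumptions, not part of this commit.

# Minimal local smoke test for the EndpointHandler above.
# The path "." is a placeholder assumption: it must contain the model
# weights and tokenizer files; this commit does not specify a path.
from handler import EndpointHandler

handler = EndpointHandler(path=".")
result = handler({"inputs": "Hello, world!"})
print(result)  # e.g. [{"generated_text": "..."}]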