samadpls committed on
Commit
ccc2704
·
verified ·
1 Parent(s): d3ecb83

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +27 -0
handler.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, List
2
+ import torch
3
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
4
+
5
class EndpointHandler():
    """Inference endpoint wrapper around a T5 summarization model.

    Loads a ``T5ForConditionalGeneration`` model and its tokenizer from
    ``path`` and serves summarization requests through ``__call__``.
    """

    def __init__(self, path=""):
        # Prefer GPU when available; inputs are moved to the same device
        # in __call__ so model and tensors always match.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        try:
            self.model = T5ForConditionalGeneration.from_pretrained(path).to(self.device)
            self.tokenizer = T5Tokenizer.from_pretrained(path)
        except Exception as e:
            print(f"Error loading model or tokenizer from path {path}: {e}")
            # Record the failure explicitly instead of leaving the
            # attributes unset — otherwise a later __call__ dies with an
            # opaque AttributeError rather than a clear error payload.
            self.model = None
            self.tokenizer = None

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Summarize ``data["inputs"]``.

        Returns a one-element list containing either
        ``{"summary": <text>}`` on success or ``{"error": <message>}``.
        """
        inputs = data.get("inputs", "")
        if not inputs:
            return [{"error": "No inputs provided"}]

        # Guard against a failed __init__ (model/tokenizer unavailable).
        if self.model is None or self.tokenizer is None:
            return [{"error": "Model not loaded"}]

        tokenized_input = self.tokenizer(
            inputs,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding="max_length",
        )
        # Move input tensors to the same device as the model.
        tokenized_input = tokenized_input.to(self.device)

        # Inference only — no_grad avoids building the autograd graph.
        with torch.no_grad():
            summary_ids = self.model.generate(
                **tokenized_input, max_length=400, do_sample=True, top_p=0.8
            )

        summary_text = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        return [{"summary": summary_text}]