jeff-RQ commited on
Commit
ae68214
·
1 Parent(s): 1ee0582

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +23 -0
handler.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict
2
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
3
+
4
+
5
+ class EndpointHandler:
6
+ def __init__(self, path=""):
7
+ # load model and processor from path
8
+ self.processor = Blip2Processor.from_pretrained(path)
9
+ self.model = Blip2ForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16)
10
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
+ self.model.to(self.device)
13
+
14
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
15
+ # process input
16
+ image = data.pop("image", data)
17
+ text = data.pop("text", data)
18
+
19
+ inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device, torch.float16)
20
+ generated_ids = self.model.generate(**inputs)
21
+ generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
22
+
23
+ return [{"answer": generated_text}]