# Hugging Face Inference Endpoints handler for allenai/olmOCR-7B-0725.
from typing import Any

import torch
from transformers import AutoModelForSeq2SeqLM, AutoModelForVision2Seq, AutoProcessor
from transformers.image_utils import load_image
class EndpointHandler:
    """
    Inference Endpoints handler for allenai/olmOCR-7B-0725.

    Input:
        {
            "inputs": <PIL.Image | base64 str | URL>,
            "parameters": {"max_new_tokens": <int, optional>}
        }
    Output: {"generated_text": <str>}
    """

    def __init__(self, path: str = "") -> None:
        # Prefer GPU when available; CPU inference works but is slow.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        model_path = path or "allenai/olmOCR-7B-0725"
        self.processor = AutoProcessor.from_pretrained(model_path)
        # olmOCR-7B is a Qwen2-VL-derived vision-language (decoder-only) model.
        # It is not registered with AutoModelForSeq2SeqLM, which would raise an
        # "Unrecognized configuration class" error at load time;
        # AutoModelForVision2Seq resolves the correct generation class.
        self.model = AutoModelForVision2Seq.from_pretrained(model_path).to(self.device)
        # Disable dropout and other train-mode behavior for inference.
        self.model.eval()

    def __call__(self, data: dict) -> Any:
        image = data.get("inputs")
        if image is None:
            raise ValueError('Request must contain an "inputs" field with an image.')
        # `or {}` also guards against an explicit `"parameters": null` payload.
        params = data.get("parameters") or {}
        max_tokens = params.get("max_new_tokens", 256)
        # load_image accepts a URL, a base64-encoded string, a local path, or a
        # PIL.Image — matching the contract documented above. The previous code
        # passed raw strings to the processor, which only accepts image objects.
        image = load_image(image)
        # NOTE(review): Qwen2-VL processors typically also expect a text prompt
        # containing image placeholder tokens — confirm an images-only call is
        # supported for this checkpoint.
        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        # inference_mode skips autograd bookkeeping: lower memory, faster decode.
        with torch.inference_mode():
            ids = self.model.generate(**inputs, max_new_tokens=max_tokens)
        # Decoder-only models echo the prompt tokens in the generate output;
        # strip them so "generated_text" contains only newly generated text.
        if "input_ids" in inputs:
            ids = ids[:, inputs["input_ids"].shape[1]:]
        text = self.processor.batch_decode(ids, skip_special_tokens=True)[0]
        return {"generated_text": text}