File size: 1,128 Bytes
d5cce9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from typing import Any

import torch
from transformers import AutoModelForImageTextToText, AutoModelForSeq2SeqLM, AutoProcessor
from transformers.image_utils import load_image

class EndpointHandler:
    """Inference Endpoints handler for ``allenai/olmOCR-7B-0725``.

    Input:
      {
        "inputs": <PIL.Image | base64 str | URL>,
        "parameters": {
          "max_new_tokens": <int, optional, default 256>,
          "prompt": <str, optional instruction sent alongside the image>
        }
      }

    Output: {"generated_text": <str>}

    Raises:
      ValueError: if the payload has no "inputs" entry.
    """

    # NOTE(review): generic fallback instruction. The model card drives the
    # model with a prompt produced by the olmocr toolkit -- TODO confirm and
    # pass that prompt through parameters["prompt"] when output parity with
    # the toolkit matters.
    DEFAULT_PROMPT = "Extract all of the text from this image."
    DEFAULT_MAX_NEW_TOKENS = 256

    def __init__(self, path: str = "") -> None:
        """Load the processor and model from *path* (or the hub id if empty)."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        model_path = path or "allenai/olmOCR-7B-0725"
        self.processor = AutoProcessor.from_pretrained(model_path)
        # The checkpoint is a Qwen2.5-VL fine-tune, i.e. a decoder-only
        # vision-language model. AutoModelForSeq2SeqLM has no mapping for
        # that architecture; the image-text-to-text auto class does.
        self.model = AutoModelForImageTextToText.from_pretrained(model_path).to(self.device)

    def __call__(self, data: dict) -> Any:
        """Run OCR on one image and return ``{"generated_text": <str>}``."""
        image = data.get("inputs")
        if image is None:
            raise ValueError('Request payload must include an image under "inputs".')
        if isinstance(image, str):
            # Strings are resolved by transformers' loader (URL / base64).
            # NOTE(review): that loader also opens local file paths; restrict
            # the accepted forms here if requests come from untrusted clients.
            image = load_image(image)

        # "parameters" may be absent or explicitly null in the request body.
        params = data.get("parameters") or {}
        max_tokens = params.get("max_new_tokens")
        if max_tokens is None:
            max_tokens = self.DEFAULT_MAX_NEW_TOKENS
        prompt = params.get("prompt") or self.DEFAULT_PROMPT

        # The processor needs chat-formatted text containing the image
        # placeholder; images alone are not a valid input for this model.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        chat_text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.processor(
            text=[chat_text], images=[image], return_tensors="pt"
        ).to(self.device)

        ids = self.model.generate(**inputs, max_new_tokens=max_tokens)
        # generate() returns prompt + continuation for decoder-only models;
        # drop the prompt so only newly generated text is returned.
        prompt_len = inputs["input_ids"].shape[1]
        new_ids = ids[:, prompt_len:]
        text = self.processor.batch_decode(new_ids, skip_special_tokens=True)[0]

        return {"generated_text": text}