import os
from typing import Dict

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class EndpointHandler:
    """Custom Inference Endpoints handler for NuExtract-2-8B (InternLM2-based)."""

    def __init__(self, path: str = "") -> None:
        # trust_remote_code=True is mandatory: the repo ships its own
        # configuration and modelling code (InternLM2), which the stock
        # transformers classes cannot load otherwise.
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.float16,  # half precision fits an 8B model on a 16 GB GPU
            device_map="auto",          # let accelerate place tensors on the GPU
        ).eval()
    def __call__(self, data: Dict[str, str]) -> Dict[str, str]:
        prompt = data.get("inputs", "")
        if not prompt:
            return {"error": "No input provided."}
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.inference_mode():  # no gradients needed at inference time
            output_ids = self.model.generate(**inputs, max_new_tokens=128)
        # Note: decoding output_ids[0] yields the prompt plus the completion.
        answer = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return {"generated_text": answer}
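

# ---------------------------------------------------------------------------
# Usage sketch (assumption): Inference Endpoints normally instantiates
# EndpointHandler itself with the local model path and calls it with a
# {"inputs": ...} payload, so a local smoke test can mimic that directly.
# MODEL_PATH and the example prompt below are illustrative placeholders,
# not part of the handler contract.
if __name__ == "__main__":
    model_path = os.environ.get("MODEL_PATH", ".")  # hypothetical env var
    handler = EndpointHandler(model_path)
    result = handler({"inputs": "Extract the date from: The launch is on 4 July 2025."})
    print(result)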