# nu-extract-2-fork / handler.py
from typing import Dict
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    """
    Minimal custom handler for InternLM2 / NuExtract-2-8B.

    Hugging Face Inference Endpoints loads this class from handler.py and
    calls it with the parsed JSON request payload.
    """

    def __init__(self, path: str = "./model"):
        # allow execution of custom model code
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,      # ← key line
            torch_dtype=torch.float16,   # load in fp16 to fit on one A10/T4
            device_map="auto",           # send to GPU if available
        ).eval()                         # put in inference mode

    def __call__(self, data: Dict[str, str]) -> Dict[str, str]:
        prompt = data.get("inputs", "")
        if not prompt:
            return {"error": "No input provided."}

        # tokenize the prompt and move it onto the model's device
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        # generate up to 128 new tokens with the model's default decoding settings
        outputs = self.model.generate(**inputs, max_new_tokens=128)
        # decode the full sequence (prompt + completion) back to text
        answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": answer}