File size: 1,427 Bytes
42d6781 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
from typing import Any, Dict, List, Union

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
class EndpointHandler:
    """Callable wrapper around a sequence-classification model.

    The tokenizer and model found at ``path`` are loaded once in
    ``__init__``. Calling the instance with a payload of the form
    ``{"inputs": ...}`` returns the top predicted label together with its
    softmax probability.

    NOTE(review): the class name suggests this is a Hugging Face Inference
    Endpoints custom handler -- confirm against the deployment config.
    """

    # Token budget handed to the tokenizer; longer inputs are truncated.
    MAX_LENGTH = 128

    def __init__(self, path: str = "."):
        """Load the tokenizer and model from *path* and set up inference.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path)
        self.model.to(self.device)
        # The handler only ever runs inference, so put the model in eval mode.
        self.model.eval()
        self.id2label = self.model.config.id2label

    def __call__(
        self, data: Dict[str, Any]
    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Classify the text found under ``data["inputs"]``.

        Args:
            data: Request payload. ``"inputs"`` holds either one text or a
                sequence of texts; it is passed to the tokenizer as-is.

        Returns:
            ``{"error": "No input provided."}`` when ``"inputs"`` is missing
            or empty. Otherwise ``{"label": ..., "confidence": ...}`` when
            the model scores exactly one row, or a list of such dicts (in
            input order) when it scores several. ``confidence`` is the top
            softmax probability rounded to four decimals.
        """
        input_text = data.get("inputs", "")
        if not input_text:
            return {"error": "No input provided."}

        encoded = self.tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.MAX_LENGTH,
        )
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = self.model(**encoded).logits
            probs = torch.softmax(logits, dim=-1)
            # Row-wise maximum: one (probability, class id) pair per input
            # row, so batched inputs no longer trip over ``.item()``.
            confidences, class_ids = torch.max(probs, dim=-1)

        results = [
            {"label": self._label_for(class_id), "confidence": round(confidence, 4)}
            for class_id, confidence in zip(class_ids.tolist(), confidences.tolist())
        ]
        # A single scored row keeps the original bare-dict response shape.
        return results[0] if len(results) == 1 else results

    def _label_for(self, class_id: int) -> Any:
        """Return the label for *class_id*, or ``None`` if it is unknown.

        ``id2label`` is looked up by integer key first and by string key
        second, since the mapping may be keyed either way.
        """
        label = self.id2label.get(class_id)
        if label is None:
            label = self.id2label.get(str(class_id))
        return label
|