output prob distribution
handler.py  CHANGED  +16 -5
@@ -18,6 +18,7 @@ class EndpointHandler:
         if not input_text:
             return {"error": "No input provided."}

+        # Tokenization
         inputs = self.tokenizer(
             input_text,
             return_tensors="pt",
@@ -27,14 +28,24 @@ class EndpointHandler:
         )
         inputs = {k: v.to(self.device) for k, v in inputs.items()}

+        # Forward pass
         with torch.no_grad():
             outputs = self.model(**inputs)
-            probs = torch.softmax(outputs.logits, dim=-1)
-
-
-
+            probs = torch.softmax(outputs.logits, dim=-1)[0]  # shape: (num_classes,)
+
+        # Get top class
+        top_class_id = torch.argmax(probs).item()
+        top_class_label = self.id2label.get(top_class_id) or self.id2label.get(str(top_class_id))
+        top_class_prob = probs[top_class_id].item()
+
+        # Convert full distribution to label->probability dict
+        prob_distribution = {
+            self.id2label.get(i) or self.id2label.get(str(i)): round(p.item(), 4)
+            for i, p in enumerate(probs)
+        }

         return {
             "label": top_class_label,
-            "
+            "probabilities": prob_distribution
         }
+
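For reference, the new probs line reduces the batched logits to a single distribution: the model returns logits of shape (batch, num_classes), softmax over dim=-1 normalizes each row, and the [0] drops the batch dimension. A minimal shape walk-through (illustrative values, not from this model):

import torch

logits = torch.tensor([[1.2, -0.3, 0.5]])    # model output: (batch=1, num_classes=3)
probs = torch.softmax(logits, dim=-1)[0]     # [0] drops the batch dim -> (num_classes,)

assert probs.shape == (3,)
assert abs(probs.sum().item() - 1.0) < 1e-6  # a softmax row sums to 1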
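The double lookup self.id2label.get(i) or self.id2label.get(str(i)) guards against the mapping having either int or str keys: JSON object keys are always strings, so an id2label read straight from a config.json uses "0", "1", ..., while transformers' PretrainedConfig normalizes them to ints. A small self-contained sketch (resolve_label is a hypothetical helper, not part of this repo):

# Hypothetical helper illustrating the int-vs-str key fallback.
def resolve_label(id2label, class_id):
    return id2label.get(class_id) or id2label.get(str(class_id))

assert resolve_label({0: "NEGATIVE", 1: "POSITIVE"}, 1) == "POSITIVE"      # int keys (PretrainedConfig)
assert resolve_label({"0": "NEGATIVE", "1": "POSITIVE"}, 1) == "POSITIVE"  # str keys (raw JSON)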
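If this handler follows the usual Hugging Face Inference Endpoints custom-handler contract (an EndpointHandler(path) constructor plus a __call__ that takes a {"inputs": ...} dict), calling it after this change would look roughly like the sketch below; the payload key and printed fields are assumptions based on that convention and on the return dict above, not confirmed by the diff:

from handler import EndpointHandler  # assumes the conventional custom-handler layout

handler = EndpointHandler(path=".")  # load model/tokenizer from the repo root
result = handler({"inputs": "I loved this movie!"})

print(result["label"])          # top class, e.g. "POSITIVE"
print(result["probabilities"])  # full label -> probability mapping, rounded to 4 decimals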