ke-lly-d commited on
Commit
458aa66
·
1 Parent(s): de51a6b

output prob distribution

Browse files
Files changed (1) hide show
  1. handler.py +16 -5
handler.py CHANGED
@@ -18,6 +18,7 @@ class EndpointHandler:
18
  if not input_text:
19
  return {"error": "No input provided."}
20
 
 
21
  inputs = self.tokenizer(
22
  input_text,
23
  return_tensors="pt",
@@ -27,14 +28,24 @@ class EndpointHandler:
27
  )
28
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
29
 
 
30
  with torch.no_grad():
31
  outputs = self.model(**inputs)
32
- probs = torch.softmax(outputs.logits, dim=-1)
33
- top_class_id = torch.argmax(probs, dim=-1).item()
34
- top_class_label = self.id2label.get(top_class_id) or self.id2label.get(str(top_class_id))
35
- top_class_prob = probs[0, top_class_id].item()
 
 
 
 
 
 
 
 
36
 
37
  return {
38
  "label": top_class_label,
39
- "confidence": round(top_class_prob, 4)
40
  }
 
 
18
  if not input_text:
19
  return {"error": "No input provided."}
20
 
21
+ # Tokenization
22
  inputs = self.tokenizer(
23
  input_text,
24
  return_tensors="pt",
 
28
  )
29
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
30
 
31
+ # Forward pass
32
  with torch.no_grad():
33
  outputs = self.model(**inputs)
34
+ probs = torch.softmax(outputs.logits, dim=-1)[0] # shape: (num_classes,)
35
+
36
+ # Get top class
37
+ top_class_id = torch.argmax(probs).item()
38
+ top_class_label = self.id2label.get(top_class_id) or self.id2label.get(str(top_class_id))
39
+ top_class_prob = probs[top_class_id].item()
40
+
41
+ # Convert full distribution to label->probability dict
42
+ prob_distribution = {
43
+ self.id2label.get(i) or self.id2label.get(str(i)): round(p.item(), 4)
44
+ for i, p in enumerate(probs)
45
+ }
46
 
47
  return {
48
  "label": top_class_label,
49
+ "probabilities": prob_distribution
50
  }
51
+