Update README.md
Browse files
README.md
CHANGED
|
@@ -2684,7 +2684,7 @@ def pooling(outputs: torch.Tensor, inputs: Dict, strategy: str = 'cls') -> np.n
|
|
| 2684 |
outputs = outputs[:, 0]
|
| 2685 |
elif strategy == 'mean':
|
| 2686 |
outputs = torch.sum(
|
| 2687 |
-
outputs * inputs["attention_mask"][:, :, None], dim=1) / torch.sum(inputs["attention_mask"])
|
| 2688 |
else:
|
| 2689 |
raise NotImplementedError
|
| 2690 |
return outputs.detach().cpu().numpy()
|
|
|
|
| 2684 |
outputs = outputs[:, 0]
|
| 2685 |
elif strategy == 'mean':
|
| 2686 |
outputs = torch.sum(
|
| 2687 |
+
outputs * inputs["attention_mask"][:, :, None], dim=1) / torch.sum(inputs["attention_mask"], dim=1, keepdim=True)
|
| 2688 |
else:
|
| 2689 |
raise NotImplementedError
|
| 2690 |
return outputs.detach().cpu().numpy()
|