Tom Aarsen
committed on
Commit
·
30e2384
1
Parent(s):
12e66dc
Cast attention_mask to bool in SDPA
Browse files
I'm pretty sure this is correct. It was an int tensor before, which SDPA doesn't like.
model.py
CHANGED
|
@@ -190,7 +190,7 @@ class EncoderBlock(nn.Module):
|
|
| 190 |
query=xq.transpose(1, 2),
|
| 191 |
key=xk.transpose(1, 2),
|
| 192 |
value=xv.transpose(1, 2),
|
| 193 |
-
attn_mask=attention_mask,
|
| 194 |
dropout_p=0,
|
| 195 |
).transpose(1, 2)
|
| 196 |
|
|
|
|
| 190 |
query=xq.transpose(1, 2),
|
| 191 |
key=xk.transpose(1, 2),
|
| 192 |
value=xv.transpose(1, 2),
|
| 193 |
+
attn_mask=attention_mask.bool(),
|
| 194 |
dropout_p=0,
|
| 195 |
).transpose(1, 2)
|
| 196 |
|