Commit 71aae6d
Parent(s): 8771224
fix: handle window_size passed as list
mha.py CHANGED
@@ -514,6 +514,10 @@ class MHA(nn.Module):
             alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device)
         else:
             alibi_slopes = None
+
+        if isinstance(window_size, list):
+            window_size = tuple(window_size)
+
         if window_size != (-1, -1):
             assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
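For context, a minimal sketch of why the cast matters (assuming window_size arrives as a list, e.g. deserialized from a JSON model config, where arrays always become Python lists): a list never compares equal to a tuple in Python, so without the conversion even the "disabled" sentinel [-1, -1] would take the sliding-window branch.

# Hypothetical repro, not part of the commit: window_size loaded from JSON.
window_size = [-1, -1]
print(window_size != (-1, -1))    # True: list != tuple always holds,
                                  # so the sliding-window branch fires.

window_size = tuple(window_size)  # the cast added in this commit
print(window_size != (-1, -1))    # False: sentinel recognized,
                                  # sliding-window path correctly skipped.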