|
from typing import Optional |
|
|
|
import torch |
|
|
|
from ._ops import ops |
|
|
|
def mha_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    alibi_slopes: torch.Tensor,
    p_dropout: float,
    softmax_scale: float,
    is_causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    return_softmax: bool,
    gen: Optional[torch.Generator],
) -> torch.Tensor:
    """Call the ``ops.mha_fwd`` binding and return ``out``.

    Every argument is handed to ``ops.mha_fwd`` unchanged and positionally,
    in the same order as this function's signature. Whatever ``ops.mha_fwd``
    itself returns is discarded; the caller gets back the very ``out`` object
    that was passed in.

    NOTE(review): presumably the op writes the attention result into ``out``
    in place -- confirm against the op's registered schema.
    NOTE(review): the meaning and expected shapes/dtypes of the individual
    arguments are defined by the op, not by this wrapper -- see its docs.

    Returns:
        The ``out`` tensor supplied by the caller.
    """
    # Order matters: the binding is invoked positionally, so this tuple must
    # mirror the parameter order expected by ``ops.mha_fwd``.
    forward_args = (
        q,
        k,
        v,
        out,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )
    ops.mha_fwd(*forward_args)
    return out