| from typing import Optional | |
| import torch | |
| from ._ops import ops | |
def mha_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    alibi_slopes: torch.Tensor,
    p_dropout: float,
    softmax_scale: float,
    is_causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    return_softmax: bool,
    gen: Optional[torch.Generator],
) -> torch.Tensor:
    """Call ``ops.mha_fwd`` with the given arguments and return its result.

    This is a thin Python-level wrapper: every argument is forwarded
    positionally, unchanged and in declaration order, to ``ops.mha_fwd``.
    No validation, conversion, or defaulting happens here.

    Args:
        q: Forwarded as-is (presumably the query tensor -- TODO confirm).
        k: Forwarded as-is (presumably the key tensor -- TODO confirm).
        v: Forwarded as-is (presumably the value tensor -- TODO confirm).
        out: Forwarded as-is. NOTE(review): looks like a caller-provided
            output buffer; whether the op writes into it is not visible
            from this wrapper -- confirm against the op implementation.
        alibi_slopes: Forwarded as-is.
        p_dropout: Forwarded as-is.
        softmax_scale: Forwarded as-is.
        is_causal: Forwarded as-is.
        window_size_left: Forwarded as-is.
        window_size_right: Forwarded as-is.
        softcap: Forwarded as-is.
        return_softmax: Forwarded as-is.
        gen: Optional random generator, forwarded as-is (may be ``None``).

    Returns:
        Whatever ``ops.mha_fwd`` returns. NOTE(review): the annotation says
        ``torch.Tensor``, but the actual type is determined by the
        underlying op's schema -- confirm. ``out`` itself is *not* returned
        by this wrapper.
    """
    # The op's result is returned directly. A trailing `return out` that
    # used to follow this statement could never execute and was removed.
    return ops.mha_fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )