from typing import Optional

import torch

from ._ops import ops


def mha_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    alibi_slopes: torch.Tensor,
    p_dropout: float,
    softmax_scale: float,
    is_causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    return_softmax: bool,
    gen: Optional[torch.Generator],
) -> torch.Tensor:
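    """Multi-head attention forward pass.

    A thin wrapper that dispatches to the underlying ``ops.mha_fwd`` kernel,
    which writes the attention result into ``out``; ``out`` is returned for
    convenience.
    """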
    ops.mha_fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )
    return out
