from typing import Optional, List

import torch

from ._ops import ops

def mha_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    p_dropout: float = 0.0,
    softmax_scale: float = 1.0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    return_softmax: bool = False,
    gen: Optional[torch.Generator] = None,
) -> List[torch.Tensor]:
    """
    Forward pass for multi-head attention.

    Args:
        q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
        k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
        v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
        out: Optional output tensor, same shape as q
        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
        p_dropout: Dropout probability
        softmax_scale: Scaling factor applied to QK^T before the softmax (typically 1/sqrt(head_size))
        is_causal: Whether to apply a causal attention mask
        window_size_left: Window size for left context (-1 for unlimited)
        window_size_right: Window size for right context (-1 for unlimited)
        softcap: Soft cap applied to attention scores (0.0 disables capping)
        return_softmax: Whether to also return the softmax weights
        gen: Optional random number generator for dropout

    Returns:
        List of tensors: [output, softmax_lse, (softmax if return_softmax)]
    """
    return ops.mha_fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )
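
# Illustrative sketch (not part of the kernel API): a minimal batched forward call.
# The shapes, dtype, device, and the 1/sqrt(head_size) scale below are assumptions
# for demonstration only.
#
#   q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
#   k = torch.randn(2, 256, 8, 64, dtype=torch.float16, device="cuda")
#   v = torch.randn(2, 256, 8, 64, dtype=torch.float16, device="cuda")
#   out, softmax_lse = mha_fwd(q, k, v, softmax_scale=64 ** -0.5, is_causal=True)[:2]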

def mha_varlen_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    max_seqlen_q: int = 0,
    max_seqlen_k: int = 0,
    p_dropout: float = 0.0,
    softmax_scale: float = 1.0,
    zero_tensors: bool = False,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    return_softmax: bool = False,
    gen: Optional[torch.Generator] = None,
) -> List[torch.Tensor]:
    """
    Forward pass for multi-head attention with variable sequence lengths.

    Args:
        q: Query tensor of shape [total_q, num_heads, head_size]
        k: Key tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
        v: Value tensor of shape [total_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
        cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
        cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
        out: Optional output tensor of shape [total_q, num_heads, head_size]
        seqused_k: Optional tensor specifying how many keys to use per batch element, of shape [batch_size]
        leftpad_k: Optional left padding for keys of shape [batch_size]
        block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
        max_seqlen_q: Maximum sequence length for queries
        max_seqlen_k: Maximum sequence length for keys
        p_dropout: Dropout probability
        softmax_scale: Scaling factor applied to QK^T before the softmax (typically 1/sqrt(head_size))
        zero_tensors: Whether to zero tensors before computation
        is_causal: Whether to apply a causal attention mask
        window_size_left: Window size for left context (-1 for unlimited)
        window_size_right: Window size for right context (-1 for unlimited)
        softcap: Soft cap applied to attention scores (0.0 disables capping)
        return_softmax: Whether to also return the softmax weights
        gen: Optional random number generator for dropout

    Returns:
        List of tensors: [output, softmax_lse, (softmax if return_softmax)]
    """
    return ops.mha_varlen_fwd(
        q,
        k,
        v,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_k,
        leftpad_k,
        block_table,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        p_dropout,
        softmax_scale,
        zero_tensors,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )
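
# Illustrative sketch (assumed shapes): two sequences of lengths 5 and 7 packed
# along the first dimension, with cumulative offsets in cu_seqlens.
#
#   cu_seqlens = torch.tensor([0, 5, 12], dtype=torch.int32, device="cuda")
#   q = torch.randn(12, 8, 64, dtype=torch.float16, device="cuda")
#   k = torch.randn(12, 8, 64, dtype=torch.float16, device="cuda")
#   v = torch.randn(12, 8, 64, dtype=torch.float16, device="cuda")
#   out, softmax_lse = mha_varlen_fwd(
#       q, k, v, cu_seqlens, cu_seqlens,
#       max_seqlen_q=7, max_seqlen_k=7,
#       softmax_scale=64 ** -0.5, is_causal=True,
#   )[:2]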

def mha_bwd(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    p_dropout: float = 0.0,
    softmax_scale: float = 1.0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    gen: Optional[torch.Generator] = None,
    rng_state: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
    """
    Backward pass for multi-head attention.

    Args:
        dout: Gradient of the output, of shape [batch_size, seqlen_q, num_heads, head_size]
        q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
        k: Key tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
        v: Value tensor of shape [batch_size, seqlen_k, num_heads_k, head_size]
        out: Output tensor from the forward pass, of shape [batch_size, seqlen_q, num_heads, head_size]
        softmax_lse: Log-sum-exp values from the forward pass, of shape [batch_size, num_heads, seqlen_q]
        dq: Optional gradient tensor for queries, same shape as q
        dk: Optional gradient tensor for keys, same shape as k
        dv: Optional gradient tensor for values, same shape as v
        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
        p_dropout: Dropout probability (must match the forward pass)
        softmax_scale: Scaling factor applied to QK^T before the softmax (typically 1/sqrt(head_size))
        is_causal: Whether to apply a causal attention mask
        window_size_left: Window size for left context (-1 for unlimited)
        window_size_right: Window size for right context (-1 for unlimited)
        softcap: Soft cap applied to attention scores (0.0 disables capping)
        deterministic: Whether to use deterministic algorithms
        gen: Optional random number generator
        rng_state: Optional RNG state from the forward pass

    Returns:
        List of tensors: [dq, dk, dv]
    """
    return ops.mha_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        gen,
        rng_state,
    )
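
# Illustrative sketch: recomputing gradients from the saved forward outputs.
# q, k, v, out, and softmax_lse are assumed to come from a matching mha_fwd call
# made with the same flags.
#
#   dout = torch.randn_like(out)
#   dq, dk, dv = mha_bwd(
#       dout, q, k, v, out, softmax_lse,
#       softmax_scale=64 ** -0.5, is_causal=True,
#   )[:3]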

def mha_varlen_bwd(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    dq: Optional[torch.Tensor] = None,
    dk: Optional[torch.Tensor] = None,
    dv: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    max_seqlen_q: int = 0,
    max_seqlen_k: int = 0,
    p_dropout: float = 0.0,
    softmax_scale: float = 1.0,
    zero_tensors: bool = False,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    deterministic: bool = False,
    gen: Optional[torch.Generator] = None,
    rng_state: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
    """
    Backward pass for multi-head attention with variable sequence lengths.

    Args:
        dout: Gradient of the output, of shape [total_q, num_heads, head_size]
        q: Query tensor of shape [total_q, num_heads, head_size]
        k: Key tensor of shape [total_k, num_heads_k, head_size]
        v: Value tensor of shape [total_k, num_heads_k, head_size]
        out: Output tensor from the forward pass, of shape [total_q, num_heads, head_size]
        softmax_lse: Log-sum-exp values from the forward pass, of shape [batch_size, num_heads, max_seqlen_q]
        cu_seqlens_q: Cumulative sequence lengths for queries of shape [batch_size+1]
        cu_seqlens_k: Cumulative sequence lengths for keys of shape [batch_size+1]
        dq: Optional gradient tensor for queries, same shape as q
        dk: Optional gradient tensor for keys, same shape as k
        dv: Optional gradient tensor for values, same shape as v
        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
        max_seqlen_q: Maximum sequence length for queries
        max_seqlen_k: Maximum sequence length for keys
        p_dropout: Dropout probability (must match the forward pass)
        softmax_scale: Scaling factor applied to QK^T before the softmax (typically 1/sqrt(head_size))
        zero_tensors: Whether to zero tensors before computation
        is_causal: Whether to apply a causal attention mask
        window_size_left: Window size for left context (-1 for unlimited)
        window_size_right: Window size for right context (-1 for unlimited)
        softcap: Soft cap applied to attention scores (0.0 disables capping)
        deterministic: Whether to use deterministic algorithms
        gen: Optional random number generator
        rng_state: Optional RNG state from the forward pass

    Returns:
        List of tensors: [dq, dk, dv]
    """
    return ops.mha_varlen_bwd(
        dout,
        q,
        k,
        v,
        out,
        softmax_lse,
        dq,
        dk,
        dv,
        cu_seqlens_q,
        cu_seqlens_k,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        p_dropout,
        softmax_scale,
        zero_tensors,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        deterministic,
        gen,
        rng_state,
    )
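
# Illustrative sketch: the varlen backward mirrors mha_varlen_fwd, reusing the same
# cu_seqlens and max_seqlen values (assumed to match the forward call above).
#
#   dout = torch.randn_like(out)
#   dq, dk, dv = mha_varlen_bwd(
#       dout, q, k, v, out, softmax_lse, cu_seqlens, cu_seqlens,
#       max_seqlen_q=7, max_seqlen_k=7,
#       softmax_scale=64 ** -0.5, is_causal=True,
#   )[:3]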

def mha_fwd_kvcache(
    q: torch.Tensor,
    kcache: torch.Tensor,
    vcache: torch.Tensor,
    k: Optional[torch.Tensor] = None,
    v: Optional[torch.Tensor] = None,
    seqlens_k: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    leftpad_k: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    softmax_scale: float = 1.0,
    is_causal: bool = False,
    window_size_left: int = -1,
    window_size_right: int = -1,
    softcap: float = 0.0,
    is_rotary_interleaved: bool = False,
    num_splits: int = 1,
) -> List[torch.Tensor]:
    """
    Forward pass for multi-head attention with a KV cache.

    Args:
        q: Query tensor of shape [batch_size, seqlen_q, num_heads, head_size]
        kcache: Key cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
        vcache: Value cache tensor of shape [batch_size_c, seqlen_k, num_heads_k, head_size] or [num_blocks, page_block_size, num_heads_k, head_size]
        k: Optional new keys to append to the cache, of shape [batch_size, seqlen_knew, num_heads_k, head_size]
        v: Optional new values to append to the cache, of shape [batch_size, seqlen_knew, num_heads_k, head_size]
        seqlens_k: Optional current cache lengths per batch element, of shape [batch_size]
        rotary_cos: Optional rotary cosine tensor of shape [seqlen_ro, rotary_dim/2]
        rotary_sin: Optional rotary sine tensor of shape [seqlen_ro, rotary_dim/2]
        cache_batch_idx: Optional indices to index into the KV cache
        leftpad_k: Optional left padding for keys of shape [batch_size]
        block_table: Optional block table of shape [batch_size, max_num_blocks_per_seq]
        alibi_slopes: Optional ALiBi slopes tensor of shape [num_heads] or [batch_size, num_heads]
        out: Optional output tensor, same shape as q
        softmax_scale: Scaling factor applied to QK^T before the softmax (typically 1/sqrt(head_size))
        is_causal: Whether to apply a causal attention mask
        window_size_left: Window size for left context (-1 for unlimited)
        window_size_right: Window size for right context (-1 for unlimited)
        softcap: Soft cap applied to attention scores (0.0 disables capping)
        is_rotary_interleaved: Whether rotary embeddings are interleaved
        num_splits: Number of splits along the key/value sequence length

    Returns:
        List of tensors: [output, softmax_lse]
    """
    return ops.mha_fwd_kvcache(
        q,
        kcache,
        vcache,
        k,
        v,
        seqlens_k,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        leftpad_k,
        block_table,
        alibi_slopes,
        out,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        is_rotary_interleaved,
        num_splits,
    )
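
# Illustrative sketch (assumed decode step): one new token per sequence is appended
# to a pre-allocated cache whose current lengths are given by seqlens_k. Shapes and
# cache sizes are placeholders for demonstration only.
#
#   kcache = torch.zeros(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
#   vcache = torch.zeros_like(kcache)
#   seqlens_k = torch.tensor([10, 37], dtype=torch.int32, device="cuda")
#   q = torch.randn(2, 1, 8, 64, dtype=torch.float16, device="cuda")
#   k_new = torch.randn(2, 1, 8, 64, dtype=torch.float16, device="cuda")
#   v_new = torch.randn(2, 1, 8, 64, dtype=torch.float16, device="cuda")
#   out, softmax_lse = mha_fwd_kvcache(
#       q, kcache, vcache, k=k_new, v=v_new, seqlens_k=seqlens_k,
#       softmax_scale=64 ** -0.5, is_causal=True,
#   )[:2]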