# Flash Attention
Flash Attention is a fast, memory-efficient implementation of the attention mechanism, designed for large models and long sequences. This is a Hugging Face-compliant kernel build of Flash Attention.

The original code is available at [https://github.com/Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention).
```python
# /// script
# dependencies = ["numpy", "torch", "kernels"]
# ///
import torch
from kernels import get_kernel

# Setup
torch.manual_seed(42)
flash_attn = get_kernel("kernels-community/flash-attn")
device = torch.device("cuda")

# Show available functions
print("Flash Attention functions:", [i for i in dir(flash_attn) if i.startswith("mha")])

# 1. Standard attention
print("\n1. Standard attention:")
B, S, H, D = 2, 5, 4, 8  # batch, seq_len, heads, head_dim
q = k = v = torch.randn(B, S, H, D, device=device, dtype=torch.float16)
out = flash_attn.mha_fwd(q=q, k=k, v=v, is_causal=False)[0]
print(f"Output: {out.shape}")

# 2. Variable length sequences
print("\n2. Variable length sequences:")
q_var = torch.randn(10, H, D, device=device, dtype=torch.float16)  # total_q=10
k_var = v_var = torch.randn(12, H, D, device=device, dtype=torch.float16)  # total_k=12
# For 3 sequences with lengths [3, 4, 3] for q and [4, 5, 3] for k
cu_q = torch.tensor([0, 3, 7, 10], device=device, dtype=torch.int32)
cu_k = torch.tensor([0, 4, 9, 12], device=device, dtype=torch.int32)
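# cu_seqlens_* are cumulative offsets (prefix sums of the per-sequence lengths):
# sequence i occupies rows cu[i]:cu[i+1] of the packed tensor, and the last
# entry equals the total number of tokens.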
out_var = flash_attn.mha_varlen_fwd(
    q=q_var,
    k=k_var,
    v=v_var,
    cu_seqlens_q=cu_q,
    cu_seqlens_k=cu_k,
    max_seqlen_q=4,
    max_seqlen_k=5,
)[0]
print(f"Output: {out_var.shape}")

# 3. KV-cache for autoregressive generation
print("\n3. KV-cache:")
cache_len, new_len = 10, 2
kcache = vcache = torch.randn(B, cache_len, H, D, device=device, dtype=torch.float16)
q_new = k_new = v_new = torch.randn(
    B, new_len, H, D, device=device, dtype=torch.float16
)
seqlens = torch.full((B,), cache_len + new_len, device=device, dtype=torch.int32)
out_kv = flash_attn.mha_fwd_kvcache(
    q=q_new,
    kcache=kcache,
    vcache=vcache,
    k=k_new,
    v=v_new,
    seqlens_k=seqlens,
    is_causal=True,
)[0]
print(f"Output: {out_kv.shape}")
```
Expected output:
```txt
Fetching 3 files: 100%|█████████████████████████████████████████████████████| 3/3 [00:00<00:00, 16384.00it/s]
Flash Attention functions: ['mha_bwd', 'mha_fwd', 'mha_fwd_kvcache', 'mha_varlen_bwd', 'mha_varlen_fwd']

1. Standard attention:
Output: torch.Size([2, 5, 4, 8])

2. Variable length sequences:
Output: torch.Size([10, 4, 8])

3. KV-cache:
Output: torch.Size([2, 2, 4, 8])
```
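
As a quick sanity check, the standard-attention path can be compared against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`. The sketch below is illustrative: it reuses the same `mha_fwd` call as above and assumes its default softmax scale is `1/sqrt(head_dim)` (the usual convention). Note that `mha_fwd` takes tensors in `(batch, seq_len, heads, head_dim)` layout, while SDPA expects `(batch, heads, seq_len, head_dim)`, hence the transposes.

```python
# /// script
# dependencies = ["torch", "kernels"]
# ///
import torch
from kernels import get_kernel

torch.manual_seed(0)
flash_attn = get_kernel("kernels-community/flash-attn")
device = torch.device("cuda")

B, S, H, D = 2, 5, 4, 8  # batch, seq_len, heads, head_dim
q = torch.randn(B, S, H, D, device=device, dtype=torch.float16)
k = torch.randn(B, S, H, D, device=device, dtype=torch.float16)
v = torch.randn(B, S, H, D, device=device, dtype=torch.float16)

# Flash Attention kernel: (B, S, H, D) layout
out_flash = flash_attn.mha_fwd(q=q, k=k, v=v, is_causal=False)[0]

# PyTorch reference: SDPA expects (B, H, S, D), so transpose in and out.
# Assumes mha_fwd's default scale matches SDPA's 1/sqrt(head_dim).
out_ref = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

# In fp16 the two results should agree to within a small tolerance
print("max abs diff:", (out_flash - out_ref).abs().max().item())
```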