updt flash_attn_triton import (#37)
- updt flash_attn_triton import (98fbef4eccfacd39b03777cbb212c60508d48594)
Co-authored-by: Vitaliy Chiley <[email protected]>
- attention.py +12 -3
attention.py CHANGED
@@ -5,6 +5,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 from einops import rearrange
+from packaging import version
 from torch import nn
 from .norm import LPLayerNorm
 
@@ -87,9 +88,17 @@ def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None
 
 def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
     try:
-        from
+        from .flash_attn_triton import flash_attn_func
     except:
-
+        _installed = False
+        if version.parse(torch.__version__) < version.parse('2.0.0'):
+            _installed = True
+            try:
+                from flash_attn.flash_attn_triton import flash_attn_func
+            except:
+                _installed = False
+        if not _installed:
+            raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
     check_valid_inputs(query, key, value)
     if dropout_p:
         raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
@@ -108,7 +117,7 @@ def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bi
     key = key.expand(*key.shape[:2], n_heads, key.size(-1))
     value = value.expand(*value.shape[:2], n_heads, value.size(-1))
     reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
-    attn_output =
+    attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
     output = attn_output.view(*attn_output.shape[:2], -1)
     return (output, None)
 
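For readers tracing the new behaviour, the sketch below restates the fallback logic this commit adds as a standalone helper. It is a minimal sketch, not the repository's code: the name resolve_flash_attn_func is invented here for illustration, the first import is written as absolute (flash_attn_triton) where attention.py uses the relative .flash_attn_triton, and the error message is shortened.

import torch
from packaging import version

def resolve_flash_attn_func():
    # Preferred path after this commit: the Triton kernel vendored alongside attention.py.
    try:
        from flash_attn_triton import flash_attn_func  # `.flash_attn_triton` in the repo
        return flash_attn_func
    except ImportError:
        pass
    # Fallback path, only attempted on PyTorch older than 2.0: the Triton kernel
    # that ships with the `flash_attn` package.
    if version.parse(torch.__version__) < version.parse('2.0.0'):
        try:
            from flash_attn.flash_attn_triton import flash_attn_func
            return flash_attn_func
        except ImportError:
            pass
    raise RuntimeError('Requirements for `attn_impl: triton` not installed.')

The version gate is the substance of the change: the vendored kernel is tried first on every PyTorch version, while the flash_attn package's Triton kernel is only treated as a viable fallback on pre-2.0 PyTorch.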