Update modeling_rwkv6qwen2.py
Added a check for the fla import requirement.
- modeling_rwkv6qwen2.py +8 -2
modeling_rwkv6qwen2.py
CHANGED
@@ -204,8 +204,14 @@ class RWKV6State(Cache):
         # self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
         # self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
 
-#from fla.ops.gla.chunk import chunk_gla
-from fla.ops.gla.fused_recurrent import fused_recurrent_gla
+try:
+    #from fla.ops.gla.chunk import chunk_gla
+    from fla.ops.gla.fused_recurrent import fused_recurrent_gla
+except ImportError:
+    print("Required module is not installed. Please install it using the following commands:")
+    print("pip install -U git+https://github.com/sustcsonglin/flash-linear-attention")
+    print("Additionally, ensure you have at least version 2.2.0 of Triton installed:")
+    print("pip install triton>=2.2.0")
 
 class RWKV6Attention(nn.Module):
     def __init__(self, config, layer_idx: Optional[int] = None):