x54-729
commited on
Commit
·
c9c773f
1
Parent(s):
f739ec7
fix flash attention import
Browse files- configuration_internlm2.py +9 -2
- modeling_internlm2.py +4 -2
configuration_internlm2.py
CHANGED
|
@@ -169,5 +169,12 @@ class InternLM2Config(PretrainedConfig):
|
|
| 169 |
raise ValueError(
|
| 170 |
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
| 171 |
)
|
| 172 |
-
if
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
raise ValueError(
|
| 170 |
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
| 171 |
)
|
| 172 |
+
if (
|
| 173 |
+
rope_scaling_factor is None
|
| 174 |
+
or not isinstance(rope_scaling_factor, (float, int))
|
| 175 |
+
or rope_scaling_factor < 1.0
|
| 176 |
+
):
|
| 177 |
+
raise ValueError(
|
| 178 |
+
f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} "
|
| 179 |
+
f"of type {type(rope_scaling_factor)}"
|
| 180 |
+
)
|
modeling_internlm2.py
CHANGED
|
@@ -40,7 +40,6 @@ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
|
| 40 |
from transformers.utils import (
|
| 41 |
add_start_docstrings,
|
| 42 |
add_start_docstrings_to_model_forward,
|
| 43 |
-
is_flash_attn_2_available,
|
| 44 |
is_flash_attn_greater_or_equal_2_10,
|
| 45 |
logging,
|
| 46 |
replace_return_docstrings,
|
|
@@ -53,9 +52,12 @@ except Exception:
|
|
| 53 |
|
| 54 |
from .configuration_internlm2 import InternLM2Config
|
| 55 |
|
| 56 |
-
|
|
|
|
| 57 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 58 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
|
|
|
|
|
|
|
| 59 |
|
| 60 |
|
| 61 |
logger = logging.get_logger(__name__)
|
|
|
|
| 40 |
from transformers.utils import (
|
| 41 |
add_start_docstrings,
|
| 42 |
add_start_docstrings_to_model_forward,
|
|
|
|
| 43 |
is_flash_attn_greater_or_equal_2_10,
|
| 44 |
logging,
|
| 45 |
replace_return_docstrings,
|
|
|
|
| 52 |
|
| 53 |
from .configuration_internlm2 import InternLM2Config
|
| 54 |
|
| 55 |
+
|
| 56 |
+
try:
|
| 57 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 58 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
|
| 59 |
+
except:
|
| 60 |
+
pass
|
| 61 |
|
| 62 |
|
| 63 |
logger = logging.get_logger(__name__)
|