x54-729 committed
Commit · e04e906 · 1 parent: 5a485b0

replace get_max_length

modeling_internlm2.py: +3 -3
modeling_internlm2.py CHANGED

@@ -1077,7 +1077,7 @@ class InternLM2Model(InternLM2PreTrainedModel):
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         if using_static_cache:
-            target_length = past_key_values.get_max_length()
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -1266,8 +1266,8 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
         if isinstance(past_key_values, Cache):
             past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
             max_cache_length = (
-                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
-                if past_key_values.get_max_length() is not None
+                torch.tensor(past_key_values.get_max_cache_shape(), device=input_ids.device)
+                if past_key_values.get_max_cache_shape() is not None
                 else None
             )
             cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
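For context: the commit swaps `Cache.get_max_length()`, which is deprecated in recent transformers releases, for its replacement `Cache.get_max_cache_shape()`. The snippet below is a minimal sketch of that API, not code from this repository; it assumes a transformers version around 4.46+ and uses a small illustrative config rather than the InternLM2 one.

```python
# Minimal sketch of the replacement API (assumption: transformers ~4.46+,
# where Cache.get_max_length() is deprecated in favor of
# Cache.get_max_cache_shape()). Config values are illustrative only.
import torch
from transformers import AutoConfig, DynamicCache, StaticCache

config = AutoConfig.for_model(
    "llama", num_hidden_layers=2, hidden_size=64,
    num_attention_heads=4, num_key_value_heads=4,
)

static_cache = StaticCache(
    config=config, max_batch_size=1, max_cache_len=512,
    device="cpu", dtype=torch.float16,
)
dynamic_cache = DynamicCache()

# A static cache has a fixed capacity, so the causal mask can target it directly.
print(static_cache.get_max_cache_shape())   # 512
# A dynamic cache grows as tokens are generated and reports no fixed maximum.
print(dynamic_cache.get_max_cache_shape())  # None
```

This is also why the updated line 1270 keeps the `is not None` guard: with a dynamic cache `get_max_cache_shape()` returns no value, `max_cache_length` stays `None`, and `past_length` is used directly.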