add sanity checks
Files changed:
- README.md (+2 −2)
- custom_generate/generate.py (+51 −31)
README.md

```diff
@@ -21,8 +21,8 @@ This implementation should match the `SinkCache` class present in `transformers`
 
 
 ## Additional Arguments
-- `window_length` (`int`, defaults to 256): The length of the context window.
-- `num_sink_tokens` (`int`, defaults to 4): The number of sink tokens. See the original paper for more information.
+- `window_length` (`int`, *optional*, defaults to 256): The length of the context window.
+- `num_sink_tokens` (`int`, *optional*, defaults to 4): The number of sink tokens. See the original paper for more information.
 
 
 ## Output Type changes
```
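For context on how these README arguments reach the code, here is a minimal usage sketch. It assumes the `transformers` `custom_generate` loading mechanism (`custom_generate=` plus `trust_remote_code=True`) and reuses the model from the docstring example removed below; the repo id is an assumption, as it is not stated in this diff.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
inputs = tokenizer("My name is Qwen2", return_tensors="pt")

# `window_length` and `num_sink_tokens` are forwarded as extra kwargs to this
# repo's custom `generate` function. The repo id below is hypothetical.
outputs = model.generate(
    **inputs,
    custom_generate="transformers-community/sink_cache",  # hypothetical repo id
    trust_remote_code=True,
    window_length=256,
    num_sink_tokens=4,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```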
custom_generate/generate.py

````diff
@@ -1,11 +1,18 @@
 import torch
 from typing import Any, Dict, List, Optional, Tuple
-from transformers.utils import logging
-from transformers.cache_utils import Cache
 
+from transformers import Cache, GenerationConfig
 
-
 
+UNSUPPORTED_GENERATION_ARGS = [
+    "cache_implementation",  # cache-related arguments, here we always use SinkCache
+    "cache_config",
+    "return_legacy_cache",
+    "num_beams",  # beam search (and cousin techniques) are not supported
+    "compile_config",  # SinkCache doesn't support torch.compile
+    "assistant_model",  # it also doesn't support speculative decoding
+]
+
 class SinkCache(Cache):
     """
     A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
@@ -15,28 +22,13 @@ class SinkCache(Cache):
     It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
     `[batch_size, num_heads, seq_len, head_dim]`.
 
+    This class was copied from transformers 4.52.0, with minor modifications.
+
     Parameters:
         window_length (`int`):
             The length of the context window.
         num_sink_tokens (`int`):
             The number of sink tokens. See the original paper for more information.
-
-    Example:
-
-        ```python
-        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
-
-        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
-
-        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
-
-        >>> # Prepare a cache class and pass it to model's forward
-        >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
-        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
-        >>> outputs.past_key_values # access cache filled with key/values from generation
-        SinkCache()
-        ```
     """
 
     def __init__(self, window_length: int, num_sink_tokens: int) -> None:
@@ -48,7 +40,6 @@ class SinkCache(Cache):
         self.cos_sin_rerotation_cache = {}
         self._cos_cache = None
         self._sin_cache = None
-        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
 
     @staticmethod
     def _rotate_half(x):
@@ -86,8 +77,6 @@ class SinkCache(Cache):
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        # TODO: deprecate this function in favor of `cache_position`
-        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
         if len(self.key_cache) <= layer_idx:
             return 0
         return self.key_cache[layer_idx].shape[-2]
@@ -130,10 +119,6 @@ class SinkCache(Cache):
         partial_rotation_size = cache_kwargs.get("partial_rotation_size")
         using_rope = cos is not None and sin is not None
 
-        # Update the number of seen tokens
-        if layer_idx == 0:
-            self._seen_tokens += key_states.shape[-2]
-
         # Update the sin/cos cache, which holds sin/cos values for all possible positions
         if using_rope and layer_idx == 0:
             # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
@@ -194,17 +179,52 @@
 
 
 def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
-
-
+    """Custom generate function for SinkCache.
+
+    Args:
+        model (`PreTrainedModel`):
+            The model to generate from.
+        window_length (`int`, *optional*, defaults to 256):
+            The length of the context window.
+        num_sink_tokens (`int`, *optional*, defaults to 4):
+            The number of sink tokens. See the original paper for more information.
+    """
+    # 1. General sanity checks
+    # 1.a. A few arguments are not allowed, especially arguments that control caches.
+    generation_config = kwargs.get("generation_config")
+    default_global_generation_config = GenerationConfig()
+    default_model_generation_config = model.generation_config
+    for arg in UNSUPPORTED_GENERATION_ARGS:
+        has_custom_gen_config_arg = (
+            generation_config is not None
+            # = and not (match global default or match model-specific default)
+            and not (
+                getattr(default_model_generation_config, arg) == getattr(generation_config, arg)
+                or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
+            )
+        )
+        if arg in kwargs or has_custom_gen_config_arg:
+            raise ValueError(
+                f"`{arg}` is set, but it's not supported in this custom generate function. List of "
+                f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
+            )
+
+    # 1.b. The model must be decoder-only
+    if model.config.is_encoder_decoder:
+        raise ValueError("This custom generate function only works with decoder-only models")
+
+    # 1.c. compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result
+    # in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
     kwargs.pop("custom_generate", None)
 
-    #
+    # 2. Generate with SinkCache
+    # 2.a. prepare the cache, if it was not passed.
     past_key_values = kwargs.pop("past_key_values", None)
     if past_key_values is None:
         past_key_values = SinkCache(window_length=window_length, num_sink_tokens=num_sink_tokens)
     elif not isinstance(past_key_values, SinkCache):
         raise ValueError(f"`past_key_values` must be a `SinkCache` instance, got a {type(past_key_values)} instance")
 
-    # generate with the cache
+    # 2.b. generate with the cache
     generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
     return generation_outputs
````
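The subtlest part of the new sanity checks is the `GenerationConfig` comparison: an unsupported argument is rejected when passed directly as a kwarg, or when set on a user-provided `GenerationConfig` to a value that matches neither the global default nor the model's own default. Below is a small, self-contained sketch of that predicate; the `rejects` helper is illustrative, not part of the repo.

```python
from typing import Optional

from transformers import GenerationConfig


def rejects(arg: str, user_config: Optional[GenerationConfig], model_default: GenerationConfig) -> bool:
    """Mirrors `has_custom_gen_config_arg` above: a value counts as custom only
    if it differs from BOTH the global default and the model's own default."""
    global_default = GenerationConfig()
    return (
        user_config is not None
        # De Morgan of the original `not (matches_model or matches_global)`
        and getattr(model_default, arg) != getattr(user_config, arg)
        and getattr(global_default, arg) != getattr(user_config, arg)
    )


model_default = GenerationConfig()  # stand-in for `model.generation_config`

print(rejects("num_beams", GenerationConfig(num_beams=4), model_default))  # True  -> ValueError
print(rejects("num_beams", GenerationConfig(), model_default))             # False -> allowed
print(rejects("num_beams", None, model_default))                           # False -> allowed
```

Comparing against the model default as well as the global default presumably avoids flagging checkpoints whose shipped `generation_config` legitimately carries a non-default value the user never touched.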
|