Gausson committed on
Commit 01f4d5b · verified · 1 Parent(s): 92c9906

Upload 5 files

custom_generate/generate.py ADDED
The diff for this file is too large to render. See raw diff
 
custom_generate_split_4_backup/functions_2_patch.py ADDED
@@ -0,0 +1,221 @@
import torch
import inspect
import importlib

from typing import Callable, Optional, Union, Any, List
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.cache_utils import Cache
from transformers.processing_utils import Unpack

from .sep_cache_utils import SepCache


def truncate_input_ids_4_autoregression(input_ids, key_states):
    if input_ids.shape[-1] != key_states.shape[-2]:
        assert input_ids.shape[-1] >= key_states.shape[-2]
        truncated_input_ids = input_ids[..., -key_states.shape[-2]:]
        return truncated_input_ids
    else:
        return input_ids


def llama_atten_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
    input_shape = hidden_states.shape[:-1]

    if hasattr(self, "head_dim"):
        head_dim = self.head_dim
    elif hasattr(self, "head_size"):
        head_dim = self.head_size

    hidden_shape = (*input_shape, -1, head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    ########################### SepCache ######################
    assert isinstance(past_key_value, SepCache), f"`past_key_value` must be of the type: `SepCache`."
    APPLY_PE_SHIFT = past_key_value.APPLY_PE_SHIFT
    APPLY_PES_INSIDE = past_key_value.APPLY_PES_INSIDE
    ############################################################

    ######################## Monkey Patching ###################
    ## Resolve the RoPE and attention helpers from the module that defines this attention class.
    module = importlib.import_module(self.__module__)

    apply_rotary_pos_emb = module.apply_rotary_pos_emb
    rotate_half = module.rotate_half
    eager_attention_forward = module.eager_attention_forward
    ALL_ATTENTION_FUNCTIONS = module.ALL_ATTENTION_FUNCTIONS
    ############################################################

    if not APPLY_PE_SHIFT:
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # ################################################ Default ################################################
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        # ##########################################################################################################

        ################################################ SepCache #################################################
        # sin and cos are specific to RoPE models; position_ids needed for the static cache
        if APPLY_PE_SHIFT and (not APPLY_PES_INSIDE):
            ### At least the shifted `sin` and `cos` should be properly provided (not `None`).
            ### NOTE: `sin`, `cos`, `sin_q`, and `cos_q` are not computed in this function; in this branch the caller is expected to supply them.
            cache_kwargs = {"sin": sin, "cos": cos, "cos_q": cos_q, "sin_q": sin_q, "cache_position": cache_position, "partial_rotation_size": None}
        else:
            cache_kwargs = {}

        if "kwargs" in locals():
            pass
        elif "flash_attn_kwargs" in locals():
            kwargs = flash_attn_kwargs
        else:
            raise NameError("`kwargs` or `flash_attn_kwargs` should be given and they need to contain `sepllm_kwargs` (which contains `input_ids`) and `position_ids`.")

        if "input_ids" not in locals():
            if "input_ids" in kwargs:
                input_ids = kwargs.get("input_ids", None)
            else:
                sepllm_kwargs = kwargs.get("sepllm_kwargs", None)
                assert sepllm_kwargs is not None, f"`sepllm_kwargs` must be provided when `input_ids` is not given."
                input_ids = sepllm_kwargs.get("input_ids", None)

            assert input_ids is not None, f"`input_ids` must be properly provided directly or through `sepllm_kwargs` when calling `update()` in `SepCache`."

        if "position_ids" not in locals():
            position_ids = kwargs.get("position_ids")

        assert input_ids is not None, f"`input_ids` must be properly provided when calling `update()` in `SepCache`."
        bsz, q_len, _ = hidden_states.size()

        input_ids = truncate_input_ids_4_autoregression(input_ids=input_ids, key_states=key_states)

        if APPLY_PE_SHIFT:
            key_states, value_states, query_states = past_key_value.update(
                key_states=key_states,
                value_states=value_states,
                query_states=query_states,
                input_ids=input_ids,
                layer_idx=self.layer_idx,
                position_ids=position_ids,
                PREFILLING_FLAG=q_len > 1,
                cache_kwargs=cache_kwargs)
        else:
            key_states, value_states = past_key_value.update(
                key_states=key_states,
                value_states=value_states,
                input_ids=input_ids,
                layer_idx=self.layer_idx,
                position_ids=position_ids,
                PREFILLING_FLAG=q_len > 1,
                cache_kwargs=cache_kwargs)

        seq_len = past_key_value.get_usable_length(self.layer_idx)

        if attention_mask is not None:
            attention_mask = attention_mask[..., :seq_len]
        ############################################################################################################

    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        **kwargs,
    )

    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights


def _validate_model_kwargs(self, model_kwargs: dict[str, Any]):
    """Validates model kwargs for generation. Generate argument typos will also be caught here."""
    # If a `Cache` instance is passed, check whether the model is compatible with it
    if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class:
        raise ValueError(
            f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please "
            "check the model documentation for supported cache formats."
        )

    # Excludes arguments that are handled before calling any model function
    if self.config.is_encoder_decoder:
        for key in ["decoder_input_ids"]:
            model_kwargs.pop(key, None)

    unused_model_args = []
    model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
    # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. If
    # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
    if "kwargs" in model_args or "model_kwargs" in model_args:
        model_args |= set(inspect.signature(self.forward).parameters)

    # Encoder-Decoder models may also need Encoder arguments from `model_kwargs`
    if self.config.is_encoder_decoder:
        base_model = getattr(self, self.base_model_prefix, None)

        # allow encoder kwargs
        encoder = getattr(self, "encoder", None)
        # `MusicgenForConditionalGeneration` has `text_encoder` and `audio_encoder`.
        # Also, it has `base_model_prefix = "encoder_decoder"` but there is no `self.encoder_decoder`
        # TODO: A better way to handle this.
        if encoder is None and base_model is not None:
            encoder = getattr(base_model, "encoder", None)

        if encoder is not None:
            encoder_model_args = set(inspect.signature(encoder.forward).parameters)
            model_args |= encoder_model_args

        # allow decoder kwargs
        decoder = getattr(self, "decoder", None)
        if decoder is None and base_model is not None:
            decoder = getattr(base_model, "decoder", None)

        if decoder is not None:
            decoder_model_args = set(inspect.signature(decoder.forward).parameters)
            model_args |= {f"decoder_{x}" for x in decoder_model_args}

    for key, value in model_kwargs.items():
        # ############################ Default ############################
        # if value is not None and key not in model_args:
        #     unused_model_args.append(key)
        # #################################################################

        ############################## SepCache ###########################
        ## Keys whose name contains "sep" (e.g., `sepllm_kwargs`) are exempted from the unused-argument check.
        if (value is not None) and (key not in model_args) and ("sep" not in str(key).lower()):
            unused_model_args.append(key)
        ####################################################################

    if unused_model_args:
        raise ValueError(
            f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
            " generate arguments will also show up in this list)"
        )
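The patched `_validate_model_kwargs` above is what lets the extra `sepllm_kwargs` entry survive `generate`'s argument validation: any model kwarg whose name contains `sep` is exempted from the unused-argument error. For reference, below is a minimal sketch (the tensor shapes are made up for illustration) of what `truncate_input_ids_4_autoregression` does during single-token decoding, where the patched forward still receives the full `input_ids` through `sepllm_kwargs` but only one new key/value position.

```python
import torch

# Hypothetical decode-step shapes: the prompt has 12 tokens, but the model only
# forwards the newest token, so the key states cover a single position.
input_ids = torch.arange(12).unsqueeze(0)   # [1, 12]
key_states = torch.zeros(1, 8, 1, 64)       # [bsz, num_heads, q_len=1, head_dim]

# Same slicing as `truncate_input_ids_4_autoregression`: keep only the token ids
# aligned with the key positions that are about to be added to the cache.
truncated = input_ids[..., -key_states.shape[-2]:]
print(truncated.shape)  # torch.Size([1, 1]) -> just the newest token id
```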
custom_generate_split_4_backup/generate.py ADDED
@@ -0,0 +1,265 @@
import torch

import types
from typing import Any, Dict, List, Optional, Tuple, Union
import transformers
from transformers import Cache, GenerationConfig
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel

from .functions_2_patch import _validate_model_kwargs, llama_atten_forward
from .monkey_patching_utils import monkey_patching
from .sep_cache_utils import SepCache


UNSUPPORTED_GENERATION_ARGS = [
    "cache_implementation",  # cache-related arguments; here we always use SepCache
    "cache_config",
    "return_legacy_cache",
    "num_beams",             # beam search (and cousin techniques) is not supported
    "compile_config",        # SepCache doesn't support torch.compile
    "assistant_model",       # it also doesn't support speculative decoding
]


def generate(model,
             ## For SepCache
             init_cache_size: Union[int, List] = 4,
             sep_cache_size: Union[int, List] = 128,
             local_size: Union[int, List] = 256,
             cache_size: Union[int, List] = 512,
             SEP_ACCUMULATION: bool = True,
             USE_MAX_SEP_CACHE: bool = False,
             SEP_PADDING_IN_BATCH: bool = False,
             separator_token_ids: List[int] = None,  ## required for initialization if `model_type` is not provided.
             PADDING_ID: int = None,                 ## required for initialization if `model_type` is not provided.

             ## For inheritance & initialization states
             past_tok_ids: List[torch.Tensor] = None,  ## It saves all the token ids corresponding to the saved KVs for all layers in SepCache.
             key_cache: List[torch.Tensor] = None,
             value_cache: List[torch.Tensor] = None,

             ## For debugging
             PRINT_KV_RATIO_INSIDE: bool = False,
             print_KV_inside_per_steps: int = 1000,
             _seen_tokens: int = 0,
             _kept_kv_ratio: List[Tuple[int]] = None,

             ### For positional encoding shifting
             APPLY_PE_SHIFT: bool = False,
             APPLY_PES_INSIDE: bool = False,
             _shifted_position_ids: List[torch.Tensor] = None,
             _rope_unsqueeze_dim: int = 1,  ## The unsqueeze_dim when applying RoPE.
             _rope_seq_dim: int = 1,        ## The seq_len dimension for the `cos` or `sin` tensors.
             pe_scaling_factor: float = 1.0,
             pe_dim: int = 128,             ## The number of dims for positional encoding. Typically, just set the `head_dim` to this.
             max_position_embeddings: int = 8192,
             base: int = 10000,             ## The base for RoPE.

             ## For basic transformer architecture
             k_seq_dim: int = 2,  ## The dimension for seq_len in key tensors
             v_seq_dim: int = 2,  ## The dimension for seq_len in value tensors
             layer_num: int = None,  ## required for initialization

             model_type: str = 'llama',  ## The model type for running the example. Choose from ['llama', 'pythia', 'falcon'].
             device=None,

             ## For verbosity of monkey patching
             monkey_patch_verbose: bool = False,

             **kwargs
             ):
    """Custom generate function for SepCache.

    A cache as described in the [SepLLM paper - ICML 2025](https://arxiv.org/abs/2412.12094). In the training phase,
    SepLLM condenses the segment information into the KV of the separator that divides the segment. In the inference phase, the
    corresponding SepCache only needs to store the KVs of initial tokens, separator tokens, and recent tokens for generation.

    It stores the Key and Value states as lists of tensors, two lists for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Frequently-Used Parameters:

        `init_cache_size: Union[int, List]`:
            The maximum number of KVs to be stored for initial tokens.
            In the paper, the hyperparameter `a` is an abbreviated alias for `self.init_cache_size`.

        `sep_cache_size: Union[int, List]`:
            The maximum number of KVs to be stored for separator tokens.
            In the paper, the hyperparameter `s` is an abbreviated alias for `self.sep_cache_size`.

        `local_size: Union[int, List]`:
            The maximum number of KVs to be stored for local tokens (i.e., sliding window).
            In the paper, the hyperparameter `w` is an abbreviated alias for `self.local_size`.

        `cache_size: Union[int, List]`:
            The maximum number of KVs to be stored for all the tokens, i.e., the size of the whole KV cache.
            In the paper, the hyperparameter `c` is an abbreviated alias for `self.cache_size`.

        Concerning the four parameters above:
            When a list is passed (its length must be `layer_num`), it gives a different value for each layer.
            When an integer is passed, the setting is the same for all layers.

        `USE_MAX_SEP_CACHE: bool`:
            If True, we only keep at most `self.sep_cache_size` separators' KVs.
            If the number exceeds this limit, older separators' KVs will be discarded, keeping only the most recent `self.sep_cache_size` KVs.
            In the paper, the hyperparameter `s` is an abbreviated alias for `self.sep_cache_size`.

        `separator_token_ids: List[int]`:
            The token ids of the separator tokens for the current model's tokenizer.
            Defaults are provided for some models; for example, for the Llama-3 series, setting `model_type='llama'` allows you
            to skip setting `separator_token_ids` and `PADDING_ID` (SepCache will auto-fill them).

        `PADDING_ID: int`:
            The token id of the padding token. You can just set `PADDING_ID` to the id of the "<|endoftext|>" token of the tokenizer for the pretrained model.

    Important Note:
        When `cache_size` and `local_size` are set to infinity (i.e., sufficiently large positive integers), and `USE_MAX_SEP_CACHE` is `False`, `SepCache` degenerates into a regular Cache.
        However, you must always ensure that `init_cache_size` + `sep_cache_size` + `local_size` + `left_padding_offset` < `cache_size`.
        Here, `left_padding_offset` denotes the number of padding tokens in the record with the most left padding within a runtime batch; `left_padding_offset` can only be determined at runtime.
        To guarantee that the above inequality always holds at runtime, you can intentionally leave a sufficient margin between both sides of the inequality
        `init_cache_size` + `sep_cache_size` + `local_size` < `cache_size` (i.e., `a` + `s` + `w` < `c` in the [SepLLM paper - ICML 2025])
        to leave room for `left_padding_offset`.

    Please refer to the `__init__` function's comments for more details on the parameters.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM
        >>> from .sep_cache_utils import SepCache
        >>> import torch
        >>> from huggingface_hub import login
        >>> login("hf_xxxXXXxxx")


        >>> def to_cuda(a_dict: dict) -> dict:
        >>>     new_dict = {}
        >>>     for k, v in a_dict.items():
        >>>         if isinstance(v, torch.Tensor):
        >>>             new_dict[k] = v.cuda()
        >>>         else:
        >>>             new_dict[k] = v
        >>>     return new_dict

        >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", attn_implementation="flash_attention_2", device_map="cuda:0")
        >>> model.bfloat16().cuda()
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
        >>> inputs = tokenizer(text="My name is Llama 3", return_tensors="pt")
        >>> inputs = to_cuda(inputs)

        >>> # Prepare a cache and pass it to the model's forward; `layer_num` is the number of layers of the pretrained model.
        >>> past_key_values = SepCache(init_cache_size=4, sep_cache_size=128, local_size=256, cache_size=512, layer_num=32, USE_MAX_SEP_CACHE=True, model_type='llama')
        >>> # `separator_token_ids` and `PADDING_ID` must also be provided if you are not using `model_type='llama'` as in this demo.
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values  # access SepCache filled with keys/values
        SepCache()
        ```

        ```python
        >>> ## When using the `update` function of SepCache to update the keys/values and the past token ids (necessary in SepCache), the current `input_ids` must also be provided.
        >>> key_states, value_states = past_key_values.update(
        ...     key_states=key_states,
        ...     value_states=value_states,
        ...     input_ids=input_ids,
        ...     layer_idx=layer_idx,
        ...     PREFILLING_FLAG=q_len > 1,  ## `q_len` is the sequence length of the current `query_states`
        ... )
        ```

    For detailed usage instructions, please refer to https://github.com/HKUDS/SepLLM
    """

    # 0. Monkey patching for the `update` function of `SepCache`
    model_layers = monkey_patching(model, model_atten_forward=llama_atten_forward, verbose=monkey_patch_verbose)

    # 1. General sanity checks
    # 1.a. A few arguments are not allowed, especially arguments that control caches.
    generation_config = kwargs.get("generation_config")
    default_global_generation_config = GenerationConfig()
    default_model_generation_config = model.generation_config
    for arg in UNSUPPORTED_GENERATION_ARGS:
        has_custom_gen_config_arg = (
            generation_config is not None
            # i.e., the value matches neither the global default nor the model-specific default
            and not (
                getattr(default_model_generation_config, arg) == getattr(generation_config, arg)
                or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
            )
        )
        kwargs_has_arg = arg in kwargs and kwargs[arg] is not None
        if kwargs_has_arg or has_custom_gen_config_arg:
            raise ValueError(
                f"`{arg}` is set, but it's not supported in this custom generate function. List of "
                f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
            )

    # 1.b. The model must be decoder-only
    if model.config.is_encoder_decoder:
        raise ValueError("This custom generate function only works with decoder-only models")

    # 1.c. Compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result
    # in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
    kwargs.pop("custom_generate", None)

    sepllm_kwargs = {}
    sepllm_kwargs["input_ids"] = kwargs["input_ids"]  ## `input_ids` must be passed to the `update` function of `SepCache`
    kwargs["sepllm_kwargs"] = sepllm_kwargs

    # 2. Generate with SepCache
    # 2.a. Prepare the cache if it was not passed.
    past_key_values = kwargs.pop("past_key_values", None)
    if past_key_values is None:
        past_key_values = SepCache(
            ## For SepCache
            init_cache_size=init_cache_size,
            sep_cache_size=sep_cache_size,
            local_size=local_size,
            cache_size=cache_size,
            SEP_ACCUMULATION=SEP_ACCUMULATION,
            USE_MAX_SEP_CACHE=USE_MAX_SEP_CACHE,
            SEP_PADDING_IN_BATCH=SEP_PADDING_IN_BATCH,
            separator_token_ids=separator_token_ids,  ## required for initialization if `model_type` is not provided.
            PADDING_ID=PADDING_ID,                    ## required for initialization if `model_type` is not provided.

            ## For inheritance & initialization states
            past_tok_ids=past_tok_ids,  ## It saves all the token ids corresponding to the saved KVs for all layers in SepCache.
            key_cache=key_cache,
            value_cache=value_cache,

            ## For debugging
            PRINT_KV_RATIO_INSIDE=PRINT_KV_RATIO_INSIDE,
            print_KV_inside_per_steps=print_KV_inside_per_steps,
            _seen_tokens=_seen_tokens,
            _kept_kv_ratio=_kept_kv_ratio,

            ### For positional encoding shifting
            APPLY_PE_SHIFT=APPLY_PE_SHIFT,
            APPLY_PES_INSIDE=APPLY_PES_INSIDE,
            _shifted_position_ids=_shifted_position_ids,
            _rope_unsqueeze_dim=_rope_unsqueeze_dim,  ## The unsqueeze_dim when applying RoPE.
            _rope_seq_dim=_rope_seq_dim,              ## The seq_len dimension for the `cos` or `sin` tensors.
            pe_scaling_factor=pe_scaling_factor,
            pe_dim=pe_dim,  ## The number of dims for positional encoding. Typically, just set the `head_dim` to this, i.e., model.config.hidden_size // model.config.num_attention_heads
            max_position_embeddings=max_position_embeddings,  # i.e., model.config.max_position_embeddings
            base=base,  ## The base for RoPE.

            ## For basic transformer architecture
            k_seq_dim=k_seq_dim,  ## The dimension for seq_len in key tensors
            v_seq_dim=v_seq_dim,  ## The dimension for seq_len in value tensors
            layer_num=len(model_layers),  ## required for initialization, i.e., model.config.num_hidden_layers

            model_type=model_type,  ## The model type for running the example. Choose from ['llama', 'pythia', 'falcon'].
            device=device,
        )

    elif not isinstance(past_key_values, SepCache):
        raise ValueError(f"`past_key_values` must be a `SepCache` instance, got a {type(past_key_values)} instance")

    # 2.b. Generate with the cache
    kwargs["use_cache"] = True
    generation_outputs = model.generate(**kwargs, past_key_values=past_key_values)
    return generation_outputs
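A minimal sketch of how this entry point is typically invoked through transformers' `custom_generate` loading mechanism (transformers 4.53 or later); the repo id below is a placeholder, and the extra keyword arguments are forwarded to the custom `generate` above and then to `SepCache`. With the defaults shown here (`a=4`, `s=128`, `w=256`, `c=512`), the margin `a + s + w = 388 < c = 512` leaves 124 positions of slack for `left_padding_offset`, satisfying the inequality from the docstring.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

inputs = tokenizer("My name is Llama 3", return_tensors="pt").to(model.device)

output_ids = model.generate(
    **inputs,
    custom_generate="<user>/<this-repo>",  # placeholder: the Hub repo that ships custom_generate/generate.py
    trust_remote_code=True,
    max_new_tokens=64,
    # Forwarded to the custom `generate`, which builds the SepCache (a=4, s=128, w=256, c=512).
    init_cache_size=4,
    sep_cache_size=128,
    local_size=256,
    cache_size=512,
    USE_MAX_SEP_CACHE=True,
    model_type="llama",
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```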
custom_generate_split_4_backup/monkey_patching_utils.py ADDED
@@ -0,0 +1,154 @@
import torch
import inspect
import importlib
import transformers
import types

import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from typing import Callable, Optional, Union, Any, List

from .functions_2_patch import _validate_model_kwargs, llama_atten_forward


def get_full_class_import_path(obj):
    """Get the complete class import path of an object."""
    # Get the class of the object
    cls = obj.__class__

    # Get the module name where the class is defined
    module = cls.__module__

    # Get the qualified name of the class (including outer classes)
    qualname = cls.__qualname__

    # Handle nested classes (e.g., ClassA.ClassB)
    if '.' in qualname:
        # Replace nested class separators
        class_path = f"{module}.{qualname.replace('.', '_')}"
    else:
        class_path = f"{module}.{qualname}"

    return class_path


def get_importable_class_path(obj):
    """Get the directly importable class path (handling special cases and dynamic classes)."""
    cls = obj.__class__
    module = cls.__module__
    qualname = cls.__qualname__

    # Handle built-in types
    if module == 'builtins':
        return qualname

    # Handle dynamically generated classes (e.g., functools.partial)
    if not hasattr(cls, '__module__') or module is None:
        return f"<dynamic class {qualname}>"

    # Handle nested classes
    if '.' in qualname:
        # Try to import the parent module to validate the path
        try:
            import importlib
            parent_module = importlib.import_module(module)

            # Follow the qualified name path
            parts = qualname.split('.')
            current = parent_module
            for part in parts:
                current = getattr(current, part)

            # If the attributes are accessible, return the original path
            return f"{module}.{qualname}"
        except (ImportError, AttributeError):
            # Fallback: join with underscores
            return f"{module}.{qualname.replace('.', '_')}"

    return f"{module}.{qualname}"


def monkey_patch_by_class_path(model, new_forward):
    """Perform monkey patching through the class path."""
    # Get the complete class path
    class_path = get_importable_class_path(model)

    # Dynamically import the class
    try:
        import importlib
        module_path, class_name = class_path.rsplit('.', 1)
        module = importlib.import_module(module_path)
        target_class = getattr(module, class_name)

        # Save the original method
        if not hasattr(target_class, '_original_forward'):
            target_class._original_forward = target_class.forward

        # Apply the patch
        target_class.forward = new_forward

        # Update the method binding of the current instance
        model.forward = types.MethodType(target_class.forward, model)

        return f"Successful Monkey Patch: {class_path}.forward"

    except (ImportError, AttributeError, ValueError) as e:
        return f"Patch Failed: {str(e)}"


def find_inner_attribute(obj, attr_name_list: List[str], default_type=PreTrainedModel):
    # Try to find an attribute whose name is in `attr_name_list`.
    for target_attr_name in attr_name_list:
        if hasattr(obj, target_attr_name):
            return getattr(obj, target_attr_name)

    # Otherwise, try to find an attribute of the type `default_type`.
    for attr_name in dir(obj):
        attr_value = getattr(obj, attr_name)
        if isinstance(attr_value, default_type):
            return attr_value

    raise AttributeError(f"In the {obj} object, there is no attribute whose name matches any name in {attr_name_list} or whose type is {default_type}.")


def find_attribute_name(obj, name_pattern_list: List[str], exclude_pattern_list: List[str], match_type=nn.Module):
    for attr_name in dir(obj):
        attr_value = getattr(obj, attr_name)
        for pattern in name_pattern_list:
            for ex_pattern in exclude_pattern_list:
                if isinstance(attr_value, match_type) and (pattern.lower() in attr_value.__class__.__name__.lower()) and (ex_pattern.lower() not in attr_value.__class__.__name__.lower()):
                    return attr_value
                elif isinstance(attr_value, match_type) and (pattern.lower() in attr_name.lower()) and (ex_pattern.lower() not in attr_name.lower()):
                    return attr_value

    raise AttributeError(f"In the {obj} object, there is no attribute whose name matches any pattern in {name_pattern_list} and excludes any pattern in {exclude_pattern_list}, and whose type is {match_type}.")


def monkey_patching(model_obj, model_atten_forward, verbose=True):
    transformers.generation.GenerationMixin._validate_model_kwargs = _validate_model_kwargs

    ## Get the inner model
    possible_inner_model_names = ["model", "transformer", "gpt_neox"]
    inner_model_type = PreTrainedModel
    inner_model = find_inner_attribute(model_obj, possible_inner_model_names, inner_model_type)

    possible_layers_names = ["layers", "h"]
    layers_type = nn.ModuleList
    model_layers = find_inner_attribute(inner_model, possible_layers_names, layers_type)

    atten_attr_name_pattern_list = ["attention", "self_attn"]
    atten_attr_name_pattern_exclude = ["norm", "layer"]

    for i, decoder_layer in enumerate(model_layers):
        self_attn_module = find_attribute_name(decoder_layer, atten_attr_name_pattern_list, atten_attr_name_pattern_exclude, nn.Module)
        result = monkey_patch_by_class_path(self_attn_module, model_atten_forward)
        if verbose:
            decoder_class_name = get_importable_class_path(decoder_layer)
            print(f"For Layer {i}'s `{decoder_class_name}`: {result}")

    return model_layers
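For reference, a short sketch of applying the patch manually, which is what step 0 of the custom `generate` does. It assumes this backup folder is importable as a package (e.g., it contains an `__init__.py` and sits on `PYTHONPATH`), and the model id is only illustrative. `monkey_patching` replaces `GenerationMixin._validate_model_kwargs`, swaps `llama_atten_forward` into every decoder layer's attention class, and returns the layer list, whose length is later used as `layer_num` for `SepCache`.

```python
from transformers import AutoModelForCausalLM

# Assumption: `custom_generate_split_4_backup` is importable as a package.
from custom_generate_split_4_backup.functions_2_patch import llama_atten_forward
from custom_generate_split_4_backup.monkey_patching_utils import monkey_patching

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto")

# Patches the attention forward of every decoder layer and returns the nn.ModuleList of layers.
model_layers = monkey_patching(model, model_atten_forward=llama_atten_forward, verbose=True)
print(len(model_layers))  # number of decoder layers, later passed to SepCache as `layer_num`
```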
custom_generate_split_4_backup/sep_cache_utils.py ADDED
@@ -0,0 +1,1205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Cache, GenerationConfig
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+ import torch
4
+ from packaging import version
5
+ from dataclasses import dataclass
6
+
7
+
8
+
9
+ class SepCache(Cache):
10
+ """
11
+ A cache as described in the [SepLLM paper - ICML 2025](https://arxiv.org/abs/2412.12094). In the training phase,
12
+ SepLLM condenses the segment information into the KV of the separator that divides the segment. In the inference phase, the
13
+ corresponding SepCache only needs to store the KVs of initial tokens, separator tokens, and recent tokens for generation.
14
+
15
+ It stores the Key and Value states as lists of tensors, two lists for each layer. The expected shape for each tensor is
16
+ `[batch_size, num_heads, seq_len, head_dim]`.
17
+
18
+ Frequently-Used Parameters:
19
+
20
+ `init_cache_size: Union[int, List]`:
21
+ The maximum number of KVs to be stored for initial tokens.
22
+ In the paper, the hyperparameter `a` is an abbreviated alias for `self.init_cache_size`.
23
+
24
+ `sep_cache_size: Union[int, List]`:
25
+ The maximum number of KVs to be stored for separator tokens.
26
+ In the paper, the hyperparameter `s` is an abbreviated alias for `self.sep_cache_size`.
27
+
28
+ `local_size: Union[int, List]`:
29
+ The maximum number of KVs to be stored for local tokens (i.e., sliding window).
30
+ In the paper, the hyperparameter `w` is an abbreviated alias for `self.local_size`.
31
+
32
+ `cache_size: Union[int, List]`:
33
+ The maximum number of KVs to be stored for all the tokens, i.e., the size for the whole KV cache.
34
+ In the paper, the hyperparameter `c` is an abbreviated alias for `self.cache_size`.
35
+
36
+ Concerning these four parameters above:
37
+ When a list is passed (its length must be `layer_num`), it represents different values for each layer.
38
+ When an integer is passed, it means the setting is the same for all layers.
39
+
40
+
41
+ `USE_MAX_SEP_CACHE: bool`:
42
+ If True, it means we only keep at most `self.sep_cache_size` seperators' KVs.
43
+ If the number exceeds this limit, older separator's KVs will be discarded, keeping only the most recent `self.sep_cache_size` KVs.
44
+ In the paper, the hyperparameter `s` is an abbreviated alias for `self.sep_cache_size`.
45
+
46
+ `separator_token_ids: List[int]`:
47
+ The token ids of the separator tokens for the current model's tokenizer.
48
+ We have some examples, such as the Llama-3 series models, where setting `model_type='llama'` allows you
49
+ to skip setting `separator_token_ids` and `PADDING_ID` (SepCache will auto-fill them).
50
+
51
+ `PADDING_ID: int`:
52
+ The token id of the padding token. You can just set `PADDING_ID` to the id of "<|endoftext|>" token of the tokenizer for the pretrained model.
53
+
54
+ Important Note:
55
+ When `cache_size` and `local_size` are set to infinity (i.e., sufficiently large positive integers), and `USE_MAX_SEP_CACHE` is `False`, `SepCache` degenerates into a regular Cache.
56
+ However, you must always ensure that `init_cache_size` + `sep_cache_size` + `local_size` + `left_padding_offset` < `cache_size`.
57
+ Here, `left_padding_offset` denotes the number of padding tokens in the record with the largest left paddings within a runtime batch. `left_padding_offset` can only be determined at runtime.
58
+ To guarantee the above inequality always holds during runtime, when setting, you can intentionally create a sufficient margin between both sides of the following inequality:
59
+ `init_cache_size` + `sep_cache_size` + `local_size` < `cache_size`, i.e., `a`+`s`+`w`<`c` in the [SepLLM paper - ICML 2025]
60
+ to leave room for `left_padding_offset`.
61
+
62
+ Please refer to the `__init__` function's comments for more details on the parameters.
63
+
64
+ Example:
65
+
66
+ ```python
67
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SepCache
68
+ >>> import torch
69
+ >>> from huggingface_hub import login
70
+ >>> login("hf_xxxXXXxxx")
71
+
72
+
73
+ >>> def to_cuda(a_dict: dict) -> dict:
74
+ >>> new_dict = {}
75
+ >>> for k,v in a_dict.items():
76
+ >>> if isinstance(v, torch.Tensor):
77
+ >>> new_dict[k] = v.cuda()
78
+ >>> else:
79
+ >>> new_dict[k] = v
80
+ >>> return new_dict
81
+
82
+ >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", attn_implementation="flash_attention_2", device_map="cuda:0")
83
+ >>> model.bfloat16().cuda()
84
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
85
+ >>> inputs = tokenizer(text="My name is Llama 3", return_tensors="pt")
86
+ >>> inputs = to_cuda(inputs)
87
+ >>> # Prepare a cache and pass it to model's forward; `layer_num` is the number of layers for the pretrained model.
88
+ >>> past_key_values = SepCache(init_cache_size=4, sep_cache_size=128, local_size=256, cache_size=512, layer_num=32, USE_MAX_SEP_CACHE=True, model_type='llama')
89
+ >>> # `separator_token_ids` and `PADDING_ID` must also be provided if you are not using `model_type='llama'` like this demo.
90
+ >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
91
+ >>> outputs.past_key_values # access SepCache filled with keys/values
92
+ SepCache()
93
+ ```
94
+
95
+ ```python
96
+ >>> ## When using the `update` function of SepCache to update the keys/values and the past token ids (necessary in SepCache), the current `input_ids` must also be provided.
97
+ >>> key_states, value_states = past_key_values.update(
98
+ key_states = key_states,
99
+ value_states = value_states,
100
+ input_ids = input_ids,
101
+ layer_idx = layer_idx,
102
+ PREFILLING_FLAG = q_len > 1, ## `q_len` is the sequence length of the current `query_states`
103
+ )
104
+
105
+ ```
106
+ For detailed usage instructions, please refer to https://github.com/HKUDS/SepLLM
107
+ """
108
+ # is_sliding = True
109
+
110
+ @staticmethod
111
+ def slice_on_1d(x, start, end):
112
+ return x[:, start:end, ...]
113
+ @staticmethod
114
+ def slice_on_2d(x, start, end):
115
+ return x[:, :, start:end, ...]
116
+ @staticmethod
117
+ def slice_on_3d(x, start, end):
118
+ return x[:, :, :, start:end, ...]
119
+
120
+
121
+ @staticmethod
122
+ def sep_1bat_select_on_1d(x, Bid, sep_index, min_sep_num=None, max_sep_num=None, SEP_PADDING_IN_BATCH=True):
123
+ """
124
+ For the record with index `Bid` in a batch, extract the K/V states corresponding to the separators on dimension 1.
125
+ If `SEP_PADDING_IN_BATCH=True`, pad to the longest length (i.e. `max_sep_num`);
126
+ otherwise, truncate to the shortest length (i.e. `min_sep_num`).
127
+ """
128
+ sep_index = sep_index.to(x.device)
129
+
130
+ if SEP_PADDING_IN_BATCH: ## Need padding
131
+ assert max_sep_num is not None, f"if `SEP_PADDING_IN_BATCH=True`, `max_sep_num` should not be None"
132
+ new_x_sep = x[Bid, sep_index, ...] # # batch x seqlen x head x dim --> sep_num x head x dim
133
+ padding_num = max_sep_num - new_x_sep.shape[0]
134
+ if padding_num > 0 :
135
+ assert padding_num <= x.shape[1], f"`padding_num` should be <= `x.shape[1]`, i.e. x's seqlen"
136
+ new_x_pad = x[Bid, -padding_num: , ...] # padding_num x head x dim
137
+ return torch.cat([new_x_sep, new_x_pad ] , dim=0) # max_sep_num x head x dim
138
+ else:
139
+ return new_x_sep # max_sep_num x head x dim
140
+
141
+ if min_sep_num is None:
142
+ return x[Bid, sep_index, ...] # # batch x seqlen x head x dim --> sep_num x head x dim
143
+ else: ## `min_sep_num` is provided. Need truncation
144
+ new_x = x[Bid, sep_index, ...] # # batch x seqlen x head x dim --> sep_num x head x dim
145
+ return new_x[ :min_sep_num, ...] # # min_sep_num x head x dim
146
+
147
+
148
+ @staticmethod
149
+ def sep_1bat_select_on_2d(x, Bid, sep_index, min_sep_num=None, max_sep_num=None, SEP_PADDING_IN_BATCH=True):
150
+ """
151
+ For the record with index `Bid` in a batch, extract the K/V states corresponding to the separators on dimension 2.
152
+ If `SEP_PADDING_IN_BATCH=True`, pad to the longest length (i.e. `max_sep_num`);
153
+ otherwise, truncate to the shortest length (i.e. `min_sep_num`).
154
+ """
155
+ sep_index = sep_index.to(x.device)
156
+
157
+ if SEP_PADDING_IN_BATCH: ## Need padding
158
+ assert max_sep_num is not None, f"if `SEP_PADDING_IN_BATCH=True`, `max_sep_num` should not be None"
159
+ new_x_sep = x[Bid, :, sep_index, ...] # # batch x head x seqlen x dim --> head x sep_num x dim
160
+ padding_num = max_sep_num - new_x_sep.shape[-2]
161
+ if padding_num > 0 :
162
+ assert padding_num<= x.shape[-2], f"`padding_num` should be <= `x.shape[-2]`, i.e. x's seqlen"
163
+ new_x_pad = x[Bid, :, -padding_num: , ...] # head x padding_num x dim
164
+ return torch.cat([new_x_sep, new_x_pad ] , dim=-2) # head x max_sep_num x dim
165
+ else:
166
+ return new_x_sep # head x max_sep_num x dim
167
+
168
+ if min_sep_num is None:
169
+ return x[Bid, :, sep_index, ...] # # batch x head x seqlen x dim --> head x sep_num x dim
170
+ else: ## `min_sep_num` is provided. Need truncation
171
+ new_x = x[Bid, :, sep_index, ...] # # batch x head x seqlen x dim --> head x sep_num x dim
172
+ return new_x[:, :min_sep_num, ...] # # head x min_sep_num x dim
173
+
174
+
175
+ @staticmethod
176
+ def sep_1bat_select_on_3d(x, Bid, sep_index, min_sep_num=None, max_sep_num=None, SEP_PADDING_IN_BATCH=True):
177
+ """
178
+ For the record with index `Bid` in a batch, extract the K/V states corresponding to the separators on dimension 3.
179
+ If `SEP_PADDING_IN_BATCH=True`, pad to the longest length (i.e. `max_sep_num`);
180
+ otherwise, truncate to the shortest length (i.e. `min_sep_num`).
181
+ """
182
+ sep_index = sep_index.to(x.device)
183
+
184
+ if SEP_PADDING_IN_BATCH: ## Need padding
185
+ assert max_sep_num is not None, f"if `SEP_PADDING_IN_BATCH=True`, `max_sep_num` should not be None"
186
+ new_x_sep = x[Bid, :, :, sep_index, ...] # # batch x head x dim x seqlen --> head x dim x sep_num
187
+ padding_num = max_sep_num - new_x_sep.shape[-1]
188
+ if padding_num > 0 :
189
+ assert padding_num <= x.shape[-1], f"`padding_num` should be <= `x.shape[-1]`, i.e. x's seqlen"
190
+ new_x_pad = x[Bid, :, :, -padding_num:, ...] # head x dim x padding_num
191
+ return torch.cat([new_x_sep, new_x_pad] , dim=-1) # head x dim x max_sep_num
192
+ else:
193
+ return new_x_sep # head x dim x max_sep_num
194
+
195
+ if min_sep_num is None:
196
+ return x[Bid, :, :, sep_index, ...] # # batch x head x dim x seqlen --> head x dim x sep_num
197
+ else: ## `min_sep_num` is provided. Need truncation
198
+ new_x = x[Bid, :, :, sep_index, ...] # # batch x head x dim x seqlen --> head x dim x sep_num
199
+ return new_x[:, :, :min_sep_num, ...] # # head x dim x min_sep_num
200
+
201
+ DIM_TO_SLICE = {
202
+ 1: slice_on_1d,
203
+ 2: slice_on_2d,
204
+ 3: slice_on_3d,
205
+ }
206
+
207
+ BAT_DIM_TO_SELECT = {
208
+ 1: sep_1bat_select_on_1d,
209
+ 2: sep_1bat_select_on_2d,
210
+ 3: sep_1bat_select_on_3d,
211
+ }
212
+
213
+ def __init__(self,
214
+ ## For SepLLM
215
+ init_cache_size: Union[int, List] = 4,
216
+ sep_cache_size: Union[int, List] = 64,
217
+ local_size: Union[int, List]=256,
218
+ cache_size: Union[int, List]=512,
219
+ SEP_ACCUMULATION: bool = True,
220
+ USE_MAX_SEP_CACHE: bool = False,
221
+ SEP_PADDING_IN_BATCH: bool = False,
222
+ separator_token_ids: List[int] = None, ## required for initialization if `model_type` is not provided.
223
+ PADDING_ID: int = None, ## required for initialization if `model_type` is not provided.
224
+
225
+ ## For inheritance & initialization states
226
+ past_tok_ids: List[torch.Tensor] = None, ## It saves all the token ids corresponding to the saved KVs for all layers in SepCache.
227
+ key_cache: List[torch.Tensor] = None,
228
+ value_cache: List[torch.Tensor] = None,
229
+
230
+ ## For debugging
231
+ PRINT_KV_RATIO_INSIDE: bool = False,
232
+ print_KV_inside_per_steps: int = 1000,
233
+ _seen_tokens: int = 0,
234
+ _kept_kv_ratio: List[Tuple[int]] = None,
235
+
236
+ ### For positional encoding shifting
237
+ APPLY_PE_SHIFT: bool = False,
238
+ APPLY_PES_INSIDE: bool = True,
239
+ _shifted_position_ids: List[torch.Tensor] = None,
240
+ _rope_unsqueeze_dim: int = 1, ## The unsqueeze_dim when applying RoPE.
241
+ _rope_seq_dim: int=1, ## The seq_len dimension for the `cos` or `sin` tensors.
242
+ pe_scaling_factor:float = 1.0,
243
+ pe_dim:int=128, ## The number of dims for positional encoding. Typically, just set the `head_dim` to this.
244
+ max_position_embeddings: int = 8192,
245
+ base: int=10000, ## The base for RoPE.
246
+
247
+ ## For basic transformer architecture
248
+ k_seq_dim: int=2, ## The dimension for seq_len in key tensors
249
+ v_seq_dim: int=2, ## The dimension for seq_len in value tensors
250
+ layer_num: int = None, ## required for initialization
251
+
252
+ model_type: str = None, ## The model type for running the example. choose from ['llama', 'pythia','falcon'].
253
+ device = None
254
+ ) -> None:
255
+ """
256
+ `SEP_ACCUMULATION`: If True, it means we will try to accumulate all the KVs for seperators. If False, only the `new_sep_kv` compressed from the `past_win_kv` will be kept (see function `compress_kv_cache_and_tokids_layer_wise`).
257
+
258
+ `USE_MAX_SEP_CACHE`: If True, it means we only keep at most `self.sep_cache_size` seperators' KVs. If the number exceeds this limit, older separator's KVs will be discarded, keeping only the most recent `self.sep_cache_size` KVs. In the paper, the hyperparameter `s` is an abbreviated alias for `self.sep_cache_size`.
259
+
260
+ `SEP_PADDING_IN_BATCH`: If True, it means that SepCache will pad separator tokens in other records to be aligned with the record with the most separators in a batch. If False, it means that SepCache will truncate older separator tokens in other records to be aligned with the record with the fewest separators in a batch.
261
+
262
+ Note: If `SEP_ACCUMULATION=True` and `USE_MAX_SEP_CACHE=False`, as the number of input tokens increases, the number of separators in the KV cache will also accumulate endlessly
263
+ and `self.cache_size` will also be infinitely expanded (no longer fixed).
264
+
265
+ When `SEP_PADDING_IN_BATCH=True` is used in combination with `USE_MAX_SEP_CACHE=False` and `SEP_ACCUMULATION=True`, the KV cache will accumulate indefinitely,
266
+ and since `SEP_PADDING_IN_BATCH=True`, the KVs of all separators will be retained (rather than being truncated).
267
+
268
+
269
+ For detailed usage instructions, please refer to https://github.com/HKUDS/SepLLM
270
+ """
271
+
272
+ super().__init__()
273
+ if (key_cache is not None) or (value_cache is not None) or (past_tok_ids is not None):
274
+ assert isinstance(key_cache, list)
275
+ assert isinstance(value_cache, list)
276
+ assert isinstance(past_tok_ids, list), f"For SepCache, if `key_cache` and `value_cache` are given (e.g., provided from legacy `past_key_values`), `past_tok_ids` corresponding to `key_cache` and `value_cache` must also be provided to initialize SepCache."
277
+
278
+ assert len(key_cache) == len(past_tok_ids), f"The length of `key_cache` ({len(key_cache)}) should be equal to that of `past_tok_ids` ({len(past_tok_ids)})."
279
+ assert len(value_cache) == len(past_tok_ids), f"The length of `value_cache` ({len(value_cache)}) should be equal to that of `past_tok_ids` ({len(past_tok_ids)})."
280
+ assert layer_num is not None, f"`layer_num` must be provided according to the pretrained model."
281
+
282
+ ## For basic parameters & states
283
+ self.key_cache: List[torch.Tensor] = key_cache if key_cache is not None else []
284
+ self.value_cache: List[torch.Tensor] = value_cache if value_cache is not None else []
285
+
286
+ self.k_seq_dim = k_seq_dim ## The dimension for the seq_len in key states. Typically, 2.
287
+ self.v_seq_dim = v_seq_dim ## The dimension for the seq_len in value states. Typically, 2.
288
+
289
+ self.k_slice = self.DIM_TO_SLICE[k_seq_dim]
290
+ self.v_slice = self.DIM_TO_SLICE[v_seq_dim]
291
+
292
+ self.k_bat_dim_select = self.BAT_DIM_TO_SELECT[k_seq_dim]
293
+ self.v_bat_dim_select = self.BAT_DIM_TO_SELECT[v_seq_dim]
294
+ self._seen_tokens: int = _seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen as well as performing statistics.
295
+ self.layer_num = layer_num
296
+ self.device = device # Deprecated
297
+
298
+
299
+ ## For debugging
300
+ self.PRINT_KV_RATIO_INSIDE = PRINT_KV_RATIO_INSIDE
301
+ self.print_KV_inside_per_steps = print_KV_inside_per_steps
302
+ self._print_kv_ratio_count = 0
303
+ self._kept_kv_ratio: List[Tuple[int]] = _kept_kv_ratio if _kept_kv_ratio is not None else []
304
+
305
+ ## For Streaming SepLLM
306
+ self.past_tok_ids: List[torch.Tensor] = past_tok_ids if past_tok_ids is not None else [] ## It saves all the token ids corresponding to the saved KVs for all layers in SepCache
307
+ self.left_padding_offset = None
308
+ self._set_layer_wise_attribute("init_cache_size", init_cache_size, layer_num)
309
+ self._set_layer_wise_attribute("local_size", local_size, layer_num)
310
+ self._set_layer_wise_attribute("cache_size", cache_size, layer_num)
311
+ self._set_layer_wise_attribute("sep_cache_size", sep_cache_size, layer_num)
312
+ self._set_layer_wise_attribute("sep_exrange", 0, layer_num) # runtime right boundary for separators, excluded
313
+ self._set_layer_wise_attribute("max_sep_exidx", self._list_element_add(self.sep_cache_size, self.init_cache_size), layer_num) # max right boundary for separators, excluded
314
+ self.SEP_ACCUMULATION = SEP_ACCUMULATION
315
+ self.USE_MAX_SEP_CACHE = USE_MAX_SEP_CACHE
316
+ self.SEP_PADDING_IN_BATCH = SEP_PADDING_IN_BATCH
317
+
318
+
319
+ ### For positional encoding shifting
320
+ self.APPLY_PE_SHIFT = APPLY_PE_SHIFT
321
+ self.APPLY_PES_INSIDE = APPLY_PES_INSIDE
322
+
323
+ self.cos_sin_rerotation_cache = {}
324
+ self._cos_cache = None
325
+ self._sin_cache = None
326
+ self._shifted_position_ids: List[torch.Tensor] = _shifted_position_ids if _shifted_position_ids is not None else []
327
+ self._rope_unsqueeze_dim = _rope_unsqueeze_dim
328
+ self._rope_seq_dim = _rope_seq_dim
329
+
330
+ self.pe_dim = pe_dim
331
+ self.max_position_embeddings = max_position_embeddings
332
+ self.base = base
333
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.pe_dim, 2, dtype=torch.int64).float().to(device) / self.pe_dim))
334
+ self.inv_freq = inv_freq
335
+ self.pe_scaling_factor = pe_scaling_factor
336
+ self._sin_cached = None
337
+ self._cos_cached = None
338
+
339
+ if model_type is None:
340
+ assert isinstance(separator_token_ids, list), f"`separator_token_ids: List[int]` must be correctly provided for initialization unless `model_type` is properly given, which will auto-fiil `separator_token_ids`."
341
+ assert len(separator_token_ids) > 0, f"`separator_token_ids: List[int]` should NOT be empty."
342
+ for i in range(len(separator_token_ids)):
343
+ assert isinstance(separator_token_ids[i], int), f"The ids in `separator_token_ids` must be of the type `int`."
344
+ assert isinstance(PADDING_ID, int), f"`PADDING_ID: int` must be correctly provided for initialization unless `model_type` is properly given, which will auto-fiil `PADDING_ID`."
345
+ self.separator_token_ids = separator_token_ids
346
+ self.PADDING_ID = PADDING_ID
347
+ else:
348
+ assert isinstance(model_type, str), f"`model_type` should be a `str` or `None`."
349
+ if 'llama' in model_type.lower():
350
+ # print("Debug: For Llama's default separators")
351
+ self.separator_token_ids = [128000, 13, 11, 30, 0, 26, 25, 198, 220, 662, 1174, 949, 758, 2652, 551, 720, 256,262] # llama3 8b
352
+ self.PADDING_ID = 128009
353
+ elif ( 'pythia' in model_type.lower() ) or ( 'gpt_neox' in model_type.lower() ):
354
+ # print("Debug: For GPTNeox's default separators")
355
+ self.separator_token_ids = [15, 13, 32, 2, 28, 27, 209, 186, 187, 964, 1157, 3736, 2195, 3706, 1163, 2490, 50276, 586, 4928, 50275 ] # pythia 14b
356
+ self.PADDING_ID = 0
357
+ elif 'falcon' in model_type.lower():
358
+ # print(f"Debug: For Falcon's default separators")
359
+ self.separator_token_ids = [25, 23, 42, 12, 38, 37, 193, 4610, 204, 258, 1212, 23787, 466 ] # falcon-40b
360
+ self.PADDING_ID = 11
361
+ else:
362
+ raise NotImplementedError(f"NOT implemented for the tokenizer of the backbone model type: `{model_type}`. You must provide `separator_token_ids: List[int]` and `PADDING_ID: int` for initialization in this case! ")
363
+
364
+ if APPLY_PE_SHIFT:
365
+ print(">>>>>>>>---------#####################################################################################-----------<<<<<<<<")
366
+ print(">>>>>>>>--------- -----------<<<<<<<<")
367
+ print(">>>>>>>>--------- Warning: When `APPLY_PE_SHIFT=True`, SepCache must store the key/value states ----------<<<<<<<<")
368
+ print(">>>>>>>>--------- before applying positional encoding (specifically RoPE) -----------<<<<<<<<")
369
+ print(">>>>>>>>---------#####################################################################################-----------<<<<<<<<")
370
+
371
+ if APPLY_PES_INSIDE:
372
+ print(">>>>>>>>---------#####################################################################################-----------<<<<<<<<")
373
+ print(">>>>>>>>--------- -----------<<<<<<<<")
374
+ print(">>>>>>>>--------- Warning: When `APPLY_PES_INSIDE=True`, there is no need to apply rotary positional embedding--<<<<<<<<")
375
+ print(">>>>>>>>--------- within the self_attention function, as this operation will be handled inside the `update` ---<<<<<<<<")
376
+ print(">>>>>>>>--------- function of SepCache. Note that `APPLY_PES_INSIDE=True` is typically used together with ---<<<<<<<<")
377
+ print(">>>>>>>>--------- `APPLY_PE_SHIFT=True`. ---<<<<<<<<")
378
+ print(">>>>>>>>---------#####################################################################################-----------<<<<<<<<")
379
+
380
+
381
+ def update(
382
+ self,
383
+ key_states: torch.Tensor,
384
+ value_states: torch.Tensor,
385
+ layer_idx: int,
386
+ input_ids: torch.Tensor = None,
387
+ PREFILLING_FLAG: bool = True,
388
+ query_states: Optional[torch.Tensor] = None,
389
+ position_ids: Optional[torch.Tensor]=None,
390
+ cache_kwargs: Optional[Dict[str, Any]] = None,
391
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor],Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
392
+ """
393
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
394
+
395
+ Parameters:
396
+ `key_states` (`torch.Tensor`):
397
+ The new key states to cache.
398
+ `value_states` (`torch.Tensor`):
399
+ The new value states to cache.
400
+ `input_ids` (`torch.Tensor`)
401
+ The ids of the input tokens (context tokens or autoregressive tokens)
402
+ `layer_idx` (`int`):
403
+ The index of the layer to cache the states for.
404
+ `PREFILLING_FLAG` (`bool`)
405
+ It should be `True` at pre-filling phase and `False` when decoding
406
+
407
+ `query_states` (`Optional[torch.Tensor]`)
408
+ The query states that need positional encoding shifting. Only useful when `self.APPLY_PE_SHIFT=True`
409
+ `position_ids` (`Optional[torch.Tensor]`)
410
+ The original positional ids of the tokens in the input sequence (i.e., indices of positions of each input sequence tokens in the position embeddings)
411
+ Only useful when `self.APPLY_PE_SHIFT=True`, i.e., SepCache will utilize `position_ids` to calculate positional shifting.
412
+ `cache_kwargs` (`Dict[str, Any]`, optional):
413
+ Additional arguments for the cache update. The following arguments can be used in `SepCache`: `sin`,
414
+ `cos`, `sin_q`, `cos_q`, `shifted_pos_ids` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
415
+ rotation as the tokens are shifted. (These are only useful when `self.APPLY_PE_SHIFT=True`)
416
+
417
+ Only useful when `self.APPLY_PE_SHIFT=True` and `self.APPLY_PES_INSIDE=False`:
418
+ `cos` and `sin` are the shifted rotation matrices for key states
419
+ `cos_q` and `sin_q` are the shifted rotation matrices for query states
420
+ `shifted_pos_ids` is the shifted positional ids for key states
421
+
422
+ When `self.APPLY_PE_SHIFT=True` and `self.APPLY_PES_INSIDE=True`:
423
+ SepCache will utilize `position_ids` to calculate positional shifting.
424
+
425
+ `partial_rotation_size` means that only the first `partial_rotation_size` slices along the rotation dimension need to be shifted (i.e., slices [0, 1, ..., `partial_rotation_size-1`] along that dimension).
426
+
427
+ Return:
428
+ A tuple containing the updated key, value, and query states (query states are optional: only applicable when `self.APPLY_PE_SHIFT=True`).
429
+
430
+ For detailed usage instructions, please refer to https://github.com/HKUDS/SepLLM
431
+ """
432
+
433
+ APPLY_PE_SHIFT = self.APPLY_PE_SHIFT
434
+ APPLY_PES_INSIDE = self.APPLY_PES_INSIDE
435
+ SEP_ACCUMULATION = self.SEP_ACCUMULATION
436
+ USE_MAX_SEP_CACHE = self.USE_MAX_SEP_CACHE
437
+ SEP_PADDING_IN_BATCH = self.SEP_PADDING_IN_BATCH
438
+
439
+ if input_ids is None:
440
+ input_ids = cache_kwargs.get("input_ids", None)
441
+ assert input_ids is not None, f"`input_ids` must be properly provided when calling `update()` in `SepCache`."
442
+
443
+ assert (self.APPLY_PE_SHIFT and (query_states is not None)) or not APPLY_PE_SHIFT, f"If `APPLY_PE_SHIFT=True`, `query_states` should be provided and it will be updated and returned"
444
+
445
+ # Update the number of seen tokens
446
+ if layer_idx == 0:
447
+ assert key_states.shape[-2] == input_ids.shape[-1], f"`key_states.shape[-2]` ({key_states.shape[-2]}) should be equal to `input_ids.shape[-1]` ({input_ids.shape[-1]})."
448
+ self._seen_tokens += input_ids.shape[-1]
449
+
450
+ # [bsz, num_heads, seq_len, head_dim]
451
+ new_kv_pair = (key_states, value_states)
452
+
453
+ if (key_states.shape[self.k_seq_dim] + self.get_usable_length(layer_idx) < self.cache_size[layer_idx]) or PREFILLING_FLAG: ## Pre-filling, or decoding while the cache is still below the cache-size budget
454
+ assert (PREFILLING_FLAG and key_states.shape[self.k_seq_dim] >= 1) or (not PREFILLING_FLAG and key_states.shape[self.k_seq_dim] == 1)
455
+
456
+ # Update cache and past token ids
457
+ self.update_kv_cache_and_past_tok_ids(new_kv_pair, input_ids, layer_idx, COMPRESS_KV=False, SEP_ACCUMULATION=SEP_ACCUMULATION, USE_MAX_SEP_CACHE=USE_MAX_SEP_CACHE, SEP_PADDING_IN_BATCH=SEP_PADDING_IN_BATCH)
458
+
459
+ if APPLY_PE_SHIFT:
460
+ shifted_keys, shifted_queries = self.apply_shifted_pos_emb(layer_idx, APPLY_PES_INSIDE, PREFILLING_FLAG, key_states, query_states, position_ids, cache_kwargs )
461
+ query_states = shifted_queries
462
+ self.set_kv_cache( (shifted_keys, self.value_cache[layer_idx]), layer_idx)
463
+
464
+ if PREFILLING_FLAG and layer_idx == 0:
465
+ self.left_padding_offset = self.get_initial_pos_offset(layer_idx)
466
+
467
+ ## Count KV usage
468
+ kv_len_ori = self.get_seq_length(layer_idx)
469
+ kv_len_cmp = self.get_usable_length(layer_idx)
470
+ self._update_kv_ratio(kv_len_cmp=kv_len_cmp, kv_len_ori=kv_len_ori, layer_idx=layer_idx)
471
+
472
+ else:
473
+ ## Update the KV cache, count KV usage, and compress the KV cache if necessary
474
+ kv_len_ori = self.get_seq_length(layer_idx)
475
+ offset_init_size_layer = self.update_kv_cache_and_past_tok_ids(new_kv_pair, input_ids, layer_idx, COMPRESS_KV=True, SEP_ACCUMULATION=SEP_ACCUMULATION, USE_MAX_SEP_CACHE=USE_MAX_SEP_CACHE, SEP_PADDING_IN_BATCH=SEP_PADDING_IN_BATCH)
476
+ kv_len_cmp = self.get_usable_length(layer_idx)
477
+ self._update_kv_ratio(kv_len_cmp=kv_len_cmp, kv_len_ori=kv_len_ori, layer_idx=layer_idx)
478
+
479
+ if APPLY_PE_SHIFT:
480
+ shifted_keys, shifted_queries = self.apply_shifted_pos_emb(layer_idx, APPLY_PES_INSIDE, PREFILLING_FLAG, key_states, query_states, position_ids, cache_kwargs )
481
+ query_states = shifted_queries
482
+ self.set_kv_cache( (shifted_keys, self.value_cache[layer_idx]), layer_idx)
483
+
484
+ if self.PRINT_KV_RATIO_INSIDE:
485
+ self._print_kv_ratio(layer_idx)
486
+
487
+ if query_states is not None:
488
+ return self.key_cache[layer_idx], self.value_cache[layer_idx], query_states
489
+ else:
490
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
491
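+ # Hedged usage sketch (comments only): a typical call site inside an attention forward,
+ # in the spirit of the monkey-patched `llama_atten_forward`; the variable names are assumptions.
+ #
+ # key_states, value_states = past_key_value.update(
+ #     key_states, value_states, layer_idx,
+ #     input_ids=input_ids,                       # token ids aligned with `key_states` along the seq dimension
+ #     PREFILLING_FLAG=key_states.shape[-2] > 1,  # True for pre-filling, False when decoding
+ # )
+ # # When `APPLY_PE_SHIFT=True`, a third element (the shifted query states) is also returned.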
+
492
+
493
+ def update_kv_cache_and_past_tok_ids(self, new_kv_pair: Tuple[torch.Tensor], input_ids: torch.Tensor, layer_idx: int, COMPRESS_KV=False, SEP_ACCUMULATION:bool=True, USE_MAX_SEP_CACHE:bool=False, SEP_PADDING_IN_BATCH:bool=True) -> Optional[int]:
494
+ """Update the KV cache and past token ids; compress the KV cache if necessary."""
495
+ assert layer_idx is not None, f"`layer_idx` must be given"
496
+ assert len(new_kv_pair) == 2, f"The length of `new_kv_pair` must be 2."
497
+ assert len(self.key_cache) == len(self.value_cache), f"The layer numbers of stored `self.key_cache` and `self.value_cache` must be the same."
498
+
499
+ self.append_past_tok_ids(input_ids, layer_idx)
500
+
501
+ key, value = new_kv_pair
502
+
503
+ if len(self.key_cache) <= layer_idx:
504
+ self.key_cache.append(key)
505
+ self.value_cache.append(value)
506
+ assert len(self.key_cache) - 1 == layer_idx, f"The key_cache should be updated sequentially according to the layer numbering."
507
+ assert len(self.value_cache) - 1 == layer_idx, f"The value_cache should be updated sequentially according to the layer numbering."
508
+ else:
509
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx] , key], dim=self.k_seq_dim)
510
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx] , value], dim=self.v_seq_dim)
511
+
512
+ assert len(self.key_cache) == len(self.value_cache), f"The layer numbers of stored key_cache and value_cache must be the same."
513
+ assert self.key_cache[layer_idx].shape[self.k_seq_dim] == self.value_cache[layer_idx].shape[self.v_seq_dim], "The seq length for key_cache and value_cache must be the same."
514
+
515
+ if COMPRESS_KV:
516
+ cmp_past_kv_pairs, cmp_past_tok_ids, offset_init_size_layer = self.compress_kv_cache_and_tokids_layer_wise((self.key_cache[layer_idx], self.value_cache[layer_idx]), layer_idx ,SEP_ACCUMULATION=SEP_ACCUMULATION, USE_MAX_SEP_CACHE=USE_MAX_SEP_CACHE, SEP_PADDING_IN_BATCH=SEP_PADDING_IN_BATCH )
517
+ self.set_kv_cache(cmp_past_kv_pairs, layer_idx)
518
+ self.set_past_tok_ids(cmp_past_tok_ids, layer_idx)
519
+ return offset_init_size_layer
520
+
521
+
522
+ def append_past_tok_ids(self, input_ids: torch.Tensor, layer_idx: int) -> None:
523
+ """Naively append the new `input_ids` to `self.past_tok_ids[layer_idx]`"""
524
+ assert layer_idx is not None, f"`layer_idx` must be given"
525
+
526
+ if len(self.past_tok_ids) <= layer_idx:
527
+ self.past_tok_ids.append(input_ids)
528
+ assert len(self.past_tok_ids) - 1 == layer_idx, f"The past_tok_ids should be updated sequentially according to the layer numbering."
529
+ else:
530
+ self.past_tok_ids[layer_idx] = torch.cat([self.past_tok_ids[layer_idx] , input_ids], dim=-1)
531
+
532
+
533
+ def compress_kv_cache_and_tokids_layer_wise(self, past_kv_pairs, layer_idx:int ,SEP_ACCUMULATION=False, USE_MAX_SEP_CACHE=False, SEP_PADDING_IN_BATCH=True ):
534
+ """
535
+ `SEP_ACCUMULATION`: If True, it means we will try to accumulate all the KVs for separators. If False, only the `new_sep_kv` compressed from the `past_win_kv` will be kept (see function `compress_kv_cache_and_tokids_layer_wise`).
536
+
537
+ `USE_MAX_SEP_CACHE`: If True, it means we only keep at most `self.sep_cache_size` separators' KVs. If the number exceeds this limit, older separators' KVs will be discarded, keeping only the most recent `self.sep_cache_size` KVs. In the paper, the hyperparameter `s` is an abbreviated alias for `self.sep_cache_size`.
538
+
539
+ `SEP_PADDING_IN_BATCH`: If True, it means that SepCache will pad separator tokens in other records to be aligned with the record with the most separators in a batch. If False, it means that SepCache will truncate older separator tokens in other records to be aligned with the record with the fewest separators in a batch.
540
+
541
+ Note: If `SEP_ACCUMULATION=True` and `USE_MAX_SEP_CACHE=False`, as the number of input tokens increases, the number of separators in the KV cache will also accumulate endlessly
542
+ and `self.cache_size` will also be infinitely expanded (no longer fixed).
543
+
544
+ When `SEP_PADDING_IN_BATCH=True` is used in combination with `USE_MAX_SEP_CACHE=False` and `SEP_ACCUMULATION=True`, the KV cache will accumulate indefinitely,
545
+ and since `SEP_PADDING_IN_BATCH=True`, the KVs of all separators will be retained (rather than being truncated).
546
+
547
+
548
+ For detailed usage instructions, please refer to https://github.com/HKUDS/SepLLM
549
+ """
550
+
551
+ key, value = past_kv_pairs
552
+ seq_len = key.size(self.k_seq_dim)
553
+ assert seq_len == self.get_usable_length(layer_idx), f"The seq_len of cached past key and value states should be the same as the return of `get_usable_length()`, which is {self.get_usable_length(layer_idx)}"
554
+
555
+
556
+ left_padding_offset = self.left_padding_offset
557
+ assert left_padding_offset is not None
558
+ offset_init_size_layer = self.init_cache_size[layer_idx] + left_padding_offset
559
+ self._set_layer_wise_attribute("max_sep_exidx", self._list_element_add(self.sep_cache_size, self.init_cache_size, bias=left_padding_offset), self.layer_num)
560
+ self._CHECK_PARAMS_VALIDITY(layer_idx, left_padding_offset)
561
+
562
+ if self.sep_exrange[layer_idx] <=0:
563
+ self.sep_exrange[layer_idx] = offset_init_size_layer
564
+
565
+ assert seq_len - self.local_size[layer_idx] > self.sep_exrange[layer_idx]
566
+
567
+ if offset_init_size_layer > 0:
568
+ initial_kv, initial_tokids = self.slice_kv_cache_and_tokids( past_kv_pairs, self.past_tok_ids[layer_idx], 0, offset_init_size_layer, seq_len=seq_len, _CHECK_IDX=True )
569
+
570
+ Before_First_Time_Compress_Flag = (self.sep_exrange[layer_idx] == offset_init_size_layer) ## If True, the current step is before t1, i.e., the first time the past window is compressed, in which only separators' KVs are kept.
571
+
572
+ if SEP_ACCUMULATION and not Before_First_Time_Compress_Flag: ## To get the old sep kv and sep token ids.
573
+ past_sep_kv, past_sep_tokids = self.slice_kv_cache_and_tokids( past_kv_pairs, self.past_tok_ids[layer_idx], offset_init_size_layer, self.sep_exrange[layer_idx], seq_len=seq_len, _CHECK_IDX=True )
574
+
575
+ past_win_kv, past_win_tokids = self.slice_kv_cache_and_tokids( past_kv_pairs, self.past_tok_ids[layer_idx], self.sep_exrange[layer_idx], seq_len - self.local_size[layer_idx], seq_len=seq_len, _CHECK_IDX=True )
576
+
577
+
578
+ local_kv, local_tokids = self.slice_kv_cache_and_tokids( past_kv_pairs, self.past_tok_ids[layer_idx], seq_len - self.local_size[layer_idx], seq_len, seq_len=seq_len, _CHECK_IDX=True )
579
+
580
+ new_sep_kv, new_sep_tokids, min_sep_num, max_sep_num = self.compress_past_win_2_seps( past_win_kv, past_win_tokids, SEP_PADDING_IN_BATCH = SEP_PADDING_IN_BATCH ) ## To get the new sep kv and sep token ids that were just compressed from the past window
581
+
582
+ if SEP_ACCUMULATION and not Before_First_Time_Compress_Flag: ## Try to accumulate all the seen seps
583
+ sep_kv, sep_tokids = self.cat_kv_cache_and_tokids( [ past_sep_kv, new_sep_kv ] , [past_sep_tokids, new_sep_tokids ] )
584
+ new_sep_len = new_sep_tokids.shape[-1]
585
+ sep_len = sep_tokids.shape[-1]
586
+ else: ## Only keep the newly obtained kv (those just compressed from the past window)
587
+ sep_kv, sep_tokids = new_sep_kv, new_sep_tokids
588
+ # new_sep_len = new_sep_tokids.shape[-1]
589
+ sep_len = sep_tokids.shape[-1]
590
+ assert (SEP_PADDING_IN_BATCH and max_sep_num==sep_len) or ( (not SEP_PADDING_IN_BATCH) and min_sep_num==sep_len)
591
+
592
+
593
+ if USE_MAX_SEP_CACHE: ## Fixed sep cache size, i.e., only keep max_sep_len seps' kv in the cache.
594
+ if offset_init_size_layer + sep_len > self.max_sep_exidx[layer_idx]:
595
+ max_sep_len = self.max_sep_exidx[layer_idx] - offset_init_size_layer
596
+ assert sep_kv[0].shape[-2] == sep_tokids.shape[-1], f"The seq_len for seps' KVs and tok_ids should be the same."
597
+
598
+ sep_kv, sep_tokids = self.slice_kv_cache_and_tokids( sep_kv, sep_tokids, sep_len-max_sep_len, sep_len, seq_len = sep_tokids.shape[-1] ,_CHECK_IDX=True )
599
+ self.sep_exrange[layer_idx] = self.max_sep_exidx[layer_idx]
600
+ else:
601
+ self.sep_exrange[layer_idx] = offset_init_size_layer + sep_len
602
+
603
+ else: ## Extend the sep cache and the whole cache if USE_MAX_SEP_CACHE is not set
604
+ self.sep_exrange[layer_idx] = offset_init_size_layer + sep_len
605
+ if self.sep_exrange[layer_idx] > self.max_sep_exidx[layer_idx]:
606
+ cache_incremental_gap = self.sep_exrange[layer_idx] - self.max_sep_exidx[layer_idx]
607
+ self.max_sep_exidx[layer_idx] = self.sep_exrange[layer_idx]
608
+ self.sep_cache_size[layer_idx] = self.sep_cache_size[layer_idx] + cache_incremental_gap
609
+ self.cache_size[layer_idx] = self.cache_size[layer_idx] + cache_incremental_gap
610
+
611
+ if offset_init_size_layer > 0:
612
+ cmp_past_kv_pairs, cmp_past_tok_ids = self.cat_kv_cache_and_tokids( [initial_kv, sep_kv, local_kv ] , [initial_tokids, sep_tokids, local_tokids ] )
613
+ else:
614
+ cmp_past_kv_pairs, cmp_past_tok_ids = self.cat_kv_cache_and_tokids( [sep_kv, local_kv ] , [sep_tokids, local_tokids ] )
615
+
616
+ return cmp_past_kv_pairs, cmp_past_tok_ids, offset_init_size_layer
617
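+ # Hedged illustration (comments only) of the cache layout this method produces, assuming
+ # `offset_init_size_layer > 0`:
+ #
+ #   [ initial_kv | sep_kv | local_kv ]  along the sequence dimension,
+ #
+ # where the three segments correspond to `init_cache_size` (plus the left-padding offset),
+ # the separators' KVs (capped at `sep_cache_size` when `USE_MAX_SEP_CACHE=True`, otherwise
+ # allowed to grow), and the most recent `local_size` tokens.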
+
618
+
619
+ def compress_past_win_2_seps(self, past_win_kv: Tuple[torch.Tensor], past_win_tokids: torch.Tensor, MIN_SEP_ALERT: bool=False, SEP_PADDING_IN_BATCH: bool=True ) -> Tuple[Union[Tuple[torch.Tensor], torch.Tensor, int ]]:
620
+ """Compress the KVs in the past window into the sep cache where only separators' KVs are kept. Padding or Truncating if necessary."""
621
+ sep_index_tensor = torch.zeros_like(past_win_tokids).bool() # batch x seq_len
622
+
623
+ for sp_id in self.separator_token_ids:
624
+ sep_index_tensor = sep_index_tensor | ( past_win_tokids == sp_id ) # batch x seq_len
625
+
626
+ sep_cnt = sep_index_tensor.int().sum(-1)
627
+ min_sep_num = sep_cnt.min() # the min sep number for the seqs in a batch
628
+ max_sep_num = sep_cnt.max() # the max sep number for the seqs in a batch
629
+
630
+
631
+ if MIN_SEP_ALERT and not SEP_PADDING_IN_BATCH:
632
+ assert min_sep_num > 0, f"When `MIN_SEP_ALERT=True` and `SEP_PADDING_IN_BATCH=False`, each record in the batch must contain at least one separator at every compression step."
633
+
634
+ batch1_sep_ids_list = []
635
+ batch_size = past_win_tokids.shape[0]
636
+ for b_id in range(batch_size):
637
+ batch1_sep_ids = past_win_tokids[b_id, sep_index_tensor[b_id]] # # sep_num
638
+ if SEP_PADDING_IN_BATCH: ## padding
639
+ sep_num = batch1_sep_ids.shape[-1]
640
+ padding_num = max_sep_num - sep_num
641
+ if padding_num > 0:
642
+ assert padding_num <= past_win_tokids.shape[-1], f"padding_num: {padding_num} should be <= past_win_tokids.shape[-1]:{past_win_tokids.shape[-1]}"
643
+ batch1_sep_ids = batch1_sep_ids # # sep_num
644
+ batch1_pad_ids = past_win_tokids[b_id, -padding_num:] # # padding_num
645
+ batch1_sep_ids = torch.cat([batch1_sep_ids, batch1_pad_ids], dim =-1) ## max_sep_num
646
+ else: ## truncating
647
+ batch1_sep_ids = batch1_sep_ids[..., :min_sep_num ] # # min_sep_num
648
+ batch1_sep_ids_list.append(batch1_sep_ids)
649
+
650
+ new_sep_tokids = torch.stack(batch1_sep_ids_list, dim=0) # # B x min_sep_num
651
+ key_cache, value_cache = past_win_kv
652
+
653
+ assert batch_size==key_cache.shape[0]
654
+ batch1_sep_k_list = []
655
+ batch1_sep_v_list = []
656
+ batch1_sep_ids_list = []
657
+ for b_id in range(batch_size):
658
+ batch1_sep_k = self.k_bat_dim_select(key_cache, b_id, sep_index_tensor[b_id], min_sep_num, max_sep_num, SEP_PADDING_IN_BATCH)
659
+ batch1_sep_k_list.append(batch1_sep_k)
660
+
661
+ batch1_sep_v = self.v_bat_dim_select(value_cache, b_id, sep_index_tensor[b_id], min_sep_num, max_sep_num, SEP_PADDING_IN_BATCH)
662
+ batch1_sep_v_list.append( batch1_sep_v )
663
+
664
+ sep_k = torch.stack(batch1_sep_k_list, dim=0) ## batch x head x min_sep_num x dim
665
+ sep_v = torch.stack(batch1_sep_v_list, dim=0) ## batch x head x min_sep_num x dim
666
+ new_sep_kv = (sep_k, sep_v)
667
+
668
+ return new_sep_kv, new_sep_tokids, min_sep_num, max_sep_num
669
+
670
+
671
+ def apply_shifted_pos_emb(self, layer_idx: int, APPLY_PES_INSIDE: bool, PREFILLING_FLAG: bool, key_states: torch.Tensor, query_states: torch.Tensor, position_ids: torch.Tensor, cache_kwargs: Optional[Dict[str, Any]] = None ) -> Tuple[torch.Tensor, torch.Tensor]:
672
+ """Perform positional encoding shifting if required"""
673
+ seq_len = self.get_usable_length(layer_idx)
674
+ keys_to_shift = self.key_cache[layer_idx]
675
+ queries_to_shift = query_states
676
+ assert keys_to_shift.shape[self.k_seq_dim] == seq_len
677
+
678
+ if cache_kwargs is None:
679
+ cache_kwargs = {}
680
+
681
+ if APPLY_PES_INSIDE:
682
+ if len(self._shifted_position_ids) <= layer_idx:
683
+ self._shifted_position_ids.append(None)
684
+
685
+ if PREFILLING_FLAG: ## for prefilling
686
+ assert position_ids.shape[-1] >= seq_len, f"The length of position_ids should be >= the usable length of kv cache when prefilling."
687
+ self._shifted_position_ids[layer_idx] = position_ids[:, :seq_len].detach()
688
+ shifted_pos_ids = self._shifted_position_ids[layer_idx]
689
+
690
+ elif self._shifted_position_ids[layer_idx].shape[-1] >= seq_len: ## for generation
691
+ assert position_ids.shape[-1] == 1, f"The length of query and position_ids should be 1 during generation."
692
+ shifted_pos_ids = self._shifted_position_ids[layer_idx][:, :seq_len].detach()
693
+
694
+ elif self._shifted_position_ids[layer_idx].shape[-1] < seq_len: ## for generation
695
+ assert position_ids.shape[-1] == 1, f"The length of query and position_ids should be 1 during generation."
696
+ increased_gap = seq_len - self._shifted_position_ids[layer_idx].shape[-1]
697
+ assert increased_gap < self._shifted_position_ids[layer_idx].shape[-1], f"Normally, for an auto-regressive model, the input length at each step should be 1 during generation."
698
+
699
+ new_position_ids = self._shifted_position_ids[layer_idx][:, -increased_gap: ] + increased_gap
700
+ self._shifted_position_ids[layer_idx] = torch.cat([self._shifted_position_ids[layer_idx], new_position_ids.detach()], dim=-1)
701
+ shifted_pos_ids = self._shifted_position_ids[layer_idx]
702
+ else:
703
+ raise RuntimeError
704
+
705
+ cos, sin = self._get_naive_shifted_cos_sin(
706
+ key_states, shifted_pos_ids, seq_len
707
+ )
708
+
709
+ q_rope_idx = torch.arange( seq_len - query_states.shape[self.k_seq_dim], seq_len).to(cos.device)
710
+ cos_q, sin_q = cos.index_select(self._rope_seq_dim, q_rope_idx), sin.index_select(self._rope_seq_dim, q_rope_idx)
711
+
712
+ else:
713
+ sin = cache_kwargs.get("sin")
714
+ cos = cache_kwargs.get("cos")
715
+ sin_q = cache_kwargs.get("sin_q")
716
+ cos_q = cache_kwargs.get("cos_q")
717
+ shifted_pos_ids = cache_kwargs.get("shifted_pos_ids")
718
+ assert (sin is not None) and (cos is not None), f"sin and cos matrices should be provided"
719
+ if sin_q is None:
720
+ q_rope_idx = torch.arange( seq_len - query_states.shape[self.k_seq_dim], seq_len).to(sin.device)
721
+ sin_q = sin.index_select(self._rope_seq_dim, q_rope_idx)
722
+ if cos_q is None:
723
+ q_rope_idx = torch.arange( seq_len - query_states.shape[self.k_seq_dim], seq_len).to(cos.device)
724
+ cos_q = cos.index_select(self._rope_seq_dim, q_rope_idx)
725
+
726
+ partial_rotation_size = cache_kwargs.get("partial_rotation_size")
727
+
728
+ # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
729
+ if partial_rotation_size is not None:
730
+ keys_to_shift, keys_pass = (
731
+ keys_to_shift[..., :partial_rotation_size],
732
+ keys_to_shift[..., partial_rotation_size:]
733
+ )
734
+ queries_to_shift, queries_pass = (
735
+ queries_to_shift[..., :partial_rotation_size],
736
+ queries_to_shift[..., partial_rotation_size:]
737
+ )
738
+
739
+ shifted_keys = self._apply_rotary_pos_emb_single(keys_to_shift, cos, sin, shifted_pos_ids, unsqueeze_dim=self._rope_unsqueeze_dim)
740
+ shifted_queries = self._apply_rotary_pos_emb_single(queries_to_shift, cos_q, sin_q, shifted_pos_ids[:, -queries_to_shift.shape[self.k_seq_dim] : ], unsqueeze_dim=self._rope_unsqueeze_dim)
741
+
742
+ if partial_rotation_size is not None:
743
+ shifted_keys = torch.cat( [shifted_keys, keys_pass], dim=-1)
744
+ shifted_queries = torch.cat( [shifted_queries, queries_pass], dim=-1)
745
+
746
+
747
+ return shifted_keys, shifted_queries
748
+
749
+
750
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
751
+ """Returns the sequence length of the seen tokens. A layer index can be optionally passed."""
752
+ return self._seen_tokens
753
+
754
+
755
+ def get_usable_length(self, layer_idx: int = 0) -> int:
756
+ """Returns the sequence length of the actual cached states. A layer index must be passed."""
757
+ if len(self.key_cache) <= layer_idx :
758
+ return 0
759
+ assert self.key_cache[layer_idx].shape[self.k_seq_dim] == self.value_cache[layer_idx].shape[self.v_seq_dim], f"`self.key_cache` and `self.value_cache` should have the same length."
760
+ return self.key_cache[layer_idx].shape[self.k_seq_dim]
761
+
762
+ def get_initial_pos_offset(self, layer_idx:int = 0) -> int:
763
+ """Return the number of padding tokens in the record with the most left padding tokens in a batch."""
764
+ assert isinstance(self.PADDING_ID, int), f"`self.PADDING_ID` should be correctly set."
765
+ assert len(self.past_tok_ids) > layer_idx, f"`self.past_tok_ids` for layer {layer_idx} must have been properly set."
766
+
767
+ past_tok_ids = self.past_tok_ids[layer_idx]
768
+ assert past_tok_ids is not None, f"`past_tok_ids` for layer {layer_idx} should not be None"
769
+
770
+ pad_index_tensor = (past_tok_ids == self.PADDING_ID) ## batch x seq_len
771
+ pad_toks_cnt = pad_index_tensor.int().sum(-1) ## [batch]
772
+ offset = pad_toks_cnt.max().item()
773
+
774
+ return offset
775
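+ # Hedged example (comments only), assuming `PADDING_ID=0` and left-padded inputs:
+ #   past_tok_ids = [[0, 0, 5, 6],
+ #                   [0, 7, 8, 9]]  ->  per-record pad counts [2, 1]  ->  offset = 2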
+
776
+
777
+ def get_batch_size(self) -> int:
778
+ """Return the batch size."""
779
+ assert self.key_cache is not None, f"`self.key_cache` should not be None."
780
+ assert self.value_cache is not None, f"`self.value_cache` should not be None."
781
+ assert len(self.key_cache) > 0, f"`self.key_cache` is empty. No batch size is available."
782
+ assert len(self.value_cache) > 0, f"self.value_cache is empty. No batch size is available."
783
+
784
+ assert len(self.value_cache) == len(self.key_cache), f"self.value_cache and self.key_cache should have the same length."
785
+ assert self.value_cache[0].shape[0] == self.key_cache[0].shape[0], f"self.value_cache and self.key_cache should have the same batch size."
786
+
787
+ return self.value_cache[0].shape[0]
788
+
789
+ def get_kv_pair(self, layer_idx: int = None) -> Tuple[torch.Tensor]:
790
+ assert layer_idx is not None, f"`layer_idx` must be given."
791
+
792
+ if (len(self.key_cache) > layer_idx) and (len(self.value_cache) > layer_idx):
793
+ key = self.key_cache[layer_idx]
794
+ value = self.value_cache[layer_idx]
795
+ else:
796
+ raise RuntimeError(f"The KV for layer:{layer_idx} have not been set.")
797
+ return (key, value)
798
+
799
+
800
+ def set_kv_cache(self, kv_pair: Tuple , layer_idx: int ) -> None:
801
+ assert len(kv_pair) == 2, f"The length of `kv_pair` must be 2."
802
+ self.key_cache[layer_idx] = kv_pair[0]
803
+ self.value_cache[layer_idx] = kv_pair[1]
804
+
805
+ def set_past_tok_ids(self, tok_ids: torch.Tensor, layer_idx:int) -> None:
806
+ self.past_tok_ids[layer_idx] = tok_ids
807
+
808
+
809
+ def cat_kv_cache_and_tokids(self, kv_pairs_list: List[Tuple[torch.Tensor]] , tok_ids_list:List[torch.Tensor]) -> Tuple[Union[Tuple[torch.Tensor],torch.Tensor]]:
810
+
811
+ return self.cat_kv_cache(kv_pairs_list), self.cat_token_ids(tok_ids_list)
812
+
813
+
814
+ def slice_kv_cache_and_tokids(self, kv_pair:Tuple[torch.Tensor], tok_ids_list:torch.Tensor, start:int, end:int, seq_len:int=None, _CHECK_IDX:bool=True, ) -> Tuple[Union[Tuple[torch.Tensor], torch.Tensor]]:
815
+
816
+ sliced_kv = self._slice_kv(start, end, kv_pair=kv_pair, seq_len=seq_len, _CHECK_IDX=_CHECK_IDX,)
817
+ sliced_tids = self._slice_tok_ids(start, end, tok_ids_list = tok_ids_list, seq_len=seq_len, _CHECK_IDX=_CHECK_IDX)
818
+
819
+ return sliced_kv , sliced_tids
820
+
821
+
822
+ def cat_kv_cache(self, kv_pairs_list: List[Tuple[torch.Tensor]] ) -> Tuple[torch.Tensor]:
823
+ assert len(kv_pairs_list) >= 1
824
+
825
+ if len(kv_pairs_list) == 1 :
826
+ return kv_pairs_list[0]
827
+ else:
828
+ ret = None
829
+ for i, kv_pair in enumerate(kv_pairs_list): # enumerate all the KVs needed to be cat
830
+ if i == 0:
831
+ ret = kv_pair
832
+ else:
833
+ ret = self._cat_kv(ret, kv_pair)
834
+ return ret
835
+
836
+
837
+ def cat_token_ids(self, tok_ids_list:List[torch.Tensor] ) -> torch.Tensor :
838
+ assert len(tok_ids_list) >= 1
839
+
840
+ return torch.cat(tok_ids_list, dim=-1)
841
+
842
+
843
+ def _cat_kv(self, kv_pair_a:Tuple[torch.Tensor], kv_pair_b:Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
844
+ k_a, v_a = kv_pair_a
845
+ k_b, v_b = kv_pair_b
846
+
847
+ cat_k = torch.cat([k_a, k_b], dim=self.k_seq_dim)
848
+ cat_v = torch.cat([v_a, v_b], dim=self.v_seq_dim)
849
+ return (cat_k, cat_v)
850
+
851
+
852
+ def _slice_kv(self, start:int, end:int, kv_pair: Tuple[torch.Tensor], seq_len:int=None, _CHECK_IDX:bool=True) -> Tuple[torch.Tensor] :
853
+ assert kv_pair is not None, f"kv_pair must NOT be None when slicing it."
854
+ key_cache = kv_pair[0]
855
+ value_cache = kv_pair[1]
856
+
857
+ if _CHECK_IDX:
858
+ assert seq_len is not None, f"seq_len must be given for checking the index for slicing"
859
+ start, end = self._CHECK_IDX(start, end, seq_len)
860
+
861
+ sliced_key_cache = self.k_slice(key_cache, start, end)
862
+ sliced_value_cache = self.v_slice(value_cache, start, end)
863
+
864
+ return ( sliced_key_cache, sliced_value_cache)
865
+
866
+
867
+ def _slice_tok_ids(self, start:int, end:int, tok_ids_list:torch.Tensor , seq_len:int=None, _CHECK_IDX:bool=False) -> torch.Tensor:
868
+ assert tok_ids_list is not None, f"tok_ids_list must NOT be None when slicing it."
869
+
870
+ if _CHECK_IDX:
871
+ assert seq_len is not None, f"seq_len must be given for checking the index for slicing"
872
+ start, end = self._CHECK_IDX(start, end, seq_len)
873
+
874
+ sliced_tok_ids = tok_ids_list[:, start:end]
875
+ return sliced_tok_ids
876
+
877
+ def _set_layer_wise_attribute(self, name: str, value: Any, layer_num:int ):
878
+ """Set layer-wise attributes"""
879
+ if isinstance(value, int):
880
+ setattr(self, name, [value] * layer_num)
881
+ elif isinstance(value, (list, tuple)):
882
+ assert len(value) == layer_num, f"The length of {name}: {len(value)} must be equal to `layer_num`: {layer_num}"
883
+ setattr(self, name, list(value))
884
+ else:
885
+ raise TypeError(f"{name} must be of the type `int` or `list` but got `{type(value)}`")
886
+
887
+ def _list_element_add(self, list_a: List, list_b: List, bias: int=0, dtype = int, device = 'cpu') -> List:
888
+ """Element-wise addition between two lists."""
889
+ assert len(list_a) == len(list_b), f"The length of `list_a` ({len(list_a)}) must be equal to that of `list_b` ({len(list_b)})."
890
+ tensor_c = torch.tensor(list_a, dtype=dtype, device=device) + torch.tensor(list_b, dtype=dtype, device=device) + torch.tensor([bias], dtype=dtype, device=device)
891
+ return tensor_c.int().tolist()
892
+
893
+ def _CHECK_IDX(self, start: int = 0, end: int = 100, seq_len: int = 1000):
894
+ assert isinstance(start, int) and isinstance(end, int) and isinstance(seq_len, int), f"`start`, `end`, `seq_len` must be `int`."
895
+ assert seq_len > 0, f"`seq_len` must be > 0"
896
+
897
+ if start <0 :
898
+ start = start % seq_len
899
+ if end < 0 :
900
+ end = end % seq_len
901
+ assert (start >=0) and (start < seq_len) , f"start:{start}, end:{end}, seq_len:{seq_len}"
902
+ assert (end >= 0) and (end <= seq_len) , f"start:{start}, end:{end}, seq_len:{seq_len}"
903
+ assert start < end, f"start:{start}, end:{end}, seq_len:{seq_len}"
904
+
905
+ return start,end
906
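+ # Hedged example (comments only): negative indices are normalized modulo `seq_len`,
+ # e.g. self._CHECK_IDX(start=-4, end=-1, seq_len=10) returns (6, 9).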
+
907
+ def _CHECK_PARAMS_VALIDITY(self, layer_idx:int, left_padding_offset:int):
908
+ assert len(self.cache_size) > layer_idx
909
+ assert len(self.init_cache_size) > layer_idx
910
+ assert len(self.sep_cache_size) > layer_idx
911
+ assert len(self.max_sep_exidx) > layer_idx
912
+ assert len(self.local_size) > layer_idx
913
+
914
+ assert self.cache_size[layer_idx] > 0 , f"`self.cache_size` for layer:{layer_idx} must be greater than 0"
915
+ assert self.init_cache_size[layer_idx] >= 0 , f"`self.init_cache_size` for layer:{layer_idx} must be greater than (equal to) 0"
916
+ assert self.local_size[layer_idx] > 0 , f"`self.local_size` for layer:{layer_idx} must be greater than 0"
917
+
918
+ assert self.sep_cache_size[layer_idx] > 0 , f"`self.sep_cache_size` for layer:{layer_idx} must be greater than 0"
919
+ assert self.max_sep_exidx[layer_idx] > 0 , f"`self.max_sep_exidx` for layer:{layer_idx} must be greater than 0"
920
+ assert self.init_cache_size[layer_idx] + self.sep_cache_size[layer_idx] + self.local_size[layer_idx] + left_padding_offset < self.cache_size[layer_idx], f"`init_cache_size` ({self.init_cache_size[layer_idx]}) + `sep_cache_size` ({self.sep_cache_size[layer_idx]}) + `local_size` ({self.local_size[layer_idx]}) + `left_padding_offset` ({left_padding_offset}) for layer {layer_idx} should be less than `cache_size`:({self.cache_size[layer_idx]}) for layer {layer_idx}, i.e., a + s + w + (left_padding_offset) < c. Please increase `cache_size` if applicable."
921
+
922
+
923
+
924
+ def _rotate_half(self, x):
925
+ """Rotates half the hidden dims of the input."""
926
+ x1 = x[..., : x.shape[-1] // 2]
927
+ x2 = x[..., x.shape[-1] // 2 :]
928
+ return torch.cat((-x2, x1), dim=-1)
929
+
930
+ def _apply_rotary_pos_emb_single(self, k, cos, sin, position_ids=None, unsqueeze_dim=1):
931
+ """Applies Rotary Position Embedding to the query and key tensors.
932
+
933
+ Args:
934
+ k (`torch.Tensor`): The tensor to rotate (key states, or query states when this helper is reused for queries).
936
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
937
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
938
+ position_ids (`torch.Tensor`, *optional*):
939
+ Deprecated and unused.
940
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
941
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
942
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
943
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
944
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
945
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
946
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
947
+ Returns:
948
+ `torch.Tensor`: the input tensor rotated using the Rotary Position Embedding.
949
+ """
950
+ cos = cos.unsqueeze(unsqueeze_dim) # batch x seq_len x dim --> batch x 1 x seq_len x dim
951
+ sin = sin.unsqueeze(unsqueeze_dim)
952
+ k_embed = (k * cos) + (self._rotate_half(k) * sin)
953
+ return k_embed
954
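+ # Hedged shape sketch (comments only), assuming the default `unsqueeze_dim=1`:
+ #   k:        [batch, heads, seq_len, head_dim]
+ #   cos, sin: [batch, seq_len, head_dim] -> unsqueeze(1) -> [batch, 1, seq_len, head_dim]
+ # so `(k * cos) + (rotate_half(k) * sin)` broadcasts over the head dimension.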
+
955
+
956
+ def _get_naive_shifted_cos_sin(self, x: torch.Tensor, position_ids: torch.Tensor=None, seq_len=None):
957
+ # x: [batch, num_attention_heads, seq_len, head_size]
958
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
959
+ position_ids_expanded = position_ids[:, None, :].float()
960
+ freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
961
+ emb = torch.cat((freqs, freqs), dim=-1)
962
+ cos = emb.cos().to(dtype=x.dtype)
963
+ sin = emb.sin().to(dtype=x.dtype)
964
+ # backwards compatibility
965
+ self._cos_cached = cos
966
+ self._sin_cached = sin
967
+ return cos, sin
968
+
969
+
970
+ def _get_scaled_shifted_cos_sin(self, x, position_ids, seq_len=None):
971
+ # difference to the original RoPE: a scaling factor is applied to the position ids
972
+ position_ids = position_ids.float() / self.scaling_factor
973
+ cos, sin = self._get_naive_shifted_cos_sin(x, position_ids, seq_len)
974
+ return cos, sin
975
+
976
+
977
+ def _get_dynamicNTK_scaling_shifted_cos_sin(self, x, position_ids, seq_len=None):
978
+ # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
979
+ seq_len = torch.max(position_ids) + 1
980
+ if seq_len > self.max_position_embeddings:
981
+ base = self.base * (
982
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
983
+ ) ** (self.dim / (self.dim - 2))
984
+ inv_freq = 1.0 / (
985
+ base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
986
+ )
987
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO: this may break with compilation
988
+
989
+ cos, sin = self._get_naive_shifted_cos_sin(x, position_ids, seq_len)
990
+ return cos, sin
991
+
992
+
993
+ def _update_kv_ratio(self, kv_len_cmp:int, kv_len_ori:int, layer_idx: int=0) -> None:
994
+ """Update the KV ratios which are for statistics and debugging."""
995
+ if len(self._kept_kv_ratio) <= layer_idx:
996
+ self._kept_kv_ratio.append( (kv_len_cmp, kv_len_ori ) )
997
+ else:
998
+ old_kv_len_cmp = self._kept_kv_ratio[layer_idx][0]
999
+ old_kv_len_ori = self._kept_kv_ratio[layer_idx][1]
1000
+ self._kept_kv_ratio[layer_idx] = (old_kv_len_cmp + kv_len_cmp, old_kv_len_ori + kv_len_ori )
1001
+
1002
+ def _print_kv_ratio(self, layer_idx : int, LAYER_WISE: bool = False):
1003
+ """Print the KV ratios."""
1004
+ self._print_kv_ratio_count += 1
1005
+ if LAYER_WISE:
1006
+ if self._print_kv_ratio_count % self.print_KV_inside_per_steps == 0:
1007
+ print(f"######################## [Kept Tokens, Seen Tokens] : {self._kept_kv_ratio[layer_idx]}, Ratio: { (self._kept_kv_ratio[layer_idx][0]+1e-6) / (self._kept_kv_ratio[layer_idx][1]+1e-6) :.4f} ########################")
1008
+
1009
+ elif self._print_kv_ratio_count % (self.print_KV_inside_per_steps * self.layer_num) == 0:
1010
+ print(f"######################## [Kept Tokens, Seen Tokens] : {self._kept_kv_ratio[layer_idx]}, Ratio: { (self._kept_kv_ratio[layer_idx][0]+1e-6) / (self._kept_kv_ratio[layer_idx][1]+1e-6) :.4f} ########################")
1011
+
1012
+
1013
+ @classmethod ## Deprecated
1014
+ def from_legacy_cache(cls,
1015
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1016
+
1017
+ ## For SepLLM
1018
+ init_cache_size: Union[int, List] = 4,
1019
+ sep_cache_size: Union[int, List] = 64,
1020
+ local_size: Union[int, List]=256,
1021
+ cache_size: Union[int, List]=512,
1022
+ SEP_ACCUMULATION: bool = True,
1023
+ USE_MAX_SEP_CACHE: bool = False,
1024
+ SEP_PADDING_IN_BATCH: bool = False,
1025
+ separator_token_ids: List[int] = None, ## required for initialization if `model_type` is not provided. set it to `[-1]` to degrade SepCache to StreamingLLM's SinkCache
1026
+ PADDING_ID: int = None, ## required for initialization if `model_type` is not provided.
1027
+
1028
+ ## For inheritance & initialization states
1029
+ past_tok_ids: List[torch.Tensor] = None, ## It saves all the token ids corresponding to the saved KVs for all layers in SepCache.
1030
+ key_cache: List[torch.Tensor] = None,
1031
+ value_cache: List[torch.Tensor] = None,
1032
+
1033
+ ## For debugging
1034
+ PRINT_KV_RATIO_INSIDE: bool = False,
1035
+ print_KV_inside_per_steps: int = 1000,
1036
+ _seen_tokens: int = 0,
1037
+ _kept_kv_ratio: List[Tuple[int]] = None,
1038
+
1039
+ ### For positional encoding shifting
1040
+ APPLY_PE_SHIFT: bool = False,
1041
+ APPLY_PES_INSIDE: bool = True,
1042
+ _shifted_position_ids: List[torch.Tensor] = None,
1043
+ _rope_unsqueeze_dim: int = 1, ## The unsqueeze_dim when applying RoPE.
1044
+ _rope_seq_dim: int=1, ## The seq_len dimension for the `cos` or `sin` tensors.
1045
+ pe_scaling_factor:float = 1.0,
1046
+ pe_dim:int=128, ## The number of dims for positional encoding. Typically, just set the `head_dim` to this.
1047
+ max_position_embeddings: int = 8192,
1048
+ base: int=10000, ## The base for RoPE.
1049
+
1050
+ ## For basic transformer architecture
1051
+ k_seq_dim: int=2, ## The dimension for seq_len in key tensors
1052
+ v_seq_dim: int=2, ## The dimension for seq_len in value tensors
1053
+ layer_num: int = None, ## required for initialization
1054
+
1055
+ model_type: str = None, ## The model type for running the example. choose from ['llama', 'pythia','falcon'].
1056
+ device = None
1057
+ ) -> "SepCache":
1058
+ """Deprecated: Converts a cache in the legacy cache format into `SepCache`."""
1059
+
1060
+ if past_key_values is not None:
1061
+ assert len(past_key_values)==0, f"`from_legacy_cache` function is deprecated. You can only use it when `past_key_values=None` or `past_key_values` is empty, in which case, `from_legacy_cache` is equivalent to the `__init__` function."
1062
+ past_key_values = None
1063
+
1064
+ assert past_key_values is None, f"`from_legacy_cache` function is deprecated. You can only use it when `past_key_values=None` or `past_key_values` is empty, in which case, `from_legacy_cache` is equivalent to the `__init__` function."
1065
+
1066
+
1067
+ if past_key_values is not None: ## Deprecated
1068
+ key_cache = []
1069
+ value_cache = []
1070
+
1071
+ for i, kv in enumerate(past_key_values):
1072
+ if i == 0:
1073
+ past_tok_ids = [] if len(kv) == 4 else past_tok_ids
1074
+
1075
+ if len(kv) == 4:
1076
+ k, v, p_tok_ids, _seen_tokens = kv
1077
+ key_cache.append(k)
1078
+ value_cache.append(v)
1079
+ past_tok_ids.append(p_tok_ids)
1080
+ _seen_tokens = _seen_tokens
1081
+ elif len(kv) == 2:
1082
+ k, v = kv
1083
+ key_cache.append(k)
1084
+ value_cache.append(v)
1085
+
1086
+ cache = cls(
1087
+ ## For SepLLM
1088
+ init_cache_size = init_cache_size,
1089
+ sep_cache_size = sep_cache_size,
1090
+ local_size = local_size,
1091
+ cache_size = cache_size,
1092
+ SEP_ACCUMULATION = SEP_ACCUMULATION,
1093
+ USE_MAX_SEP_CACHE = USE_MAX_SEP_CACHE,
1094
+ SEP_PADDING_IN_BATCH = SEP_PADDING_IN_BATCH,
1095
+ separator_token_ids = separator_token_ids,
1096
+ PADDING_ID = PADDING_ID,
1097
+
1098
+ ## For inheritance & initialization states
1099
+ past_tok_ids = past_tok_ids, ## It saves all the token ids corresponding to the saved KVs for all layers in SepCache
1100
+ key_cache = key_cache,
1101
+ value_cache = value_cache,
1102
+
1103
+ ## For debugging
1104
+ PRINT_KV_RATIO_INSIDE = PRINT_KV_RATIO_INSIDE,
1105
+ print_KV_inside_per_steps = print_KV_inside_per_steps,
1106
+ _seen_tokens = _seen_tokens,
1107
+ _kept_kv_ratio = _kept_kv_ratio,
1108
+
1109
+ ### For positional encoding shifting
1110
+ APPLY_PE_SHIFT = APPLY_PE_SHIFT,
1111
+ APPLY_PES_INSIDE = APPLY_PES_INSIDE,
1112
+ _shifted_position_ids = _shifted_position_ids,
1113
+ _rope_unsqueeze_dim = _rope_unsqueeze_dim,
1114
+ _rope_seq_dim = _rope_seq_dim,
1115
+ pe_scaling_factor = pe_scaling_factor,
1116
+ pe_dim = pe_dim,
1117
+ max_position_embeddings = max_position_embeddings,
1118
+ base = base,
1119
+
1120
+ ## For basic transformer architecture
1121
+ k_seq_dim = k_seq_dim,
1122
+ v_seq_dim = v_seq_dim,
1123
+ layer_num = layer_num,
1124
+
1125
+ model_type = model_type,
1126
+ device = device,
1127
+ )
1128
+
1129
+ return cache
1130
+
1131
+
1132
+ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]]: ## Deprecated
1133
+ """Deprecated: Converts the `SepCache` instance into the legacy cache format, i.e., tuple."""
1134
+ print(">>>>>>>>>>>Warnings: Please try to avoid using this deprecated `to_legacy_cache` function since it will drop many useful parameters or states in SepCache.<<<<<<<<<<<")
1135
+ legacy_cache = ()
1136
+ for layer_idx in range(len(self.key_cache)):
1137
+ legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx], self.past_tok_ids[layer_idx], self._seen_tokens), )
1138
+ return legacy_cache
1139
+
1140
+
1141
+ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
1142
+ if layer_idx < len(self):
1143
+ return (self.key_cache[layer_idx], self.value_cache[layer_idx])
1144
+ else:
1145
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
1146
+
1147
+ def __iter__(self):
1148
+ """
1149
+ Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
1150
+ keys and values
1151
+ """
1152
+ for layer_idx in range(len(self)):
1153
+ yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
1154
+
1155
+ def __len__(self):
1156
+ """
1157
+ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
1158
+ to the number of layers in the model.
1159
+ """
1160
+ if self.key_cache is not None:
1161
+ return len(self.key_cache)
1162
+ else:
1163
+ return 0
1164
+
1165
+ @property
1166
+ def seen_tokens(self):
1167
+ if hasattr(self, "_seen_tokens"):
1168
+ return self._seen_tokens
1169
+ else:
1170
+ return None
1171
+
1172
+
1173
+
1174
+ class KVUsageCounter:
1175
+ def __init__(self, PRINT_KV_per_ITERs:int = 100):
1176
+ """
1177
+ For detailed usage instructions, please refer to sepllm.github.io
1178
+ """
1179
+ self._total_kept_kv_ratio = (0, 0)
1180
+ self._printing_counter = 0
1181
+ self.PRINT_KV_per_ITERs = PRINT_KV_per_ITERs
1182
+
1183
+ def accumulate_historical_kv_stats(self, cache: SepCache = None) -> None:
1184
+ assert cache is not None, f"The KV cache object (of the class SepCache) must be given."
1185
+ assert hasattr(cache, "_kept_kv_ratio"), f"The cache object must have the attribute _kept_kv_ratio."
1186
+ assert hasattr(cache, "layer_num"), f"The cache object must have the attribute layer_num."
1187
+
1188
+
1189
+ assert len(cache._kept_kv_ratio) == cache.layer_num, f"The length ({len(cache._kept_kv_ratio)}) of the cache object's `_kept_kv_ratio` attribute must be equal to `layer_num` ({cache.layer_num})"
1190
+ for ly in range(cache.layer_num):
1191
+ self._total_kept_kv_ratio = (self._total_kept_kv_ratio[0] + cache._kept_kv_ratio[ly][0], self._total_kept_kv_ratio[1] + cache._kept_kv_ratio[ly][1] )
1192
+ self._printing_counter += 1
1193
+
1194
+ def WHETHER_2_PRINT(self) -> bool:
1195
+ return (self._printing_counter % self.PRINT_KV_per_ITERs) == 0
1196
+
1197
+
1198
+ def print_KV_ratio(self) -> None:
1199
+ print(f"######################## The KVs for ALL layers: (KV number kept for predicting current token)/(Total seen KV number) ########################")
1200
+ print(f"########################>>>>>>>>>>> [Kept Tokens, Seen Tokens] : {self._total_kept_kv_ratio}, Ratio: { (self._total_kept_kv_ratio[0]+1e-6) / (self._total_kept_kv_ratio[1]+1e-6):.4f} <<<<<<<<<<<<##########################")
1201
+ print(f"######################## -------------------------------------------------------------------------------------------- ########################")
1202
+
1203
+
1204
+
1205
+
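+ # Hedged usage sketch (comments only) for `KVUsageCounter`; the loop and variable names are assumptions:
+ #
+ # kv_counter = KVUsageCounter(PRINT_KV_per_ITERs=100)
+ # for prompt in prompts:
+ #     # ... run generation with a SepCache instance `cache` ...
+ #     kv_counter.accumulate_historical_kv_stats(cache)
+ #     if kv_counter.WHETHER_2_PRINT():
+ #         kv_counter.print_KV_ratio()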