import torch
import inspect
import importlib
from typing import Callable, Optional, Union, Any, List

from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.cache_utils import Cache
from transformers.processing_utils import Unpack

from .sep_cache_utils import SepCache


def truncate_input_ids_4_autoregression(input_ids, key_states):
    """Truncate `input_ids` so its sequence length matches that of `key_states`.

    During autoregressive decoding the provided `input_ids` may be longer than the
    key states of the current step; only the trailing tokens are kept.
    """
    if input_ids.shape[-1] != key_states.shape[-2]:
        assert input_ids.shape[-1] >= key_states.shape[-2]
        truncated_input_ids = input_ids[..., -key_states.shape[-2]:]
        return truncated_input_ids
    else:
        return input_ids


def llama_atten_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Drop-in replacement for Llama attention `forward` that routes the KV cache through `SepCache`."""
    input_shape = hidden_states.shape[:-1]

    if hasattr(self, "head_dim"):
        head_dim = self.head_dim
    elif hasattr(self, "head_size"):
        head_dim = self.head_size
    hidden_shape = (*input_shape, -1, head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    ############################### SepCache ###############################
    assert isinstance(past_key_value, SepCache), "`past_key_value` must be of type `SepCache`."
    APPLY_PE_SHIFT = past_key_value.APPLY_PE_SHIFT
    APPLY_PES_INSIDE = past_key_value.APPLY_PES_INSIDE
    #########################################################################

    ############################ Monkey Patching ###########################
    # Resolve the RoPE and attention helpers from the module that defines this
    # attention class (e.g. `transformers.models.llama.modeling_llama`), so the
    # patched forward keeps using the model's own implementations.
    module = importlib.import_module(self.__module__)
    apply_rotary_pos_emb = module.apply_rotary_pos_emb
    rotate_half = module.rotate_half
    eager_attention_forward = module.eager_attention_forward
    ALL_ATTENTION_FUNCTIONS = module.ALL_ATTENTION_FUNCTIONS
    #########################################################################

    if not APPLY_PE_SHIFT:
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # ########################################## Default ###############################################
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        # ###################################################################################################

        ########################################### SepCache ###############################################
        # sin and cos are specific to RoPE models; position_ids needed for the static cache
        if APPLY_PE_SHIFT and (not APPLY_PES_INSIDE):
            # At least the shifted `sin` and `cos` should be properly provided (not `None`).
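            # Note (assumption): in this configuration the positional-encoding shift is
            # applied outside of `SepCache`, so the shifted key-side (`cos`, `sin`) and
            # query-side (`cos_q`, `sin_q`) rotary embeddings must already be available
            # in this scope; they are not computed by this patched forward.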
cache_kwargs = {"sin": sin, "cos": cos, "cos_q": cos_q, "sin_q": sin_q, "cache_position": cache_position, "partial_rotation_size": None } else: cache_kwargs = {} if "kwargs" in locals(): pass elif "flash_attn_kwargs" in locals(): kwargs = flash_attn_kwargs else: raise NameError("`kwargs` or `flash_attn_kwargs` should be given and they need to contain `sepllm_kwargs` (which contains `input_ids`) and `position_ids`.") if "input_ids" not in locals(): if "input_ids" in kwargs: input_ids = kwargs.get("input_ids", None) else: sepllm_kwargs = kwargs.get("sepllm_kwargs", None) assert sepllm_kwargs is not None, f"`sepllm_kwargs` must be provided when `input_ids` is not given." input_ids = sepllm_kwargs.get("input_ids", None) assert input_ids is not None, f"`input_ids` must be properly provided directly or through `sepllm_kwargs` when calling `update()` in `SepCache`." if "position_ids" not in locals(): position_ids = kwargs.get("position_ids") assert input_ids is not None, f"`input_ids` must be properly provided when calling `update()` in `SepCache`." bsz, q_len, _ = hidden_states.size() input_ids = truncate_input_ids_4_autoregression(input_ids = input_ids, key_states = key_states ) if APPLY_PE_SHIFT: key_states, value_states, query_states = past_key_value.update( key_states = key_states, value_states = value_states, query_states = query_states, input_ids = input_ids, layer_idx = self.layer_idx, position_ids = position_ids, PREFILLING_FLAG = q_len > 1, cache_kwargs = cache_kwargs ) else: key_states, value_states = past_key_value.update( key_states = key_states, value_states = value_states, input_ids = input_ids, layer_idx = self.layer_idx, position_ids = position_ids, PREFILLING_FLAG = q_len > 1, cache_kwargs = cache_kwargs ) seq_len = past_key_value.get_usable_length(self.layer_idx) if attention_mask is not None: attention_mask = attention_mask[..., :seq_len] ################################################################################################################## attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights def _validate_model_kwargs(self, model_kwargs: dict[str, Any]): """Validates model kwargs for generation. Generate argument typos will also be caught here.""" # If a `Cache` instance is passed, checks whether the model is compatible with it if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class: raise ValueError( f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please " "check the model documentation for supported cache formats." ) # Excludes arguments that are handled before calling any model function if self.config.is_encoder_decoder: for key in ["decoder_input_ids"]: model_kwargs.pop(key, None) unused_model_args = [] model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters) # `kwargs`/`model_kwargs` is often used to handle optional forward pass inputs like `attention_mask`. 
    # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
    if "kwargs" in model_args or "model_kwargs" in model_args:
        model_args |= set(inspect.signature(self.forward).parameters)

    # Encoder-Decoder models may also need Encoder arguments from `model_kwargs`
    if self.config.is_encoder_decoder:
        base_model = getattr(self, self.base_model_prefix, None)

        # allow encoder kwargs
        encoder = getattr(self, "encoder", None)
        # `MusicgenForConditionalGeneration` has `text_encoder` and `audio_encoder`.
        # Also, it has `base_model_prefix = "encoder_decoder"` but there is no `self.encoder_decoder`
        # TODO: A better way to handle this.
        if encoder is None and base_model is not None:
            encoder = getattr(base_model, "encoder", None)

        if encoder is not None:
            encoder_model_args = set(inspect.signature(encoder.forward).parameters)
            model_args |= encoder_model_args

        # allow decoder kwargs
        decoder = getattr(self, "decoder", None)
        if decoder is None and base_model is not None:
            decoder = getattr(base_model, "decoder", None)

        if decoder is not None:
            decoder_model_args = set(inspect.signature(decoder.forward).parameters)
            model_args |= {f"decoder_{x}" for x in decoder_model_args}

    for key, value in model_kwargs.items():
        # ############################# Default ############################
        # if value is not None and key not in model_args:
        #     unused_model_args.append(key)
        # ##################################################################

        ############################### SepCache ###########################
        # Kwargs whose names contain "sep" (e.g. `sepllm_kwargs`) are passed through
        # to the patched attention forward, so they are not flagged as unused here.
        if (value is not None) and (key not in model_args) and ("sep" not in str(key).lower()):
            unused_model_args.append(key)
        #####################################################################

    if unused_model_args:
        raise ValueError(
            f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
            " generate arguments will also show up in this list)"
        )
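

# --------------------------------------------------------------------------- #
# Illustrative helper (not part of the original patches): one possible way to
# install the two monkey patches above on a Hugging Face Llama model. The import
# paths for `LlamaAttention` and `GenerationMixin` follow recent `transformers`
# releases; the helper name below, the `SepCache` constructor arguments, and the
# exact `sepllm_kwargs` payload passed to `generate()` are assumptions that
# depend on `.sep_cache_utils` and the surrounding repo.
# --------------------------------------------------------------------------- #
def apply_sepcache_monkey_patch():
    """Hypothetical convenience helper: patch Llama attention and generation-kwarg validation in place."""
    from transformers.generation.utils import GenerationMixin
    from transformers.models.llama.modeling_llama import LlamaAttention

    # Every Llama attention layer now runs the SepCache-aware forward, and
    # `generate()` stops rejecting the extra `sep*` model kwargs.
    LlamaAttention.forward = llama_atten_forward
    GenerationMixin._validate_model_kwargs = _validate_model_kwargs


# Expected call pattern (sketch):
#
#     apply_sepcache_monkey_patch()
#     model = AutoModelForCausalLM.from_pretrained(...)   # any Llama checkpoint
#     cache = SepCache(...)                                # see `.sep_cache_utils` for the arguments
#     outputs = model.generate(
#         **inputs,
#         past_key_values=cache,
#         sepllm_kwargs={"input_ids": inputs["input_ids"]},
#     )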