from typing import Optional, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F

from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2Model,
    Qwen2DecoderLayer,
    Qwen2PreTrainedModel,
    Qwen2ForCausalLM,
)

from .configuration_qwen2_mla import Qwen2MLAConfig

from transformers.models.gemma2.modeling_gemma2 import (
    eager_attention_forward,  # for supporting softcap
    logger,
)
from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
    apply_rotary_pos_emb_interleave,
    DeepseekV3RMSNorm,
)


class MLAAttention(nn.Module):
    """
    Multi-Head Latent Attention (MLA), modified from
    `transformers.models.deepseek_v3.modeling_deepseek_v3.DeepseekV3Attention`
    to add support for attention bias and logit softcapping.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.num_heads = config.num_attention_heads
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.kv_lora_rank = config.kv_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.v_head_dim = config.v_head_dim
        self.qk_head_dim = config.qk_head_dim
        self.qk_latent_layernorm = getattr(config, "qk_latent_layernorm", True)

        self.is_causal = True

        # Query projection: either a single dense projection or a low-rank
        # (down-/up-projection) path with an optional RMSNorm on the latent.
        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=config.attention_bias)
        else:
            self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=False)
            if self.qk_latent_layernorm:
                self.q_a_layernorm = DeepseekV3RMSNorm(self.q_lora_rank)
            self.q_b_proj = nn.Linear(
                config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=config.attention_bias
            )

        # Keys/values share a compressed latent plus a small per-token RoPE component.
        self.kv_a_proj_with_mqa = nn.Linear(
            config.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        if self.qk_latent_layernorm:
            self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            config.hidden_size,
            bias=False,
        )

        self.scaling = self.qk_head_dim**-0.5

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)

        # Project queries, then split into the non-rotary and rotary slices.
        if self.q_lora_rank is None:
            q_states = self.q_proj(hidden_states)
        elif self.qk_latent_layernorm:
            q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        else:
            q_states = self.q_b_proj(self.q_a_proj(hidden_states))
        q_states = q_states.view(query_shape).transpose(1, 2)
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Compress keys/values into the shared latent, then expand to per-head keys and values.
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        if self.qk_latent_layernorm:
            k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
        else:
            k_pass = self.kv_b_proj(k_pass).view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        # Apply interleaved RoPE to the rotary slices only, then broadcast the
        # shared rotary key across all heads.
        cos, sin = position_embeddings
        q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Flash Attention 2 requires matching query/key and value head dims, so pad
        # the values here and slice the padding off again after the attention call.
        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

        attention_interface = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            softcap=getattr(self.config, "attn_logit_softcapping", None),
            **kwargs,
        )

        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class Qwen2MLADecoderLayer(Qwen2DecoderLayer):
    """
    Qwen2 decoder layer with MLA (Multi-Head Latent Attention) instead of standard attention.

    This class inherits from Qwen2DecoderLayer and only replaces the self_attn component.
    """

    def __init__(self, config: Qwen2MLAConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Replace the standard Qwen2 attention with MLA attention
        self.self_attn = MLAAttention(config, layer_idx)


class Qwen2MLAPreTrainedModel(Qwen2PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for
    downloading and loading pretrained models for Qwen2 with MLA attention.
    """

    config_class = Qwen2MLAConfig
    _no_split_modules = ["Qwen2MLADecoderLayer"]


class Qwen2MLAModel(Qwen2MLAPreTrainedModel, Qwen2Model):
    """
    The Qwen2 model with MLA attention layers.

    This model inherits from both Qwen2MLAPreTrainedModel and Qwen2Model, replacing the
    standard Qwen2 decoder layers with MLA-enabled ones.
    """

    def __init__(self, config: Qwen2MLAConfig):
        super().__init__(config)
        # Replace the layers with MLA-enabled decoder layers
        self.layers = nn.ModuleList(
            [Qwen2MLADecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )


class Qwen2MLAForCausalLM(Qwen2MLAPreTrainedModel, Qwen2ForCausalLM):
    """
    The Qwen2 model with MLA attention for causal language modeling.

    This model can be used for text generation tasks, providing the same interface as
    Qwen2ForCausalLM but with an MLA attention mechanism.
    """

    def __init__(self, config: Qwen2MLAConfig):
        super().__init__(config)
        # Replace the base model with the MLA version
        self.model = Qwen2MLAModel(config)


# Export the main classes for external use
__all__ = [
    "Qwen2MLAForCausalLM",
    "Qwen2MLAModel",
    "Qwen2MLAPreTrainedModel",
]
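

# ---------------------------------------------------------------------------
# Optional smoke test: a minimal sketch, not part of the public API. It
# exercises MLAAttention in isolation using a SimpleNamespace stand-in for the
# config (the field names mirror those read in __init__ above) and hand-built
# interleaved RoPE tables; real usage should go through Qwen2MLAConfig and
# Qwen2MLAForCausalLM instead. Because this module uses a relative import, run
# it as a module (`python -m <package>.<module>`) rather than as a script.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        hidden_size=256,
        num_attention_heads=4,
        num_key_value_heads=4,
        attention_dropout=0.0,
        attention_bias=False,
        rope_theta=10000.0,
        q_lora_rank=None,  # skip the query low-rank path
        kv_lora_rank=64,
        qk_rope_head_dim=16,
        qk_nope_head_dim=32,
        qk_head_dim=48,  # qk_nope_head_dim + qk_rope_head_dim
        v_head_dim=32,
        qk_latent_layernorm=True,
        _attn_implementation="eager",
    )
    attn = MLAAttention(cfg, layer_idx=0)

    batch, seq = 1, 8
    hidden = torch.randn(batch, seq, cfg.hidden_size)
    # Interleaved RoPE tables for the qk_rope_head_dim rotary slice.
    inv_freq = 1.0 / (
        cfg.rope_theta ** (torch.arange(0, cfg.qk_rope_head_dim, 2).float() / cfg.qk_rope_head_dim)
    )
    freqs = torch.outer(torch.arange(seq).float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos()[None], emb.sin()[None]

    with torch.no_grad():
        out, _ = attn(hidden, (cos, sin), attention_mask=None)
    print(out.shape)  # expected: torch.Size([1, 8, 256])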