renll commited on
Commit
86604fc
·
verified ·
1 Parent(s): 590843e

Upload 12 files

Browse files
added_tokens.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "<|/tool_call|>": 200026,
3
+ "<|/tool|>": 200024,
4
+ "<|assistant|>": 200019,
5
+ "<|end|>": 200020,
6
+ "<|system|>": 200022,
7
+ "<|tag|>": 200028,
8
+ "<|tool_call|>": 200025,
9
+ "<|tool_response|>": 200027,
10
+ "<|tool|>": 200023,
11
+ "<|user|>": 200021
12
+ }
config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Phi4FlashForCausalLM"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_phi4flash.Phi4FlashConfig",
8
+ "AutoModelForCausalLM": "modeling_phi4flash.Phi4FlashForCausalLM",
9
+ "AutoTokenizer": "Xenova/gpt-4o"
10
+ },
11
+ "pad_token_id": 199999,
12
+ "bos_token_id": 199999,
13
+ "embd_pdrop": 0.0,
14
+ "eos_token_id": 199999,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 2560,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 10240,
19
+ "layer_norm_eps": 1e-5,
20
+ "max_position_embeddings": 262144,
21
+ "_attn_implementation": "flash_attention_2",
22
+ "mb_per_layer": 2,
23
+ "model_type": "phi4flash",
24
+ "num_attention_heads": 40,
25
+ "num_hidden_layers": 32,
26
+ "num_key_value_heads": 20,
27
+ "resid_pdrop": 0.0,
28
+ "sliding_window": 512,
29
+ "torch_dtype": "bfloat16",
30
+ "tie_word_embeddings": true,
31
+ "transformers_version": "4.46.1",
32
+ "use_cache": true,
33
+ "mlp_bias": false,
34
+ "lm_head_bias": false,
35
+ "vocab_size": 200064
36
+ }
37
+
configuration_phi4flash.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Phi4Flash model configuration"""
17
+
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+ import math
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class Phi4FlashConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`Phi4FlashModel`]. It is used to instantiate an Phi4Flash
28
+ model according to the specified arguments, defining the model architecture.
29
+
30
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
31
+ documentation from [`PretrainedConfig`] for more information.
32
+
33
+ Args:
34
+ vocab_size (`int`, *optional*, defaults to 51200):
35
+ Vocabulary size of the Phi4Flash model. Defines the number of different tokens that can be represented by the
36
+ `inputs_ids` passed when calling [`Phi4FlashModel`].
37
+ hidden_size (`int`, *optional*, defaults to 2048):
38
+ Dimension of the hidden representations.
39
+ intermediate_size (`int`, *optional*, defaults to 8192):
40
+ Dimension of the MLP representations.
41
+ num_hidden_layers (`int`, *optional*, defaults to 24):
42
+ Number of hidden layers in the Transformer decoder.
43
+ num_attention_heads (`int`, *optional*, defaults to 32):
44
+ Number of attention heads for each attention layer in the Transformer decoder.
45
+ num_key_value_heads (`int`, *optional*):
46
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
47
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
48
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
49
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
50
+ by meanpooling all the original heads within that group. For more details checkout [this
51
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
52
+ `num_attention_heads`.
53
+ resid_pdrop (`float`, *optional*, defaults to 0.0):
54
+ Dropout probability for mlp outputs.
55
+ embd_pdrop (`int`, *optional*, defaults to 0.0):
56
+ The dropout ratio for the embeddings.
57
+ attention_dropout (`float`, *optional*, defaults to 0.0):
58
+ The dropout ratio after computing the attention scores.
59
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`):
60
+ The non-linear activation function (function or string) in the decoder.
61
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
62
+ The maximum sequence length that this model might ever be used with. Phi-1 and Phi-1.5 supports up to 2048
63
+ tokens.
64
+ initializer_range (`float`, *optional*, defaults to 0.02):
65
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
66
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
67
+ The epsilon used by the rms normalization layers.
68
+ use_cache (`bool`, *optional*, defaults to `True`):
69
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
70
+ relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not.
71
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
72
+ Whether to tie weight embeddings
73
+ rope_theta (`float`, *optional*, defaults to 10000.0):
74
+ The base period of the RoPE embeddings.
75
+
76
+ Example:
77
+
78
+ ```python
79
+ >>> from transformers import Phi4FlashModel, Phi4FlashConfig
80
+
81
+ >>> # Initializing a Phi4Flash style configuration
82
+ >>> configuration = Phi4FlashConfig.from_pretrained("microsoft/Phi4-mini-flash-reasoning")
83
+
84
+ >>> # Initializing a model from the configuration
85
+ >>> model = Phi4FlashModel(configuration)
86
+
87
+ >>> # Accessing the model configuration
88
+ >>> configuration = model.config
89
+ ```"""
90
+
91
+ model_type = "phi4flash"
92
+ keys_to_ignore_at_inference = ["past_key_values"]
93
+
94
+ def __init__(
95
+ self,
96
+ vocab_size=51200,
97
+ hidden_size=2560,
98
+ intermediate_size=9216,
99
+ num_hidden_layers=32,
100
+ num_attention_heads=40,
101
+ num_key_value_heads=4,
102
+ resid_pdrop=0.0,
103
+ embd_pdrop=0.0,
104
+ attention_dropout=0.0,
105
+ hidden_act="silu",
106
+ max_position_embeddings=4096,
107
+ initializer_range=0.02,
108
+ layer_norm_eps=1e-5,
109
+ use_cache=True,
110
+ tie_word_embeddings=True,
111
+ rope_theta=10000.0,
112
+ bos_token_id=1,
113
+ eos_token_id=2,
114
+ sliding_window=2047,
115
+ mb_per_layer= 2,
116
+ mamba_d_state=16,
117
+ mamba_d_conv=4,
118
+ mamba_expand=2,
119
+ mamba_dt_rank="auto",
120
+ mamba_conv_bias=True,
121
+ mamba_proj_bias=False,
122
+ **kwargs,
123
+ ):
124
+ self.vocab_size = vocab_size
125
+ self.hidden_size = hidden_size
126
+ self.intermediate_size = intermediate_size
127
+ self.num_hidden_layers = num_hidden_layers
128
+ self.num_attention_heads = num_attention_heads
129
+
130
+ if num_key_value_heads is None:
131
+ num_key_value_heads = num_attention_heads
132
+
133
+ self.num_key_value_heads = num_key_value_heads
134
+ self.resid_pdrop = resid_pdrop
135
+ self.embd_pdrop = embd_pdrop
136
+ self.attention_dropout = attention_dropout
137
+ self.hidden_act = hidden_act
138
+ self.max_position_embeddings = max_position_embeddings
139
+ self.initializer_range = initializer_range
140
+ self.layer_norm_eps = layer_norm_eps
141
+ self.use_cache = use_cache
142
+ self.rope_theta = rope_theta
143
+ self.mb_per_layer = mb_per_layer
144
+ self.sliding_window = [
145
+ sliding_window if layer_idx < num_hidden_layers // 2 and layer_idx % 2 == 1 else None
146
+ for layer_idx in range(num_hidden_layers)
147
+ ]
148
+
149
+ self.mamba_d_state = mamba_d_state
150
+ self.mamba_d_conv = mamba_d_conv
151
+ self.mamba_expand = mamba_expand
152
+ self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
153
+ self.mamba_conv_bias = mamba_conv_bias
154
+ self.mamba_proj_bias = mamba_proj_bias
155
+
156
+ super().__init__(
157
+ bos_token_id=bos_token_id,
158
+ eos_token_id=eos_token_id,
159
+ tie_word_embeddings=tie_word_embeddings,
160
+ **kwargs,
161
+ )
162
+
163
+
164
+ @property
165
+ def layers_block_type(self):
166
+ layer_block_types = []
167
+ for i in range(self.num_hidden_layers):
168
+ if i % 2 == 1:
169
+ layer_block_type = "attention" if i <= (self.num_hidden_layers //2 +1) else "shared_attention"
170
+ else:
171
+ layer_block_type = "mamba"
172
+ layer_block_types.append(layer_block_type)
173
+ return layer_block_types
generation_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 199999,
4
+ "eos_token_id": [
5
+ 200020,
6
+ 199999
7
+ ],
8
+ "pad_token_id": 199999,
9
+ "transformers_version": "4.45.0"
10
+ }
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4683d17ca19ab12e0278b6a1db98db76301cbbc3119d9599739df14f45554d03
3
+ size 4952270280
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:40beaf0c37ad2788ccb63d698afe9725e84479d68bf7a1e9c0ce921af0e3916e
3
+ size 3777232440
model.safetensors.index.json ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 8729453568
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00002-of-00002.safetensors",
7
+ "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
8
+ "model.final_layernorm.bias": "model-00002-of-00002.safetensors",
9
+ "model.final_layernorm.weight": "model-00002-of-00002.safetensors",
10
+ "model.layers.0.attn.A_log": "model-00001-of-00002.safetensors",
11
+ "model.layers.0.attn.D": "model-00001-of-00002.safetensors",
12
+ "model.layers.0.attn.conv1d.bias": "model-00001-of-00002.safetensors",
13
+ "model.layers.0.attn.conv1d.weight": "model-00001-of-00002.safetensors",
14
+ "model.layers.0.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
15
+ "model.layers.0.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
16
+ "model.layers.0.attn.in_proj.weight": "model-00001-of-00002.safetensors",
17
+ "model.layers.0.attn.out_proj.weight": "model-00001-of-00002.safetensors",
18
+ "model.layers.0.attn.x_proj.weight": "model-00001-of-00002.safetensors",
19
+ "model.layers.0.input_layernorm.bias": "model-00001-of-00002.safetensors",
20
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
21
+ "model.layers.0.mlp.fc1.weight": "model-00001-of-00002.safetensors",
22
+ "model.layers.0.mlp.fc2.weight": "model-00001-of-00002.safetensors",
23
+ "model.layers.0.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
24
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
25
+ "model.layers.1.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
26
+ "model.layers.1.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
27
+ "model.layers.1.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
28
+ "model.layers.1.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
29
+ "model.layers.1.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
30
+ "model.layers.1.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
31
+ "model.layers.1.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
32
+ "model.layers.1.attn.out_proj.bias": "model-00001-of-00002.safetensors",
33
+ "model.layers.1.attn.out_proj.weight": "model-00001-of-00002.safetensors",
34
+ "model.layers.1.input_layernorm.bias": "model-00001-of-00002.safetensors",
35
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
36
+ "model.layers.1.mlp.fc1.weight": "model-00001-of-00002.safetensors",
37
+ "model.layers.1.mlp.fc2.weight": "model-00001-of-00002.safetensors",
38
+ "model.layers.1.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
39
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
40
+ "model.layers.10.attn.A_log": "model-00001-of-00002.safetensors",
41
+ "model.layers.10.attn.D": "model-00001-of-00002.safetensors",
42
+ "model.layers.10.attn.conv1d.bias": "model-00001-of-00002.safetensors",
43
+ "model.layers.10.attn.conv1d.weight": "model-00001-of-00002.safetensors",
44
+ "model.layers.10.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
45
+ "model.layers.10.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
46
+ "model.layers.10.attn.in_proj.weight": "model-00001-of-00002.safetensors",
47
+ "model.layers.10.attn.out_proj.weight": "model-00001-of-00002.safetensors",
48
+ "model.layers.10.attn.x_proj.weight": "model-00001-of-00002.safetensors",
49
+ "model.layers.10.input_layernorm.bias": "model-00001-of-00002.safetensors",
50
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00002.safetensors",
51
+ "model.layers.10.mlp.fc1.weight": "model-00001-of-00002.safetensors",
52
+ "model.layers.10.mlp.fc2.weight": "model-00001-of-00002.safetensors",
53
+ "model.layers.10.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
54
+ "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
55
+ "model.layers.11.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
56
+ "model.layers.11.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
57
+ "model.layers.11.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
58
+ "model.layers.11.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
59
+ "model.layers.11.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
60
+ "model.layers.11.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
61
+ "model.layers.11.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
62
+ "model.layers.11.attn.out_proj.bias": "model-00001-of-00002.safetensors",
63
+ "model.layers.11.attn.out_proj.weight": "model-00001-of-00002.safetensors",
64
+ "model.layers.11.input_layernorm.bias": "model-00001-of-00002.safetensors",
65
+ "model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
66
+ "model.layers.11.mlp.fc1.weight": "model-00001-of-00002.safetensors",
67
+ "model.layers.11.mlp.fc2.weight": "model-00001-of-00002.safetensors",
68
+ "model.layers.11.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
69
+ "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
70
+ "model.layers.12.attn.A_log": "model-00001-of-00002.safetensors",
71
+ "model.layers.12.attn.D": "model-00001-of-00002.safetensors",
72
+ "model.layers.12.attn.conv1d.bias": "model-00001-of-00002.safetensors",
73
+ "model.layers.12.attn.conv1d.weight": "model-00001-of-00002.safetensors",
74
+ "model.layers.12.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
75
+ "model.layers.12.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
76
+ "model.layers.12.attn.in_proj.weight": "model-00001-of-00002.safetensors",
77
+ "model.layers.12.attn.out_proj.weight": "model-00001-of-00002.safetensors",
78
+ "model.layers.12.attn.x_proj.weight": "model-00001-of-00002.safetensors",
79
+ "model.layers.12.input_layernorm.bias": "model-00001-of-00002.safetensors",
80
+ "model.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors",
81
+ "model.layers.12.mlp.fc1.weight": "model-00001-of-00002.safetensors",
82
+ "model.layers.12.mlp.fc2.weight": "model-00001-of-00002.safetensors",
83
+ "model.layers.12.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
84
+ "model.layers.12.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
85
+ "model.layers.13.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
86
+ "model.layers.13.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
87
+ "model.layers.13.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
88
+ "model.layers.13.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
89
+ "model.layers.13.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
90
+ "model.layers.13.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
91
+ "model.layers.13.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
92
+ "model.layers.13.attn.out_proj.bias": "model-00001-of-00002.safetensors",
93
+ "model.layers.13.attn.out_proj.weight": "model-00001-of-00002.safetensors",
94
+ "model.layers.13.input_layernorm.bias": "model-00001-of-00002.safetensors",
95
+ "model.layers.13.input_layernorm.weight": "model-00001-of-00002.safetensors",
96
+ "model.layers.13.mlp.fc1.weight": "model-00001-of-00002.safetensors",
97
+ "model.layers.13.mlp.fc2.weight": "model-00001-of-00002.safetensors",
98
+ "model.layers.13.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
99
+ "model.layers.13.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
100
+ "model.layers.14.attn.A_log": "model-00001-of-00002.safetensors",
101
+ "model.layers.14.attn.D": "model-00001-of-00002.safetensors",
102
+ "model.layers.14.attn.conv1d.bias": "model-00001-of-00002.safetensors",
103
+ "model.layers.14.attn.conv1d.weight": "model-00001-of-00002.safetensors",
104
+ "model.layers.14.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
105
+ "model.layers.14.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
106
+ "model.layers.14.attn.in_proj.weight": "model-00001-of-00002.safetensors",
107
+ "model.layers.14.attn.out_proj.weight": "model-00001-of-00002.safetensors",
108
+ "model.layers.14.attn.x_proj.weight": "model-00001-of-00002.safetensors",
109
+ "model.layers.14.input_layernorm.bias": "model-00001-of-00002.safetensors",
110
+ "model.layers.14.input_layernorm.weight": "model-00001-of-00002.safetensors",
111
+ "model.layers.14.mlp.fc1.weight": "model-00001-of-00002.safetensors",
112
+ "model.layers.14.mlp.fc2.weight": "model-00001-of-00002.safetensors",
113
+ "model.layers.14.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
114
+ "model.layers.14.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
115
+ "model.layers.15.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
116
+ "model.layers.15.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
117
+ "model.layers.15.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
118
+ "model.layers.15.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
119
+ "model.layers.15.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
120
+ "model.layers.15.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
121
+ "model.layers.15.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
122
+ "model.layers.15.attn.out_proj.bias": "model-00001-of-00002.safetensors",
123
+ "model.layers.15.attn.out_proj.weight": "model-00001-of-00002.safetensors",
124
+ "model.layers.15.input_layernorm.bias": "model-00001-of-00002.safetensors",
125
+ "model.layers.15.input_layernorm.weight": "model-00001-of-00002.safetensors",
126
+ "model.layers.15.mlp.fc1.weight": "model-00001-of-00002.safetensors",
127
+ "model.layers.15.mlp.fc2.weight": "model-00001-of-00002.safetensors",
128
+ "model.layers.15.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
129
+ "model.layers.15.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
130
+ "model.layers.16.attn.A_log": "model-00001-of-00002.safetensors",
131
+ "model.layers.16.attn.D": "model-00001-of-00002.safetensors",
132
+ "model.layers.16.attn.conv1d.bias": "model-00001-of-00002.safetensors",
133
+ "model.layers.16.attn.conv1d.weight": "model-00001-of-00002.safetensors",
134
+ "model.layers.16.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
135
+ "model.layers.16.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
136
+ "model.layers.16.attn.in_proj.weight": "model-00001-of-00002.safetensors",
137
+ "model.layers.16.attn.out_proj.weight": "model-00001-of-00002.safetensors",
138
+ "model.layers.16.attn.x_proj.weight": "model-00001-of-00002.safetensors",
139
+ "model.layers.16.input_layernorm.bias": "model-00001-of-00002.safetensors",
140
+ "model.layers.16.input_layernorm.weight": "model-00001-of-00002.safetensors",
141
+ "model.layers.16.mlp.fc1.weight": "model-00001-of-00002.safetensors",
142
+ "model.layers.16.mlp.fc2.weight": "model-00001-of-00002.safetensors",
143
+ "model.layers.16.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
144
+ "model.layers.16.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
145
+ "model.layers.17.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
146
+ "model.layers.17.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
147
+ "model.layers.17.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
148
+ "model.layers.17.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
149
+ "model.layers.17.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
150
+ "model.layers.17.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
151
+ "model.layers.17.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
152
+ "model.layers.17.attn.out_proj.bias": "model-00001-of-00002.safetensors",
153
+ "model.layers.17.attn.out_proj.weight": "model-00001-of-00002.safetensors",
154
+ "model.layers.17.input_layernorm.bias": "model-00001-of-00002.safetensors",
155
+ "model.layers.17.input_layernorm.weight": "model-00001-of-00002.safetensors",
156
+ "model.layers.17.mlp.fc1.weight": "model-00001-of-00002.safetensors",
157
+ "model.layers.17.mlp.fc2.weight": "model-00001-of-00002.safetensors",
158
+ "model.layers.17.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
159
+ "model.layers.17.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
160
+ "model.layers.18.attn.in_proj.weight": "model-00002-of-00002.safetensors",
161
+ "model.layers.18.attn.out_proj.weight": "model-00002-of-00002.safetensors",
162
+ "model.layers.18.input_layernorm.bias": "model-00002-of-00002.safetensors",
163
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00002.safetensors",
164
+ "model.layers.18.mlp.fc1.weight": "model-00002-of-00002.safetensors",
165
+ "model.layers.18.mlp.fc2.weight": "model-00002-of-00002.safetensors",
166
+ "model.layers.18.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
167
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
168
+ "model.layers.19.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
169
+ "model.layers.19.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
170
+ "model.layers.19.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
171
+ "model.layers.19.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
172
+ "model.layers.19.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
173
+ "model.layers.19.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
174
+ "model.layers.19.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
175
+ "model.layers.19.attn.out_proj.bias": "model-00002-of-00002.safetensors",
176
+ "model.layers.19.attn.out_proj.weight": "model-00002-of-00002.safetensors",
177
+ "model.layers.19.input_layernorm.bias": "model-00002-of-00002.safetensors",
178
+ "model.layers.19.input_layernorm.weight": "model-00002-of-00002.safetensors",
179
+ "model.layers.19.mlp.fc1.weight": "model-00002-of-00002.safetensors",
180
+ "model.layers.19.mlp.fc2.weight": "model-00002-of-00002.safetensors",
181
+ "model.layers.19.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
182
+ "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
183
+ "model.layers.2.attn.A_log": "model-00001-of-00002.safetensors",
184
+ "model.layers.2.attn.D": "model-00001-of-00002.safetensors",
185
+ "model.layers.2.attn.conv1d.bias": "model-00001-of-00002.safetensors",
186
+ "model.layers.2.attn.conv1d.weight": "model-00001-of-00002.safetensors",
187
+ "model.layers.2.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
188
+ "model.layers.2.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
189
+ "model.layers.2.attn.in_proj.weight": "model-00001-of-00002.safetensors",
190
+ "model.layers.2.attn.out_proj.weight": "model-00001-of-00002.safetensors",
191
+ "model.layers.2.attn.x_proj.weight": "model-00001-of-00002.safetensors",
192
+ "model.layers.2.input_layernorm.bias": "model-00001-of-00002.safetensors",
193
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
194
+ "model.layers.2.mlp.fc1.weight": "model-00001-of-00002.safetensors",
195
+ "model.layers.2.mlp.fc2.weight": "model-00001-of-00002.safetensors",
196
+ "model.layers.2.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
197
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
198
+ "model.layers.20.attn.in_proj.weight": "model-00002-of-00002.safetensors",
199
+ "model.layers.20.attn.out_proj.weight": "model-00002-of-00002.safetensors",
200
+ "model.layers.20.input_layernorm.bias": "model-00002-of-00002.safetensors",
201
+ "model.layers.20.input_layernorm.weight": "model-00002-of-00002.safetensors",
202
+ "model.layers.20.mlp.fc1.weight": "model-00002-of-00002.safetensors",
203
+ "model.layers.20.mlp.fc2.weight": "model-00002-of-00002.safetensors",
204
+ "model.layers.20.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
205
+ "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
206
+ "model.layers.21.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
207
+ "model.layers.21.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
208
+ "model.layers.21.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
209
+ "model.layers.21.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
210
+ "model.layers.21.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
211
+ "model.layers.21.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
212
+ "model.layers.21.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
213
+ "model.layers.21.attn.out_proj.bias": "model-00002-of-00002.safetensors",
214
+ "model.layers.21.attn.out_proj.weight": "model-00002-of-00002.safetensors",
215
+ "model.layers.21.input_layernorm.bias": "model-00002-of-00002.safetensors",
216
+ "model.layers.21.input_layernorm.weight": "model-00002-of-00002.safetensors",
217
+ "model.layers.21.mlp.fc1.weight": "model-00002-of-00002.safetensors",
218
+ "model.layers.21.mlp.fc2.weight": "model-00002-of-00002.safetensors",
219
+ "model.layers.21.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
220
+ "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
221
+ "model.layers.22.attn.in_proj.weight": "model-00002-of-00002.safetensors",
222
+ "model.layers.22.attn.out_proj.weight": "model-00002-of-00002.safetensors",
223
+ "model.layers.22.input_layernorm.bias": "model-00002-of-00002.safetensors",
224
+ "model.layers.22.input_layernorm.weight": "model-00002-of-00002.safetensors",
225
+ "model.layers.22.mlp.fc1.weight": "model-00002-of-00002.safetensors",
226
+ "model.layers.22.mlp.fc2.weight": "model-00002-of-00002.safetensors",
227
+ "model.layers.22.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
228
+ "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
229
+ "model.layers.23.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
230
+ "model.layers.23.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
231
+ "model.layers.23.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
232
+ "model.layers.23.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
233
+ "model.layers.23.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
234
+ "model.layers.23.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
235
+ "model.layers.23.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
236
+ "model.layers.23.attn.out_proj.bias": "model-00002-of-00002.safetensors",
237
+ "model.layers.23.attn.out_proj.weight": "model-00002-of-00002.safetensors",
238
+ "model.layers.23.input_layernorm.bias": "model-00002-of-00002.safetensors",
239
+ "model.layers.23.input_layernorm.weight": "model-00002-of-00002.safetensors",
240
+ "model.layers.23.mlp.fc1.weight": "model-00002-of-00002.safetensors",
241
+ "model.layers.23.mlp.fc2.weight": "model-00002-of-00002.safetensors",
242
+ "model.layers.23.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
243
+ "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
244
+ "model.layers.24.attn.in_proj.weight": "model-00002-of-00002.safetensors",
245
+ "model.layers.24.attn.out_proj.weight": "model-00002-of-00002.safetensors",
246
+ "model.layers.24.input_layernorm.bias": "model-00002-of-00002.safetensors",
247
+ "model.layers.24.input_layernorm.weight": "model-00002-of-00002.safetensors",
248
+ "model.layers.24.mlp.fc1.weight": "model-00002-of-00002.safetensors",
249
+ "model.layers.24.mlp.fc2.weight": "model-00002-of-00002.safetensors",
250
+ "model.layers.24.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
251
+ "model.layers.24.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
252
+ "model.layers.25.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
253
+ "model.layers.25.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
254
+ "model.layers.25.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
255
+ "model.layers.25.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
256
+ "model.layers.25.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
257
+ "model.layers.25.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
258
+ "model.layers.25.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
259
+ "model.layers.25.attn.out_proj.bias": "model-00002-of-00002.safetensors",
260
+ "model.layers.25.attn.out_proj.weight": "model-00002-of-00002.safetensors",
261
+ "model.layers.25.input_layernorm.bias": "model-00002-of-00002.safetensors",
262
+ "model.layers.25.input_layernorm.weight": "model-00002-of-00002.safetensors",
263
+ "model.layers.25.mlp.fc1.weight": "model-00002-of-00002.safetensors",
264
+ "model.layers.25.mlp.fc2.weight": "model-00002-of-00002.safetensors",
265
+ "model.layers.25.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
266
+ "model.layers.25.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
267
+ "model.layers.26.attn.in_proj.weight": "model-00002-of-00002.safetensors",
268
+ "model.layers.26.attn.out_proj.weight": "model-00002-of-00002.safetensors",
269
+ "model.layers.26.input_layernorm.bias": "model-00002-of-00002.safetensors",
270
+ "model.layers.26.input_layernorm.weight": "model-00002-of-00002.safetensors",
271
+ "model.layers.26.mlp.fc1.weight": "model-00002-of-00002.safetensors",
272
+ "model.layers.26.mlp.fc2.weight": "model-00002-of-00002.safetensors",
273
+ "model.layers.26.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
274
+ "model.layers.26.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
275
+ "model.layers.27.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
276
+ "model.layers.27.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
277
+ "model.layers.27.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
278
+ "model.layers.27.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
279
+ "model.layers.27.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
280
+ "model.layers.27.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
281
+ "model.layers.27.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
282
+ "model.layers.27.attn.out_proj.bias": "model-00002-of-00002.safetensors",
283
+ "model.layers.27.attn.out_proj.weight": "model-00002-of-00002.safetensors",
284
+ "model.layers.27.input_layernorm.bias": "model-00002-of-00002.safetensors",
285
+ "model.layers.27.input_layernorm.weight": "model-00002-of-00002.safetensors",
286
+ "model.layers.27.mlp.fc1.weight": "model-00002-of-00002.safetensors",
287
+ "model.layers.27.mlp.fc2.weight": "model-00002-of-00002.safetensors",
288
+ "model.layers.27.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
289
+ "model.layers.27.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
290
+ "model.layers.28.attn.in_proj.weight": "model-00002-of-00002.safetensors",
291
+ "model.layers.28.attn.out_proj.weight": "model-00002-of-00002.safetensors",
292
+ "model.layers.28.input_layernorm.bias": "model-00002-of-00002.safetensors",
293
+ "model.layers.28.input_layernorm.weight": "model-00002-of-00002.safetensors",
294
+ "model.layers.28.mlp.fc1.weight": "model-00002-of-00002.safetensors",
295
+ "model.layers.28.mlp.fc2.weight": "model-00002-of-00002.safetensors",
296
+ "model.layers.28.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
297
+ "model.layers.28.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
298
+ "model.layers.29.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
299
+ "model.layers.29.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
300
+ "model.layers.29.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
301
+ "model.layers.29.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
302
+ "model.layers.29.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
303
+ "model.layers.29.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
304
+ "model.layers.29.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
305
+ "model.layers.29.attn.out_proj.bias": "model-00002-of-00002.safetensors",
306
+ "model.layers.29.attn.out_proj.weight": "model-00002-of-00002.safetensors",
307
+ "model.layers.29.input_layernorm.bias": "model-00002-of-00002.safetensors",
308
+ "model.layers.29.input_layernorm.weight": "model-00002-of-00002.safetensors",
309
+ "model.layers.29.mlp.fc1.weight": "model-00002-of-00002.safetensors",
310
+ "model.layers.29.mlp.fc2.weight": "model-00002-of-00002.safetensors",
311
+ "model.layers.29.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
312
+ "model.layers.29.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
313
+ "model.layers.3.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
314
+ "model.layers.3.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
315
+ "model.layers.3.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
316
+ "model.layers.3.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
317
+ "model.layers.3.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
318
+ "model.layers.3.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
319
+ "model.layers.3.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
320
+ "model.layers.3.attn.out_proj.bias": "model-00001-of-00002.safetensors",
321
+ "model.layers.3.attn.out_proj.weight": "model-00001-of-00002.safetensors",
322
+ "model.layers.3.input_layernorm.bias": "model-00001-of-00002.safetensors",
323
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
324
+ "model.layers.3.mlp.fc1.weight": "model-00001-of-00002.safetensors",
325
+ "model.layers.3.mlp.fc2.weight": "model-00001-of-00002.safetensors",
326
+ "model.layers.3.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
327
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
328
+ "model.layers.30.attn.in_proj.weight": "model-00002-of-00002.safetensors",
329
+ "model.layers.30.attn.out_proj.weight": "model-00002-of-00002.safetensors",
330
+ "model.layers.30.input_layernorm.bias": "model-00002-of-00002.safetensors",
331
+ "model.layers.30.input_layernorm.weight": "model-00002-of-00002.safetensors",
332
+ "model.layers.30.mlp.fc1.weight": "model-00002-of-00002.safetensors",
333
+ "model.layers.30.mlp.fc2.weight": "model-00002-of-00002.safetensors",
334
+ "model.layers.30.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
335
+ "model.layers.30.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
336
+ "model.layers.31.attn.Wqkv.bias": "model-00002-of-00002.safetensors",
337
+ "model.layers.31.attn.Wqkv.weight": "model-00002-of-00002.safetensors",
338
+ "model.layers.31.attn.inner_cross_attn.lambda_k1": "model-00002-of-00002.safetensors",
339
+ "model.layers.31.attn.inner_cross_attn.lambda_k2": "model-00002-of-00002.safetensors",
340
+ "model.layers.31.attn.inner_cross_attn.lambda_q1": "model-00002-of-00002.safetensors",
341
+ "model.layers.31.attn.inner_cross_attn.lambda_q2": "model-00002-of-00002.safetensors",
342
+ "model.layers.31.attn.inner_cross_attn.subln.weight": "model-00002-of-00002.safetensors",
343
+ "model.layers.31.attn.out_proj.bias": "model-00002-of-00002.safetensors",
344
+ "model.layers.31.attn.out_proj.weight": "model-00002-of-00002.safetensors",
345
+ "model.layers.31.input_layernorm.bias": "model-00002-of-00002.safetensors",
346
+ "model.layers.31.input_layernorm.weight": "model-00002-of-00002.safetensors",
347
+ "model.layers.31.mlp.fc1.weight": "model-00002-of-00002.safetensors",
348
+ "model.layers.31.mlp.fc2.weight": "model-00002-of-00002.safetensors",
349
+ "model.layers.31.post_attention_layernorm.bias": "model-00002-of-00002.safetensors",
350
+ "model.layers.31.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
351
+ "model.layers.4.attn.A_log": "model-00001-of-00002.safetensors",
352
+ "model.layers.4.attn.D": "model-00001-of-00002.safetensors",
353
+ "model.layers.4.attn.conv1d.bias": "model-00001-of-00002.safetensors",
354
+ "model.layers.4.attn.conv1d.weight": "model-00001-of-00002.safetensors",
355
+ "model.layers.4.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
356
+ "model.layers.4.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
357
+ "model.layers.4.attn.in_proj.weight": "model-00001-of-00002.safetensors",
358
+ "model.layers.4.attn.out_proj.weight": "model-00001-of-00002.safetensors",
359
+ "model.layers.4.attn.x_proj.weight": "model-00001-of-00002.safetensors",
360
+ "model.layers.4.input_layernorm.bias": "model-00001-of-00002.safetensors",
361
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
362
+ "model.layers.4.mlp.fc1.weight": "model-00001-of-00002.safetensors",
363
+ "model.layers.4.mlp.fc2.weight": "model-00001-of-00002.safetensors",
364
+ "model.layers.4.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
365
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
366
+ "model.layers.5.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
367
+ "model.layers.5.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
368
+ "model.layers.5.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
369
+ "model.layers.5.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
370
+ "model.layers.5.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
371
+ "model.layers.5.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
372
+ "model.layers.5.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
373
+ "model.layers.5.attn.out_proj.bias": "model-00001-of-00002.safetensors",
374
+ "model.layers.5.attn.out_proj.weight": "model-00001-of-00002.safetensors",
375
+ "model.layers.5.input_layernorm.bias": "model-00001-of-00002.safetensors",
376
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00002.safetensors",
377
+ "model.layers.5.mlp.fc1.weight": "model-00001-of-00002.safetensors",
378
+ "model.layers.5.mlp.fc2.weight": "model-00001-of-00002.safetensors",
379
+ "model.layers.5.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
380
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
381
+ "model.layers.6.attn.A_log": "model-00001-of-00002.safetensors",
382
+ "model.layers.6.attn.D": "model-00001-of-00002.safetensors",
383
+ "model.layers.6.attn.conv1d.bias": "model-00001-of-00002.safetensors",
384
+ "model.layers.6.attn.conv1d.weight": "model-00001-of-00002.safetensors",
385
+ "model.layers.6.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
386
+ "model.layers.6.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
387
+ "model.layers.6.attn.in_proj.weight": "model-00001-of-00002.safetensors",
388
+ "model.layers.6.attn.out_proj.weight": "model-00001-of-00002.safetensors",
389
+ "model.layers.6.attn.x_proj.weight": "model-00001-of-00002.safetensors",
390
+ "model.layers.6.input_layernorm.bias": "model-00001-of-00002.safetensors",
391
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00002.safetensors",
392
+ "model.layers.6.mlp.fc1.weight": "model-00001-of-00002.safetensors",
393
+ "model.layers.6.mlp.fc2.weight": "model-00001-of-00002.safetensors",
394
+ "model.layers.6.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
395
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
396
+ "model.layers.7.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
397
+ "model.layers.7.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
398
+ "model.layers.7.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
399
+ "model.layers.7.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
400
+ "model.layers.7.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
401
+ "model.layers.7.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
402
+ "model.layers.7.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
403
+ "model.layers.7.attn.out_proj.bias": "model-00001-of-00002.safetensors",
404
+ "model.layers.7.attn.out_proj.weight": "model-00001-of-00002.safetensors",
405
+ "model.layers.7.input_layernorm.bias": "model-00001-of-00002.safetensors",
406
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00002.safetensors",
407
+ "model.layers.7.mlp.fc1.weight": "model-00001-of-00002.safetensors",
408
+ "model.layers.7.mlp.fc2.weight": "model-00001-of-00002.safetensors",
409
+ "model.layers.7.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
410
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
411
+ "model.layers.8.attn.A_log": "model-00001-of-00002.safetensors",
412
+ "model.layers.8.attn.D": "model-00001-of-00002.safetensors",
413
+ "model.layers.8.attn.conv1d.bias": "model-00001-of-00002.safetensors",
414
+ "model.layers.8.attn.conv1d.weight": "model-00001-of-00002.safetensors",
415
+ "model.layers.8.attn.dt_proj.bias": "model-00001-of-00002.safetensors",
416
+ "model.layers.8.attn.dt_proj.weight": "model-00001-of-00002.safetensors",
417
+ "model.layers.8.attn.in_proj.weight": "model-00001-of-00002.safetensors",
418
+ "model.layers.8.attn.out_proj.weight": "model-00001-of-00002.safetensors",
419
+ "model.layers.8.attn.x_proj.weight": "model-00001-of-00002.safetensors",
420
+ "model.layers.8.input_layernorm.bias": "model-00001-of-00002.safetensors",
421
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00002.safetensors",
422
+ "model.layers.8.mlp.fc1.weight": "model-00001-of-00002.safetensors",
423
+ "model.layers.8.mlp.fc2.weight": "model-00001-of-00002.safetensors",
424
+ "model.layers.8.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
425
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
426
+ "model.layers.9.attn.Wqkv.bias": "model-00001-of-00002.safetensors",
427
+ "model.layers.9.attn.Wqkv.weight": "model-00001-of-00002.safetensors",
428
+ "model.layers.9.attn.inner_cross_attn.lambda_k1": "model-00001-of-00002.safetensors",
429
+ "model.layers.9.attn.inner_cross_attn.lambda_k2": "model-00001-of-00002.safetensors",
430
+ "model.layers.9.attn.inner_cross_attn.lambda_q1": "model-00001-of-00002.safetensors",
431
+ "model.layers.9.attn.inner_cross_attn.lambda_q2": "model-00001-of-00002.safetensors",
432
+ "model.layers.9.attn.inner_cross_attn.subln.weight": "model-00001-of-00002.safetensors",
433
+ "model.layers.9.attn.out_proj.bias": "model-00001-of-00002.safetensors",
434
+ "model.layers.9.attn.out_proj.weight": "model-00001-of-00002.safetensors",
435
+ "model.layers.9.input_layernorm.bias": "model-00001-of-00002.safetensors",
436
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00002.safetensors",
437
+ "model.layers.9.mlp.fc1.weight": "model-00001-of-00002.safetensors",
438
+ "model.layers.9.mlp.fc2.weight": "model-00001-of-00002.safetensors",
439
+ "model.layers.9.post_attention_layernorm.bias": "model-00001-of-00002.safetensors",
440
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00002.safetensors"
441
+ }
442
+ }
modeling_phi4flash.py ADDED
@@ -0,0 +1,2098 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ PyTorch Phi4Flash model."""
17
+
18
+
19
+ import inspect
20
+ import math
21
+ import warnings
22
+ from typing import List, Optional, Tuple, Union, Dict, Any
23
+ import copy
24
+ import torch
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+ from transformers.activations import ACT2FN
30
+ from transformers.cache_utils import Cache, DynamicCache
31
+ from transformers.utils import is_torchdynamo_compiling
32
+ from transformers.modeling_outputs import (
33
+ BaseModelOutputWithPast,
34
+ CausalLMOutputWithPast,
35
+ SequenceClassifierOutputWithPast,
36
+ TokenClassifierOutput,
37
+ )
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.generation import GenerationMixin
40
+ from transformers.utils import (
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ is_flash_attn_greater_or_equal_2_10,
45
+ logging,
46
+ replace_return_docstrings,
47
+ )
48
+ from einops import rearrange, repeat
49
+
50
+ from .configuration_phi4flash import Phi4FlashConfig
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
55
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
56
+
57
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
58
+
59
+ if not _flash_supports_window_size:
60
+ raise ValueError("Please update flash-attention to support window size.")
61
+
62
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
63
+ import causal_conv1d_cuda
64
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
65
+
66
+ from torch.amp import custom_bwd, custom_fwd
67
+ import selective_scan_cuda
68
+
69
+ _CHECKPOINT_FOR_DOC = "microsoft/Phi-4-mini-flash-reasoning"
70
+ _CONFIG_FOR_DOC = "Phi4FlashConfig"
71
+
72
+ # monkey patch to add support for our cache
73
+ def _prepare_cache_for_generation(
74
+ self,
75
+ generation_config,
76
+ model_kwargs: Dict,
77
+ assistant_model: "PreTrainedModel",
78
+ batch_size: int,
79
+ max_cache_length: int,
80
+ device: torch.device,
81
+ ) -> bool:
82
+ """
83
+ Prepares the cache for generation (if applicable), given `generate`'s parameterization. If a cache is
84
+ instantiated, writes it to `model_kwargs`, under the name expected by the model.
85
+ """
86
+
87
+ cache_name = "past_key_values"
88
+
89
+ # Quick escape route 2: if the user specifies no cache is to be used. (conflicting arguments are handled in
90
+ # `generation_config.validate()`)
91
+ if generation_config.use_cache is False:
92
+ return
93
+
94
+ # Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation`
95
+
96
+ # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
97
+ # which is only supported in dynamic caches atm
98
+ if assistant_model is not None:
99
+ logger.warning_once(
100
+ "An assistant model is provided, using a dynamic cache instead of a cache of type="
101
+ f"'{generation_config.cache_implementation}'."
102
+ )
103
+ model_kwargs[cache_name] = DynamicCache()
104
+ return
105
+
106
+ model_kwargs[cache_name] = self._get_cache(
107
+ cache_implementation="sambay",
108
+ batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
109
+ max_cache_len=max_cache_length,
110
+ device=device,
111
+ model_kwargs=model_kwargs,
112
+ )
113
+
114
+ def _get_cache(
115
+ self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
116
+ ) -> Cache:
117
+ """
118
+ Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
119
+ new `generate` call requires a larger cache or uses a different batch size.
120
+
121
+ Returns the resulting cache object.
122
+ """
123
+ cache_cls: Cache = SambaYCache
124
+ requires_cross_attention_cache = (
125
+ self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
126
+ )
127
+
128
+ if hasattr(self, "_cache"):
129
+ cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache
130
+
131
+ if cache_implementation == "sliding_window":
132
+ max_cache_len = min(self.config.sliding_window[1], max_cache_len)
133
+
134
+ need_new_cache = (
135
+ not hasattr(self, "_cache")
136
+ or (not isinstance(cache_to_check, cache_cls))
137
+ or cache_to_check.batch_size != batch_size
138
+ )
139
+ if cache_implementation != "mamba":
140
+ need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
141
+
142
+ if requires_cross_attention_cache and hasattr(self, "_cache"):
143
+ need_new_cache = (
144
+ need_new_cache
145
+ or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1]
146
+ )
147
+
148
+ if need_new_cache:
149
+ if hasattr(self.config, "_pre_quantization_dtype"):
150
+ cache_dtype = self.config._pre_quantization_dtype
151
+ else:
152
+ if not is_torchdynamo_compiling():
153
+ cache_dtype = self.dtype
154
+ else:
155
+ # NOTE: self.dtype is not compatible with torch.compile, as it calls `self.parameters()`.
156
+ # Workaround: trust the lm_head, whose attribute name is somewhat consistent across generative
157
+ # models. May cause troubles with non-text modalities.
158
+ cache_dtype = self.get_output_embeddings().weight.dtype
159
+
160
+ def get_layer_device_map(execution_device_map: Optional[dict] = None):
161
+ if execution_device_map is None:
162
+ return None
163
+ elif len(execution_device_map) == 1 and "" in execution_device_map:
164
+ return {idx: execution_device_map[""] for idx in range(self.config.num_hidden_layers)}
165
+ layer_device_map = {}
166
+ for layer in execution_device_map:
167
+ for idx in range(self.config.num_hidden_layers):
168
+ if f".{idx}." in f"{layer}.":
169
+ layer_device_map[idx] = execution_device_map[layer]
170
+ break
171
+ for idx in range(self.config.num_hidden_layers):
172
+ if idx not in layer_device_map:
173
+ raise RuntimeError(f"layer {idx} has not been mapped to a device.")
174
+ return layer_device_map
175
+
176
+ execution_device_map = None
177
+ # Taken from dispatch_model from accelerate.
178
+ # This is needed here if we don't want to make changes in accelerate in order to save execution_device
179
+ # For offloaded case, we need to get the execution device, not just the device where it is offloaded
180
+ if hasattr(self, "hf_device_map"):
181
+ main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
182
+ execution_device_map = {
183
+ name: main_device if device in ["cpu", "disk"] else device
184
+ for name, device in self.hf_device_map.items()
185
+ }
186
+ layer_device_map = get_layer_device_map(execution_device_map)
187
+
188
+ cache_kwargs = {
189
+ "config": self.config.get_text_config(),
190
+ "batch_size": batch_size,
191
+ "max_cache_len": max_cache_len,
192
+ "device": device,
193
+ "dtype": cache_dtype,
194
+ "layer_device_map": layer_device_map,
195
+ }
196
+ self._cache = cache_cls(**cache_kwargs)
197
+ else:
198
+ self._cache.reset()
199
+ return self._cache
200
+
201
+ GenerationMixin._prepare_cache_for_generation = _prepare_cache_for_generation
202
+ GenerationMixin._get_cache = _get_cache
203
+
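+ # Illustrative note (a minimal sketch, not part of the upstream modeling code): with the two monkey
+ # patches above, a plain `generate()` call is expected to allocate a `SambaYCache` via `_get_cache`.
+ # The prompt and generation arguments below are placeholders for illustration only.
+ #
+ #   from transformers import AutoModelForCausalLM, AutoTokenizer
+ #   model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT_FOR_DOC, trust_remote_code=True)
+ #   tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT_FOR_DOC)
+ #   inputs = tokenizer("What is 2 + 2?", return_tensors="pt").to(model.device)
+ #   out = model.generate(**inputs, max_new_tokens=32)  # builds and reuses a SambaYCache internally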
204
+ class SambaYCache(Cache):
205
+ """
206
+ A cache that handles the sliding-window attention caches, one layer of full-attention cache, and the Mamba caches
207
+ (which have a constant shape regardless of seq_len).
208
+
209
+ """
210
+
211
+ def __init__(self,
212
+ config: Phi4FlashConfig,
213
+ batch_size: int = None,
214
+ max_cache_len: int = None,
215
+ device: Union[torch.device, str] = "cuda",
216
+ dtype: torch.dtype = torch.float16,
217
+ max_batch_size: Optional[int] = None,
218
+ layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
219
+ ) -> None:
220
+ super().__init__()
221
+ self.dtype = dtype
222
+ self.has_previous_state = False # only used by mamba
223
+ intermediate_size = config.mamba_expand * config.hidden_size
224
+ ssm_state_size = config.mamba_d_state
225
+ conv_kernel_size = config.mamba_d_conv
226
+ self.conv_kernel_size = conv_kernel_size
227
+
228
+ if batch_size is not None:
229
+ logger.warning_once(
230
+ f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
231
+ "v4.49. Use the more precisely named 'max_batch_size' argument instead."
232
+ )
233
+
234
+ self.max_cache_len = max_cache_len
235
+ self.max_batch_size = batch_size or max_batch_size
236
+ # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
237
+ self.head_dim = config.hidden_size // config.num_attention_heads
238
+ self.num_key_value_heads = config.num_key_value_heads
239
+ self.global_attn_idx = config.num_hidden_layers//2 + 1
240
+ self.key_cache: List[torch.Tensor] = []
241
+ self.value_cache: List[torch.Tensor] = []
242
+ global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
243
+ sliding_cache_shape = (
244
+ self.max_batch_size,
245
+ self.num_key_value_heads,
246
+ min(config.sliding_window[1], max_cache_len),
247
+ self.head_dim,
248
+ )
249
+ conv_cache_shape = (self.max_batch_size, intermediate_size, conv_kernel_size)
250
+ ssm_cache_shape = (self.max_batch_size, intermediate_size, ssm_state_size)
251
+ for i in range(config.num_hidden_layers//2 + 2):
252
+ if layer_device_map is not None:
253
+ layer_device = layer_device_map[i]
254
+ else:
255
+ layer_device = device
256
+ # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
257
+ # breaks when updating the cache.
258
+ if i == self.global_attn_idx:
259
+ key_cache_shape = value_cache_shape = global_cache_shape
260
+ elif i % 2 == 0:
261
+ key_cache_shape = conv_cache_shape
262
+ value_cache_shape = ssm_cache_shape
263
+ else:
264
+ key_cache_shape = value_cache_shape = sliding_cache_shape
265
+ new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=layer_device)
266
+ new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=layer_device)
267
+ torch._dynamo.mark_static_address(new_layer_key_cache)
268
+ torch._dynamo.mark_static_address(new_layer_value_cache)
269
+ self.key_cache.append(new_layer_key_cache)
270
+ self.value_cache.append(new_layer_value_cache)
271
+
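+ # Cache layout built by __init__ above (described here for readability): only the first
+ # num_hidden_layers // 2 + 2 layers get cache entries. Even indices (the Mamba layers when
+ # mb_per_layer == 2) store (conv_state, ssm_state) in the key/value slots, odd indices store
+ # sliding-window KV tensors, and the single index `global_attn_idx` stores the full-attention KV
+ # that later YOCO cross-decoder layers reuse instead of allocating their own cache.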
272
+ def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
273
+ if cache_position.shape[0] > max_cache_len:
274
+ k_out = key_states[:, :, -max_cache_len:, :]
275
+ v_out = value_states[:, :, -max_cache_len:, :]
276
+ # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
277
+ self.key_cache[layer_idx] += k_out
278
+ self.value_cache[layer_idx] += v_out
279
+ # we return the whole key/value states instead of k_out, v_out so that the whole prompt is taken
280
+ # into consideration when building the kv cache, instead of just throwing away tokens outside of the window
281
+ return key_states, value_states
282
+
283
+ slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
284
+ cache_position = cache_position.clamp(0, max_cache_len - 1)
285
+ to_shift = cache_position >= max_cache_len - 1
286
+ indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
287
+ k_out = k_out[:, :, indices]
288
+ v_out = v_out[:, :, indices]
289
+
290
+ k_out[:, :, cache_position] = key_states
291
+ v_out[:, :, cache_position] = value_states
292
+ # `.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
293
+ self.key_cache[layer_idx].zero_()
294
+ self.value_cache[layer_idx].zero_()
295
+
296
+ self.key_cache[layer_idx] += k_out
297
+ self.value_cache[layer_idx] += v_out
298
+ return k_out, v_out
299
+
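+ # `_sliding_update` above implements a rolling window buffer: during prefill longer than the window,
+ # only the last `max_cache_len` tokens are kept in the cache while the full key/value states are
+ # returned; once the window is full during decoding, the buffer is rolled left by one slot and the
+ # newest token is written at the last position.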
300
+ def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
301
+ k_out[:, :, cache_position] = key_states
302
+ v_out[:, :, cache_position] = value_states
303
+
304
+ self.key_cache[layer_idx] = k_out
305
+ self.value_cache[layer_idx] = v_out
306
+ return k_out, v_out
307
+
308
+ def update(
309
+ self,
310
+ key_states: torch.Tensor,
311
+ value_states: torch.Tensor,
312
+ layer_idx: int,
313
+ cache_kwargs: Optional[Dict[str, Any]] = None,
314
+ ) -> Tuple[torch.Tensor]:
315
+ cache_position = cache_kwargs.get("cache_position")
316
+ k_out = self.key_cache[layer_idx]
317
+ v_out = self.value_cache[layer_idx]
318
+ if layer_idx == self.global_attn_idx:
319
+ update_fn = self._static_update
320
+ elif layer_idx % 2 == 1:
321
+ update_fn = self._sliding_update
322
+
323
+ return update_fn(
324
+ cache_position,
325
+ layer_idx,
326
+ key_states,
327
+ value_states,
328
+ k_out,
329
+ v_out,
330
+ k_out.shape[2],
331
+ )
332
+
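+ # Note on `update` above: it is only reached by the attention layers (odd layer indices and the
+ # global full-attention layer). The Mamba layers read and write their conv/ssm states in place via
+ # `Phi3Mamba._get_states_from_cache` below, which is why no branch is needed for even layer indices.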
333
+ def get_max_cache_shape(self) -> Optional[int]:
334
+ return self.max_cache_len
335
+
336
+ def get_seq_length(self, layer_idx: Optional[int] = 0):
337
+ # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
338
+ # limit the check to the first batch member and head dimension.
339
+ # TODO: deprecate this function in favor of `cache_position`
340
+ return (self.key_cache[self.global_attn_idx][0, 0].any(dim=-1)).sum()
341
+
342
+ def reset(self):
343
+ """Resets the cache values while preserving the objects"""
344
+ for layer_idx in range(len(self.key_cache)):
345
+ # In-place ops prevent breaking the static address
346
+ self.key_cache[layer_idx].zero_()
347
+ self.value_cache[layer_idx].zero_()
348
+
349
+ @property
350
+ def batch_size(self):
351
+ logger.warning_once(
352
+ f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
353
+ "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
354
+ )
355
+ return self.max_batch_size
356
+
357
+
358
+
359
+
360
+ swiglu_fwd_codestring = """
361
+ template <typename T> T swiglu_fwd(T x, T y) {
362
+ return float(x) * float(y) / (1.0f + ::exp(-float(x)));
363
+ }
364
+ """
365
+ swiglu_bwd_codestring = """
366
+ template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
367
+ float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
368
+ dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
369
+ dy = float(x) * x_sigmoid * float(g);
370
+ }
371
+ """
372
+ swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
373
+ swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
374
+
375
+
376
+ class SwiGLUFunction(torch.autograd.Function):
377
+
378
+ @staticmethod
379
+ def forward(ctx, x, y):
380
+ ctx.save_for_backward(x, y)
381
+ return swiglu_fwd(x, y)
382
+
383
+ @staticmethod
384
+ def backward(ctx, dout):
385
+ x, y = ctx.saved_tensors
386
+ return swiglu_bwd(x, y, dout)
387
+
388
+ swiglu = SwiGLUFunction.apply
389
+
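+ # Sanity-check sketch (illustrative only, not executed here): the fused jiterator kernel above is
+ # expected to match the eager formulation silu(x) * y, since x * y / (1 + exp(-x)) == silu(x) * y.
+ #
+ #   x, y = torch.randn(4, 8, device="cuda"), torch.randn(4, 8, device="cuda")
+ #   torch.testing.assert_close(swiglu(x, y), torch.nn.functional.silu(x) * y)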
390
+
391
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->SambaY
392
+ class SambaYRMSNorm(nn.Module):
393
+ def __init__(self, hidden_size, eps=1e-5):
394
+ """
395
+ SambaYRMSNorm is equivalent to T5LayerNorm
396
+ """
397
+ super().__init__()
398
+ self.weight = nn.Parameter(torch.ones(hidden_size))
399
+ self.variance_epsilon = eps
400
+
401
+ def forward(self, hidden_states):
402
+ input_dtype = hidden_states.dtype
403
+ hidden_states = hidden_states.to(torch.float32)
404
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
405
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
406
+ return self.weight * hidden_states.to(input_dtype)
407
+
408
+
409
+ PHI_NORM_CLASS = nn.LayerNorm
410
+
411
+
412
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
413
+ def _get_unpad_data(attention_mask):
414
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
415
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
416
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
417
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
418
+ return (
419
+ indices,
420
+ cu_seqlens,
421
+ max_seqlen_in_batch,
422
+ )
423
+
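+ # Toy example for `_get_unpad_data` above (illustrative only): for
+ # attention_mask = [[1, 1, 0], [1, 1, 1]] it returns indices = [0, 1, 3, 4, 5],
+ # cu_seqlens = [0, 2, 5] and max_seqlen_in_batch = 3, i.e. the (indices, cumulative lengths,
+ # max length) layout expected by the varlen flash-attention kernels.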
424
+
425
+ class SambaYMLP(nn.Module):
426
+ """Gated Linear Unit.
427
+
428
+ Reference:
429
+ Language Modeling with Gated Convolutional Networks.
430
+ https://arxiv.org/pdf/1612.08083v3.pdf.
431
+
432
+ """
433
+
434
+ def __init__(self, config):
435
+ super().__init__()
436
+
437
+ self.config = config
438
+ self.fc1 = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
439
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
440
+
441
+ self.activation_fn = ACT2FN[config.hidden_act]
442
+
443
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
444
+ y = self.fc1(hidden_states)
445
+
446
+ # Special case for SwiGLU
447
+ if self.config.hidden_act == "silu" and swiglu is not None:
448
+ gate, y = y.chunk(2, dim=-1)
449
+ y = swiglu(gate, y)
450
+ else:
451
+ gate, y = y.chunk(2, dim=-1)
452
+ y = y * self.activation_fn(gate)
453
+
454
+ return self.fc2(y)
455
+
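+ # Shape sketch for SambaYMLP above (H = hidden_size, I = intermediate_size): fc1 maps
+ # (B, L, H) -> (B, L, 2 * I), the chunk splits that into a gate and a value of (B, L, I) each,
+ # the gate is passed through SiLU (via the fused `swiglu` when hidden_act == "silu"), and fc2
+ # projects the gated product back to (B, L, H).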
456
+
457
+ class SambaYAttention(nn.Module):
458
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
459
+
460
+ def __init__(self, config: Phi4FlashConfig, layer_idx: Optional[int] = None, yoco_cross: bool = False):
461
+ super().__init__()
462
+ self.config = config
463
+ self.layer_idx = layer_idx
464
+ if layer_idx is None:
465
+ logger.warning_once(
466
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
467
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
468
+ "when creating this class."
469
+ )
470
+
471
+ self.attention_dropout = config.attention_dropout
472
+ self.hidden_size = config.hidden_size
473
+ self.num_heads = config.num_attention_heads
474
+ self.head_dim = self.hidden_size // self.num_heads
475
+ self.num_key_value_heads = config.num_key_value_heads
476
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
477
+ self.max_position_embeddings = config.max_position_embeddings
478
+ self.is_causal = True
479
+ self.yoco_cross = yoco_cross
480
+
481
+ if (self.head_dim * self.num_heads) != self.hidden_size:
482
+ raise ValueError(
483
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
484
+ f" and `num_heads`: {self.num_heads})."
485
+ )
486
+
487
+ op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
488
+ self.out_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
489
+ if yoco_cross:
490
+ self.Wqkv = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
491
+ else:
492
+ self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)
493
+
494
+ self.inner_cross_attn = FlashDiffCustomAttention(self.head_dim, self.layer_idx,)
495
+
496
+
497
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
498
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
499
+
500
+ def forward(
501
+ self,
502
+ hidden_states: torch.Tensor,
503
+ attention_mask: Optional[torch.Tensor] = None,
504
+ position_ids: Optional[torch.LongTensor] = None,
505
+ past_key_value: Optional[Cache] = None,
506
+ output_attentions: bool = False,
507
+ use_cache: bool = False,
508
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
509
+ raise NotImplementedError("SambaYAttention only support flash attention")
510
+
511
+
512
+ class SambaYFlashAttention2(SambaYAttention):
513
+ """
514
+ SambaY flash attention module. This module inherits from `SambaYAttention` as the weights of the module stays
515
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
516
+ flash attention and deal with padding tokens in case the input contains any of them.
517
+ """
518
+
519
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
520
+ def __init__(self, *args, **kwargs):
521
+ super().__init__(*args, **kwargs)
522
+
523
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
524
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
525
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
526
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
527
+
528
+
529
+
530
+ def forward(
531
+ self,
532
+ hidden_states: torch.Tensor,
533
+ attention_mask: Optional[torch.LongTensor] = None,
534
+ position_ids: Optional[torch.LongTensor] = None,
535
+ past_key_value: Optional[Cache] = None,
536
+ output_attentions: bool = False,
537
+ use_cache: bool = False,
538
+ cache_position: Optional[torch.LongTensor] = None,
539
+ yoco_key_values: Optional[torch.Tensor] = None,
540
+ **kwargs,
541
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
542
+ # SambaYFlashAttention2 attention does not support output_attentions
543
+
544
+ output_attentions = False
545
+ if "padding_mask" in kwargs:
546
+ warnings.warn(
547
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
548
+ )
549
+
550
+ # overwrite attention_mask with padding_mask
551
+ attention_mask = kwargs.pop("padding_mask")
552
+
553
+ bsz, q_len, _ = hidden_states.size()
554
+ if self.yoco_cross:
555
+ q = self.Wqkv(hidden_states)
556
+ q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim).transpose(1,2)
557
+ key_states, value_states = yoco_key_values
558
+ query_states = q
559
+
560
+ use_sliding_windows = False
561
+ else:
562
+
563
+ qkv = self.Wqkv(hidden_states)
564
+ query_pos = self.num_heads * self.head_dim
565
+ query_states = qkv[..., :query_pos]
566
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
567
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
568
+
569
+ # Flash attention requires the input to have the shape
570
+ # batch_size x seq_length x num_heads x head_dim
571
+ # therefore we just need to keep the original shape
572
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
573
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
574
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
575
+
576
+ use_sliding_windows = self.config.sliding_window is not None and self.config.sliding_window[self.layer_idx] is not None
577
+
578
+ if past_key_value is not None:
579
+
580
+ cache_kwargs = {"cache_position": cache_position}# Specific to RoPE models
581
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
582
+
583
+
584
+ yoco_key_values = key_states, value_states
585
+
586
+ attn_dropout = self.attention_dropout if self.training else 0.0
587
+
588
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
589
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
590
+ # cast them back to the correct dtype just to be sure everything works as expected.
591
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
592
+ # in fp32.
593
+
594
+ if query_states.dtype == torch.float32:
595
+ if torch.is_autocast_enabled():
596
+ target_dtype = torch.get_autocast_gpu_dtype()
597
+ # Handle the case where the model is quantized
598
+ elif hasattr(self.config, "_pre_quantization_dtype"):
599
+ target_dtype = self.config._pre_quantization_dtype
600
+ else:
601
+ target_dtype = self.Wqkv.weight.dtype
602
+
603
+ logger.warning_once(
604
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
605
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
606
+ f" {target_dtype}."
607
+ )
608
+
609
+ query_states = query_states.to(target_dtype)
610
+ key_states = key_states.to(target_dtype)
611
+ value_states = value_states.to(target_dtype)
612
+
613
+ # Reshape to the expected shape for Flash Attention
614
+ # -> b,q,h,d
615
+ query_states = query_states.transpose(1, 2)
616
+ key_states = key_states.transpose(1, 2)
617
+ value_states = value_states.transpose(1, 2)
618
+ if attention_mask is not None:
619
+ key_states = key_states[:, :attention_mask.shape[-1]]
620
+ value_states = value_states[:, :attention_mask.shape[-1]]
621
+ attn_output = self._flash_attention_forward(
622
+ query_states,
623
+ key_states,
624
+ value_states,
625
+ attention_mask,
626
+ q_len,
627
+ dropout=attn_dropout,
628
+ use_sliding_windows=use_sliding_windows,
629
+ )
630
+
631
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
632
+ attn_output = self.out_proj(attn_output)
633
+
634
+ if not output_attentions:
635
+ attn_weights = None
636
+
637
+ return attn_output, attn_weights, yoco_key_values
638
+
639
+ def _flash_attention_forward(
640
+ self,
641
+ query_states,
642
+ key_states,
643
+ value_states,
644
+ attention_mask,
645
+ query_length,
646
+ dropout=0.0,
647
+ softmax_scale=None,
648
+ use_sliding_windows=False,
649
+ ):
650
+ """
651
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
652
+ first unpads the input, then computes the attention scores and pads the final attention scores.
653
+
654
+ Args:
655
+ query_states (`torch.Tensor`):
656
+ Input query states to be passed to Flash Attention API
657
+ key_states (`torch.Tensor`):
658
+ Input key states to be passed to Flash Attention API
659
+ value_states (`torch.Tensor`):
660
+ Input value states to be passed to Flash Attention API
661
+ attention_mask (`torch.Tensor`):
662
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
663
+ position of padding tokens and 1 for the position of non-padding tokens.
664
+ dropout (`float`):
665
+ Attention dropout
666
+ softmax_scale (`float`, *optional*):
667
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
668
+ use_sliding_windows (`bool`, *optional*):
669
+ Whether to activate sliding window attention.
670
+ """
671
+ causal = self.is_causal
672
+ # Contains at least one padding token in the sequence
673
+ if attention_mask is not None:
674
+ batch_size = query_states.shape[0]
675
+ (
676
+ query_states,
677
+ key_states,
678
+ value_states,
679
+ indices_q,
680
+ cu_seq_lens,
681
+ max_seq_lens,
682
+ ) = self._upad_input(query_states, key_states, value_states, attention_mask, query_length)
683
+
684
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
685
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
686
+
687
+ if not use_sliding_windows:
688
+ attn_output_unpad = self.inner_cross_attn(
689
+ query_states,
690
+ key_states,
691
+ value_states,
692
+ cu_seqlens_q=cu_seqlens_q,
693
+ cu_seqlens_k=cu_seqlens_k,
694
+ max_seqlen_q=max_seqlen_in_batch_q,
695
+ max_seqlen_k=max_seqlen_in_batch_k,
696
+ dropout_p=dropout,
697
+ softmax_scale=softmax_scale,
698
+ causal=causal,
699
+ )
700
+ else:
701
+ attn_output_unpad = self.inner_cross_attn(
702
+ query_states,
703
+ key_states,
704
+ value_states,
705
+ cu_seqlens_q=cu_seqlens_q,
706
+ cu_seqlens_k=cu_seqlens_k,
707
+ max_seqlen_q=max_seqlen_in_batch_q,
708
+ max_seqlen_k=max_seqlen_in_batch_k,
709
+ dropout_p=dropout,
710
+ softmax_scale=softmax_scale,
711
+ causal=causal,
712
+ window_size=(
713
+ self.config.sliding_window[self.layer_idx] -1,
714
+ self.config.sliding_window[self.layer_idx] -1,
715
+ ),
716
+ )
717
+
718
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
719
+ else:
720
+ if not use_sliding_windows:
721
+ attn_output = self.inner_cross_attn(
722
+ query_states,
723
+ key_states,
724
+ value_states,
725
+ dropout_p=dropout,
726
+ softmax_scale=softmax_scale,
727
+ causal=causal,
728
+ )
729
+ else:
730
+ attn_output = self.inner_cross_attn(
731
+ query_states,
732
+ key_states,
733
+ value_states,
734
+ dropout_p=dropout,
735
+ softmax_scale=softmax_scale,
736
+ causal=causal,
737
+ window_size=(
738
+ self.config.sliding_window[self.layer_idx] -1,
739
+ self.config.sliding_window[self.layer_idx] -1,
740
+ ),
741
+ )
742
+
743
+ return attn_output
744
+
745
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
746
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
747
+
748
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
749
+
750
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
751
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
752
+
753
+ if query_length == kv_seq_len:
754
+ query_layer = index_first_axis(
755
+ query_layer.reshape(batch_size * kv_seq_len, -1, head_dim),
756
+ indices_k,
757
+ )
758
+ cu_seqlens_q = cu_seqlens_k
759
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
760
+ indices_q = indices_k
761
+ elif query_length == 1:
762
+ max_seqlen_in_batch_q = 1
763
+ cu_seqlens_q = torch.arange(
764
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
765
+ ) # There is a memcpy here, that is very bad.
766
+ indices_q = cu_seqlens_q[:-1]
767
+ query_layer = query_layer.squeeze(1)
768
+ else:
769
+ # The -q_len: slice assumes left padding.
770
+ attention_mask = attention_mask[:, -query_length:]
771
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
772
+
773
+ return (
774
+ query_layer,
775
+ key_layer,
776
+ value_layer,
777
+ indices_q,
778
+ (cu_seqlens_q, cu_seqlens_k),
779
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
780
+ )
781
+
782
+
783
+
784
+ class Phi3Mamba(nn.Module):
785
+ def __init__(
786
+ self,
787
+ d_model,
788
+ d_state=16,
789
+ d_conv=4,
790
+ expand=2,
791
+ dt_rank="auto",
792
+ conv_bias=True,
793
+ bias=False,
794
+ use_fast_path=True, # Fused kernel options
795
+ layer_idx=None,
796
+ yoco_cross=False,
797
+ yoco_kv=False,
798
+ dtype=None,
799
+ ):
800
+ factory_kwargs = {"dtype": dtype}
801
+ super().__init__()
802
+ self.d_model = d_model
803
+ self.d_state = d_state
804
+ self.d_conv = d_conv
805
+ self.expand = expand
806
+ self.d_inner = int(self.expand * self.d_model)
807
+ self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
808
+ self.use_fast_path = use_fast_path
809
+ self.layer_idx = layer_idx
810
+
811
+ self.yoco_cross = yoco_cross
812
+ self.yoco_kv = yoco_kv
813
+ if self.yoco_cross:
814
+ self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs)
815
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
816
+ else:
817
+ self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
818
+
819
+ self.conv1d = nn.Conv1d(
820
+ in_channels=self.d_inner,
821
+ out_channels=self.d_inner,
822
+ bias=conv_bias,
823
+ kernel_size=d_conv,
824
+ groups=self.d_inner,
825
+ padding=d_conv - 1,
826
+ **factory_kwargs,
827
+ )
828
+
829
+ self.activation = "silu"
830
+ self.act = nn.SiLU()
831
+
832
+ self.x_proj = nn.Linear(
833
+ self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
834
+ )
835
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
836
+
837
+ # S4D real initialization
838
+ A = repeat(
839
+ torch.arange(1, self.d_state + 1, dtype=torch.float32),
840
+ "n -> d n",
841
+ d=self.d_inner,
842
+ ).contiguous()
843
+ A_log = torch.log(A) # Keep A_log in fp32
844
+ self.A_log = nn.Parameter(A_log)
845
+
846
+ # D "skip" parameter
847
+ self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32
848
+
849
+ self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
850
+
851
+ def forward(self, hidden_states, inference_params=None, mask= None, yoco_key_values = None, cache_position = None):
852
+ """
853
+ hidden_states: (B, L, D)
854
+ Returns: same shape as hidden_states
855
+ """
856
+
857
+ if self.yoco_cross:
858
+ out = self.in_proj(hidden_states)
859
+ out = swiglu(out, yoco_key_values)
860
+ out = self.out_proj(out)
861
+ return out, yoco_key_values
862
+
863
+ batch, seqlen, _ = hidden_states.shape
864
+ conv_state, ssm_state = None, None
865
+ if inference_params is not None:
866
+ conv_state, ssm_state = self._get_states_from_cache(inference_params)
867
+ if cache_position[0] > 0: #inference_params.get_seq_length(self.layer_idx) > 0:
868
+ # The states are updated inplace
869
+ out, _, _, yoco_key_values = self.step(hidden_states, conv_state, ssm_state, yoco_key_values)
870
+ return out, yoco_key_values
871
+
872
+ # We do matmul and transpose BLH -> HBL at the same time
873
+ xz = rearrange(
874
+ self.in_proj.weight @ rearrange(hidden_states.to(dtype = self.in_proj.weight.dtype), "b l d -> d (b l)"),
875
+ "d (b l) -> b d l",
876
+ l=seqlen,
877
+ )
878
+ if self.in_proj.bias is not None:
879
+ xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")
880
+
881
+
882
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
883
+ # In the backward pass we write dx and dz next to each other to avoid torch.cat
884
+ if (not self.yoco_kv) and self.use_fast_path and inference_params is None: # Doesn't support outputting the states
885
+ out = mamba_inner_fn(
886
+ xz,
887
+ self.conv1d.weight,
888
+ self.conv1d.bias,
889
+ self.x_proj.weight,
890
+ self.dt_proj.weight,
891
+ self.out_proj.weight,
892
+ self.out_proj.bias,
893
+ A,
894
+ None, # input-dependent B
895
+ None, # input-dependent C
896
+ self.D.float(),
897
+ delta_bias=self.dt_proj.bias.float(),
898
+ mask=mask,
899
+ delta_softplus=True,
900
+ )
901
+ else:
902
+ x, z = xz.chunk(2, dim=1)
903
+ if self.yoco_kv:
904
+ z = z.transpose(-1,-2).contiguous()
905
+ if mask is not None:
906
+ x = x * mask.unsqueeze(1)
907
+ # Compute short convolution
908
+ if conv_state is not None:
909
+ # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
910
+ # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
911
+ conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) # Update state (B D W)
912
+ if causal_conv1d_fn is None:
913
+ x = self.act(self.conv1d(x)[..., :seqlen])
914
+ else:
915
+ assert self.activation in ["silu", "swish"]
916
+ x = causal_conv1d_fn(
917
+ x=x,
918
+ weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
919
+ bias=self.conv1d.bias,
920
+ activation=self.activation,
921
+ )
922
+ if mask is not None:
923
+ x = x * mask.unsqueeze(1)
924
+ # We're careful here about the layout, to avoid extra transposes.
925
+ # We want dt to have d as the slowest moving dimension
926
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
927
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d)
928
+ dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
929
+ dt = self.dt_proj.weight @ dt.t()
930
+ dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
931
+ B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
932
+ C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
933
+ assert self.activation in ["silu", "swish"]
934
+ y = selective_scan_fn(
935
+ x,
936
+ dt,
937
+ A,
938
+ B,
939
+ C,
940
+ self.D.float(),
941
+ z= None if self.yoco_kv else z,
942
+ delta_bias=self.dt_proj.bias.float(),
943
+ delta_softplus=True,
944
+ return_last_state=ssm_state is not None,
945
+ )
946
+ if ssm_state is not None:
947
+ y, last_state = y
948
+ ssm_state.copy_(last_state)
949
+ y = rearrange(y, "b d l -> b l d")
950
+ if self.yoco_kv:
951
+ yoco_key_values = y
952
+ y = swiglu(z, y)
953
+ out = self.out_proj(y)
954
+ return out, yoco_key_values
955
+
956
+ def step(self, hidden_states, conv_state, ssm_state, yoco_key_values):
957
+ dtype = hidden_states.dtype
958
+ assert hidden_states.shape[1] == 1, "Only supports decoding with 1 token at a time for now"
959
+ xz = self.in_proj(hidden_states.to(dtype = self.in_proj.weight.dtype).squeeze(1)) # (B 2D)
960
+ x, z = xz.chunk(2, dim=-1) # (B D)
961
+
962
+ # Conv step
963
+ if causal_conv1d_update is None:
964
+ conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
965
+ conv_state[:, :, -1] = x
966
+ x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D)
967
+ if self.conv1d.bias is not None:
968
+ x = x + self.conv1d.bias
969
+ x = self.act(x).to(dtype=dtype)
970
+ else:
971
+ x = causal_conv1d_update(
972
+ x,
973
+ conv_state,
974
+ rearrange(self.conv1d.weight, "d 1 w -> d w"),
975
+ self.conv1d.bias,
976
+ self.activation,
977
+ )
978
+
979
+ x_db = self.x_proj(x) # (B dt_rank+2*d_state)
980
+ dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
981
+ # Don't add dt_bias here
982
+ dt = F.linear(dt, self.dt_proj.weight) # (B d_inner)
983
+ A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
984
+
985
+ # SSM step
986
+ if selective_state_update is None:
987
+ # Discretize A and B
988
+ dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
989
+ dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
990
+ dB = torch.einsum("bd,bn->bdn", dt, B)
991
+ ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
992
+ y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
993
+ y = y + self.D.to(dtype) * x
994
+ y = y * self.act(z) # (B D)
995
+ else:
996
+ y = selective_state_update(
997
+ ssm_state, x, dt, A, B, C, self.D, z= None if self.yoco_kv else z, dt_bias=self.dt_proj.bias, dt_softplus=True
998
+ )
999
+ if self.yoco_kv:
1000
+ yoco_key_values = y.unsqueeze(1)
1001
+ y = swiglu(z, y)
1002
+ out = self.out_proj(y)
1003
+ return out.unsqueeze(1), conv_state, ssm_state, yoco_key_values
1004
+
1005
+ def _get_states_from_cache(self, inference_params):
1006
+ conv_state, ssm_state = inference_params.key_cache[self.layer_idx], inference_params.value_cache[self.layer_idx]
1007
+ return conv_state, ssm_state
1008
+
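+ # The (conv_state, ssm_state) pair above lives in the key/value slots that SambaYCache allocates for
+ # even layer indices, so the Mamba layers and the attention layers share a single cache object.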
1009
+
1010
+
1011
+
1012
+ class SambaYDecoderLayer(nn.Module):
1013
+ def __init__(self, config: Phi4FlashConfig, layer_idx: int):
1014
+ super().__init__()
1015
+
1016
+ self.mlp = SambaYMLP(config)
1017
+ self.input_layernorm = PHI_NORM_CLASS(config.hidden_size, eps=config.layer_norm_eps)
1018
+
1019
+ self.yoco_kv = False
1020
+ self.yoco_cross = False
1021
+ self.yoco_mb = False
1022
+ self.layer_idx = layer_idx
1023
+ assert config.num_hidden_layers % 4 == 0, 'num_hidden_layers should be divisible by 4 for SambaY'
1024
+ if layer_idx >= config.num_hidden_layers//2:
1025
+ self.yoco_mb = True
1026
+ self.yoco_kv = (layer_idx >= (config.num_hidden_layers//2 +1))
1027
+ self.yoco_cross = (layer_idx >= (config.num_hidden_layers//2 +2))
1028
+ if (layer_idx >= (config.num_hidden_layers//2 +1)):
1029
+ config = copy.deepcopy(config)
1030
+ config.sliding_window = None
1031
+ self.config= config
1032
+
1033
+ self.use_mamba = config.mb_per_layer > 0 and layer_idx % config.mb_per_layer == 0
1034
+ if self.use_mamba:
1035
+ factory_kwargs = {"d_conv": config.mamba_d_conv, "d_state": config.mamba_d_state, "expand": config.mamba_expand , "dtype": None}
1036
+ self.attn = Phi3Mamba(config.hidden_size, layer_idx=layer_idx, yoco_cross=self.yoco_cross, yoco_kv=self.yoco_mb, **factory_kwargs)
1037
+ else:
1038
+ self.attn = SambaYFlashAttention2(config, layer_idx=layer_idx, yoco_cross=self.yoco_cross)
1039
+
1040
+ self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
1041
+ self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
1042
+ self.post_attention_layernorm = PHI_NORM_CLASS(config.hidden_size, eps=config.layer_norm_eps)
1043
+
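+ # Layer schedule implied by the flags set in __init__ above: in the first half of the network, even
+ # layer indices use Mamba (with mb_per_layer == 2) and odd indices use sliding-window attention.
+ # Layer num_hidden_layers // 2 + 1 has its sliding window disabled and acts as the single
+ # full-attention layer, and every layer from num_hidden_layers // 2 + 2 onwards is a yoco_cross
+ # layer that reuses the states produced earlier instead of computing its own.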
1044
+ def forward(
1045
+ self,
1046
+ hidden_states: torch.Tensor,
1047
+ attention_mask: Optional[torch.Tensor] = None,
1048
+ position_ids: Optional[torch.LongTensor] = None,
1049
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1050
+ output_attentions: Optional[bool] = False,
1051
+ use_cache: Optional[bool] = False,
1052
+ cache_position: Optional[torch.LongTensor] = None,
1053
+ ssm_output: Optional[torch.Tensor] = None,
1054
+ yoco_key_values: Optional[torch.Tensor] = None,
1055
+ **kwargs,
1056
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
1057
+ """
1058
+ Args:
1059
+ hidden_states (`torch.FloatTensor`):
1060
+ input to the layer of shape `(batch, seq_len, embed_dim)`
1061
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
1062
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
1063
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
1064
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
1065
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
1066
+ output_attentions (`bool`, *optional*):
1067
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1068
+ returned tensors for more detail.
1069
+ use_cache (`bool`, *optional*):
1070
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1071
+ (see `past_key_values`).
1072
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
1073
+ """
1074
+
1075
+ residual = hidden_states
1076
+
1077
+ hidden_states = self.input_layernorm(hidden_states.to(dtype=self.input_layernorm.weight.dtype))
1078
+
1079
+ if self.use_mamba:
1080
+ attn_outputs, ssm_output = self.attn(
1081
+ hidden_states, inference_params=past_key_value,
1082
+ mask = attention_mask, yoco_key_values = ssm_output,
1083
+ cache_position=cache_position,
1084
+ )
1085
+ residual = residual.to(torch.float32)
1086
+ self_attn_weights = None
1087
+ else:
1088
+ if self.config.sliding_window is not None and self.config.sliding_window[self.layer_idx] is not None and attention_mask is not None: # efficient SDPA and no padding
1089
+ if past_key_value is not None and cache_position[0] > 0: # when decoding
1090
+ attention_mask = attention_mask[:, -self.config.sliding_window[self.layer_idx]:]
1091
+ #hidden_states = self.input_layernorm2(hidden_states.to(dtype=self.input_layernorm2.weight.dtype))
1092
+ # Self Attention
1093
+ attn_outputs, self_attn_weights, yoco_key_values = self.attn(
1094
+ hidden_states=hidden_states,
1095
+ attention_mask=attention_mask,
1096
+ position_ids=position_ids,
1097
+ past_key_value=past_key_value,
1098
+ output_attentions=output_attentions,
1099
+ use_cache=use_cache,
1100
+ cache_position=cache_position,
1101
+ yoco_key_values = yoco_key_values,
1102
+ )
1103
+
1104
+ hidden_states = residual + self.resid_attn_dropout(attn_outputs)
1105
+
1106
+ residual = hidden_states
1107
+ hidden_states = self.post_attention_layernorm(hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype))
1108
+ hidden_states = self.mlp(hidden_states)
1109
+ hidden_states = residual + self.resid_mlp_dropout(hidden_states)
1110
+
1111
+ outputs = (hidden_states,)
1112
+ outputs += (ssm_output,)
1113
+ outputs += (yoco_key_values,)
1114
+ if output_attentions:
1115
+ outputs += (self_attn_weights,)
1116
+
1117
+ return outputs
1118
+
1119
+
1120
+ PHI_START_DOCSTRING = r"""
1121
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1122
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1123
+ etc.)
1124
+
1125
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1126
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
1127
+ and behavior.
1128
+
1129
+ Parameters:
1130
+ config ([`Phi4FlashConfig`]):
1131
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1132
+ load the weights associated with the model, only the configuration. Check out the
1133
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1134
+ """
1135
+
1136
+
1137
+ @add_start_docstrings(
1138
+ "The bare Phi4Flash Model outputting raw hidden-states without any specific head on top.",
1139
+ PHI_START_DOCSTRING,
1140
+ )
1141
+ class Phi4FlashPreTrainedModel(PreTrainedModel):
1142
+ config_class = Phi4FlashConfig
1143
+ base_model_prefix = "model"
1144
+ supports_gradient_checkpointing = True
1145
+ _no_split_modules = ["SambaYDecoderLayer"]
1146
+ _skip_keys_device_placement = "past_key_values"
1147
+ _supports_flash_attn_2 = True
1148
+ _supports_sdpa = False
1149
+ _supports_cache_class = True
1150
+
1151
+ def _init_weights(self, module):
1152
+ std = self.config.initializer_range
1153
+ if isinstance(module, nn.Linear):
1154
+ module.weight.data.normal_(mean=0.0, std=std)
1155
+ if module.bias is not None:
1156
+ module.bias.data.zero_()
1157
+ elif isinstance(module, nn.Embedding):
1158
+ module.weight.data.normal_(mean=0.0, std=std)
1159
+ if module.padding_idx is not None:
1160
+ module.weight.data[module.padding_idx].zero_()
1161
+
1162
+
1163
+ PHI_INPUTS_DOCSTRING = r"""
1164
+ Args:
1165
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1166
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1167
+ it.
1168
+
1169
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1170
+ [`PreTrainedTokenizer.__call__`] for details.
1171
+
1172
+ [What are input IDs?](../glossary#input-ids)
1173
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1174
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1175
+
1176
+ - 1 for tokens that are **not masked**,
1177
+ - 0 for tokens that are **masked**.
1178
+
1179
+ [What are attention masks?](../glossary#attention-mask)
1180
+
1181
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1182
+ [`PreTrainedTokenizer.__call__`] for details.
1183
+
1184
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1185
+ `past_key_values`).
1186
+
1187
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1188
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1189
+ information on the default strategy.
1190
+
1191
+ - 1 indicates the head is **not masked**,
1192
+ - 0 indicates the head is **masked**.
1193
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1194
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1195
+ config.n_positions - 1]`.
1196
+
1197
+ [What are position IDs?](../glossary#position-ids)
1198
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1199
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1200
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
1201
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1202
+
1203
+ Two formats are allowed:
1204
+ - a [`~cache_utils.Cache`] instance;
1205
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1206
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1207
+ cache format.
1208
+
1209
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1210
+ legacy cache format will be returned.
1211
+
1212
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1213
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1214
+ of shape `(batch_size, sequence_length)`.
1215
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1216
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1217
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1218
+ model's internal embedding lookup matrix.
1219
+ use_cache (`bool`, *optional*):
1220
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1221
+ `past_key_values`).
1222
+ output_attentions (`bool`, *optional*):
1223
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1224
+ tensors for more detail.
1225
+ output_hidden_states (`bool`, *optional*):
1226
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1227
+ more detail.
1228
+ return_dict (`bool`, *optional*):
1229
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1230
+ """
1231
+
1232
+
1233
+ @add_start_docstrings(
1234
+ "The bare Phi4Flash Model outputting raw hidden-states without any specific head on top.",
1235
+ PHI_START_DOCSTRING,
1236
+ )
1237
+ class Phi4FlashModel(Phi4FlashPreTrainedModel):
1238
+ """
1239
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SambaYDecoderLayer`]
1240
+
1241
+ Args:
1242
+ config: Phi4FlashConfig
1243
+ """
1244
+
1245
+ def __init__(self, config: Phi4FlashConfig):
1246
+ super().__init__(config)
1247
+ self.padding_idx = config.pad_token_id
1248
+ self.vocab_size = config.vocab_size
1249
+
1250
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
1251
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
1252
+ self.layers = nn.ModuleList(
1253
+ [SambaYDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1254
+ )
1255
+ self.final_layernorm = PHI_NORM_CLASS(config.hidden_size, eps=config.layer_norm_eps)
1256
+
1257
+ self._attn_implementation = config._attn_implementation
1258
+
1259
+ self.gradient_checkpointing = False
1260
+ # Initialize weights and apply final processing
1261
+ self.post_init()
1262
+
1263
+ def get_input_embeddings(self):
1264
+ return self.embed_tokens
1265
+
1266
+ def set_input_embeddings(self, value):
1267
+ self.embed_tokens = value
1268
+
1269
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1270
+ def forward(
1271
+ self,
1272
+ input_ids: torch.LongTensor = None,
1273
+ attention_mask: Optional[torch.Tensor] = None,
1274
+ position_ids: Optional[torch.LongTensor] = None,
1275
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1276
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1277
+ use_cache: Optional[bool] = None,
1278
+ output_attentions: Optional[bool] = None,
1279
+ output_hidden_states: Optional[bool] = None,
1280
+ return_dict: Optional[bool] = None,
1281
+ cache_position: Optional[torch.LongTensor] = None,
1282
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1283
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1284
+ output_hidden_states = (
1285
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1286
+ )
1287
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1288
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1289
+
1290
+ # retrieve input_ids and inputs_embeds
1291
+ if input_ids is not None and inputs_embeds is not None:
1292
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1293
+ elif input_ids is not None:
1294
+ batch_size, seq_length = input_ids.shape[:2]
1295
+ elif inputs_embeds is not None:
1296
+ batch_size, seq_length = inputs_embeds.shape[:2]
1297
+ else:
1298
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1299
+
1300
+
1301
+ if self.gradient_checkpointing and self.training:
1302
+ if use_cache:
1303
+ logger.warning_once(
1304
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1305
+ )
1306
+ use_cache = False
1307
+
1308
+ if inputs_embeds is None:
1309
+ inputs_embeds = self.embed_tokens(input_ids)
1310
+
1311
+ if use_cache and past_key_values is None and not self.training:
1312
+ batch_size, seq_len, _ = inputs_embeds.shape
1313
+ past_key_values = SambaYCache(
1314
+ self.config,
1315
+ max_batch_size=batch_size,
1316
+ max_cache_len=seq_len,
1317
+ device=self.device,
1318
+ dtype=inputs_embeds.dtype,
1319
+ )
1320
+
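+ # When `generate()` is used, the monkey-patched `_prepare_cache_for_generation` already supplies a
+ # SambaYCache, so the allocation just above mainly covers direct `forward()` calls with use_cache=True.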
1321
+
1322
+ if cache_position is None:
1323
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
1324
+ cache_position = torch.arange(
1325
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
1326
+ )
1327
+
1328
+ if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache and not self.training:
1329
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
1330
+ if is_padding_right:
1331
+ raise ValueError(
1332
+ "You are attempting to perform batched generation with padding_side='right'"
1333
+ " this may lead to unexpected behaviour for Flash Attention version of Phi4Flash. Make sure to "
1334
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1335
+ )
1336
+
1337
+ hidden_states = inputs_embeds
1338
+
1339
+ # decoder layers
1340
+ all_hidden_states = () if output_hidden_states else None
1341
+ all_self_attns = () if output_attentions else None
1342
+ ssm_output = None
1343
+ yoco_key_values = None
1344
+ for decoder_layer in self.layers: # TODO: only need to inference the first half of the layers during pre-fill
1345
+ if output_hidden_states:
1346
+ all_hidden_states += (hidden_states,)
1347
+
1348
+ if self.gradient_checkpointing and self.training:
1349
+ layer_outputs = self._gradient_checkpointing_func(
1350
+ decoder_layer.__call__,
1351
+ hidden_states,
1352
+ attention_mask,
1353
+ position_ids,
1354
+ past_key_values,
1355
+ output_attentions,
1356
+ use_cache,
1357
+ cache_position,
1358
+ ssm_output,
1359
+ yoco_key_values,
1360
+ )
1361
+ else:
1362
+ layer_outputs = decoder_layer(
1363
+ hidden_states,
1364
+ attention_mask=attention_mask,
1365
+ position_ids=position_ids,
1366
+ past_key_value=past_key_values,
1367
+ output_attentions=output_attentions,
1368
+ use_cache=use_cache,
1369
+ cache_position = cache_position,
1370
+ ssm_output = ssm_output,
1371
+ yoco_key_values = yoco_key_values,
1372
+ )
1373
+
1374
+ hidden_states = layer_outputs[0]
1375
+ ssm_output = layer_outputs[1]
1376
+ yoco_key_values = layer_outputs[2]
1377
+
1378
+ if output_attentions:
1379
+ all_self_attns += (layer_outputs[3],)
1380
+
1381
+ hidden_states = self.final_layernorm(hidden_states.to(dtype=self.final_layernorm.weight.dtype))
1382
+
1383
+ # add hidden states from the last decoder layer
1384
+ if output_hidden_states:
1385
+ all_hidden_states += (hidden_states,)
1386
+
1387
+ output = BaseModelOutputWithPast(
1388
+ last_hidden_state=hidden_states,
1389
+ past_key_values=past_key_values,
1390
+ hidden_states=all_hidden_states,
1391
+ attentions=all_self_attns,
1392
+ )
1393
+ return output if return_dict else output.to_tuple()
1394
+
1395
+
1396
+
1397
+ class Phi4FlashForCausalLM(Phi4FlashPreTrainedModel, GenerationMixin):
1398
+ _tied_weights_keys = ["lm_head.weight"]
1399
+
1400
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi4Flash,bias=False->bias=True
1401
+ def __init__(self, config):
1402
+ super().__init__(config)
1403
+ self.model = Phi4FlashModel(config)
1404
+ self.vocab_size = config.vocab_size
1405
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1406
+
1407
+ # Initialize weights and apply final processing
1408
+ self.post_init()
1409
+
1410
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
1411
+ def get_input_embeddings(self):
1412
+ return self.model.embed_tokens
1413
+
1414
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
1415
+ def set_input_embeddings(self, value):
1416
+ self.model.embed_tokens = value
1417
+
1418
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
1419
+ def get_output_embeddings(self):
1420
+ return self.lm_head
1421
+
1422
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
1423
+ def set_output_embeddings(self, new_embeddings):
1424
+ self.lm_head = new_embeddings
1425
+
1426
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
1427
+ def set_decoder(self, decoder):
1428
+ self.model = decoder
1429
+
1430
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
1431
+ def get_decoder(self):
1432
+ return self.model
1433
+
1434
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1435
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1436
+ def forward(
1437
+ self,
1438
+ input_ids: torch.LongTensor = None,
1439
+ attention_mask: Optional[torch.Tensor] = None,
1440
+ position_ids: Optional[torch.LongTensor] = None,
1441
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1442
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1443
+ labels: Optional[torch.LongTensor] = None,
1444
+ use_cache: Optional[bool] = None,
1445
+ output_attentions: Optional[bool] = None,
1446
+ output_hidden_states: Optional[bool] = None,
1447
+ return_dict: Optional[bool] = None,
1448
+ cache_position: Optional[torch.LongTensor] = None,
1449
+ num_logits_to_keep: int = 0,
1450
+ **loss_kwargs,
1451
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1452
+ r"""
1453
+ Args:
1454
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1455
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1456
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1457
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1458
+
1459
+ Returns:
1460
+
1461
+ Example:
1462
+
1463
+ ```python
1464
+ >>> from transformers import AutoTokenizer, Phi4FlashForCausalLM
1465
+
1466
+ >>> model = Phi4FlashForCausalLM.from_pretrained("microsoft/Phi-4-mini-flash-reasoning")
1467
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-mini-flash-reasoning")
1468
+
1469
+ >>> prompt = "This is an example script ."
1470
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
1471
+
1472
+ >>> # Generate
1473
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1474
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1475
+ 'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
1476
+ ```"""
1477
+
1478
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1479
+ output_hidden_states = (
1480
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1481
+ )
1482
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1483
+
1484
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1485
+ outputs = self.model(
1486
+ input_ids=input_ids,
1487
+ attention_mask=attention_mask,
1488
+ position_ids=position_ids,
1489
+ past_key_values=past_key_values,
1490
+ inputs_embeds=inputs_embeds,
1491
+ use_cache=use_cache,
1492
+ output_attentions=output_attentions,
1493
+ output_hidden_states=output_hidden_states,
1494
+ return_dict=return_dict,
1495
+ cache_position = cache_position,
1496
+ )
1497
+
1498
+ hidden_states = outputs[0]
1499
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
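+ # Note: with num_logits_to_keep == 0 the slice [:, -0:, :] keeps every position, so logits are
+ # computed for the full sequence.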
1500
+
1501
+ loss = None
1502
+ if labels is not None:
1503
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
1504
+
1505
+ if not return_dict:
1506
+ output = (logits,) + outputs[1:]
1507
+ return (loss,) + output if loss is not None else output
1508
+
1509
+ return CausalLMOutputWithPast(
1510
+ loss=loss,
1511
+ logits=logits,
1512
+ past_key_values=outputs.past_key_values,
1513
+ hidden_states=outputs.hidden_states,
1514
+ attentions=outputs.attentions,
1515
+ )
1516
+
1517
+
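As the docstring above notes, label positions set to `-100` are excluded from the loss. A minimal sketch of computing that supervised loss, assuming the checkpoint name from the docstring example above, that `trust_remote_code=True` is needed for the repository's custom modeling code, and that a CUDA device with the required attention kernels is available:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint name taken from the docstring example above (hypothetical environment).
model_id = "microsoft/Phi4-mini-flash-reasoning"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype=torch.bfloat16
).to("cuda")

inputs = tokenizer("The capital of France is Paris.", return_tensors="pt").to("cuda")
labels = inputs.input_ids.clone()
labels[:, :4] = -100  # these positions are ignored (masked) by the loss

outputs = model(**inputs, labels=labels)
print(outputs.loss)   # cross-entropy over the unmasked positions only
```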
1518
+ @add_start_docstrings(
1519
+ """
1520
+ The Phi4FlashModel with a sequence classification head on top (linear layer).
1521
+
1522
+ [`Phi4FlashForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1523
+ (e.g. GPT-2) do.
1524
+
1525
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1526
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1527
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1528
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1529
+ each row of the batch).
1530
+ """,
1531
+ PHI_START_DOCSTRING,
1532
+ )
1533
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PHI,Llama->Phi4Flash with self.transformer->self.model, transformer_outputs->model_outputs
1534
+ class Phi4FlashForSequenceClassification(Phi4FlashPreTrainedModel):
1535
+ def __init__(self, config):
1536
+ super().__init__(config)
1537
+ self.num_labels = config.num_labels
1538
+ self.model = Phi4FlashModel(config)
1539
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1540
+
1541
+ # Initialize weights and apply final processing
1542
+ self.post_init()
1543
+
1544
+ def get_input_embeddings(self):
1545
+ return self.model.embed_tokens
1546
+
1547
+ def set_input_embeddings(self, value):
1548
+ self.model.embed_tokens = value
1549
+
1550
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1551
+ def forward(
1552
+ self,
1553
+ input_ids: torch.LongTensor = None,
1554
+ attention_mask: Optional[torch.Tensor] = None,
1555
+ position_ids: Optional[torch.LongTensor] = None,
1556
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1557
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1558
+ labels: Optional[torch.LongTensor] = None,
1559
+ use_cache: Optional[bool] = None,
1560
+ output_attentions: Optional[bool] = None,
1561
+ output_hidden_states: Optional[bool] = None,
1562
+ return_dict: Optional[bool] = None,
1563
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1564
+ r"""
1565
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1566
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1567
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1568
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1569
+ """
1570
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1571
+
1572
+ model_outputs = self.model(
1573
+ input_ids,
1574
+ attention_mask=attention_mask,
1575
+ position_ids=position_ids,
1576
+ past_key_values=past_key_values,
1577
+ inputs_embeds=inputs_embeds,
1578
+ use_cache=use_cache,
1579
+ output_attentions=output_attentions,
1580
+ output_hidden_states=output_hidden_states,
1581
+ return_dict=return_dict,
1582
+ )
1583
+ hidden_states = model_outputs[0]
1584
+ logits = self.score(hidden_states)
1585
+
1586
+ if input_ids is not None:
1587
+ batch_size = input_ids.shape[0]
1588
+ else:
1589
+ batch_size = inputs_embeds.shape[0]
1590
+
1591
+ if self.config.pad_token_id is None and batch_size != 1:
1592
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1593
+ if self.config.pad_token_id is None:
1594
+ sequence_lengths = -1
1595
+ else:
1596
+ if input_ids is not None:
1597
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1598
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1599
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1600
+ sequence_lengths = sequence_lengths.to(logits.device)
1601
+ else:
1602
+ sequence_lengths = -1
1603
+
1604
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1605
+
1606
+ loss = None
1607
+ if labels is not None:
1608
+ labels = labels.to(logits.device)
1609
+ if self.config.problem_type is None:
1610
+ if self.num_labels == 1:
1611
+ self.config.problem_type = "regression"
1612
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1613
+ self.config.problem_type = "single_label_classification"
1614
+ else:
1615
+ self.config.problem_type = "multi_label_classification"
1616
+
1617
+ if self.config.problem_type == "regression":
1618
+ loss_fct = MSELoss()
1619
+ if self.num_labels == 1:
1620
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1621
+ else:
1622
+ loss = loss_fct(pooled_logits, labels)
1623
+ elif self.config.problem_type == "single_label_classification":
1624
+ loss_fct = CrossEntropyLoss()
1625
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1626
+ elif self.config.problem_type == "multi_label_classification":
1627
+ loss_fct = BCEWithLogitsLoss()
1628
+ loss = loss_fct(pooled_logits, labels)
1629
+ if not return_dict:
1630
+ output = (pooled_logits,) + model_outputs[1:]
1631
+ return ((loss,) + output) if loss is not None else output
1632
+
1633
+ return SequenceClassifierOutputWithPast(
1634
+ loss=loss,
1635
+ logits=pooled_logits,
1636
+ past_key_values=model_outputs.past_key_values,
1637
+ hidden_states=model_outputs.hidden_states,
1638
+ attentions=model_outputs.attentions,
1639
+ )
1640
+
1641
+
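The pooling logic above selects the logit of the last non-padding token via an argmax-plus-modulo trick (noted in the code as being ONNX-friendly). A self-contained sketch of just that step, with hypothetical tensors and `pad_token_id=0`:

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([
    [11, 12, 13, 0, 0],    # padded sequence: last real token at index 2
    [21, 22, 23, 24, 25],  # full sequence: no pad token present
])
logits = torch.randn(2, 5, 3)  # (batch, seq_len, num_labels)

# index of the first pad token minus one; argmax returns 0 when no pad exists,
# so the modulo maps -1 to seq_len - 1 (i.e. the last position)
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]

pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(sequence_lengths.tolist())  # [2, 4]
print(pooled_logits.shape)        # torch.Size([2, 3])
```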
1642
+ @add_start_docstrings(
1643
+ """
1644
+ Phi4FlashModel with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1645
+ Named-Entity-Recognition (NER) tasks.
1646
+ """,
1647
+ PHI_START_DOCSTRING,
1648
+ )
1649
+ # Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with MPT->PHI,Mpt->Phi4Flash,self.transformer->self.model,transformer_outputs->model_outputs
1650
+ class Phi4FlashForTokenClassification(Phi4FlashPreTrainedModel):
1651
+ def __init__(self, config: Phi4FlashConfig):
1652
+ super().__init__(config)
1653
+ self.num_labels = config.num_labels
1654
+
1655
+ self.model = Phi4FlashModel(config)
1656
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
1657
+ classifier_dropout = config.classifier_dropout
1658
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1659
+ classifier_dropout = config.hidden_dropout
1660
+ else:
1661
+ classifier_dropout = 0.1
1662
+ self.dropout = nn.Dropout(classifier_dropout)
1663
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1664
+
1665
+ # Initialize weights and apply final processing
1666
+ self.post_init()
1667
+
1668
+ @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
1669
+ @add_code_sample_docstrings(
1670
+ checkpoint=_CHECKPOINT_FOR_DOC,
1671
+ output_type=TokenClassifierOutput,
1672
+ config_class=_CONFIG_FOR_DOC,
1673
+ )
1674
+ def forward(
1675
+ self,
1676
+ input_ids: Optional[torch.LongTensor] = None,
1677
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1678
+ attention_mask: Optional[torch.Tensor] = None,
1679
+ inputs_embeds: Optional[torch.Tensor] = None,
1680
+ labels: Optional[torch.Tensor] = None,
1681
+ use_cache: Optional[bool] = None,
1682
+ output_attentions: Optional[bool] = None,
1683
+ output_hidden_states: Optional[bool] = None,
1684
+ return_dict: Optional[bool] = None,
1685
+ **deprecated_arguments,
1686
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1687
+ r"""
1688
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1689
+ Labels for computing the token classification loss. Indices should be in
1690
+ `[0, ..., config.num_labels - 1]`. Tokens with indices set to `-100` are
1691
+ ignored (masked) by the default cross-entropy loss.
1692
+ """
1693
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1694
+
1695
+ model_outputs = self.model(
1696
+ input_ids,
1697
+ past_key_values=past_key_values,
1698
+ attention_mask=attention_mask,
1699
+ inputs_embeds=inputs_embeds,
1700
+ use_cache=use_cache,
1701
+ output_attentions=output_attentions,
1702
+ output_hidden_states=output_hidden_states,
1703
+ return_dict=return_dict,
1704
+ )
1705
+
1706
+ hidden_states = model_outputs[0]
1707
+ hidden_states = self.dropout(hidden_states)
1708
+ logits = self.classifier(hidden_states)
1709
+
1710
+ loss = None
1711
+ if labels is not None:
1712
+ # move labels to correct device to enable model parallelism
1713
+ labels = labels.to(logits.device)
1714
+ batch_size, seq_length = labels.shape
1715
+ loss_fct = CrossEntropyLoss()
1716
+ loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length))
1717
+
1718
+ if not return_dict:
1719
+ output = (logits,) + model_outputs[2:]
1720
+ return ((loss,) + output) if loss is not None else output
1721
+
1722
+ return TokenClassifierOutput(
1723
+ loss=loss,
1724
+ logits=logits,
1725
+ hidden_states=model_outputs.hidden_states,
1726
+ attentions=model_outputs.attentions,
1727
+ )
1728
+
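The token-classification loss above flattens the `(batch, seq_len, num_labels)` logits and the `(batch, seq_len)` labels before applying cross-entropy. A minimal sketch of that reshaping with hypothetical tensors:

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_length, num_labels = 2, 6, 4
logits = torch.randn(batch_size, seq_length, num_labels)
labels = torch.randint(0, num_labels, (batch_size, seq_length))
labels[0, -2:] = -100  # -100 positions are ignored by CrossEntropyLoss by default

loss_fct = CrossEntropyLoss()
loss = loss_fct(
    logits.view(batch_size * seq_length, num_labels),  # (B*S, C)
    labels.view(batch_size * seq_length),              # (B*S,)
)
print(loss)
```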
1729
+ ## Custom Mamba/selective-scan kernels with attention-mask support for batched generation.
1730
+
1731
+ class SelectiveScanFn(torch.autograd.Function):
1732
+
1733
+ @staticmethod
1734
+ def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
1735
+ return_last_state=False):
1736
+ if u.stride(-1) != 1:
1737
+ u = u.contiguous()
1738
+ if delta.stride(-1) != 1:
1739
+ delta = delta.contiguous()
1740
+ if D is not None:
1741
+ D = D.contiguous()
1742
+ if B.stride(-1) != 1:
1743
+ B = B.contiguous()
1744
+ if C.stride(-1) != 1:
1745
+ C = C.contiguous()
1746
+ if z is not None and z.stride(-1) != 1:
1747
+ z = z.contiguous()
1748
+ if B.dim() == 3:
1749
+ B = rearrange(B, "b dstate l -> b 1 dstate l")
1750
+ ctx.squeeze_B = True
1751
+ if C.dim() == 3:
1752
+ C = rearrange(C, "b dstate l -> b 1 dstate l")
1753
+ ctx.squeeze_C = True
1754
+ out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
1755
+ ctx.delta_softplus = delta_softplus
1756
+ ctx.has_z = z is not None
1757
+ last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
1758
+ if not ctx.has_z:
1759
+ ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
1760
+ return out if not return_last_state else (out, last_state)
1761
+ else:
1762
+ ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
1763
+ out_z = rest[0]
1764
+ return out_z if not return_last_state else (out_z, last_state)
1765
+
1766
+ @staticmethod
1767
+ def backward(ctx, dout, *args):
1768
+ if not ctx.has_z:
1769
+ u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
1770
+ z = None
1771
+ out = None
1772
+ else:
1773
+ u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
1774
+ if dout.stride(-1) != 1:
1775
+ dout = dout.contiguous()
1776
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
1777
+ # backward of selective_scan_cuda with the backward of chunk).
1778
+ # Here we just pass in None and dz will be allocated in the C++ code.
1779
+ du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
1780
+ u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
1781
+ False # option to recompute out_z, not used here
1782
+ )
1783
+ dz = rest[0] if ctx.has_z else None
1784
+ dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
1785
+ dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
1786
+ return (du, ddelta, dA, dB, dC,
1787
+ dD if D is not None else None,
1788
+ dz,
1789
+ ddelta_bias if delta_bias is not None else None,
1790
+ None,
1791
+ None)
1792
+
1793
+
1794
+ def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
1795
+ return_last_state=False):
1796
+ """if return_last_state is True, returns (out, last_state)
1797
+ last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
1798
+ not considered in the backward pass.
1799
+ """
1800
+ return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
1801
+
1802
+
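A shape-only sketch of calling `selective_scan_fn` defined just above; the tensor layouts are assumptions inferred from the rearranges in `SelectiveScanFn.forward`, and running it requires the `selective_scan_cuda` extension (e.g. from the mamba-ssm package) plus a CUDA device:

```python
import torch

# Hypothetical sizes: batch, model dim, state size, sequence length.
b, d, n, l = 2, 64, 16, 128
device, dtype = "cuda", torch.float32

u = torch.randn(b, d, l, device=device, dtype=dtype)      # input sequence
delta = torch.rand(b, d, l, device=device, dtype=dtype)   # per-step timescales
A = -torch.rand(d, n, device=device, dtype=dtype)         # state matrix
B = torch.randn(b, n, l, device=device, dtype=dtype)      # input-dependent B, rearranged internally
C = torch.randn(b, n, l, device=device, dtype=dtype)      # input-dependent C, rearranged internally
D = torch.randn(d, device=device, dtype=dtype)            # skip connection
z = torch.randn(b, d, l, device=device, dtype=dtype)      # gate

out, last_state = selective_scan_fn(
    u, delta, A, B, C, D=D, z=z, delta_softplus=True, return_last_state=True
)
print(out.shape)         # (b, d, l)
print(last_state.shape)  # (b, d, n), not differentiated through per the docstring
```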
1803
+ class MambaInnerFn(torch.autograd.Function):
1804
+
1805
+ @staticmethod
1806
+ @custom_fwd(device_type="cuda")
1807
+ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
1808
+ out_proj_weight, out_proj_bias,
1809
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
1810
+ C_proj_bias=None, mask=None, delta_softplus=True, checkpoint_lvl=1,):
1811
+ """
1812
+ xz: (batch, dim, seqlen)
1813
+ """
1814
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
1815
+ assert checkpoint_lvl in [0, 1]
1816
+ L = xz.shape[-1]
1817
+ delta_rank = delta_proj_weight.shape[1]
1818
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
1819
+ if torch.is_autocast_enabled():
1820
+ x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
1821
+ delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
1822
+ out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
1823
+ out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
1824
+ if out_proj_bias is not None else None)
1825
+ if xz.stride(-1) != 1:
1826
+ xz = xz.contiguous()
1827
+ conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
1828
+ x, z = xz.chunk(2, dim=1)
1829
+ if mask is not None:
1830
+ x = x * mask.unsqueeze(1)
1831
+ conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
1832
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
1833
+ x, conv1d_weight, conv1d_bias, None, None, None, True
1834
+ )
1835
+ if mask is not None:
1836
+ conv1d_out = conv1d_out * mask.unsqueeze(1)
1837
+ # We're being very careful here about the layout, to avoid extra transposes.
1838
+ # We want delta to have d as the slowest moving dimension
1839
+ # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
1840
+ x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
1841
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
1842
+ ctx.is_variable_B = B is None
1843
+ ctx.is_variable_C = C is None
1844
+ ctx.B_proj_bias_is_None = B_proj_bias is None
1845
+ ctx.C_proj_bias_is_None = C_proj_bias is None
1846
+ if B is None: # variable B
1847
+ B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
1848
+ if B_proj_bias is not None:
1849
+ B = B + B_proj_bias.to(dtype=B.dtype)
1850
+ if not A.is_complex():
1851
+ # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
1852
+ B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
1853
+ else:
1854
+ B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
1855
+ else:
1856
+ if B.stride(-1) != 1:
1857
+ B = B.contiguous()
1858
+ if C is None: # variable C
1859
+ C = x_dbl[:, -d_state:] # (bl dstate)
1860
+ if C_proj_bias is not None:
1861
+ C = C + C_proj_bias.to(dtype=C.dtype)
1862
+ if not A.is_complex():
1863
+ # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
1864
+ C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
1865
+ else:
1866
+ C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
1867
+ else:
1868
+ if C.stride(-1) != 1:
1869
+ C = C.contiguous()
1870
+ if D is not None:
1871
+ D = D.contiguous()
1872
+ out, scan_intermediates, out_z = selective_scan_cuda.fwd(
1873
+ conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
1874
+ )
1875
+ ctx.delta_softplus = delta_softplus
1876
+ ctx.out_proj_bias_is_None = out_proj_bias is None
1877
+ ctx.checkpoint_lvl = checkpoint_lvl
1878
+ if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
1879
+ conv1d_out, delta = None, None
1880
+ ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
1881
+ delta_proj_weight, out_proj_weight, conv1d_out, delta,
1882
+ A, B, C, D, delta_bias, scan_intermediates, out)
1883
+ return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
1884
+
1885
+ @staticmethod
1886
+ @custom_bwd(device_type="cuda")
1887
+ def backward(ctx, dout):
1888
+ # dout: (batch, seqlen, dim)
1889
+ assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
1890
+ (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
1891
+ conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
1892
+ L = xz.shape[-1]
1893
+ delta_rank = delta_proj_weight.shape[1]
1894
+ d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
1895
+ x, z = xz.chunk(2, dim=1)
1896
+ if dout.stride(-1) != 1:
1897
+ dout = dout.contiguous()
1898
+ if ctx.checkpoint_lvl == 1:
1899
+ conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
1900
+ x, conv1d_weight, conv1d_bias, None, None, None, True
1901
+ )
1902
+ delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
1903
+ "d (b l) -> b d l", l = L)
1904
+ # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
1905
+ # backward of selective_scan_cuda with the backward of chunk).
1906
+ dxz = torch.empty_like(xz) # (batch, dim, seqlen)
1907
+ dx, dz = dxz.chunk(2, dim=1)
1908
+ dout = rearrange(dout, "b l e -> e (b l)")
1909
+ dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
1910
+ dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
1911
+ conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
1912
+ ctx.delta_softplus,
1913
+ True # option to recompute out_z
1914
+ )
1915
+ dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
1916
+ dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
1917
+ dD = dD if D is not None else None
1918
+ dx_dbl = torch.empty_like(x_dbl)
1919
+ dB_proj_bias = None
1920
+ if ctx.is_variable_B:
1921
+ if not A.is_complex():
1922
+ dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
1923
+ else:
1924
+ dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
1925
+ dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
1926
+ dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
1927
+ dB = None
1928
+ dC_proj_bias = None
1929
+ if ctx.is_variable_C:
1930
+ if not A.is_complex():
1931
+ dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
1932
+ else:
1933
+ dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
1934
+ dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
1935
+ dx_dbl[:, -d_state:] = dC # (bl d)
1936
+ dC = None
1937
+ ddelta = rearrange(ddelta, "b d l -> d (b l)")
1938
+ ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
1939
+ dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
1940
+ dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
1941
+ dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
1942
+ dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
1943
+ dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
1944
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
1945
+ # backward of conv1d with the backward of chunk).
1946
+ dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
1947
+ x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
1948
+ )
1949
+ dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
1950
+ dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
1951
+ return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
1952
+ dout_proj_weight, dout_proj_bias,
1953
+ dA, dB, dC, dD,
1954
+ ddelta_bias if delta_bias is not None else None,
1955
+ dB_proj_bias, dC_proj_bias, None, None)
1956
+
1957
+
1958
+ def mamba_inner_fn(
1959
+ xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
1960
+ out_proj_weight, out_proj_bias,
1961
+ A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
1962
+ C_proj_bias=None, mask=None, delta_softplus=True
1963
+ ):
1964
+ return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
1965
+ out_proj_weight, out_proj_bias,
1966
+ A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, mask, delta_softplus)
1967
+
1968
+
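The extra `mask` argument is what makes these fused kernels usable for batched generation with padded inputs: padded timesteps are zeroed out both before and after the causal convolution so they cannot leak into the recurrent state. A tiny sketch of that masking step alone, with hypothetical tensors:

```python
import torch

batch, dim, seqlen = 2, 4, 6
x = torch.randn(batch, dim, seqlen)

# 1 for real tokens, 0 for padding (e.g. derived from the attention mask)
mask = torch.tensor([
    [1, 1, 1, 1, 0, 0],
    [1, 1, 1, 1, 1, 1],
], dtype=x.dtype)

x_masked = x * mask.unsqueeze(1)  # broadcast over the channel dimension
print(x_masked[0, :, 4:])         # padded positions are exactly zero
```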
1969
+ def lambda_init_fn(depth):
1970
+ return 0.8 - 0.6 * math.exp(-0.3 * depth)
1971
+
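`lambda_init_fn` gives shallow layers a small re-weighting coefficient that saturates towards 0.8 for deeper layers, for example:

```python
>>> import math
>>> [round(0.8 - 0.6 * math.exp(-0.3 * depth), 3) for depth in (0, 1, 4, 16)]
[0.2, 0.356, 0.619, 0.795]
```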
1972
+
1973
+ def split_heads(x):
1974
+ # split by num_heads, the stripe pattern is friendly to tensor parallel.
1975
+ x = rearrange(x, "... (H two) D -> ... H two D", two=2)
1976
+ x1 = x[..., 0, :]
1977
+ x2 = x[..., 1, :]
1978
+ return x1, x2
1979
+
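`split_heads` interleaves the two differential-attention halves head by head (even-indexed heads go to the first half, odd-indexed heads to the second), which keeps the layout friendly to tensor parallelism. A small self-contained sketch that re-declares the function for illustration:

```python
import torch
from einops import rearrange

def split_heads(x):
    x = rearrange(x, "... (H two) D -> ... H two D", two=2)
    return x[..., 0, :], x[..., 1, :]

# 1 batch, 1 position, 6 "physical" heads of dim 2, each filled with its head index
x = torch.arange(6).repeat_interleave(2).float().view(1, 1, 6, 2)
x1, x2 = split_heads(x)
print(x1[0, 0, :, 0].tolist())  # [0.0, 2.0, 4.0] -> even-indexed heads
print(x2[0, 0, :, 0].tolist())  # [1.0, 3.0, 5.0] -> odd-indexed heads
```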
1980
+ class FlashDiffCustomAttention(nn.Module):
1981
+ """Implement the scaled dot product attention with softmax.
1982
+ Arguments
1983
+ ---------
1984
+ head_dim: The dimension of the heads.
1985
+ depth: The layer id, starting from 0.
1986
+ """
1987
+
1988
+ def __init__(
1989
+ self,
1990
+ head_dim,
1991
+ depth,
1992
+ fa_og=True,
1993
+ ):
1994
+ super().__init__()
1995
+ assert flash_attn_varlen_func is not None, "FlashAttention is not installed"
1996
+ assert flash_attn_func is not None, "FlashAttention is not installed"
1997
+ self.head_dim = head_dim
1998
+ self.fa_og = fa_og # setting this to False requires the customized FlashAttention kernel from https://github.com/xiayuqing0622/flex_head_fa
1999
+ self.lambda_init = lambda_init_fn(depth)
2000
+ self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
2001
+ self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
2002
+ self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
2003
+ self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
2004
+
2005
+ self.subln = SambaYRMSNorm(2 * self.head_dim, eps=1e-5)
2006
+
2007
+ def forward(
2008
+ self,
2009
+ q,
2010
+ k,
2011
+ v,
2012
+ dropout_p=0.0,
2013
+ cu_seqlens_q=None,
2014
+ max_seqlen_q=None,
2015
+ cu_seqlens_k=None,
2016
+ max_seqlen_k=None,
2017
+ softmax_scale=None,
2018
+ window_size=(-1, -1),
2019
+ causal=None,
2020
+ ):
2021
+ """Implements the multihead softmax attention.
2022
+ Arguments
2023
+ ---------
2024
+ q, k, v: The tensors containing the query, key, and value.
2025
+ If cu_seqlens is None and max_seqlen is None, then each has shape (B, S, H, D).
2026
+ If cu_seqlens is not None and max_seqlen is not None, then each has shape
2027
+ (total, H, D), where total is the sum of the sequence lengths in the batch.
2028
+ causal: whether to apply a causal attention mask.
2029
+ cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
2030
+ of the sequences in the batch, used to index into qkv.
2031
+ max_seqlen: int. Maximum sequence length in the batch.
2032
+ Returns:
2033
+ --------
2034
+ out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
2035
+ else (B, S, H, D).
2036
+ """
2037
+ q = q.to(torch.bfloat16)
2038
+ k = k.to(torch.bfloat16)
2039
+ v = v.to(torch.bfloat16)
2040
+
2041
+ assert q.dtype in [torch.float16, torch.bfloat16]
2042
+ assert q.is_cuda and k.is_cuda and v.is_cuda
2043
+ #causal = self.causal if causal is None else causal
2044
+ unpadded = cu_seqlens_q is not None
2045
+ q1, q2 = split_heads(q)
2046
+ k1, k2 = split_heads(k)
2047
+ if self.fa_og:
2048
+ v1, v2 = split_heads(v)
2049
+ else:
2050
+ v = rearrange(v, "... (H two) D -> ... H (two D)", two=2)
2051
+
2052
+ kwargs = {
2053
+ "dropout_p": dropout_p,
2054
+ "softmax_scale": softmax_scale,
2055
+ "causal": causal,
2056
+ "window_size": window_size,
2057
+ }
2058
+
2059
+ if unpadded:
2060
+ assert cu_seqlens_q.dtype == torch.int32
2061
+ assert max_seqlen_q is not None
2062
+ assert isinstance(max_seqlen_q, int)
2063
+ assert cu_seqlens_k is not None
2064
+ assert cu_seqlens_k.dtype == torch.int32
2065
+ assert max_seqlen_k is not None
2066
+ assert isinstance(max_seqlen_k, int)
2067
+
2068
+ kwargs.update({
2069
+ "cu_seqlens_q": cu_seqlens_q,
2070
+ "max_seqlen_q": max_seqlen_q,
2071
+ "cu_seqlens_k": cu_seqlens_k,
2072
+ "max_seqlen_k": max_seqlen_k,
2073
+ })
2074
+ attn_func = flash_attn_varlen_func
2075
+ else:
2076
+ attn_func = flash_attn_func
2077
+
2078
+ if self.fa_og:
2079
+ attn11 = attn_func(q1, k1, v1, **kwargs)
2080
+ attn12 = attn_func(q1, k1, v2, **kwargs)
2081
+ attn1 = torch.cat([attn11, attn12], dim=-1)
2082
+ attn21 = attn_func(q2, k2, v1, **kwargs)
2083
+ attn22 = attn_func(q2, k2, v2, **kwargs)
2084
+ attn2 = torch.cat([attn21, attn22], dim=-1)
2085
+ else:
2086
+ attn1 = attn_func(q1, k1, v, **kwargs)
2087
+ attn2 = attn_func(q2, k2, v, **kwargs)
2088
+
2089
+ lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
2090
+ lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
2091
+ lambda_full = lambda_1 - lambda_2 + self.lambda_init
2092
+
2093
+ attn = attn1 - lambda_full * attn2
2094
+ attn = self.subln(attn)
2095
+ attn = attn * (1 - self.lambda_init)
2096
+ # reshape back to 2 * num_head
2097
+ attn = rearrange(attn, "... H (two D) -> ... (H two) D", two=2)
2098
+ return attn
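The forward pass above combines two FlashAttention calls as `attn1 - lambda_full * attn2`, applies an RMS norm over the doubled head dimension, and rescales by `(1 - lambda_init)`. A minimal eager-mode reference of that combination, without the FlashAttention kernels, with stand-in values for the learned lambda parameters, and mirroring the path where values carry a doubled head dimension:

```python
import math
import torch

B, S, H, D = 1, 5, 2, 8           # batch, seq len, half the number of heads, head dim
q1, q2 = torch.randn(2, B, S, H, D)
k1, k2 = torch.randn(2, B, S, H, D)
v = torch.randn(B, S, H, 2 * D)   # values viewed with a doubled head dim

def causal_sdpa(q, k, v):
    scores = torch.einsum("bshd,bthd->bhst", q, k) / math.sqrt(q.shape[-1])
    causal = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))
    return torch.einsum("bhst,bthd->bshd", scores.softmax(dim=-1), v)

depth = 0
lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)
lambda_full = 0.12 - 0.05 + lambda_init  # stand-in for exp(q1.k1 lambdas) - exp(q2.k2 lambdas)

attn = causal_sdpa(q1, k1, v) - lambda_full * causal_sdpa(q2, k2, v)
attn = attn / torch.sqrt(attn.pow(2).mean(dim=-1, keepdim=True) + 1e-5)  # RMS norm, no learned scale
attn = attn * (1 - lambda_init)
print(attn.shape)  # (B, S, H, 2 * D)
```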
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|endoftext|>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,111 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": false,
5
+ "added_tokens_decoder": {
6
+ "199999": {
7
+ "content": "<|endoftext|>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "200018": {
15
+ "content": "<|endofprompt|>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "200019": {
23
+ "content": "<|assistant|>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": true,
27
+ "single_word": false,
28
+ "special": true
29
+ },
30
+ "200020": {
31
+ "content": "<|end|>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": true,
35
+ "single_word": false,
36
+ "special": true
37
+ },
38
+ "200021": {
39
+ "content": "<|user|>",
40
+ "lstrip": false,
41
+ "normalized": false,
42
+ "rstrip": true,
43
+ "single_word": false,
44
+ "special": true
45
+ },
46
+ "200022": {
47
+ "content": "<|system|>",
48
+ "lstrip": false,
49
+ "normalized": false,
50
+ "rstrip": true,
51
+ "single_word": false,
52
+ "special": true
53
+ },
54
+ "200023": {
55
+ "content": "<|tool|>",
56
+ "lstrip": false,
57
+ "normalized": false,
58
+ "rstrip": true,
59
+ "single_word": false,
60
+ "special": false
61
+ },
62
+ "200024": {
63
+ "content": "<|/tool|>",
64
+ "lstrip": false,
65
+ "normalized": false,
66
+ "rstrip": true,
67
+ "single_word": false,
68
+ "special": false
69
+ },
70
+ "200025": {
71
+ "content": "<|tool_call|>",
72
+ "lstrip": false,
73
+ "normalized": false,
74
+ "rstrip": true,
75
+ "single_word": false,
76
+ "special": false
77
+ },
78
+ "200026": {
79
+ "content": "<|/tool_call|>",
80
+ "lstrip": false,
81
+ "normalized": false,
82
+ "rstrip": true,
83
+ "single_word": false,
84
+ "special": false
85
+ },
86
+ "200027": {
87
+ "content": "<|tool_response|>",
88
+ "lstrip": false,
89
+ "normalized": false,
90
+ "rstrip": true,
91
+ "single_word": false,
92
+ "special": false
93
+ },
94
+ "200028": {
95
+ "content": "<|tag|>",
96
+ "lstrip": false,
97
+ "normalized": false,
98
+ "rstrip": true,
99
+ "single_word": false,
100
+ "special": true
101
+ }
102
+ },
103
+ "bos_token": "<|endoftext|>",
104
+ "chat_template": "{% for message in messages %}{% if message['role'] == 'system' and 'tools' in message and message['tools'] is not none %}{{ '<|' + message['role'] + '|>' + message['content'] + '<|tool|>' + message['tools'] + '<|/tool|>' + '<|end|>' }}{% else %}{{ '<|' + message['role'] + '|>' + message['content'] + '<|end|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>' }}{% else %}{{ eos_token }}{% endif %}",
105
+ "clean_up_tokenization_spaces": false,
106
+ "eos_token": "<|endoftext|>",
107
+ "model_max_length": 65536,
108
+ "pad_token": "<|endoftext|>",
109
+ "tokenizer_class": "GPT2Tokenizer",
110
+ "unk_token": "<|endoftext|>"
111
+ }
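The `chat_template` above wraps each message as `<|role|>content<|end|>` and appends `<|assistant|>` when a generation prompt is requested. A quick way to sanity-check the rendered prompt, using the checkpoint name from the modeling docstring example earlier in this commit (assumed to resolve to these tokenizer files):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi4-mini-flash-reasoning")
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# <|system|>You are a helpful assistant.<|end|><|user|>What is 2 + 2?<|end|><|assistant|>
```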
vocab.json ADDED
The diff for this file is too large to render. See raw diff