yuanshuai commited on
Commit
d20c3ca
·
verified ·
1 Parent(s): e5be3b4

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. draft/pytorch_model.bin +1 -1
  2. draft/qwen2.py +641 -0
draft/pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5e1f84cbe63318de1308a4a7ed2b509349b30639c47635747ac3a55b8a7f8bb0
3
  size 1534011812
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f80f49be625de0c703ae279805a764b07377bae6d2a31468e75e12fbf2c17298
3
  size 1534011812
draft/qwen2.py ADDED
@@ -0,0 +1,641 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+
15
+ # Adapted from llama2.py
16
+ # Modify details for the adaptation of Qwen2 model.
17
+ """Inference-only Qwen2 model compatible with HuggingFace weights."""
18
+ import logging
19
+ from typing import Any, Dict, Iterable, Optional, Tuple, Union, List
20
+
21
+ import torch
22
+ from torch import nn
23
+
24
+ from sglang.srt.distributed import (
25
+ get_pp_group,
26
+ get_tensor_model_parallel_rank,
27
+ get_tensor_model_parallel_world_size,
28
+ )
29
+ from sglang.srt.layers.activation import SiluAndMul
30
+ from sglang.srt.layers.layernorm import RMSNorm
31
+ from sglang.srt.layers.linear import (
32
+ MergedColumnParallelLinear,
33
+ QKVParallelLinear,
34
+ RowParallelLinear,
35
+ )
36
+ from sglang.srt.layers.logits_processor import LogitsProcessor
37
+ from sglang.srt.layers.pooler import Pooler, PoolingType
38
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
39
+ from sglang.srt.layers.radix_attention import RadixAttention
40
+ from sglang.srt.layers.rotary_embedding import get_rope
41
+ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
42
+ from sglang.srt.layers.vocab_parallel_embedding import (
43
+ ParallelLMHead,
44
+ VocabParallelEmbedding,
45
+ )
46
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
47
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
48
+ from sglang.srt.model_loader.weight_utils import (
49
+ default_weight_loader,
50
+ kv_cache_scales_loader,
51
+ )
52
+ from sglang.srt.utils import add_prefix, make_layers
53
+
54
+ Qwen2Config = None
55
+
56
+
57
+ logger = logging.getLogger(__name__)
58
+
59
+
60
+ class Qwen2MLP(nn.Module):
61
+ def __init__(
62
+ self,
63
+ hidden_size: int,
64
+ intermediate_size: int,
65
+ hidden_act: str,
66
+ quant_config: Optional[QuantizationConfig] = None,
67
+ prefix: str = "",
68
+ ) -> None:
69
+ super().__init__()
70
+ self.gate_up_proj = MergedColumnParallelLinear(
71
+ hidden_size,
72
+ [intermediate_size] * 2,
73
+ bias=False,
74
+ quant_config=quant_config,
75
+ prefix=add_prefix("gate_up_proj", prefix),
76
+ )
77
+ self.down_proj = RowParallelLinear(
78
+ intermediate_size,
79
+ hidden_size,
80
+ bias=False,
81
+ quant_config=quant_config,
82
+ prefix=add_prefix("down_proj", prefix),
83
+ )
84
+ if hidden_act != "silu":
85
+ raise ValueError(
86
+ f"Unsupported activation: {hidden_act}. "
87
+ "Only silu is supported for now."
88
+ )
89
+ self.act_fn = SiluAndMul()
90
+
91
+ def forward(self, x):
92
+ gate_up, _ = self.gate_up_proj(x)
93
+ x = self.act_fn(gate_up)
94
+ x, _ = self.down_proj(x)
95
+ return x
96
+
97
+
98
+ class Qwen2Attention(nn.Module):
99
+ def __init__(
100
+ self,
101
+ hidden_size: int,
102
+ num_heads: int,
103
+ num_kv_heads: int,
104
+ head_dim: Optional[int] = None,
105
+ layer_id: int = 0,
106
+ rope_theta: float = 1000000,
107
+ rope_scaling: Optional[Dict[str, Any]] = None,
108
+ max_position_embeddings: int = 32768,
109
+ quant_config: Optional[QuantizationConfig] = None,
110
+ dual_chunk_attention_config: Optional[dict[str, Any]] = None,
111
+ prefix: str = "",
112
+ ) -> None:
113
+ super().__init__()
114
+ self.hidden_size = hidden_size
115
+ tp_size = get_tensor_model_parallel_world_size()
116
+ self.total_num_heads = num_heads
117
+ assert self.total_num_heads % tp_size == 0
118
+ self.num_heads = self.total_num_heads // tp_size
119
+ self.total_num_kv_heads = num_kv_heads
120
+ if self.total_num_kv_heads >= tp_size:
121
+ # Number of KV heads is greater than TP size, so we partition
122
+ # the KV heads across multiple tensor parallel GPUs.
123
+ assert self.total_num_kv_heads % tp_size == 0
124
+ else:
125
+ # Number of KV heads is less than TP size, so we replicate
126
+ # the KV heads across multiple tensor parallel GPUs.
127
+ assert tp_size % self.total_num_kv_heads == 0
128
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
129
+ if head_dim is not None:
130
+ self.head_dim = head_dim
131
+ else:
132
+ self.head_dim = hidden_size // self.total_num_heads
133
+ self.q_size = self.num_heads * self.head_dim
134
+ self.kv_size = self.num_kv_heads * self.head_dim
135
+ self.scaling = self.head_dim**-0.5
136
+ self.rope_theta = rope_theta
137
+ self.max_position_embeddings = max_position_embeddings
138
+
139
+ self.qkv_proj = QKVParallelLinear(
140
+ hidden_size,
141
+ self.head_dim,
142
+ self.total_num_heads,
143
+ self.total_num_kv_heads,
144
+ bias=True,
145
+ quant_config=quant_config,
146
+ prefix=add_prefix("qkv_proj", prefix),
147
+ )
148
+ self.o_proj = RowParallelLinear(
149
+ self.total_num_heads * self.head_dim,
150
+ hidden_size,
151
+ bias=False,
152
+ quant_config=quant_config,
153
+ prefix=add_prefix("o_proj", prefix),
154
+ )
155
+
156
+ self.rotary_emb = get_rope(
157
+ self.head_dim,
158
+ rotary_dim=self.head_dim,
159
+ max_position=max_position_embeddings,
160
+ base=rope_theta,
161
+ rope_scaling=rope_scaling,
162
+ dual_chunk_attention_config=dual_chunk_attention_config,
163
+ )
164
+ self.attn = RadixAttention(
165
+ self.num_heads,
166
+ self.head_dim,
167
+ self.scaling,
168
+ num_kv_heads=self.num_kv_heads,
169
+ layer_id=layer_id,
170
+ quant_config=quant_config,
171
+ prefix=add_prefix("attn", prefix),
172
+ )
173
+
174
+ def forward(
175
+ self,
176
+ positions: torch.Tensor,
177
+ hidden_states: torch.Tensor,
178
+ forward_batch: ForwardBatch,
179
+ ) -> torch.Tensor:
180
+ qkv, _ = self.qkv_proj(hidden_states)
181
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
182
+ q, k = self.rotary_emb(positions, q, k)
183
+ attn_output = self.attn(q, k, v, forward_batch)
184
+ output, _ = self.o_proj(attn_output)
185
+ return output
186
+
187
+
188
+ class Qwen2DecoderLayer(nn.Module):
189
+ def __init__(
190
+ self,
191
+ config: Qwen2Config,
192
+ layer_id: int = 0,
193
+ quant_config: Optional[QuantizationConfig] = None,
194
+ prefix: str = "",
195
+ alt_stream: Optional[torch.cuda.Stream] = None,
196
+ ) -> None:
197
+ super().__init__()
198
+ self.hidden_size = config.hidden_size
199
+ rope_theta = getattr(config, "rope_theta", 1000000)
200
+ rope_scaling = getattr(config, "rope_scaling", None)
201
+ max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
202
+ head_dim = getattr(config, "head_dim", None)
203
+ dual_chunk_attention_config = getattr(
204
+ config, "dual_chunk_attention_config", None
205
+ )
206
+ self.self_attn = Qwen2Attention(
207
+ hidden_size=self.hidden_size,
208
+ num_heads=config.num_attention_heads,
209
+ num_kv_heads=config.num_key_value_heads,
210
+ head_dim=head_dim,
211
+ layer_id=layer_id,
212
+ rope_theta=rope_theta,
213
+ rope_scaling=rope_scaling,
214
+ max_position_embeddings=max_position_embeddings,
215
+ quant_config=quant_config,
216
+ dual_chunk_attention_config=dual_chunk_attention_config,
217
+ prefix=add_prefix("self_attn", prefix),
218
+ )
219
+ self.mlp = Qwen2MLP(
220
+ hidden_size=self.hidden_size,
221
+ intermediate_size=config.intermediate_size,
222
+ hidden_act=config.hidden_act,
223
+ quant_config=quant_config,
224
+ prefix=add_prefix("mlp", prefix),
225
+ )
226
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
227
+ self.post_attention_layernorm = RMSNorm(
228
+ config.hidden_size, eps=config.rms_norm_eps
229
+ )
230
+
231
+ def forward(
232
+ self,
233
+ positions: torch.Tensor,
234
+ hidden_states: torch.Tensor,
235
+ forward_batch: ForwardBatch,
236
+ residual: Optional[torch.Tensor],
237
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
238
+ # Self Attention
239
+ if residual is None:
240
+ residual = hidden_states
241
+ hidden_states = self.input_layernorm(hidden_states)
242
+ else:
243
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
244
+ hidden_states = self.self_attn(
245
+ positions=positions,
246
+ hidden_states=hidden_states,
247
+ forward_batch=forward_batch,
248
+ )
249
+
250
+ # Fully Connected
251
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
252
+ hidden_states = self.mlp(hidden_states)
253
+ return hidden_states, residual
254
+
255
+
256
+ class Qwen2Model(nn.Module):
257
+ def __init__(
258
+ self,
259
+ config: Qwen2Config,
260
+ quant_config: Optional[QuantizationConfig] = None,
261
+ prefix: str = "",
262
+ decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer,
263
+ alt_stream: Optional[torch.cuda.Stream] = None,
264
+ ) -> None:
265
+ super().__init__()
266
+ self.config = config
267
+ self.padding_idx = config.pad_token_id
268
+ self.vocab_size = config.vocab_size
269
+ self.pp_group = get_pp_group()
270
+
271
+ if self.pp_group.is_first_rank:
272
+ self.embed_tokens = VocabParallelEmbedding(
273
+ config.vocab_size,
274
+ config.hidden_size,
275
+ quant_config=quant_config,
276
+ enable_tp=not global_server_args_dict["enable_dp_attention"],
277
+ prefix=add_prefix("embed_tokens", prefix),
278
+ )
279
+ else:
280
+ self.embed_tokens = PPMissingLayer()
281
+
282
+ # Use the provided decoder layer type or default to Qwen2DecoderLayer
283
+ decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
284
+ self.layers, self.start_layer, self.end_layer = make_layers(
285
+ config.num_hidden_layers,
286
+ lambda idx, prefix: decoder_layer_type(
287
+ layer_id=idx,
288
+ config=config,
289
+ quant_config=quant_config,
290
+ prefix=prefix,
291
+ alt_stream=alt_stream,
292
+ ),
293
+ pp_rank=self.pp_group.rank_in_group,
294
+ pp_size=self.pp_group.world_size,
295
+ prefix=add_prefix("layers", prefix),
296
+ )
297
+ if self.pp_group.is_last_rank:
298
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
299
+ else:
300
+ self.norm = PPMissingLayer(return_tuple=True)
301
+
302
+ # For EAGLE3 support
303
+ self.layers_to_capture = []
304
+
305
+ def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
306
+ if hasattr(self.config, "scale_emb"):
307
+ return self.get_input_embeddings()(input_ids) * self.config.scale_emb
308
+ else:
309
+ return self.get_input_embeddings()(input_ids)
310
+
311
+ def get_input_embeddings(self) -> nn.Embedding:
312
+ return self.embed_tokens
313
+
314
+ def forward(
315
+ self,
316
+ input_ids: torch.Tensor,
317
+ positions: torch.Tensor,
318
+ forward_batch: ForwardBatch,
319
+ input_embeds: torch.Tensor = None,
320
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
321
+ ) -> Union[torch.Tensor, PPProxyTensors]:
322
+ if self.pp_group.is_first_rank:
323
+ if input_embeds is None:
324
+ hidden_states = self.embed_tokens(input_ids)
325
+ else:
326
+ hidden_states = input_embeds
327
+ residual = None
328
+ else:
329
+ assert pp_proxy_tensors is not None
330
+ hidden_states = pp_proxy_tensors["hidden_states"]
331
+ residual = pp_proxy_tensors["residual"]
332
+
333
+ aux_hidden_states = []
334
+ for i in range(self.start_layer, self.end_layer):
335
+ if i in self.layers_to_capture:
336
+ aux_hidden_states.append(
337
+ hidden_states + residual if residual is not None else hidden_states
338
+ )
339
+ layer = self.layers[i]
340
+ hidden_states, residual = layer(
341
+ positions,
342
+ hidden_states,
343
+ forward_batch,
344
+ residual,
345
+ )
346
+ if not self.pp_group.is_last_rank:
347
+ return PPProxyTensors(
348
+ {
349
+ "hidden_states": hidden_states,
350
+ "residual": residual,
351
+ }
352
+ )
353
+ else:
354
+ if hidden_states.shape[0] != 0:
355
+ if residual is None:
356
+ hidden_states = self.norm(hidden_states)
357
+ else:
358
+ hidden_states, _ = self.norm(hidden_states, residual)
359
+
360
+ if len(aux_hidden_states) == 0:
361
+ return hidden_states
362
+
363
+ return hidden_states, aux_hidden_states
364
+
365
+ # If this function is called, it should always initialize KV cache scale
366
+ # factors (or else raise an exception). Thus, handled exceptions should
367
+ # make sure to leave KV cache scale factors in a known good (dummy) state
368
+ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
369
+ tp_size = get_tensor_model_parallel_world_size()
370
+ tp_rank = get_tensor_model_parallel_rank()
371
+ for layer_idx, scaling_factor in kv_cache_scales_loader(
372
+ quantization_param_path,
373
+ tp_rank,
374
+ tp_size,
375
+ self.config.num_hidden_layers,
376
+ self.config.__class__.model_type,
377
+ ):
378
+ if not isinstance(self.layers[layer_idx], nn.Identity):
379
+ layer_self_attn = self.layers[layer_idx].self_attn
380
+ if hasattr(layer_self_attn.attn, "k_scale"):
381
+ layer_self_attn.attn.k_scale = scaling_factor
382
+ layer_self_attn.attn.v_scale = scaling_factor
383
+ else:
384
+ raise RuntimeError(
385
+ "Self attention has no KV cache scaling " "factor attribute!"
386
+ )
387
+
388
+
389
+ class Qwen2ForCausalLM(nn.Module):
390
+ # BitandBytes specific attributes
391
+ default_bitsandbytes_target_modules = [
392
+ ".gate_proj.",
393
+ ".down_proj.",
394
+ ".up_proj.",
395
+ ".q_proj.",
396
+ ".k_proj.",
397
+ ".v_proj.",
398
+ ".o_proj.",
399
+ ]
400
+ bitsandbytes_stacked_params_mapping = {
401
+ # shard_name, weight_name, index
402
+ "q_proj": ("qkv_proj", 0),
403
+ "k_proj": ("qkv_proj", 1),
404
+ "v_proj": ("qkv_proj", 2),
405
+ "gate_proj": ("gate_up_proj", 0),
406
+ "up_proj": ("gate_up_proj", 1),
407
+ }
408
+
409
+ def __init__(
410
+ self,
411
+ config: Qwen2Config,
412
+ quant_config: Optional[QuantizationConfig] = None,
413
+ prefix: str = "",
414
+ ) -> None:
415
+ super().__init__()
416
+ self.pp_group = get_pp_group()
417
+ self.config = config
418
+ self.quant_config = quant_config
419
+ self.model = Qwen2Model(
420
+ config, quant_config=quant_config, prefix=add_prefix("model", prefix)
421
+ )
422
+ self.capture_aux_hidden_states = False
423
+
424
+ # handle the lm head on different pp ranks
425
+ if self.pp_group.is_last_rank:
426
+ if self.pp_group.world_size == 1 and config.tie_word_embeddings:
427
+ self.lm_head = self.model.embed_tokens
428
+ else:
429
+ self.lm_head = ParallelLMHead(
430
+ config.vocab_size,
431
+ config.hidden_size,
432
+ quant_config=quant_config,
433
+ prefix=add_prefix("lm_head", prefix),
434
+ )
435
+ else:
436
+ # ranks other than the last rank will have a placeholder layer
437
+ self.lm_head = PPMissingLayer()
438
+
439
+ # perform weight tying for PP
440
+ if self.pp_group.world_size > 1 and config.tie_word_embeddings:
441
+ if self.pp_group.is_first_rank:
442
+ self.pp_group.send(
443
+ self.model.embed_tokens.weight, dst=self.pp_group.last_rank
444
+ )
445
+ else:
446
+ emb_token_weight = self.pp_group.recv(
447
+ size=(config.vocab_size, config.hidden_size),
448
+ dtype=next(self.model.parameters()).dtype,
449
+ src=self.pp_group.first_rank,
450
+ )
451
+ self.lm_head.weight.copy_(emb_token_weight)
452
+
453
+ self.logits_processor = LogitsProcessor(config)
454
+ self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
455
+
456
+ def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
457
+ return self.model.get_input_embedding(input_ids)
458
+
459
+ def get_input_embeddings(self) -> nn.Embedding:
460
+ return self.model.embed_tokens
461
+
462
+ @torch.no_grad()
463
+ def forward(
464
+ self,
465
+ input_ids: torch.Tensor,
466
+ positions: torch.Tensor,
467
+ forward_batch: ForwardBatch,
468
+ input_embeds: torch.Tensor = None,
469
+ get_embedding: bool = False,
470
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
471
+ ) -> torch.Tensor:
472
+ hidden_states = self.model(
473
+ input_ids,
474
+ positions,
475
+ forward_batch,
476
+ input_embeds,
477
+ pp_proxy_tensors=pp_proxy_tensors,
478
+ )
479
+ aux_hidden_states = None
480
+ if self.capture_aux_hidden_states:
481
+ hidden_states, aux_hidden_states = hidden_states
482
+
483
+ if self.pp_group.is_last_rank:
484
+ if not get_embedding:
485
+ return self.logits_processor(
486
+ input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
487
+ )
488
+ else:
489
+ return self.pooler(hidden_states, forward_batch)
490
+ else:
491
+ return hidden_states
492
+
493
+ @torch.no_grad()
494
+ def forward_split_prefill(
495
+ self,
496
+ input_ids: torch.Tensor,
497
+ positions: torch.Tensor,
498
+ forward_batch: ForwardBatch,
499
+ split_interval: Tuple[int, int], # [start, end) 0-based
500
+ input_embeds: torch.Tensor = None,
501
+ ):
502
+ start, end = split_interval
503
+ # embed
504
+ if start == 0:
505
+ if input_embeds is None:
506
+ forward_batch.hidden_states = self.model.embed_tokens(input_ids)
507
+ else:
508
+ forward_batch.hidden_states = input_embeds
509
+ # decoder layer
510
+ for i in range(start, end):
511
+ layer = self.model.layers[i]
512
+ forward_batch.hidden_states, forward_batch.residual = layer(
513
+ positions,
514
+ forward_batch.hidden_states,
515
+ forward_batch,
516
+ forward_batch.residual,
517
+ )
518
+
519
+ if end == self.model.config.num_hidden_layers:
520
+ # norm
521
+ hidden_states, _ = self.model.norm(
522
+ forward_batch.hidden_states, forward_batch.residual
523
+ )
524
+ forward_batch.hidden_states = hidden_states
525
+ # logits process
526
+ result = self.logits_processor(
527
+ input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
528
+ )
529
+ else:
530
+ result = None
531
+
532
+ return result
533
+
534
+ @property
535
+ def start_layer(self):
536
+ return self.model.start_layer
537
+
538
+ @property
539
+ def end_layer(self):
540
+ return self.model.end_layer
541
+
542
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
543
+ stacked_params_mapping = [
544
+ # (param_name, shard_name, shard_id)
545
+ ("qkv_proj", "q_proj", "q"),
546
+ ("qkv_proj", "k_proj", "k"),
547
+ ("qkv_proj", "v_proj", "v"),
548
+ ("gate_up_proj", "gate_proj", 0),
549
+ ("gate_up_proj", "up_proj", 1),
550
+ ]
551
+
552
+ params_dict = dict(self.named_parameters())
553
+ for name, loaded_weight in weights:
554
+ layer_id = get_layer_id(name)
555
+ if (
556
+ layer_id is not None
557
+ and hasattr(self.model, "start_layer")
558
+ and (
559
+ layer_id < self.model.start_layer
560
+ or layer_id >= self.model.end_layer
561
+ )
562
+ ):
563
+ continue
564
+
565
+ if "rotary_emb.inv_freq" in name or "projector" in name:
566
+ continue
567
+ if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
568
+ # Models trained using ColossalAI may include these tensors in
569
+ # the checkpoint. Skip them.
570
+ continue
571
+ if self.config.tie_word_embeddings and "lm_head.weight" in name:
572
+ if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
573
+ # Handle pp weight tying here
574
+ # find the embed_tokens.weight in the weights
575
+ embed_token_weights = next(
576
+ filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
577
+ )[1]
578
+ loaded_weight = embed_token_weights
579
+ else:
580
+ continue
581
+ if name.startswith("model.vision_tower") and name not in params_dict:
582
+ continue
583
+
584
+ for param_name, weight_name, shard_id in stacked_params_mapping:
585
+ if weight_name not in name:
586
+ continue
587
+ name = name.replace(weight_name, param_name)
588
+ # Skip loading extra bias for GPTQ models.
589
+ if name.endswith(".bias") and name not in params_dict:
590
+ continue
591
+ if name not in params_dict:
592
+ continue
593
+ param = params_dict[name]
594
+ weight_loader = param.weight_loader
595
+ weight_loader(param, loaded_weight, shard_id)
596
+ break
597
+ else:
598
+ # Skip loading extra bias for GPTQ models.
599
+ if name.endswith(".bias") and name not in params_dict:
600
+ continue
601
+
602
+ if name in params_dict.keys():
603
+ param = params_dict[name]
604
+ weight_loader = getattr(
605
+ param, "weight_loader", default_weight_loader
606
+ )
607
+ weight_loader(param, loaded_weight)
608
+ else:
609
+ logger.warning(f"Parameter {name} not found in params_dict")
610
+
611
+ def get_embed_and_head(self):
612
+ return self.model.embed_tokens.weight, self.lm_head.weight
613
+
614
+ def set_embed_and_head(self, embed, head):
615
+ del self.model.embed_tokens.weight
616
+ del self.lm_head.weight
617
+ self.model.embed_tokens.weight = embed
618
+ self.lm_head.weight = head
619
+ torch.cuda.empty_cache()
620
+ torch.cuda.synchronize()
621
+
622
+ def load_kv_cache_scales(self, quantization_param_path: str) -> None:
623
+ self.model.load_kv_cache_scales(quantization_param_path)
624
+
625
+ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
626
+ if not self.pp_group.is_last_rank:
627
+ return
628
+
629
+ self.capture_aux_hidden_states = True
630
+ if layer_ids is None:
631
+ num_layers = self.config.num_hidden_layers
632
+ self.model.layers_to_capture = [
633
+ 2,
634
+ num_layers // 2,
635
+ num_layers - 3,
636
+ ] # Specific layers for EAGLE3 support
637
+ else:
638
+ self.model.layers_to_capture = [val + 1 for val in layer_ids]
639
+
640
+
641
+ EntryClass = Qwen2ForCausalLM