Chris Scott committed
Commit 22e647c · verified · 1 Parent(s): cfd2541

Upload gptq_marlin.py

Files changed (1):
  gptq_marlin.py  +643  -0
gptq_marlin.py ADDED
@@ -0,0 +1,643 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, List, Optional, Set, Union

import torch

import vllm.model_executor.layers.fused_moe  # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
    MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
    get_linear_quant_method)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    check_marlin_supported, check_moe_marlin_supports_layer,
    marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks,
    verify_marlin_supported)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
                                           PackedColumnParameter,
                                           PackedvLLMParameter,
                                           RowvLLMParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)


class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""

    # (num_bits, is_sym) -> quant_type
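    # Note: uint4b8 / uint8b128 are unsigned integer types stored with a fixed
    # bias of 8 / 128, which is how symmetric (sym=True) GPTQ checkpoints are
    # represented without explicit zero points.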
    TYPE_MAP = {
        (4, True): scalar_types.uint4b8,
        (8, True): scalar_types.uint8b128,
    }

    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                 is_sym: bool, lm_head_quantized: bool,
                 dynamic: Dict[str, Dict[str, Union[int, bool]]],
                 full_config: Dict[str, Any]) -> None:
        super().__init__()
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        # GPTQModel uses the `dynamic` config property to allow per-module
        # quantization config, so each module can be individually optimized.
        # The format is Dict[str, Dict] where the key is a regex string that
        # can perform either positive ("+:" prefixed) or negative ("-:"
        # prefixed) matching of a module.
        # If no prefix is used, positive matching is the default; the value is
        # a dict of field keys and override values applied on top of the base
        # quant config.
        # Negative matching will skip quantization init for this module
        # entirely: non-quantized inference. More details and quantization
        # examples can be found at: https://github.com/ModelCloud/GPTQModel
        # Example:
        #   # the last 1/2 of the layers (10-21) use 8 bits vs 4 bits for 0-9
        #   # the last 1/4 of the layers (16-21) use 8 bits and group_size 64
        #   dynamic = {
        #       # `.*\.` matches the layers_node prefix
        #       # positive match: layers 10-15
        #       r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
        #       # positive match: layers 16-21
        #       r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
        #       r"-:.*\.moe\..*": {},  # negative match (skip) all `moe` layers
        #   }
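        # For illustration (hypothetical module names): under the example
        # above, "model.layers.12.self_attn.q_proj" matches the first rule and
        # is quantized at 8 bits, while "model.layers.3.moe.gate" matches the
        # negative rule and is skipped (kept unquantized).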
        self.dynamic = dynamic

        self.weight_bits = weight_bits
        self.is_sym = is_sym

        self.pack_factor = 32 // weight_bits  # packed into int32
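        # (pack_factor is 8 for 4-bit weights and 4 for 8-bit weights)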
        self.group_size = group_size
        self.desc_act = desc_act
        self.lm_head_quantized = lm_head_quantized
        self.full_config = full_config

        if (weight_bits, is_sym) not in self.TYPE_MAP:
            raise ValueError("Unsupported quantization config: "
                             f"bits={weight_bits}, sym={is_sym}")

        self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]

    def __repr__(self) -> str:
        return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act}, "
                f"lm_head_quantized={self.lm_head_quantized}, "
                f"dynamic={self.dynamic})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
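        # 80 == compute capability 8.0 (Ampere); Marlin kernels require SM80+.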
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
        dynamic = {} if dynamic is None else dynamic

        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(weight_bits, group_size, desc_act, is_sym,
                   lm_head_quantized, dynamic, config)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)

        is_valid_user_quant = (user_quant is None or user_quant == "marlin"
                               or user_quant == "gptq_marlin")

        if can_convert and is_valid_user_quant:
            msg = ("The model is convertible to {} during runtime."
                   " Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "gptq":
            logger.info("Detected that the model can run with gptq_marlin"
                        ", however you specified quantization=gptq explicitly,"
                        " so forcing gptq. Use quantization=gptq_marlin for"
                        " faster inference")
            return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, FusedMoE):
            from vllm.model_executor.layers.quantization.moe_wna16 import (
                MoeWNA16Config)
            if not check_moe_marlin_supports_layer(layer, self.group_size):
                logger.warning(
                    f"Layer '{prefix}' is not supported by GPTQMoeMarlin. "
                    "Falling back to Moe WNA16 kernels.")
                return MoeWNA16Config.from_config(
                    self.full_config).get_quant_method(layer, prefix)
            return GPTQMarlinMoEMethod(self)
        return get_linear_quant_method(self, layer, prefix,
                                       GPTQMarlinLinearMethod)

    @classmethod
    def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits")
        group_size = quant_config.get("group_size")
        sym = quant_config.get("sym")
        desc_act = quant_config.get("desc_act")

        if not current_platform.is_cuda():
            return False

        if quant_method != "gptq":
            return False

        # Marlin conversion is only valid if required properties are found
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        if (num_bits, sym) not in cls.TYPE_MAP:
            return False

        return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
                                      group_size=group_size)


class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    _kernel_backends_being_used: Set[str] = set()

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

        # Verify supported on platform.
        verify_marlin_supported(quant_type=self.quant_config.quant_type,
                                group_size=self.quant_config.group_size)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        output_size_per_partition = sum(output_partition_sizes)
        is_row_parallel = input_size != input_size_per_partition
        weight_loader = extra_weight_attrs.get("weight_loader")

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=\
                (input_size_per_partition, output_size_per_partition),
            weight_type=self.quant_config.quant_type,
            act_type=params_dtype,
            group_size=self.quant_config.group_size,
            zero_points=False,
            has_g_idx=self.quant_config.desc_act
        )

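        # Select a mixed-precision GEMM backend for this shape/dtype/group-size
        # combination; on supported GPUs this typically resolves to a Marlin
        # kernel.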
        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for GPTQMarlinLinearMethod",
                        kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        # Determine sharding
        if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
                                             self.quant_config.group_size,
                                             is_row_parallel):
            # By setting scale_dim == None, weight_loader will
            # repeat the scales on each GPU in TP>1 case.
            scales_and_zp_input_dim = None
            scales_and_zp_size = input_size // group_size
        else:
            # By setting scale_dim == 0, weight_loader will
            # shard the scales in TP>1 case.
            scales_and_zp_input_dim = 0
            scales_and_zp_size = input_size_per_partition // group_size

        # Quantized weights
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        # Activation order
        g_idx = RowvLLMParameter(data=torch.empty(
            input_size_per_partition,
            dtype=torch.int32,
        ),
                                 input_dim=0,
                                 weight_loader=weight_loader)

        qzeros_args = {
            "data":
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader":
            weight_loader
        }
        weight_scale_args = {
            "data":
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader":
            weight_loader
        }

        if scales_and_zp_input_dim is None:
            scales = ChannelQuantScaleParameter(output_dim=1,
                                                **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)

        else:
            scales = GroupQuantScaleParameter(output_dim=1,
                                              input_dim=0,
                                              **weight_scale_args)
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)

        self.kernel = kernel_type(mp_linear_kernel_config,
                                  w_q_param_name="qweight",
                                  w_s_param_name="scales",
                                  w_zp_param_name="qzeros",
                                  w_gidx_param_name="g_idx")

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        self.kernel.process_weights_after_loading(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.kernel.apply_weights(layer, x, bias)


class GPTQMarlinMoEMethod(FusedMoEMethodBase):
    """MoE Marlin method with quantization."""

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        intermediate_size_full = extra_weight_attrs.pop(
            "intermediate_size_full")

        self.is_k_full = (not self.quant_config.desc_act) or (
            intermediate_size_per_partition == intermediate_size_full)
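        # When act_order (desc_act) is used and the intermediate dimension is
        # sharded across tensor-parallel ranks, each rank only holds part of
        # the K dimension, so the kernel must be told that k is not "full".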
        if self.quant_config.group_size != -1:
            scales_size13 = hidden_size // self.quant_config.group_size
            w2_scales_size = (intermediate_size_full
                              if self.quant_config.desc_act else
                              intermediate_size_per_partition)
            scales_size2 = (w2_scales_size // self.quant_config.group_size)
            strategy = FusedMoeWeightScaleSupported.GROUP.value
        else:
            scales_size13 = 1
            scales_size2 = 1
            strategy = FusedMoeWeightScaleSupported.CHANNEL.value

        extra_weight_attrs.update({
            "quant_method": strategy,
            "is_transposed": True
        })
        # Fused gate_up_proj (column parallel)
        w13_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size // self.quant_config.pack_factor,
                2 * intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)
        # down_proj (row parallel)
        w2_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition //
                self.quant_config.pack_factor,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)
        # up_proj scales
        w13_scales = torch.nn.Parameter(
            torch.empty(num_experts,
                        scales_size13,
                        2 * intermediate_size_per_partition,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)
        # down_proj scales
        w2_scales = torch.nn.Parameter(
            torch.empty(num_experts,
                        scales_size2,
                        hidden_size,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)
        # don't shard the w2 scales when running act order
        set_weight_attrs(w2_scales,
                         {"load_full_w2": self.quant_config.desc_act})
        # up_proj zeros
        w13_qzeros = torch.nn.Parameter(
            torch.empty(num_experts,
                        scales_size13,
                        2 * intermediate_size_per_partition //
                        self.quant_config.pack_factor,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w13_qzeros", w13_qzeros)
        set_weight_attrs(w13_qzeros, extra_weight_attrs)
        # down_proj zeros
        w2_qzeros = torch.nn.Parameter(
            torch.empty(num_experts,
                        scales_size2,
                        hidden_size // self.quant_config.pack_factor,
                        dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w2_qzeros", w2_qzeros)
        set_weight_attrs(w2_qzeros, extra_weight_attrs)
        # don't shard the w2 zeros when running act order
        set_weight_attrs(w2_qzeros,
                         {"load_full_w2": self.quant_config.desc_act})
        w13_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx", w13_g_idx)
        set_weight_attrs(w13_g_idx, extra_weight_attrs)
        w2_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx", w2_g_idx)
        set_weight_attrs(w2_g_idx, extra_weight_attrs)
        w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx_sort_indices",
                                 w13_g_idx_sort_indices)
        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
        w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx_sort_indices",
                                 w2_g_idx_sort_indices)
        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)

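        # Scratch workspace for the fused Marlin MoE kernel, sized per SM;
        # the kernel uses it to coordinate partial results across thread
        # blocks.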
        device = layer.w13_qweight.device
        sms = torch.cuda.get_device_properties(device).multi_processor_count
        layer.workspace = torch.zeros((sms * 4, ),
                                      dtype=torch.int,
                                      device=device,
                                      requires_grad=False)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

        # Process act_order
        if self.quant_config.desc_act:
            # Get sorting based on g_idx
            num_experts = layer.w13_g_idx.shape[0]
            w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
            w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
            w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
            w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
            for e in range(num_experts):
                w13_g_idx_sort_indices[e] = torch.argsort(
                    layer.w13_g_idx[e]).to(torch.int32)
                w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
                    torch.int32)
                w13_sorted_g_idx[e] = layer.w13_g_idx[e][
                    w13_g_idx_sort_indices[e]]
                w2_sorted_g_idx[e] = layer.w2_g_idx[e][
                    w2_g_idx_sort_indices[e]]
            replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
            replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
            replace_parameter(layer, "w13_g_idx_sort_indices",
                              w13_g_idx_sort_indices)
            replace_parameter(layer, "w2_g_idx_sort_indices",
                              w2_g_idx_sort_indices)
        else:
            # Reset g_idx related tensors
            num_experts = layer.w13_g_idx.shape[0]
            device = layer.w13_g_idx.device
            layer.w13_g_idx = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
            layer.w2_g_idx = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32,
                            device=device),
                requires_grad=False,
            )
        # Repack weights
        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w13_qweight.shape[2],
            self.quant_config.quant_type.size_bits,
        )
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w2_qweight.shape[2],
            self.quant_config.quant_type.size_bits,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
        # Repack scales
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w13_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w13_scales", marlin_w13_scales)
        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_scales,
            size_k=layer.w2_scales.shape[1] *
            (self.quant_config.group_size if self.quant_config.group_size != -1
             else self.quant_config.pack_factor),
            size_n=layer.w2_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w2_scales", marlin_w2_scales)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        assert activation == "silu", "Only SiLU activation is supported."
        if apply_router_weight_on_input:
            raise NotImplementedError(
                "Apply router weight on input is not supported for "
                "fused Marlin MoE method.")

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        return torch.ops.vllm.fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            layer.w13_scales,
            layer.w2_scales,
            router_logits,
            topk_weights,
            topk_ids,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            g_idx1=layer.w13_g_idx,
            g_idx2=layer.w2_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            num_bits=self.quant_config.quant_type.size_bits,
            workspace=layer.workspace,
            is_k_full=self.is_k_full)