JunHowie's picture
Upload folder using huggingface_hub
8e91b6a verified
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from fractions import Fraction
from typing import Optional, Union
import regex as re
import torch
from vllm.config import QuantizationConfig
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, UnquantizedEmbeddingMethod)
# Match the dynamic rules against the module name (prefix) and override the
# quantization config if the module (prefix) matches a rule.
def override_config(config: QuantizationConfig, prefix: str):
    """Apply the dynamic rule matching ``prefix`` to ``config`` in place.

    The settings ``bits``, ``group_size`` and ``desc_act`` (plus ``sym`` when
    ``config.get_name()`` is ``"gptq_marlin"``) are looked up through
    ``get_dynamic_override`` and written to ``config`` when the returned
    value has the expected type. ``config.pack_factor`` is then recomputed
    and the resulting bit width is validated.

    Args:
        config: Quantization config to update. It is mutated, so callers
            that share a base config should pass a copy.
        prefix: Module name that is matched against the dynamic rules.

    Raises:
        ValueError: If the resulting settings are not supported by the
            ``"gptq_marlin"`` or ``"gptq"`` method.
    """
    # A negative ("-:") rule makes get_dynamic_override() return the sentinel
    # False for *every* key. That sentinel is not an override value: writing
    # it to weight_bits would make Fraction(32, False) divide by zero and it
    # would silently force desc_act / is_sym to False. Keep the base settings
    # in that case.
    excluded = get_dynamic_override(config, prefix) is False

    if not excluded:
        weight_bits = get_dynamic_override(config, prefix, "bits",
                                           config.weight_bits)
        # bool is a subclass of int, so it must be rejected explicitly.
        if isinstance(weight_bits, int) and not isinstance(weight_bits, bool):
            config.weight_bits = weight_bits

        group_size = get_dynamic_override(config, prefix, "group_size",
                                          config.group_size)
        if isinstance(group_size, int) and not isinstance(group_size, bool):
            config.group_size = group_size

        desc_act = get_dynamic_override(config, prefix, "desc_act",
                                        config.desc_act)
        if isinstance(desc_act, bool):
            config.desc_act = desc_act

    config.pack_factor = Fraction(32, config.weight_bits)  # packed into int32

    if config.get_name() == "gptq_marlin":
        if not excluded:
            is_sym = get_dynamic_override(config, prefix, "sym",
                                          config.is_sym)
            if isinstance(is_sym, bool):
                config.is_sym = is_sym

        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
            raise ValueError("Unsupported quantization config: "
                             f"bits={config.weight_bits}, sym={config.is_sym}")

        config.quant_type = config.TYPE_MAP[(config.weight_bits,
                                             config.is_sym)]
    elif config.get_name() == "gptq":
        if config.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {config.weight_bits} bits.")
def get_dynamic_override(
    config: QuantizationConfig,
    layer_name: str,
    key: Optional[str] = None,
    default_value: Union[int, bool, None] = None,
) -> Union[dict, int, bool, None]:
    """Look up the dynamic rule that applies to ``layer_name``.

    ``config.dynamic`` maps regex rules to dicts of overrides. Rules are
    tried in the mapping's iteration order and the first rule whose regex
    matches ``layer_name`` (via ``re.match``) decides the result:

    * a rule starting with ``"-:"`` is a negative match and yields ``False``
      (the module is excluded from quantized init);
    * any other rule (an optional ``"+:"`` prefix is stripped) is a positive
      match and yields the whole override dict when ``key`` is ``None``,
      otherwise the value stored under ``key`` or ``default_value`` when the
      dict has no such key.

    Returns ``default_value`` when no rule matches.
    """
    for rule, overrides in config.dynamic.items():
        negative = rule.startswith("-:")
        if negative:
            regex_src = rule.removeprefix("-:")
        else:
            regex_src = rule.removeprefix("+:")

        if not re.match(regex_src, layer_name):
            continue

        # Negative match: matched modules are excluded from quantized init.
        if negative:
            return False

        # Positive match: the rule's properties override the base config.
        if key is None:
            return overrides
        return overrides.get(key, default_value)

    return default_value
def get_linear_quant_method(
    config: QuantizationConfig,
    layer: torch.nn.Module,
    prefix: str,
    linear_method_cls: type,
):
    """Pick the quantization method object for ``layer``.

    Args:
        config: Base quantization config. It is never mutated; per-module
            overrides are applied to a deep copy.
        layer: Module being initialised.
        prefix: Module name matched against the dynamic rules.
        linear_method_cls: Class instantiated with the (possibly overridden)
            config for modules that are quantized.

    Returns:
        ``None`` when ``layer`` is neither a ``LinearBase`` nor a
        ``ParallelLMHead`` with ``config.lm_head_quantized`` set. An
        unquantized method (``UnquantizedEmbeddingMethod`` for the quantized
        LM-head case, ``UnquantizedLinearMethod`` otherwise) when a negative
        dynamic rule excludes ``prefix``. Otherwise
        ``linear_method_cls(cloned_config)``.
    """
    parallel_lm_head_quantized = isinstance(
        layer, ParallelLMHead) and config.lm_head_quantized
    if not (isinstance(layer, LinearBase) or parallel_lm_head_quantized):
        return None

    # False = skip module, None = no override, else = positive match.
    # The lookup only reads the config, so it can use the base config; the
    # identity test keeps the three states apart without relying on ``==``.
    if get_dynamic_override(config, layer_name=prefix) is False:
        if parallel_lm_head_quantized:
            return UnquantizedEmbeddingMethod()
        return UnquantizedLinearMethod()

    # Copy only when a quantized method is actually built, so modules that
    # are skipped or unsupported do not pay for a deep copy.
    cloned_config = deepcopy(config)
    if prefix:
        # Dynamic per module/layer rules may override base config
        override_config(cloned_config, prefix=prefix)

    return linear_method_cls(cloned_config)