File size: 13,247 Bytes
0c1b5cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
# -*- coding: utf-8 -*-
"""Quantization scale module."""

import math
import typing as tp
from dataclasses import dataclass, field

import torch

from ...data.dtype import QuantDataType
from ...data.range import DynamicRange, QuantRange, RangeBound
from ...data.scale import QuantScale
from ...data.utils import ScaleUtils
from ...data.zero import ZeroPointDomain
from .simple import simple_quantize

from deepcompressor.utils import tools
logger = tools.logging.getLogger(__name__)

__all__ = ["quantize_scale", "QuantScaleInfo"]


def quantize_scale(
    s: torch.Tensor,
    /,
    *,
    quant_dtypes: tp.Sequence[QuantDataType],
    quant_spans: tp.Sequence[float],
    view_shapes: tp.Sequence[torch.Size],
) -> QuantScale:
    """Quantize the scale tensor progressively over multiple levels.

    Args:
        s (`torch.Tensor`):
            The scale tensor.
        quant_dtypes (`Sequence[QuantDataType]`):
            The quantization dtypes of the scale tensor, one per level.
        quant_spans (`Sequence[float]`):
            The quantization spans of the scale tensor, one per level.
        view_shapes (`Sequence[torch.Size]`):
            The view shapes of the scale tensor, one per level.

    Returns:
        `QuantScale`:
            The quantized scale tensor.

    Raises:
        RuntimeError: If ``s`` is a meta tensor (no materialized data).
        ValueError: If ``s`` is empty or contains NaN/Inf values.
    """
    # Meta tensors carry no data, so every value inspection below would fail on
    # them — reject meta tensors before touching the contents.
    if s.is_meta:
        raise RuntimeError("Cannot quantize scale with meta tensor. Ensure model is loaded on actual device.")
    if s.numel() == 0:
        raise ValueError("Input tensor is empty")
    if s.isnan().any() or s.isinf().any():
        raise ValueError("Input tensor contains NaN or Inf values")
    if (s == 0).all():
        logger.warning("Input tensor contains all zeros - this may indicate meta tensor materialization issues")
        # Substitute a tiny epsilon so the per-group normalization below does
        # not divide by an all-zero group maximum.
        s = torch.ones_like(s) * 1e-6

    scale = QuantScale()
    s = s.abs()
    # Progressive quantization: at every level except the last, quantize the
    # per-group dynamic span into a coarser scale and normalize `s` by it
    # before descending to the next level.
    for view_shape, quant_dtype, quant_span in zip(view_shapes[:-1], quant_dtypes[:-1], quant_spans[:-1], strict=True):
        s = s.view(view_shape)  # (#g0, rs0, #g1, rs1, #g2, rs2, ...)
        ss = s.amax(dim=list(range(1, len(view_shape), 2)), keepdim=True)  # i.e., s_dynamic_span
        ss = simple_quantize(
            ss / quant_span, has_zero_point=False, quant_dtype=quant_dtype
        )  # i.e., s_scale = s_dynamic_span / s_quant_span
        s = s / ss
        scale.append(ss)
    view_shape = view_shapes[-1]
    s = s.view(view_shape)
    if any(v != 1 for v in view_shape[1::2]):
        # The last level still groups several elements: reduce to the
        # per-group span before quantizing.
        ss = s.amax(dim=list(range(1, len(view_shape), 2)), keepdim=True)
        ss = simple_quantize(ss / quant_spans[-1], has_zero_point=False, quant_dtype=quant_dtypes[-1])
    else:
        assert quant_spans[-1] == 1, "The last quant span must be 1."
        ss = simple_quantize(s, has_zero_point=False, quant_dtype=quant_dtypes[-1])
    scale.append(ss)
    scale.remove_zero()
    return scale


@dataclass
class QuantScaleInfo:
    """Configuration for quantizing a tensor's scale and zero point.

    The ``init=False`` fields are derived in ``__post_init__``, which splits
    the scale levels into a linear part and an exponent part.
    """

    # region tensor information
    tensor_view_shape: torch.Size  # view shape applied to the tensor before measuring its range
    tensor_quant_dtype: torch.dtype | QuantDataType  # plain torch.dtype is rejected in __post_init__
    tensor_zero_domain: ZeroPointDomain | None  # None means the tensor has no zero point
    tensor_quant_range: QuantRange  # normalized in __post_init__ via QuantRange.construct
    tensor_range_bound: RangeBound | None  # optional bound intersected with the measured dynamic range
    # endregion
    default_quant_dtype: torch.dtype | QuantDataType  # fallback dtype for inferred scale/zero dtypes
    scale_view_shapes: list[torch.Size]  # one view shape per scale quantization level
    scale_quant_dtypes: list[torch.dtype | QuantDataType]  # one dtype per scale quantization level
    exponent_scale_level: int = field(init=False)  # index where exponent scale dtypes begin
    zero_quant_dtype: torch.dtype | QuantDataType = field(init=False)  # derived from zero domain; None without zero point
    # region linear scale information
    linear_tensor_quant_span: float = field(init=False)
    linear_scale_quant_dtypes: list[torch.dtype | QuantDataType] = field(init=False)
    linear_scale_view_shapes: list[torch.Size] = field(init=False)
    linear_scale_quant_spans: list[float] = field(init=False)
    # endregion
    # region exponent scale information
    exponent_tensor_quant_span: float = field(init=False)
    exponent_scale_quant_dtypes: list[torch.dtype | QuantDataType] = field(init=False)
    exponent_scale_view_shapes: list[torch.Size] = field(init=False)
    exponent_scale_quant_spans: list[float] = field(init=False)
    # endregion

    @property
    def has_zero_point(self) -> bool:
        """Whether a zero-point domain is configured for the tensor."""
        zero_domain = self.tensor_zero_domain
        return zero_domain is not None

    def __post_init__(self):
        """Normalize the declared fields and derive the linear/exponent split.

        Populates every ``init=False`` field: the zero-point dtype, the tensor
        quant spans, and the per-level dtypes/view-shapes/spans partitioned at
        ``exponent_scale_level``.
        """
        # Only QuantDataType tensors are supported for now.
        if isinstance(self.tensor_quant_dtype, torch.dtype):
            raise NotImplementedError("torch.dtype is not supported yet.")
        # Normalize the declared quant range against the tensor dtype.
        self.tensor_quant_range = QuantRange.construct(
            self.tensor_quant_dtype, has_zero_point=self.has_zero_point, quant_range=self.tensor_quant_range
        )
        self.scale_quant_dtypes = ScaleUtils.infer_scale_dtypes(self.scale_quant_dtypes, self.default_quant_dtype)
        # Level index at which exponent-typed scale dtypes start (see branch below).
        self.exponent_scale_level = ScaleUtils.infer_exponent_scale_level(self.scale_quant_dtypes)
        if self.has_zero_point:
            if self.tensor_zero_domain == ZeroPointDomain.PreScale:
                # Pre-scale zero point shares the tensor's quant dtype.
                self.zero_quant_dtype = self.tensor_quant_dtype
            elif self.tensor_zero_domain == ZeroPointDomain.PostScale:
                # TODO: fix zero quant dtype (signed or unsigned)
                self.zero_quant_dtype = self.scale_quant_dtypes[-1]
                # Exponent dtypes cannot represent a zero point; fall back to the default.
                if isinstance(self.zero_quant_dtype, QuantDataType) and self.zero_quant_dtype.is_exponent:
                    self.zero_quant_dtype = self.default_quant_dtype
            else:
                raise ValueError(f"Unsupported zero point domain: {self.tensor_zero_domain}")
            # With a zero point the tensor span covers the full min..max width.
            self.linear_tensor_quant_span = self.tensor_quant_range.max - self.tensor_quant_range.min
            self.exponent_tensor_quant_span = 2 ** int(
                math.log2(self.tensor_quant_range.max) + int(self.tensor_quant_dtype.signed)
            )
        else:
            self.zero_quant_dtype = None
            self.linear_tensor_quant_span = self.tensor_quant_range.max
            self.exponent_tensor_quant_span = 2 ** int(math.log2(self.tensor_quant_range.max))
        # Split the scale levels: everything before `exponent_scale_level` is
        # linear, everything from it onward is exponent. Out-of-range level
        # means all levels are linear.
        if self.exponent_scale_level >= 0 and self.exponent_scale_level < len(self.scale_quant_dtypes):
            lin_s_dtypes = self.scale_quant_dtypes[: self.exponent_scale_level]
            exp_s_dtypes = self.scale_quant_dtypes[self.exponent_scale_level :]
            lin_s_view_shapes = self.scale_view_shapes[: self.exponent_scale_level]
            exp_s_view_shapes = self.scale_view_shapes[self.exponent_scale_level :]
            exp_s_spans = ScaleUtils.infer_scale_quant_spans(exp_s_dtypes)
            # Linear spans are chained on top of the last exponent span.
            lin_s_spans = ScaleUtils.infer_scale_quant_spans(lin_s_dtypes, base=exp_s_spans[-1]) if lin_s_dtypes else []
        else:
            lin_s_dtypes, exp_s_dtypes = self.scale_quant_dtypes, []
            lin_s_view_shapes, exp_s_view_shapes = self.scale_view_shapes, []
            lin_s_spans, exp_s_spans = ScaleUtils.infer_scale_quant_spans(lin_s_dtypes), []
        self.linear_scale_quant_dtypes = lin_s_dtypes
        self.linear_scale_view_shapes = lin_s_view_shapes
        self.linear_scale_quant_spans = lin_s_spans
        self.exponent_scale_quant_dtypes = exp_s_dtypes
        self.exponent_scale_view_shapes = exp_s_view_shapes
        self.exponent_scale_quant_spans = exp_s_spans

    def quantize(
        self,
        *,
        # scale-based quantization related arguments
        scale: torch.Tensor | None = None,
        zero: torch.Tensor | None = None,
        # range-based quantization related arguments
        tensor: torch.Tensor | None = None,
        dynamic_range: DynamicRange | None = None,
    ) -> tuple[QuantScale, torch.Tensor]:
        """Get the quantization scale and zero point of the tensor to be quantized.

        Args:
            scale (`torch.Tensor` or `None`, *optional*, defaults to `None`):
                The scale tensor. If `None`, the scale is derived from the
                dynamic range of `tensor` (range-based quantization).
            zero (`torch.Tensor` or `None`, *optional*, defaults to `None`):
                The zero point tensor.
            tensor (`torch.Tensor` or `None`, *optional*, defaults to `None`):
                The tensor to be quantized. This is only used for range-based quantization.
            dynamic_range (`DynamicRange` or `None`, *optional*, defaults to `None`):
                The dynamic range of the tensor to be quantized.

        Returns:
            `tuple[QuantScale, torch.Tensor]`:
                The scale and the zero point.

        Raises:
            RuntimeError: If the computed scale has no data.
        """
        # region step 1: get the dynamic span for range-based scale or the scale tensor
        if scale is None:
            range_based = True
            assert isinstance(tensor, torch.Tensor), "View tensor must be a tensor."
            dynamic_range = dynamic_range or DynamicRange()
            dynamic_range = dynamic_range.measure(
                tensor.view(self.tensor_view_shape),
                zero_domain=self.tensor_zero_domain,
                is_float_point=self.tensor_quant_dtype.is_float_point,
            )
            dynamic_range = dynamic_range.intersect(self.tensor_range_bound)
            dynamic_span = (dynamic_range.max - dynamic_range.min) if self.has_zero_point else dynamic_range.max
        else:
            range_based = False
            scale = scale.view(self.scale_view_shapes[-1])
            assert isinstance(scale, torch.Tensor), "Scale must be a tensor."
        # endregion
        # region step 2: get the scale
        if self.linear_scale_quant_dtypes:
            if range_based:
                linear_scale = dynamic_span / self.linear_tensor_quant_span
            elif self.exponent_scale_quant_dtypes:
                # Rebase the given scale from the exponent span to the linear span.
                linear_scale = scale.mul(self.exponent_tensor_quant_span).div(self.linear_tensor_quant_span)
            else:
                linear_scale = scale
            lin_s = quantize_scale(
                linear_scale,
                quant_dtypes=self.linear_scale_quant_dtypes,
                quant_spans=self.linear_scale_quant_spans,
                view_shapes=self.linear_scale_view_shapes,
            )
            assert lin_s.data is not None, "Linear scale tensor is None."
            # Meta tensors carry no values, so NaN/Inf checks only apply to
            # materialized data. NOTE: this check must stay INSIDE this branch;
            # `lin_s` does not exist in the no-linear-dtypes case.
            if not lin_s.data.is_meta:
                assert not lin_s.data.isnan().any(), "Linear scale tensor contains NaN."
                assert not lin_s.data.isinf().any(), "Linear scale tensor contains Inf."
        else:
            # No linear levels configured: start from an empty scale.
            lin_s = QuantScale()
        if self.exponent_scale_quant_dtypes:
            if range_based:
                exp_scale = dynamic_span / self.exponent_tensor_quant_span
            else:
                exp_scale = scale
            if lin_s.data is not None:
                # Factor the already-quantized linear part out of the exponent scale.
                lin_s.data = lin_s.data.expand(self.linear_scale_view_shapes[-1]).reshape(self.scale_view_shapes[-1])
                exp_scale = exp_scale / lin_s.data
            exp_s = quantize_scale(
                exp_scale,
                quant_dtypes=self.exponent_scale_quant_dtypes,
                quant_spans=self.exponent_scale_quant_spans,
                view_shapes=self.exponent_scale_view_shapes,
            )
            assert exp_s.data is not None, "Exponential scale tensor is None."
            assert not exp_s.data.isnan().any(), "Exponential scale tensor contains NaN."
            assert not exp_s.data.isinf().any(), "Exponential scale tensor contains Inf."
            s = exp_s if lin_s.data is None else lin_s.extend(exp_s)
        else:
            s = lin_s
        if s.data is None:
            # Log configuration context through the module logger (lazy %-args)
            # before failing, to aid debugging of misconfigured scale levels.
            logger.error("Linear scale dtypes: %s", self.linear_scale_quant_dtypes)
            logger.error("Exponent scale dtypes: %s", self.exponent_scale_quant_dtypes)
            if lin_s.data is not None:
                logger.error("Linear scale data shape: %s", lin_s.data.shape)
            raise RuntimeError("Scale computation failed - resulting scale is None")
        assert not s.data.isnan().any(), "Scale tensor contains NaN."
        assert not s.data.isinf().any(), "Scale tensor contains Inf."
        # endregion
        # region step 3: get the zero point
        if self.has_zero_point:
            if range_based:
                if self.tensor_zero_domain == ZeroPointDomain.PreScale:
                    zero = self.tensor_quant_range.min - dynamic_range.min / s.data
                else:
                    zero = self.tensor_quant_range.min * s.data - dynamic_range.min
            assert isinstance(zero, torch.Tensor), "Zero point must be a tensor."
            z = simple_quantize(zero, has_zero_point=True, quant_dtype=self.zero_quant_dtype)
        else:
            z = torch.tensor(0, dtype=s.data.dtype, device=s.data.device)
        assert not z.isnan().any(), "Zero point tensor contains NaN."
        assert not z.isinf().any(), "Zero point tensor contains Inf."
        # endregion
        return s, z