lym00 committed on
Commit
0c1b5cc
·
verified ·
1 Parent(s): 6f59bc1

Upload scale.py

Browse files
Files changed (1) hide show
  1. scale.py +273 -0
scale.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Quantization scale module."""
3
+
4
+ import math
5
+ import typing as tp
6
+ from dataclasses import dataclass, field
7
+
8
+ import torch
9
+
10
+ from ...data.dtype import QuantDataType
11
+ from ...data.range import DynamicRange, QuantRange, RangeBound
12
+ from ...data.scale import QuantScale
13
+ from ...data.utils import ScaleUtils
14
+ from ...data.zero import ZeroPointDomain
15
+ from .simple import simple_quantize
16
+
17
+ from deepcompressor.utils import tools
18
+ logger = tools.logging.getLogger(__name__)
19
+
20
+ __all__ = ["quantize_scale", "QuantScaleInfo"]
21
+
22
+
23
def quantize_scale(
    s: torch.Tensor,
    /,
    *,
    quant_dtypes: tp.Sequence[QuantDataType],
    quant_spans: tp.Sequence[float],
    view_shapes: tp.Sequence[torch.Size],
) -> QuantScale:
    """Quantize the scale tensor level by level.

    For every level except the last, the (absolute) scale is viewed with the
    level's view shape, reduced with ``amax`` over the odd-indexed dimensions,
    divided by the level's quantization span and quantized; the running scale
    is then divided by that quantized level before moving on.

    Args:
        s (`torch.Tensor`):
            The scale tensor.
        quant_dtypes (`Sequence[QuantDataType]`):
            The quantization dtypes of the scale tensor, one per level.
        quant_spans (`Sequence[float]`):
            The quantization spans of the scale tensor, one per level.
        view_shapes (`Sequence[torch.Size]`):
            The view shapes of the scale tensor, one per level.

    Returns:
        `QuantScale`:
            The quantized scale tensor.

    Raises:
        RuntimeError:
            If ``s`` is a meta tensor.
        ValueError:
            If ``s`` is empty or contains NaN/Inf values.
    """
    # A meta tensor has no data, so it must be rejected before any of the
    # data-dependent checks below: using ``.any()`` / ``.all()`` as a condition
    # needs real values and would fail with an unrelated error first.
    if s.is_meta:
        raise RuntimeError("Cannot quantize scale with meta tensor. Ensure model is loaded on actual device.")
    if s.numel() == 0:
        raise ValueError("Input tensor is empty")
    if s.isnan().any() or s.isinf().any():
        raise ValueError("Input tensor contains NaN or Inf values")
    if (s == 0).all():
        logger.warning("Input tensor contains all zeros - this may indicate meta tensor materialization issues")
        # Substitute a minimal non-zero tensor so quantization can proceed.
        s = torch.ones_like(s) * 1e-6

    scale = QuantScale()
    s = s.abs()
    # ``strict=True`` makes a length mismatch between the three sequences an error.
    for view_shape, quant_dtype, quant_span in zip(view_shapes[:-1], quant_dtypes[:-1], quant_spans[:-1], strict=True):
        s = s.view(view_shape)  # (#g0, rs0, #g1, rs1, #g2, rs2, ...)
        ss = s.amax(dim=list(range(1, len(view_shape), 2)), keepdim=True)  # i.e., s_dynamic_span
        ss = simple_quantize(
            ss / quant_span, has_zero_point=False, quant_dtype=quant_dtype
        )  # i.e., s_scale = s_dynamic_span / s_quant_span
        s = s / ss
        scale.append(ss)
    # The last level is handled separately: it only needs a reduction when one
    # of the odd-indexed dimensions of its view shape is larger than 1.
    view_shape = view_shapes[-1]
    s = s.view(view_shape)
    if any(v != 1 for v in view_shape[1::2]):
        ss = s.amax(dim=list(range(1, len(view_shape), 2)), keepdim=True)
        ss = simple_quantize(ss / quant_spans[-1], has_zero_point=False, quant_dtype=quant_dtypes[-1])
    else:
        # Nothing to reduce: the remaining scale is quantized directly.
        assert quant_spans[-1] == 1, "The last quant span must be 1."
        ss = simple_quantize(s, has_zero_point=False, quant_dtype=quant_dtypes[-1])
    scale.append(ss)
    # NOTE(review): presumably replaces zero entries so later divisions by the
    # scale stay finite; confirm against ``QuantScale.remove_zero``.
    scale.remove_zero()
    return scale
86
+
87
+
88
@dataclass
class QuantScaleInfo:
    """Configuration used to compute the quantization scale and zero point of a tensor.

    The fields without ``init=False`` are supplied by the caller. The remaining
    fields are derived in ``__post_init__``, which splits the scale levels into
    a linear group and an exponent group and computes the spans of each.
    """

    # region tensor information
    tensor_view_shape: torch.Size
    tensor_quant_dtype: torch.dtype | QuantDataType
    tensor_zero_domain: ZeroPointDomain | None
    tensor_quant_range: QuantRange
    tensor_range_bound: RangeBound | None
    # endregion
    default_quant_dtype: torch.dtype | QuantDataType
    scale_view_shapes: list[torch.Size]
    scale_quant_dtypes: list[torch.dtype | QuantDataType]
    exponent_scale_level: int = field(init=False)
    # ``None`` when the tensor has no zero point (see ``__post_init__``).
    zero_quant_dtype: torch.dtype | QuantDataType | None = field(init=False)
    # region linear scale information
    linear_tensor_quant_span: float = field(init=False)
    linear_scale_quant_dtypes: list[torch.dtype | QuantDataType] = field(init=False)
    linear_scale_view_shapes: list[torch.Size] = field(init=False)
    linear_scale_quant_spans: list[float] = field(init=False)
    # endregion
    # region exponent scale information
    exponent_tensor_quant_span: float = field(init=False)
    exponent_scale_quant_dtypes: list[torch.dtype | QuantDataType] = field(init=False)
    exponent_scale_view_shapes: list[torch.Size] = field(init=False)
    exponent_scale_quant_spans: list[float] = field(init=False)
    # endregion

    @property
    def has_zero_point(self) -> bool:
        """Whether a zero-point domain is configured for the tensor."""
        return self.tensor_zero_domain is not None

    def __post_init__(self):
        """Derive the ``init=False`` fields from the user-provided ones.

        Raises:
            NotImplementedError: If ``tensor_quant_dtype`` is a ``torch.dtype``.
            ValueError: If ``tensor_zero_domain`` is neither ``PreScale`` nor ``PostScale``.
        """
        if isinstance(self.tensor_quant_dtype, torch.dtype):
            raise NotImplementedError("torch.dtype is not supported yet.")
        self.tensor_quant_range = QuantRange.construct(
            self.tensor_quant_dtype, has_zero_point=self.has_zero_point, quant_range=self.tensor_quant_range
        )
        self.scale_quant_dtypes = ScaleUtils.infer_scale_dtypes(self.scale_quant_dtypes, self.default_quant_dtype)
        self.exponent_scale_level = ScaleUtils.infer_exponent_scale_level(self.scale_quant_dtypes)
        if self.has_zero_point:
            if self.tensor_zero_domain == ZeroPointDomain.PreScale:
                self.zero_quant_dtype = self.tensor_quant_dtype
            elif self.tensor_zero_domain == ZeroPointDomain.PostScale:
                # TODO: fix zero quant dtype (signed or unsigned)
                self.zero_quant_dtype = self.scale_quant_dtypes[-1]
                # An exponent dtype is not used for the zero point; fall back to the default dtype.
                if isinstance(self.zero_quant_dtype, QuantDataType) and self.zero_quant_dtype.is_exponent:
                    self.zero_quant_dtype = self.default_quant_dtype
            else:
                raise ValueError(f"Unsupported zero point domain: {self.tensor_zero_domain}")
            # With a zero point the whole [min, max] range is usable.
            self.linear_tensor_quant_span = self.tensor_quant_range.max - self.tensor_quant_range.min
            self.exponent_tensor_quant_span = 2 ** int(
                math.log2(self.tensor_quant_range.max) + int(self.tensor_quant_dtype.signed)
            )
        else:
            self.zero_quant_dtype = None
            self.linear_tensor_quant_span = self.tensor_quant_range.max
            self.exponent_tensor_quant_span = 2 ** int(math.log2(self.tensor_quant_range.max))
        # Levels before ``exponent_scale_level`` are linear, the rest are
        # exponent levels; an out-of-range level means "all linear".
        if self.exponent_scale_level >= 0 and self.exponent_scale_level < len(self.scale_quant_dtypes):
            lin_s_dtypes = self.scale_quant_dtypes[: self.exponent_scale_level]
            exp_s_dtypes = self.scale_quant_dtypes[self.exponent_scale_level :]
            lin_s_view_shapes = self.scale_view_shapes[: self.exponent_scale_level]
            exp_s_view_shapes = self.scale_view_shapes[self.exponent_scale_level :]
            exp_s_spans = ScaleUtils.infer_scale_quant_spans(exp_s_dtypes)
            lin_s_spans = ScaleUtils.infer_scale_quant_spans(lin_s_dtypes, base=exp_s_spans[-1]) if lin_s_dtypes else []
        else:
            lin_s_dtypes, exp_s_dtypes = self.scale_quant_dtypes, []
            lin_s_view_shapes, exp_s_view_shapes = self.scale_view_shapes, []
            lin_s_spans, exp_s_spans = ScaleUtils.infer_scale_quant_spans(lin_s_dtypes), []
        self.linear_scale_quant_dtypes = lin_s_dtypes
        self.linear_scale_view_shapes = lin_s_view_shapes
        self.linear_scale_quant_spans = lin_s_spans
        self.exponent_scale_quant_dtypes = exp_s_dtypes
        self.exponent_scale_view_shapes = exp_s_view_shapes
        self.exponent_scale_quant_spans = exp_s_spans

    def quantize(
        self,
        *,
        # scale-based quantization related arguments
        scale: torch.Tensor | None = None,
        zero: torch.Tensor | None = None,
        # range-based quantization related arguments
        tensor: torch.Tensor | None = None,
        dynamic_range: DynamicRange | None = None,
    ) -> tuple[QuantScale, torch.Tensor]:
        """Get the quantization scale and zero point of the tensor to be quantized.

        When ``scale`` is ``None`` the scale is derived from the dynamic range
        measured on ``tensor`` (range-based); otherwise the given ``scale`` (and
        ``zero``) are quantized directly (scale-based).

        Args:
            scale (`torch.Tensor` or `None`, *optional*, defaults to `None`):
                The scale tensor.
            zero (`torch.Tensor` or `None`, *optional*, defaults to `None`):
                The zero point tensor.
            tensor (`torch.Tensor` or `None`, *optional*, defaults to `None`):
                The tensor to be quantized. This is only used for range-based quantization.
            dynamic_range (`DynamicRange` or `None`, *optional*, defaults to `None`):
                The dynamic range of the tensor to be quantized.

        Returns:
            `tuple[QuantScale, torch.Tensor]`:
                The scale and the zero point.

        Raises:
            RuntimeError: If the resulting scale holds no data.
        """
        # region step 1: get the dynamic span for range-based scale or the scale tensor
        if scale is None:
            range_based = True
            assert isinstance(tensor, torch.Tensor), "View tensor must be a tensor."
            dynamic_range = dynamic_range or DynamicRange()
            dynamic_range = dynamic_range.measure(
                tensor.view(self.tensor_view_shape),
                zero_domain=self.tensor_zero_domain,
                is_float_point=self.tensor_quant_dtype.is_float_point,
            )
            dynamic_range = dynamic_range.intersect(self.tensor_range_bound)
            dynamic_span = (dynamic_range.max - dynamic_range.min) if self.has_zero_point else dynamic_range.max
        else:
            range_based = False
            # Check the type before calling ``view`` so a non-tensor fails with
            # this message instead of an unrelated AttributeError.
            assert isinstance(scale, torch.Tensor), "Scale must be a tensor."
            scale = scale.view(self.scale_view_shapes[-1])
        # endregion
        # region step 2: get the scale
        if self.linear_scale_quant_dtypes:
            if range_based:
                linear_scale = dynamic_span / self.linear_tensor_quant_span
            elif self.exponent_scale_quant_dtypes:
                # Re-express the given scale in terms of the linear span.
                linear_scale = scale.mul(self.exponent_tensor_quant_span).div(self.linear_tensor_quant_span)
            else:
                linear_scale = scale
            lin_s = quantize_scale(
                linear_scale,
                quant_dtypes=self.linear_scale_quant_dtypes,
                quant_spans=self.linear_scale_quant_spans,
                view_shapes=self.linear_scale_view_shapes,
            )
            assert lin_s.data is not None, "Linear scale tensor is None."
            if not lin_s.data.is_meta:
                assert not lin_s.data.isnan().any(), "Linear scale tensor contains NaN."
                assert not lin_s.data.isinf().any(), "Linear scale tensor contains Inf."
        else:
            lin_s = QuantScale()
        if self.exponent_scale_quant_dtypes:
            if range_based:
                exp_scale = dynamic_span / self.exponent_tensor_quant_span
            else:
                exp_scale = scale
            if lin_s.data is not None:
                # Bring the linear scale to the finest scale shape and factor it
                # out, so the exponent levels only cover the remainder.
                lin_s.data = lin_s.data.expand(self.linear_scale_view_shapes[-1]).reshape(self.scale_view_shapes[-1])
                exp_scale = exp_scale / lin_s.data
            exp_s = quantize_scale(
                exp_scale,
                quant_dtypes=self.exponent_scale_quant_dtypes,
                quant_spans=self.exponent_scale_quant_spans,
                view_shapes=self.exponent_scale_view_shapes,
            )
            assert exp_s.data is not None, "Exponential scale tensor is None."
            assert not exp_s.data.isnan().any(), "Exponential scale tensor contains NaN."
            assert not exp_s.data.isinf().any(), "Exponential scale tensor contains Inf."
            s = exp_s if lin_s.data is None else lin_s.extend(exp_s)
        else:
            s = lin_s

        if s.data is None:
            # Report the configuration through the module logger (not stdout)
            # before failing, to help diagnose why no scale level was produced.
            logger.error(f"Linear scale dtypes: {self.linear_scale_quant_dtypes}")
            logger.error(f"Exponent scale dtypes: {self.exponent_scale_quant_dtypes}")
            if lin_s.data is not None:
                logger.error(f"Linear scale data shape: {lin_s.data.shape}")
            raise RuntimeError("Scale computation failed - resulting scale is None")
        assert not s.data.isnan().any(), "Scale tensor contains NaN."
        assert not s.data.isinf().any(), "Scale tensor contains Inf."
        # endregion
        # region step 3: get the zero point
        if self.has_zero_point:
            if range_based:
                if self.tensor_zero_domain == ZeroPointDomain.PreScale:
                    zero = self.tensor_quant_range.min - dynamic_range.min / s.data
                else:
                    zero = self.tensor_quant_range.min * s.data - dynamic_range.min
            # In the scale-based path the caller must have supplied ``zero``.
            assert isinstance(zero, torch.Tensor), "Zero point must be a tensor."
            z = simple_quantize(zero, has_zero_point=True, quant_dtype=self.zero_quant_dtype)
        else:
            z = torch.tensor(0, dtype=s.data.dtype, device=s.data.device)
        assert not z.isnan().any(), "Zero point tensor contains NaN."
        assert not z.isinf().any(), "Zero point tensor contains Inf."
        # endregion
        return s, z