Update README.md
Browse files
README.md
CHANGED
@@ -416,6 +416,68 @@ def collect(config: DiffusionPtqRunConfig, dataset: datasets.Dataset):
|
|
416 |
Potential fix (in `deepcompressor/quantizer/impl/scale.py`):
|
417 |
|
418 |
```python
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
419 |
def quantize(
|
420 |
self,
|
421 |
*,
|
|
|
416 |
Potential fix (in `deepcompressor/quantizer/impl/scale.py`):
|
417 |
|
418 |
```python
|
419 |
+
def quantize_scale(
    s: torch.Tensor,
    /,
    *,
    quant_dtypes: tp.Sequence[QuantDataType],
    quant_spans: tp.Sequence[float],
    view_shapes: tp.Sequence[torch.Size],
) -> QuantScale:
    """Quantize the scale tensor.

    The absolute value of ``s`` is quantized level by level. For every level
    except the last, ``s`` is viewed with the level's view shape, reduced with
    ``amax`` over the odd-indexed dimensions, divided by the level's quant
    span, passed through ``simple_quantize`` (without zero point), and then
    divided out of ``s`` before the next level. The last level either repeats
    that reduction (when any odd-indexed dimension of its view shape is not 1)
    or quantizes ``s`` directly.

    Args:
        s (`torch.Tensor`):
            The scale tensor.
        quant_dtypes (`Sequence[QuantDataType]`):
            The quantization dtypes of the scale tensor.
        quant_spans (`Sequence[float]`):
            The quantization spans of the scale tensor.
        view_shapes (`Sequence[torch.Size]`):
            The view shapes of the scale tensor.

    Returns:
        `QuantScale`:
            The quantized scale tensor.

    Raises:
        RuntimeError:
            If ``s`` is a meta tensor.
        ValueError:
            If ``s`` is empty, contains NaN or Inf values, or is all zeros.
    """
    # The meta check has to come first: the value-based checks below convert a
    # tensor result to a Python bool, which needs real data and would fail on a
    # meta tensor before this explicit, actionable message could be raised.
    if s.is_meta:
        raise RuntimeError("Cannot quantize scale with meta tensor. Ensure model is loaded on actual device.")
    if s.numel() == 0:
        raise ValueError("Input tensor is empty")
    # A single finiteness pass covers both the NaN and the Inf case.
    if not s.isfinite().all():
        raise ValueError("Input tensor contains NaN or Inf values")
    if (s == 0).all():
        raise ValueError("Input tensor contains all zeros")

    scale = QuantScale()
    s = s.abs()
    for view_shape, quant_dtype, quant_span in zip(view_shapes[:-1], quant_dtypes[:-1], quant_spans[:-1], strict=True):
        s = s.view(view_shape)  # (#g0, rs0, #g1, rs1, #g2, rs2, ...)
        ss = s.amax(dim=list(range(1, len(view_shape), 2)), keepdim=True)  # i.e., s_dynamic_span
        ss = simple_quantize(
            ss / quant_span, has_zero_point=False, quant_dtype=quant_dtype
        )  # i.e., s_scale = s_dynamic_span / s_quant_span
        # Divide this level's quantized scale out so the next level sees the residual.
        s = s / ss
        scale.append(ss)
    view_shape = view_shapes[-1]
    s = s.view(view_shape)
    if any(v != 1 for v in view_shape[1::2]):
        # The last level still groups several elements: reduce like the earlier levels.
        ss = s.amax(dim=list(range(1, len(view_shape), 2)), keepdim=True)
        ss = simple_quantize(ss / quant_spans[-1], has_zero_point=False, quant_dtype=quant_dtypes[-1])
    else:
        # Every reduced dimension has size 1, so `s` is quantized as-is.
        assert quant_spans[-1] == 1, "The last quant span must be 1."
        ss = simple_quantize(s, has_zero_point=False, quant_dtype=quant_dtypes[-1])
    scale.append(ss)
    scale.remove_zero()
    return scale
|
480 |
+
|
481 |
def quantize(
|
482 |
self,
|
483 |
*,
|