lym00 commited on
Commit
b1d4119
·
verified ·
1 Parent(s): d6f7ae7

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +62 -0
README.md CHANGED
@@ -416,6 +416,68 @@ def collect(config: DiffusionPtqRunConfig, dataset: datasets.Dataset):
416
  Potential Fix: deepcompressor.quantizer.impl.scale.py
417
 
418
  ```python
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
  def quantize(
420
  self,
421
  *,
 
416
  Potential Fix: deepcompressor.quantizer.impl.scale.py
417
 
418
  ```python
419
+ def quantize_scale(
420
+ s: torch.Tensor,
421
+ /,
422
+ *,
423
+ quant_dtypes: tp.Sequence[QuantDataType],
424
+ quant_spans: tp.Sequence[float],
425
+ view_shapes: tp.Sequence[torch.Size],
426
+ ) -> QuantScale:
427
+ """Quantize the scale tensor.
428
+
429
+ Args:
430
+ s (`torch.Tensor`):
431
+ The scale tensor.
432
+ quant_dtypes (`Sequence[QuantDataType]`):
433
+ The quantization dtypes of the scale tensor.
434
+ quant_spans (`Sequence[float]`):
435
+ The quantization spans of the scale tensor.
436
+ view_shapes (`Sequence[torch.Size]`):
437
+ The view shapes of the scale tensor.
438
+
439
+ Returns:
440
+ `QuantScale`:
441
+ The quantized scale tensor.
442
+ """
443
+ # Add validation at the start
444
+ if s.numel() == 0:
445
+ raise ValueError("Input tensor is empty")
446
+ if s.isnan().any() or s.isinf().any():
447
+ raise ValueError("Input tensor contains NaN or Inf values")
448
+ if (s == 0).all():
449
+ raise ValueError("Input tensor contains all zeros")
450
+
451
+ # Add meta tensor check before any operations
452
+ if s.is_meta:
453
+ raise RuntimeError("Cannot quantize scale with meta tensor. Ensure model is loaded on actual device.")
454
+
455
+ # Existing validation
456
+ if s.isnan().any() or s.isinf().any():
457
+ raise ValueError("Input tensor contains NaN or Inf values")
458
+
459
+ scale = QuantScale()
460
+ s = s.abs()
461
+ for view_shape, quant_dtype, quant_span in zip(view_shapes[:-1], quant_dtypes[:-1], quant_spans[:-1], strict=True):
462
+ s = s.view(view_shape) # (#g0, rs0, #g1, rs1, #g2, rs2, ...)
463
+ ss = s.amax(dim=list(range(1, len(view_shape), 2)), keepdim=True) # i.e., s_dynamic_span
464
+ ss = simple_quantize(
465
+ ss / quant_span, has_zero_point=False, quant_dtype=quant_dtype
466
+ ) # i.e., s_scale = s_dynamic_span / s_quant_span
467
+ s = s / ss
468
+ scale.append(ss)
469
+ view_shape = view_shapes[-1]
470
+ s = s.view(view_shape)
471
+ if any(v != 1 for v in view_shape[1::2]):
472
+ ss = s.amax(dim=list(range(1, len(view_shape), 2)), keepdim=True)
473
+ ss = simple_quantize(ss / quant_spans[-1], has_zero_point=False, quant_dtype=quant_dtypes[-1])
474
+ else:
475
+ assert quant_spans[-1] == 1, "The last quant span must be 1."
476
+ ss = simple_quantize(s, has_zero_point=False, quant_dtype=quant_dtypes[-1])
477
+ scale.append(ss)
478
+ scale.remove_zero()
479
+ return scale
480
+
481
  def quantize(
482
  self,
483
  *,