license: apache-2.0
Repos
https://github.com/mit-han-lab/deepcompressor
Installation
https://github.com/mit-han-lab/deepcompressor/issues/56
https://github.com/nunchaku-tech/deepcompressor/issues/80
Windows
https://learn.microsoft.com/en-us/windows/wsl/install
https://www.anaconda.com/docs/getting-started/miniconda/install
Environment
Hardware:
Nvidia RTX 5060 Ti (Blackwell, sm_120)
Software (WSL):
Python 3.12.11
pip 25.1
CUDA 12.8
Torch 2.7.1+cu128
Diffusers 0.35.0.dev0
Transformers 4.53.2
flash_attn 2.7.4.post1
xformers 0.0.31.post1
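A quick sanity check (a minimal sketch, nothing deepcompressor-specific) that the WSL environment matches the versions above and that the Blackwell GPU is visible with compute capability 12.0 (sm_120):
import torch
import diffusers
import transformers

print("torch", torch.__version__, "cuda", torch.version.cuda)
print("diffusers", diffusers.__version__, "transformers", transformers.__version__)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    # RTX 5060 Ti (Blackwell) should report compute capability (12, 0), i.e. sm_120.
    print(torch.cuda.get_device_name(0), torch.cuda.get_device_capability(0))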
Calibration Dataset Preparation
Example: python -m deepcompressor.app.diffusion.dataset.collect.calib svdq/flux.1-kontext-dev.yaml examples/diffusion/configs/collect/qdiff.yaml --pipeline-path svdq/flux.1-kontext-dev/
Sample Log
In total 32 samples
Evaluating with batch size 1
Data: 3%|██▎ | 1/32 [13:57<7:12:32, 837.19s/it]
Sampling: 12%|█████████▍ | 1/8 [01:34<11:01, 94.44s/it]
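Going by the collect() code under Blockers below, the run writes PNG samples and per-step caches under the configured output root. A small hedged check of one cache file (the output path is a placeholder):
import glob
import torch

# Replace with the output root configured for the collect run.
cache_files = sorted(glob.glob("/PATH/TO/OUTPUT/ROOT/caches/*.pt"))
print(f"{len(cache_files)} cache files")
if cache_files:
    cache = torch.load(cache_files[0], map_location="cpu", weights_only=False)
    for key, value in cache.items():
        if isinstance(value, torch.Tensor):
            print(key, tuple(value.shape), value.dtype)
        else:
            print(key, type(value).__name__)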
Quantization
Model Path: https://github.com/nunchaku-tech/deepcompressor/issues/70#issuecomment-2788155233
Save model: --save-model true or --save-model /PATH/TO/CHECKPOINT/DIR
Example: python -m deepcompressor.app.diffusion.ptq svdq/flux.1-kontext-dev.yaml examples/diffusion/configs/svdquant/nvfp4.yaml --pipeline-path svdq/flux.1-kontext-dev/ --save-model ~/svdq/
Model Files Structure
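The checkpoint layout is not documented here, but going by the ptq() code under Blockers below, the save directory should hold smooth.pt, branch.pt, wgts.pt and acts.pt for the quantizer settings, plus model.pt and scale.pt when --save-model is set; which files actually appear depends on the enabled options. A hedged check against the ~/svdq/ example above:
import os
import torch

save_dir = os.path.expanduser("~/svdq")  # matches the --save-model ~/svdq/ example above
for filename in ["smooth.pt", "branch.pt", "wgts.pt", "acts.pt", "model.pt", "scale.pt"]:
    path = os.path.join(save_dir, filename)
    print(f"{filename:10s}", "found" if os.path.exists(path) else "missing")
model_path = os.path.join(save_dir, "model.pt")
if os.path.exists(model_path):
    state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
    print(len(state_dict), "entries, e.g.", next(iter(state_dict)))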
Deploy
https://github.com/nunchaku-tech/deepcompressor/blob/main/examples/diffusion/README.md#deployment
Example: python -m deepcompressor.backend.nunchaku.convert --quant-path ~/svdq/ --output-root ~/svdq/ --model-name flux.1-kontext-dev-svdq-fp4
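For loading the converted checkpoint, the linked nunchaku example (flux.1-kontext-dev.py) plugs a Nunchaku transformer into the regular diffusers pipeline. A rough sketch, with the class and argument names taken from that example rather than verified here; the exact path format (directory vs. a single .safetensors file) depends on the nunchaku version:
import os
import torch
from diffusers import FluxKontextPipeline
from nunchaku import NunchakuFluxTransformer2dModel

# Output of the convert step above; adjust to the actual file/directory nunchaku expects.
quantized_path = os.path.expanduser("~/svdq/flux.1-kontext-dev-svdq-fp4")
transformer = NunchakuFluxTransformer2dModel.from_pretrained(quantized_path)
pipeline = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")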
Blockers
- NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.
Potential fix: app.diffusion.pipeline.config.py
@staticmethod
def _default_build(
    name: str, path: str, dtype: str | torch.dtype, device: str | torch.device, shift_activations: bool
) -> DiffusionPipeline:
    if not path:
        if name == "sdxl":
            path = "stabilityai/stable-diffusion-xl-base-1.0"
        elif name == "sdxl-turbo":
            path = "stabilityai/sdxl-turbo"
        elif name == "pixart-sigma":
            path = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
        elif name == "flux.1-kontext-dev":
            path = "black-forest-labs/FLUX.1-Kontext-dev"
        elif name == "flux.1-dev":
            path = "black-forest-labs/FLUX.1-dev"
        elif name == "flux.1-canny-dev":
            path = "black-forest-labs/FLUX.1-Canny-dev"
        elif name == "flux.1-depth-dev":
            path = "black-forest-labs/FLUX.1-Depth-dev"
        elif name == "flux.1-fill-dev":
            path = "black-forest-labs/FLUX.1-Fill-dev"
        elif name == "flux.1-schnell":
            path = "black-forest-labs/FLUX.1-schnell"
        else:
            raise ValueError(f"Path for {name} is not specified.")
    if name in ["flux.1-kontext-dev"]:
        pipeline = FluxKontextPipeline.from_pretrained(path, torch_dtype=dtype)
    elif name in ["flux.1-canny-dev", "flux.1-depth-dev"]:
        pipeline = FluxControlPipeline.from_pretrained(path, torch_dtype=dtype)
    elif name == "flux.1-fill-dev":
        pipeline = FluxFillPipeline.from_pretrained(path, torch_dtype=dtype)
    elif name.startswith("sana-"):
        if dtype == torch.bfloat16:
            pipeline = SanaPipeline.from_pretrained(path, variant="bf16", torch_dtype=dtype, use_safetensors=True)
            pipeline.vae.to(dtype)
            pipeline.text_encoder.to(dtype)
        else:
            pipeline = SanaPipeline.from_pretrained(path, torch_dtype=dtype)
    else:
        pipeline = AutoPipelineForText2Image.from_pretrained(path, torch_dtype=dtype)
    # Debug output
    print(">>> DEVICE:", device)
    print(">>> PIPELINE TYPE:", type(pipeline))
    # Try to move each component using .to_empty(); use a separate loop variable
    # so the `name` argument (the model name) is not shadowed.
    for component_name in ["unet", "transformer", "vae", "text_encoder"]:
        module = getattr(pipeline, component_name, None)
        if isinstance(module, torch.nn.Module):
            try:
                print(f">>> Moving {component_name} to {device} using to_empty()")
                module.to_empty(device=device)
            except Exception as e:
                print(f">>> WARNING: {component_name}.to_empty({device}) failed: {e}")
                try:
                    print(f">>> Falling back to {component_name}.to({device})")
                    module.to(device)
                except Exception as ee:
                    print(f">>> ERROR: {component_name}.to({device}) also failed: {ee}")
    # Identify main model (for patching)
    model = getattr(pipeline, "unet", None) or getattr(pipeline, "transformer", None)
    if model is not None:
        replace_fused_linear_with_concat_linear(model)
        replace_up_block_conv_with_concat_conv(model)
        if shift_activations:
            shift_input_activations(model)
    else:
        print(">>> WARNING: No model (unet/transformer) found for patching")
    return pipeline
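For reference, the meta-tensor behavior this fix works around can be reproduced with plain torch, independent of diffusers:
import torch
import torch.nn as nn

# Build a module on the meta device: it has shapes but no storage.
with torch.device("meta"):
    layer = nn.Linear(8, 8)
print(layer.weight.is_meta)  # True

try:
    layer.to("cpu")  # raises NotImplementedError: cannot copy out of meta tensor
except NotImplementedError as e:
    print("to() failed:", e)

# to_empty() allocates real (uninitialized) storage on the target device;
# real weights still have to be loaded afterwards, e.g. via load_state_dict(..., assign=True).
layer = layer.to_empty(device="cpu")
print(layer.weight.is_meta)  # False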
- KeyError: <class 'diffusers.models.transformers.transformer_flux.FluxAttention'>
Potential fix: app.diffusion.nn.struct.py
@staticmethod
def _default_construct(
    module: Attention,
    /,
    parent: tp.Optional["DiffusionTransformerBlockStruct"] = None,
    fname: str = "",
    rname: str = "",
    rkey: str = "",
    idx: int = 0,
    **kwargs,
) -> "DiffusionAttentionStruct":
    if isinstance(module, FluxAttention):
        # FluxAttention has different attribute names than standard attention
        with_rope = True
        num_query_heads = module.heads  # FluxAttention uses 'heads', not 'num_heads'
        num_key_value_heads = module.heads  # FLUX typically uses same for q/k/v
        # FluxAttention doesn't have 'to_out', but may have other output projections
        # Check what output projection attributes actually exist
        o_proj = None
        o_proj_rname = ""
        # Try to find the correct output projection
        if hasattr(module, 'to_out') and module.to_out is not None:
            o_proj = module.to_out[0] if isinstance(module.to_out, (list, tuple)) else module.to_out
            o_proj_rname = "to_out.0" if isinstance(module.to_out, (list, tuple)) else "to_out"
        elif hasattr(module, 'to_add_out'):
            o_proj = module.to_add_out
            o_proj_rname = "to_add_out"
        q_proj, k_proj, v_proj = module.to_q, module.to_k, module.to_v
        q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "to_k", "to_v"
        q, k, v = module.to_q, module.to_k, module.to_v
        q_rname, k_rname, v_rname = "to_q", "to_k", "to_v"
        # Handle the add_* projections that FluxAttention has
        add_q_proj = getattr(module, "add_q_proj", None)
        add_k_proj = getattr(module, "add_k_proj", None)
        add_v_proj = getattr(module, "add_v_proj", None)
        add_o_proj = getattr(module, "to_add_out", None)
        add_q_proj_rname = "add_q_proj" if add_q_proj else ""
        add_k_proj_rname = "add_k_proj" if add_k_proj else ""
        add_v_proj_rname = "add_v_proj" if add_v_proj else ""
        add_o_proj_rname = "to_add_out" if add_o_proj else ""
        kwargs = (
            "encoder_hidden_states",
            "attention_mask",
            "image_rotary_emb",
        )
        cross_attention = add_k_proj is not None
    elif module.is_cross_attention:
        q_proj, k_proj, v_proj = module.to_q, None, None
        add_q_proj, add_k_proj, add_v_proj, add_o_proj = None, module.to_k, module.to_v, None
        q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "", ""
        add_q_proj_rname, add_k_proj_rname, add_v_proj_rname, add_o_proj_rname = "", "to_k", "to_v", ""
    else:
        q_proj, k_proj, v_proj = module.to_q, module.to_k, module.to_v
        add_q_proj = getattr(module, "add_q_proj", None)
        add_k_proj = getattr(module, "add_k_proj", None)
        add_v_proj = getattr(module, "add_v_proj", None)
        add_o_proj = getattr(module, "to_add_out", None)
        q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "to_k", "to_v"
        add_q_proj_rname, add_k_proj_rname, add_v_proj_rname = "add_q_proj", "add_k_proj", "add_v_proj"
        add_o_proj_rname = "to_add_out"
    if getattr(module, "to_out", None) is not None:
        o_proj = module.to_out[0]
        o_proj_rname = "to_out.0"
        assert isinstance(o_proj, nn.Linear)
    elif parent is not None:
        assert isinstance(parent.module, FluxSingleTransformerBlock)
        assert isinstance(parent.module.proj_out, ConcatLinear)
        assert len(parent.module.proj_out.linears) == 2
        o_proj = parent.module.proj_out.linears[0]
        o_proj_rname = ".proj_out.linears.0"
    else:
        raise RuntimeError("Cannot find the output projection.")
    if isinstance(module.processor, DiffusionAttentionProcessor):
        with_rope = module.processor.rope is not None
    elif module.processor.__class__.__name__.startswith("Flux"):
        with_rope = True
    else:
        with_rope = False  # TODO: fix for other processors
    config = AttentionConfigStruct(
        hidden_size=q_proj.weight.shape[1],
        add_hidden_size=add_k_proj.weight.shape[1] if add_k_proj is not None else 0,
        inner_size=q_proj.weight.shape[0],
        num_query_heads=module.heads,
        num_key_value_heads=module.to_k.weight.shape[0] // (module.to_q.weight.shape[0] // module.heads),
        with_qk_norm=module.norm_q is not None,
        with_rope=with_rope,
        linear_attn=isinstance(module.processor, SanaLinearAttnProcessor2_0),
    )
    return DiffusionAttentionStruct(
        module=module,
        parent=parent,
        fname=fname,
        idx=idx,
        rname=rname,
        rkey=rkey,
        config=config,
        q_proj=q_proj,
        k_proj=k_proj,
        v_proj=v_proj,
        o_proj=o_proj,
        add_q_proj=add_q_proj,
        add_k_proj=add_k_proj,
        add_v_proj=add_v_proj,
        add_o_proj=add_o_proj,
        q=None,  # TODO: add q, k, v
        k=None,
        v=None,
        q_proj_rname=q_proj_rname,
        k_proj_rname=k_proj_rname,
        v_proj_rname=v_proj_rname,
        o_proj_rname=o_proj_rname,
        add_q_proj_rname=add_q_proj_rname,
        add_k_proj_rname=add_k_proj_rname,
        add_v_proj_rname=add_v_proj_rname,
        add_o_proj_rname=add_o_proj_rname,
        q_rname="",
        k_rname="",
        v_rname="",
    )
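To double-check the attribute names this fix assumes (to_q/to_k/to_v, heads, to_out, to_add_out, add_*_proj), a small helper can dump what each attention module in the loaded FLUX transformer actually exposes. It assumes a pipeline built as in the first fix above:
import torch.nn as nn

def dump_attention_layout(transformer: nn.Module) -> None:
    # Print which projection submodules exist on each attention module.
    names = ["to_q", "to_k", "to_v", "to_out", "to_add_out", "add_q_proj", "add_k_proj", "add_v_proj", "norm_q"]
    for module_name, module in transformer.named_modules():
        if module.__class__.__name__.endswith("Attention"):
            present = [n for n in names if getattr(module, n, None) is not None]
            print(f"{module_name} ({module.__class__.__name__}): heads={getattr(module, 'heads', '?')} {present}")

# Assumes `pipeline` was built as in the _default_build fix above.
# dump_attention_layout(pipeline.transformer)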
- ValueError: Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.
Potential Fix: app.diffusion.dataset.collect.calib.py
def collect(config: DiffusionPtqRunConfig, dataset: datasets.Dataset):
    samples_dirpath = os.path.join(config.output.root, "samples")
    caches_dirpath = os.path.join(config.output.root, "caches")
    os.makedirs(samples_dirpath, exist_ok=True)
    os.makedirs(caches_dirpath, exist_ok=True)
    caches = []
    pipeline = config.pipeline.build()
    model = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
    assert isinstance(model, nn.Module)
    model.register_forward_hook(CollectHook(caches=caches), with_kwargs=True)
    batch_size = config.eval.batch_size
    print(f"In total {len(dataset)} samples")
    print(f"Evaluating with batch size {batch_size}")
    pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
    for batch in tqdm(
        dataset.iter(batch_size=batch_size, drop_last_batch=False),
        desc="Data",
        leave=False,
        dynamic_ncols=True,
        total=(len(dataset) + batch_size - 1) // batch_size,
    ):
        filenames = batch["filename"]
        prompts = batch["prompt"]
        seeds = [hash_str_to_int(name) for name in filenames]
        generators = [torch.Generator(device=pipeline.device).manual_seed(seed) for seed in seeds]
        pipeline_kwargs = config.eval.get_pipeline_kwargs()
        task = config.pipeline.task
        control_root = config.eval.control_root
        if task in ["canny-to-image", "depth-to-image", "inpainting"]:
            controls = get_control(
                task,
                batch["image"],
                names=batch["filename"],
                data_root=os.path.join(
                    control_root, collect_config.dataset_name, f"{dataset.config_name}-{config.eval.num_samples}"
                ),
            )
            if task == "inpainting":
                pipeline_kwargs["image"] = controls[0]
                pipeline_kwargs["mask_image"] = controls[1]
            else:
                pipeline_kwargs["control_image"] = controls
        # Handle meta tensors by moving individual components
        try:
            pipeline = pipeline.to("cuda")
        except NotImplementedError:
            # Move individual pipeline components that have to_empty method
            if hasattr(pipeline, 'transformer') and pipeline.transformer is not None:
                try:
                    pipeline.transformer = pipeline.transformer.to("cuda")
                except NotImplementedError:
                    pipeline.transformer = pipeline.transformer.to_empty(device="cuda")
            if hasattr(pipeline, 'text_encoder') and pipeline.text_encoder is not None:
                try:
                    pipeline.text_encoder = pipeline.text_encoder.to("cuda")
                except NotImplementedError:
                    pipeline.text_encoder = pipeline.text_encoder.to_empty(device="cuda")
            if hasattr(pipeline, 'text_encoder_2') and pipeline.text_encoder_2 is not None:
                try:
                    pipeline.text_encoder_2 = pipeline.text_encoder_2.to("cuda")
                except NotImplementedError:
                    pipeline.text_encoder_2 = pipeline.text_encoder_2.to_empty(device="cuda")
            if hasattr(pipeline, 'vae') and pipeline.vae is not None:
                try:
                    pipeline.vae = pipeline.vae.to("cuda")
                except NotImplementedError:
                    pipeline.vae = pipeline.vae.to_empty(device="cuda")
        result_images = pipeline(prompt=prompts, generator=generators, **pipeline_kwargs).images
        num_guidances = (len(caches) // batch_size) // config.eval.num_steps
        num_steps = len(caches) // (batch_size * num_guidances)
        assert (
            len(caches) == batch_size * num_steps * num_guidances
        ), f"Unexpected number of caches: {len(caches)} != {batch_size} * {config.eval.num_steps} * {num_guidances}"
        for j, (filename, image) in enumerate(zip(filenames, result_images, strict=True)):
            image.save(os.path.join(samples_dirpath, f"{filename}.png"))
            for s in range(num_steps):
                for g in range(num_guidances):
                    c = caches[s * batch_size * num_guidances + g * batch_size + j]
                    c["filename"] = filename
                    c["step"] = s
                    c["guidance"] = g
                    c = tree_map(lambda x: process(x), c)
                    torch.save(c, os.path.join(caches_dirpath, f"{filename}-{s:05d}-{g}.pt"))
        caches.clear()
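The prompt/prompt_embeds error above usually means a calibration batch reached the pipeline with an empty prompt column. A hedged pre-check over the calibration dataset, assuming the filename/prompt columns used by collect():
def check_prompts(dataset) -> None:
    # Fail fast if any calibration record is missing a usable prompt,
    # which would otherwise surface later as the prompt/prompt_embeds ValueError.
    missing = []
    for i, record in enumerate(dataset):
        prompt = record.get("prompt")
        if not isinstance(prompt, str) or not prompt.strip():
            missing.append((i, record.get("filename")))
    if missing:
        raise ValueError(f"{len(missing)} records without prompts, e.g. {missing[:5]}")
    print(f"All {len(dataset)} calibration records have prompts")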
- RuntimeError: Tensor.item() cannot be called on meta tensors
Potential Fix: quantizer.impl.scale.py
def quantize_scale(
    s: torch.Tensor,
    /,
    *,
    quant_dtypes: tp.Sequence[QuantDataType],
    quant_spans: tp.Sequence[float],
    view_shapes: tp.Sequence[torch.Size],
) -> QuantScale:
    """Quantize the scale tensor.
    Args:
        s (`torch.Tensor`):
            The scale tensor.
        quant_dtypes (`Sequence[QuantDataType]`):
            The quantization dtypes of the scale tensor.
        quant_spans (`Sequence[float]`):
            The quantization spans of the scale tensor.
        view_shapes (`Sequence[torch.Size]`):
            The view shapes of the scale tensor.
    Returns:
        `QuantScale`:
            The quantized scale tensor.
    """
    # Meta tensor check must come before any value-based validation,
    # since reductions on meta tensors cannot be evaluated.
    if s.is_meta:
        raise RuntimeError("Cannot quantize scale with meta tensor. Ensure model is loaded on actual device.")
    # Validate the input scale tensor
    if s.numel() == 0:
        raise ValueError("Input tensor is empty")
    if s.isnan().any() or s.isinf().any():
        raise ValueError("Input tensor contains NaN or Inf values")
    if (s == 0).all():
        raise ValueError("Input tensor contains all zeros")
    scale = QuantScale()
    s = s.abs()
    for view_shape, quant_dtype, quant_span in zip(view_shapes[:-1], quant_dtypes[:-1], quant_spans[:-1], strict=True):
        s = s.view(view_shape)  # (#g0, rs0, #g1, rs1, #g2, rs2, ...)
        ss = s.amax(dim=list(range(1, len(view_shape), 2)), keepdim=True)  # i.e., s_dynamic_span
        ss = simple_quantize(
            ss / quant_span, has_zero_point=False, quant_dtype=quant_dtype
        )  # i.e., s_scale = s_dynamic_span / s_quant_span
        s = s / ss
        scale.append(ss)
    view_shape = view_shapes[-1]
    s = s.view(view_shape)
    if any(v != 1 for v in view_shape[1::2]):
        ss = s.amax(dim=list(range(1, len(view_shape), 2)), keepdim=True)
        ss = simple_quantize(ss / quant_spans[-1], has_zero_point=False, quant_dtype=quant_dtypes[-1])
    else:
        assert quant_spans[-1] == 1, "The last quant span must be 1."
        ss = simple_quantize(s, has_zero_point=False, quant_dtype=quant_dtypes[-1])
    scale.append(ss)
    scale.remove_zero()
    return scale
def quantize(
    self,
    *,
    # scale-based quantization related arguments
    scale: torch.Tensor | None = None,
    zero: torch.Tensor | None = None,
    # range-based quantization related arguments
    tensor: torch.Tensor | None = None,
    dynamic_range: DynamicRange | None = None,
) -> tuple[QuantScale, torch.Tensor]:
    """Get the quantization scale and zero point of the tensor to be quantized.
    Args:
        scale (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            The scale tensor.
        zero (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            The zero point tensor.
        tensor (`torch.Tensor` or `None`, *optional*, defaults to `None`):
            The tensor to be quantized. This is only used for range-based quantization.
        dynamic_range (`DynamicRange` or `None`, *optional*, defaults to `None`):
            The dynamic range of the tensor to be quantized.
    Returns:
        `tuple[QuantScale, torch.Tensor]`:
            The scale and the zero point.
    """
    # region step 1: get the dynamic span for range-based scale or the scale tensor
    if scale is None:
        range_based = True
        assert isinstance(tensor, torch.Tensor), "View tensor must be a tensor."
        dynamic_range = dynamic_range or DynamicRange()
        dynamic_range = dynamic_range.measure(
            tensor.view(self.tensor_view_shape),
            zero_domain=self.tensor_zero_domain,
            is_float_point=self.tensor_quant_dtype.is_float_point,
        )
        dynamic_range = dynamic_range.intersect(self.tensor_range_bound)
        dynamic_span = (dynamic_range.max - dynamic_range.min) if self.has_zero_point else dynamic_range.max
    else:
        range_based = False
        scale = scale.view(self.scale_view_shapes[-1])
        assert isinstance(scale, torch.Tensor), "Scale must be a tensor."
    # endregion
    # region step 2: get the scale
    if self.linear_scale_quant_dtypes:
        if range_based:
            linear_scale = dynamic_span / self.linear_tensor_quant_span
        elif self.exponent_scale_quant_dtypes:
            linear_scale = scale.mul(self.exponent_tensor_quant_span).div(self.linear_tensor_quant_span)
        else:
            linear_scale = scale
        lin_s = quantize_scale(
            linear_scale,
            quant_dtypes=self.linear_scale_quant_dtypes,
            quant_spans=self.linear_scale_quant_spans,
            view_shapes=self.linear_scale_view_shapes,
        )
        assert lin_s.data is not None, "Linear scale tensor is None."
        if not lin_s.data.is_meta:
            assert not lin_s.data.isnan().any(), "Linear scale tensor contains NaN."
            assert not lin_s.data.isinf().any(), "Linear scale tensor contains Inf."
    else:
        lin_s = QuantScale()
    if self.exponent_scale_quant_dtypes:
        if range_based:
            exp_scale = dynamic_span / self.exponent_tensor_quant_span
        else:
            exp_scale = scale
        if lin_s.data is not None:
            lin_s.data = lin_s.data.expand(self.linear_scale_view_shapes[-1]).reshape(self.scale_view_shapes[-1])
            exp_scale = exp_scale / lin_s.data
        exp_s = quantize_scale(
            exp_scale,
            quant_dtypes=self.exponent_scale_quant_dtypes,
            quant_spans=self.exponent_scale_quant_spans,
            view_shapes=self.exponent_scale_view_shapes,
        )
        assert exp_s.data is not None, "Exponential scale tensor is None."
        assert not exp_s.data.isnan().any(), "Exponential scale tensor contains NaN."
        assert not exp_s.data.isinf().any(), "Exponential scale tensor contains Inf."
        s = exp_s if lin_s.data is None else lin_s.extend(exp_s)
    else:
        s = lin_s
    # Before the final assertions, add debugging and validation
    if s.data is None:
        # Log debugging information
        print(f"Linear scale dtypes: {self.linear_scale_quant_dtypes}")
        print(f"Exponent scale dtypes: {self.exponent_scale_quant_dtypes}")
        if hasattr(lin_s, 'data') and lin_s.data is not None:
            print(f"Linear scale data shape: {lin_s.data.shape}")
        raise RuntimeError("Scale computation failed - resulting scale is None")
    assert s.data is not None, "Scale tensor is None."
    assert not s.data.isnan().any(), "Scale tensor contains NaN."
    assert not s.data.isinf().any(), "Scale tensor contains Inf."
    # endregion
    # region step 3: get the zero point
    if self.has_zero_point:
        if range_based:
            if self.tensor_zero_domain == ZeroPointDomain.PreScale:
                zero = self.tensor_quant_range.min - dynamic_range.min / s.data
            else:
                zero = self.tensor_quant_range.min * s.data - dynamic_range.min
        assert isinstance(zero, torch.Tensor), "Zero point must be a tensor."
        z = simple_quantize(zero, has_zero_point=True, quant_dtype=self.zero_quant_dtype)
    else:
        z = torch.tensor(0, dtype=s.data.dtype, device=s.data.device)
    assert not z.isnan().any(), "Zero point tensor contains NaN."
    assert not z.isinf().any(), "Zero point tensor contains Inf."
    # endregion
    return s, z
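For intuition, the grouped view shapes quantize_scale() operates on, (#g0, rs0, #g1, rs1, ...), reduce over the odd (group-size) dimensions. A toy, deepcompressor-free illustration with made-up numbers:
import torch

s = torch.rand(4, 16) + 1e-3                # pretend per-channel scales
view_shape = (4, 1, 4, 4)                   # (#g0, rs0, #g1, rs1): 16 groups of 4 along the last dim
grouped = s.view(view_shape)
group_span = grouped.amax(dim=(1, 3), keepdim=True)  # dynamic span per group
quant_span = 6.0                            # stand-in for the scale dtype's quant span
group_scale = group_span / quant_span       # stand-in for simple_quantize(group_span / quant_span)
normalized = grouped / group_scale
print(group_scale.squeeze().shape)          # (4, 4): one scale per group
print(float(normalized.amax()))             # bounded by quant_span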
Potential Fix: app.diffusion.ptq.py
def ptq(  # noqa: C901
    model: DiffusionModelStruct,
    config: DiffusionQuantConfig,
    cache: DiffusionPtqCacheConfig | None = None,
    load_dirpath: str = "",
    save_dirpath: str = "",
    copy_on_save: bool = False,
    save_model: bool = False,
) -> DiffusionModelStruct:
    """Post-training quantization of a diffusion model.
    Args:
        model (`DiffusionModelStruct`):
            The diffusion model.
        config (`DiffusionQuantConfig`):
            The diffusion model post-training quantization configuration.
        cache (`DiffusionPtqCacheConfig`, *optional*, defaults to `None`):
            The diffusion model quantization cache path configuration.
        load_dirpath (`str`, *optional*, defaults to `""`):
            The directory path to load the quantization checkpoint.
        save_dirpath (`str`, *optional*, defaults to `""`):
            The directory path to save the quantization checkpoint.
        copy_on_save (`bool`, *optional*, defaults to `False`):
            Whether to copy the cache to the save directory.
        save_model (`bool`, *optional*, defaults to `False`):
            Whether to save the quantized model checkpoint.
    Returns:
        `DiffusionModelStruct`:
            The quantized diffusion model.
    """
    logger = tools.logging.getLogger(__name__)
    if not isinstance(model, DiffusionModelStruct):
        model = DiffusionModelStruct.construct(model)
    assert isinstance(model, DiffusionModelStruct)
    quant_wgts = config.enabled_wgts
    quant_ipts = config.enabled_ipts
    quant_opts = config.enabled_opts
    quant_acts = quant_ipts or quant_opts
    quant = quant_wgts or quant_acts
    load_model_path, load_path, save_path = "", None, None
    if load_dirpath:
        load_path = DiffusionQuantCacheConfig(
            smooth=os.path.join(load_dirpath, "smooth.pt"),
            branch=os.path.join(load_dirpath, "branch.pt"),
            wgts=os.path.join(load_dirpath, "wgts.pt"),
            acts=os.path.join(load_dirpath, "acts.pt"),
        )
        load_model_path = os.path.join(load_dirpath, "model.pt")
        if os.path.exists(load_model_path):
            if config.enabled_wgts and config.wgts.enabled_low_rank:
                if os.path.exists(load_path.branch):
                    load_model = True
                else:
                    logger.warning(f"Model low-rank branch checkpoint {load_path.branch} does not exist")
                    load_model = False
            else:
                load_model = True
            if load_model:
                logger.info(f"* Loading model from {load_model_path}")
                save_dirpath = ""  # do not save the model if loading
        else:
            logger.warning(f"Model checkpoint {load_model_path} does not exist")
            load_model = False
    else:
        load_model = False
    if save_dirpath:
        os.makedirs(save_dirpath, exist_ok=True)
        save_path = DiffusionQuantCacheConfig(
            smooth=os.path.join(save_dirpath, "smooth.pt"),
            branch=os.path.join(save_dirpath, "branch.pt"),
            wgts=os.path.join(save_dirpath, "wgts.pt"),
            acts=os.path.join(save_dirpath, "acts.pt"),
        )
    else:
        save_model = False
    if quant and config.enabled_rotation:
        logger.info("* Rotating model for quantization")
        tools.logging.Formatter.indent_inc()
        rotate_diffusion(model, config=config)
        tools.logging.Formatter.indent_dec()
        gc.collect()
        torch.cuda.empty_cache()
    # region smooth quantization
    if quant and config.enabled_smooth:
        logger.info("* Smoothing model for quantization")
        tools.logging.Formatter.indent_inc()
        load_from = ""
        if load_path and os.path.exists(load_path.smooth):
            load_from = load_path.smooth
        elif cache and cache.path.smooth and os.path.exists(cache.path.smooth):
            load_from = cache.path.smooth
        if load_from:
            logger.info(f"- Loading smooth scales from {load_from}")
            smooth_cache = torch.load(load_from)
            smooth_diffusion(model, config, smooth_cache=smooth_cache)
        else:
            logger.info("- Generating smooth scales")
            smooth_cache = smooth_diffusion(model, config)
            if cache and cache.path.smooth:
                logger.info(f"- Saving smooth scales to {cache.path.smooth}")
                os.makedirs(cache.dirpath.smooth, exist_ok=True)
                torch.save(smooth_cache, cache.path.smooth)
                load_from = cache.path.smooth
        if save_path:
            if not copy_on_save and load_from:
                logger.info(f"- Linking smooth scales to {save_path.smooth}")
                os.symlink(os.path.relpath(load_from, save_dirpath), save_path.smooth)
            else:
                logger.info(f"- Saving smooth scales to {save_path.smooth}")
                torch.save(smooth_cache, save_path.smooth)
        del smooth_cache
        tools.logging.Formatter.indent_dec()
        gc.collect()
        torch.cuda.empty_cache()
    # endregion
    # region collect original state dict
    if config.needs_acts_quantizer_cache:
        if load_path and os.path.exists(load_path.acts):
            orig_state_dict = None
        elif cache and cache.path.acts and os.path.exists(cache.path.acts):
            orig_state_dict = None
        else:
            orig_state_dict: dict[str, torch.Tensor] = {
                name: param.detach().clone() for name, param in model.module.named_parameters() if param.ndim > 1
            }
    else:
        orig_state_dict = None
    # endregion
    if load_model:
        logger.info(f"* Loading model checkpoint from {load_model_path}")
        load_diffusion_weights_state_dict(
            model,
            config,
            state_dict=torch.load(load_model_path),
            branch_state_dict=torch.load(load_path.branch) if os.path.exists(load_path.branch) else None,
        )
        gc.collect()
        torch.cuda.empty_cache()
    elif quant_wgts:
        logger.info("* Ensuring model is on actual device before quantization")
        # Check if model has meta tensors
        has_meta_tensors = any(param.is_meta for param in model.module.parameters())
        if has_meta_tensors:
            logger.info("* Model contains meta tensors, materializing to actual device")
            # Option 1: Use to_empty() and reload weights (recommended)
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # Store original state dict if available
            try:
                original_state_dict = model.module.state_dict()
                model.module = model.module.to_empty(device=device)
                model.module.load_state_dict(original_state_dict)
                logger.info("* Successfully materialized model with original weights")
            except Exception as e:
                logger.warning(f"* Failed to preserve weights during materialization: {e}")
                # Fallback: just move to empty device (weights will be uninitialized)
                model.module = model.module.to_empty(device=device)
                logger.warning("* Model moved to device but weights may be uninitialized")
        else:
            # Model already has real tensors, just ensure it's on the right device
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model.module = model.module.to(device)
        # Verify no meta tensors remain
        remaining_meta = [name for name, param in model.module.named_parameters() if param.is_meta]
        if remaining_meta:
            raise RuntimeError(f"Parameters still on meta device: {remaining_meta}")
        logger.info("* Model successfully prepared for quantization")
        logger.info("* Quantizing weights")
        tools.logging.Formatter.indent_inc()
        quantizer_state_dict, quantizer_load_from = None, ""
        if load_path and os.path.exists(load_path.wgts):
            quantizer_load_from = load_path.wgts
        elif cache and cache.path.wgts and os.path.exists(cache.path.wgts):
            quantizer_load_from = cache.path.wgts
        if quantizer_load_from:
            logger.info(f"- Loading weight settings from {quantizer_load_from}")
            quantizer_state_dict = torch.load(quantizer_load_from)
        branch_state_dict, branch_load_from = None, ""
        if load_path and os.path.exists(load_path.branch):
            branch_load_from = load_path.branch
        elif cache and cache.path.branch and os.path.exists(cache.path.branch):
            branch_load_from = cache.path.branch
        if branch_load_from:
            logger.info(f"- Loading branch settings from {branch_load_from}")
            branch_state_dict = torch.load(branch_load_from)
        if not quantizer_load_from:
            logger.info("- Generating weight settings")
        if not branch_load_from:
            logger.info("- Generating branch settings")
        quantizer_state_dict, branch_state_dict, scale_state_dict = quantize_diffusion_weights(
            model,
            config,
            quantizer_state_dict=quantizer_state_dict,
            branch_state_dict=branch_state_dict,
            return_with_scale_state_dict=bool(save_dirpath),
        )
        if not quantizer_load_from and cache and cache.dirpath.wgts:
            logger.info(f"- Saving weight settings to {cache.path.wgts}")
            os.makedirs(cache.dirpath.wgts, exist_ok=True)
            torch.save(quantizer_state_dict, cache.path.wgts)
            quantizer_load_from = cache.path.wgts
        if not branch_load_from and cache and cache.dirpath.branch:
            logger.info(f"- Saving branch settings to {cache.path.branch}")
            os.makedirs(cache.dirpath.branch, exist_ok=True)
            torch.save(branch_state_dict, cache.path.branch)
            branch_load_from = cache.path.branch
        if save_path:
            if not copy_on_save and quantizer_load_from:
                logger.info(f"- Linking weight settings to {save_path.wgts}")
                os.symlink(os.path.relpath(quantizer_load_from, save_dirpath), save_path.wgts)
            else:
                logger.info(f"- Saving weight settings to {save_path.wgts}")
                torch.save(quantizer_state_dict, save_path.wgts)
            if not copy_on_save and branch_load_from:
                logger.info(f"- Linking branch settings to {save_path.branch}")
                os.symlink(os.path.relpath(branch_load_from, save_dirpath), save_path.branch)
            else:
                logger.info(f"- Saving branch settings to {save_path.branch}")
                torch.save(branch_state_dict, save_path.branch)
        if save_model:
            logger.info(f"- Saving model to {save_dirpath}")
            torch.save(scale_state_dict, os.path.join(save_dirpath, "scale.pt"))
            torch.save(model.module.state_dict(), os.path.join(save_dirpath, "model.pt"))
        del quantizer_state_dict, branch_state_dict, scale_state_dict
        tools.logging.Formatter.indent_dec()
        gc.collect()
        torch.cuda.empty_cache()
    if quant_acts:
        logger.info("* Quantizing activations")
        tools.logging.Formatter.indent_inc()
        if config.needs_acts_quantizer_cache:
            load_from = ""
            if load_path and os.path.exists(load_path.acts):
                load_from = load_path.acts
            elif cache and cache.path.acts and os.path.exists(cache.path.acts):
                load_from = cache.path.acts
            if load_from:
                logger.info(f"- Loading activation settings from {load_from}")
                quantizer_state_dict = torch.load(load_from)
                quantize_diffusion_activations(
                    model, config, quantizer_state_dict=quantizer_state_dict, orig_state_dict=orig_state_dict
                )
            else:
                logger.info("- Generating activation settings")
                quantizer_state_dict = quantize_diffusion_activations(model, config, orig_state_dict=orig_state_dict)
                if cache and cache.dirpath.acts and quantizer_state_dict is not None:
                    logger.info(f"- Saving activation settings to {cache.path.acts}")
                    os.makedirs(cache.dirpath.acts, exist_ok=True)
                    torch.save(quantizer_state_dict, cache.path.acts)
                    load_from = cache.path.acts
            if save_dirpath:
                if not copy_on_save and load_from:
                    logger.info(f"- Linking activation quantizer settings to {save_path.acts}")
                    os.symlink(os.path.relpath(load_from, save_dirpath), save_path.acts)
                else:
                    logger.info(f"- Saving activation quantizer settings to {save_path.acts}")
                    torch.save(quantizer_state_dict, save_path.acts)
            del quantizer_state_dict
        else:
            logger.info("- No need to generate/load activation quantizer settings")
            quantize_diffusion_activations(model, config, orig_state_dict=orig_state_dict)
        tools.logging.Formatter.indent_dec()
        del orig_state_dict
        gc.collect()
        torch.cuda.empty_cache()
    return model
- RuntimeError: Dataset scripts are no longer supported, but found COCO.py
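No fix is included above for this one. The error comes from newer huggingface datasets releases that no longer execute loading scripts such as COCO.py; two workarounds are pinning an older datasets release (check the deepcompressor issues for the exact version bound) or building the calibration prompts without the script. A rough sketch of the latter, assuming a locally downloaded COCO captions annotation file (path is hypothetical) and the filename/prompt columns expected by collect():
import json
import datasets

# Hypothetical local copy of the COCO captions annotations.
with open("/PATH/TO/annotations/captions_val2017.json") as f:
    coco = json.load(f)
records = [
    {"filename": f"{ann['image_id']:012d}", "prompt": ann["caption"]}
    for ann in coco["annotations"]
]
calib_dataset = datasets.Dataset.from_list(records)
print(calib_dataset)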
References
https://github.com/nunchaku-tech/nunchaku/commit/b99fb8be615bc98c6915bbe06a1e0092cbc074a5
https://github.com/nunchaku-tech/nunchaku/blob/main/examples/flux.1-kontext-dev.py
https://github.com/nunchaku-tech/deepcompressor/issues/91
https://deepwiki.com/nunchaku-tech/deepcompressor
Dependencies
https://github.com/Dao-AILab/flash-attention
https://github.com/facebookresearch/xformers
https://github.com/openai/CLIP
https://github.com/THUDM/ImageReward
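A quick import check for these (plus the core stack) inside WSL; version attributes may differ per package:
import importlib

for module_name in ("flash_attn", "xformers", "clip", "ImageReward", "torch", "diffusers", "transformers"):
    try:
        module = importlib.import_module(module_name)
        print(module_name, getattr(module, "__version__", "ok"))
    except ImportError as exc:
        print(module_name, "MISSING:", exc)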
Wheels
https://huggingface.co/datasets/siraxe/PrecompiledWheels_Torch-2.8-cu128-cp312
https://huggingface.co/lldacing/flash-attention-windows-wheel