"""Collect calibration dataset."""

import os
from dataclasses import dataclass

import datasets
import torch
from omniconfig import configclass
from torch import nn
from tqdm import tqdm

from deepcompressor.app.diffusion.config import DiffusionPtqRunConfig
from deepcompressor.utils.common import hash_str_to_int, tree_map

from ...utils import get_control
from ..data import get_dataset
from .utils import CollectHook
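

# `process` returns an independent CPU copy of a cached tensor: the value is
# upcast to float32 for the numpy hop (numpy has no bfloat16) and cast back to
# its original dtype, yielding a contiguous copy that shares no storage with
# the live model state.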
def process(x: torch.Tensor) -> torch.Tensor:
    dtype = x.dtype
    return torch.from_numpy(x.float().numpy()).to(dtype)


def collect(config: DiffusionPtqRunConfig, collect_config: "CollectConfig", dataset: datasets.Dataset):
    samples_dirpath = os.path.join(config.output.root, "samples")
    caches_dirpath = os.path.join(config.output.root, "caches")
    os.makedirs(samples_dirpath, exist_ok=True)
    os.makedirs(caches_dirpath, exist_ok=True)
    caches = []

    pipeline = config.pipeline.build()
    model = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
    assert isinstance(model, nn.Module)
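    # `CollectHook` (from .utils) is expected to append each intercepted set of
    # forward inputs to `caches`; with_kwargs=True makes the hook receive the
    # keyword arguments of every forward call as well.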
    model.register_forward_hook(CollectHook(caches=caches), with_kwargs=True)

    batch_size = config.eval.batch_size
    print(f"In total {len(dataset)} samples")
    print(f"Evaluating with batch size {batch_size}")
    pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
    for batch in tqdm(
        dataset.iter(batch_size=batch_size, drop_last_batch=False),
        desc="Data",
        leave=False,
        dynamic_ncols=True,
        total=(len(dataset) + batch_size - 1) // batch_size,
    ):
        filenames = batch["filename"]
        prompts = batch["prompt"]
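        # Seed one generator per sample from a hash of its filename so that the
        # initial latents are reproducible across collection runs.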
        seeds = [hash_str_to_int(name) for name in filenames]
        generators = [torch.Generator(device=pipeline.device).manual_seed(seed) for seed in seeds]
        pipeline_kwargs = config.eval.get_pipeline_kwargs()

        task = config.pipeline.task
        control_root = config.eval.control_root
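        # Control-style tasks need conditioning inputs alongside the prompt;
        # `get_control` presumably loads (or derives) them per sample from the
        # directory rooted at `data_root`.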
        if task in ["canny-to-image", "depth-to-image", "inpainting"]:
            controls = get_control(
                task,
                batch["image"],
                names=batch["filename"],
                data_root=os.path.join(
                    control_root, collect_config.dataset_name, f"{dataset.config_name}-{config.eval.num_samples}"
                ),
            )
            if task == "inpainting":
                pipeline_kwargs["image"] = controls[0]
                pipeline_kwargs["mask_image"] = controls[1]
            else:
                pipeline_kwargs["control_image"] = controls

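        # Moving the whole pipeline may raise NotImplementedError (e.g. when a
        # component holds meta tensors); fall back to moving the components one
        # by one, using `to_empty` as a last resort. Note that `to_empty`
        # allocates uninitialized storage on the target device.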
        try:
            pipeline = pipeline.to("cuda")
        except NotImplementedError:
            for name in ("transformer", "text_encoder", "text_encoder_2", "vae"):
                component = getattr(pipeline, name, None)
                if component is None:
                    continue
                try:
                    setattr(pipeline, name, component.to("cuda"))
                except NotImplementedError:
                    setattr(pipeline, name, component.to_empty(device="cuda"))

        result_images = pipeline(prompt=prompts, generator=generators, **pipeline_kwargs).images
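        # Each pipeline call should have produced one cache entry per sample,
        # per guidance branch, per denoising step; recover num_guidances from
        # the recorded count and re-derive num_steps as a consistency check.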
        num_guidances = (len(caches) // batch_size) // config.eval.num_steps
        num_steps = len(caches) // (batch_size * num_guidances)
        assert (
            len(caches) == batch_size * num_steps * num_guidances
        ), f"Unexpected number of caches: {len(caches)} != {batch_size} * {num_steps} * {num_guidances}"
        for j, (filename, image) in enumerate(zip(filenames, result_images, strict=True)):
            image.save(os.path.join(samples_dirpath, f"{filename}.png"))
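            # The flat `caches` list is assumed step-major: all entries of step
            # s precede step s + 1, grouped by guidance branch and then by
            # position within the batch, hence the index expression below.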
            for s in range(num_steps):
                for g in range(num_guidances):
                    c = caches[s * batch_size * num_guidances + g * batch_size + j]
                    c["filename"] = filename
                    c["step"] = s
                    c["guidance"] = g
                    c = tree_map(process, c)
                    torch.save(c, os.path.join(caches_dirpath, f"{filename}-{s:05d}-{g}.pt"))
        caches.clear()


@configclass
@dataclass
class CollectConfig:
    """Configuration for collecting calibration dataset.

    Args:
        root (`str`, *optional*, defaults to `"datasets"`):
            Root directory to save the collected dataset.
        dataset_name (`str`, *optional*, defaults to `"qdiff"`):
            Name of the collected dataset.
        data_path (`str`, *optional*, defaults to `"prompts/qdiff.yaml"`):
            Path to the prompt file.
        num_samples (`int`, *optional*, defaults to `128`):
            Number of samples to collect.
    """

    root: str = "datasets"
    dataset_name: str = "qdiff"
    data_path: str = "prompts/qdiff.yaml"
    num_samples: int = 128
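

# Sketch of an invocation (hypothetical flag spellings: omniconfig generates
# the actual CLI from DiffusionPtqRunConfig plus the CollectConfig registered
# under the "collect" scope/prefix below):
#   python -m <this_module> <ptq_config>.yaml --collect-num-samples 128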
if __name__ == "__main__":
    parser = DiffusionPtqRunConfig.get_parser()
    parser.add_config(CollectConfig, scope="collect", prefix="collect")
    configs, _, unused_cfgs, unused_args, unknown_args = parser.parse_known_args()
    ptq_config, collect_config = configs[""], configs["collect"]
    assert isinstance(ptq_config, DiffusionPtqRunConfig)
    assert isinstance(collect_config, CollectConfig)
    if len(unused_cfgs) > 0:
        print(f"Warning: unused configurations {unused_cfgs}")
    if unused_args is not None:
        print(f"Warning: unused arguments {unused_args}")
    assert len(unknown_args) == 0, f"Unknown arguments: {unknown_args}"

    collect_dirpath = os.path.join(
        collect_config.root,
        str(ptq_config.pipeline.dtype),
        ptq_config.pipeline.name,
        ptq_config.eval.protocol,
        collect_config.dataset_name,
        f"s{collect_config.num_samples}",
    )
    print(f"Saving caches to {collect_dirpath}")

    dataset = get_dataset(
        collect_config.data_path,
        max_dataset_size=collect_config.num_samples,
        return_gt=ptq_config.pipeline.task in ["canny-to-image"],
        repeat=1,
    )

    ptq_config.output.root = collect_dirpath
    os.makedirs(ptq_config.output.root, exist_ok=True)
    collect(ptq_config, collect_config, dataset=dataset)