lym00 committed
Commit c65ecd1 · verified · Parent: a9e8301

Upload calib.py

Files changed (1)
  calib.py +174 -0
calib.py ADDED
# -*- coding: utf-8 -*-
"""Collect calibration dataset."""

import os
from dataclasses import dataclass

import datasets
import torch
from omniconfig import configclass
from torch import nn
from tqdm import tqdm

from deepcompressor.app.diffusion.config import DiffusionPtqRunConfig
from deepcompressor.utils.common import hash_str_to_int, tree_map

from ...utils import get_control
from ..data import get_dataset
from .utils import CollectHook


def process(x: torch.Tensor) -> torch.Tensor:
    # Round-trip through NumPy to coerce the value into a plain CPU
    # torch.Tensor while preserving its original dtype.
    dtype = x.dtype
    return torch.from_numpy(x.float().numpy()).to(dtype)


def collect(config: DiffusionPtqRunConfig, collect_config: "CollectConfig", dataset: datasets.Dataset):
    samples_dirpath = os.path.join(config.output.root, "samples")
    caches_dirpath = os.path.join(config.output.root, "caches")
    os.makedirs(samples_dirpath, exist_ok=True)
    os.makedirs(caches_dirpath, exist_ok=True)
    caches = []

    pipeline = config.pipeline.build()
    model = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
    assert isinstance(model, nn.Module)
    # Record the denoiser's inputs on every forward call.
    model.register_forward_hook(CollectHook(caches=caches), with_kwargs=True)

    batch_size = config.eval.batch_size
    print(f"In total {len(dataset)} samples")
    print(f"Evaluating with batch size {batch_size}")
    pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
    for batch in tqdm(
        dataset.iter(batch_size=batch_size, drop_last_batch=False),
        desc="Data",
        leave=False,
        dynamic_ncols=True,
        total=(len(dataset) + batch_size - 1) // batch_size,
    ):
        filenames = batch["filename"]
        prompts = batch["prompt"]
        # The last batch may be smaller than `batch_size` since drop_last_batch=False.
        cur_batch_size = len(filenames)
        # Seed each sample deterministically from its filename for reproducibility.
        seeds = [hash_str_to_int(name) for name in filenames]
        generators = [torch.Generator(device=pipeline.device).manual_seed(seed) for seed in seeds]
        pipeline_kwargs = config.eval.get_pipeline_kwargs()

        task = config.pipeline.task
        control_root = config.eval.control_root
        if task in ["canny-to-image", "depth-to-image", "inpainting"]:
            controls = get_control(
                task,
                batch["image"],
                names=batch["filename"],
                data_root=os.path.join(
                    control_root, collect_config.dataset_name, f"{dataset.config_name}-{config.eval.num_samples}"
                ),
            )
            if task == "inpainting":
                pipeline_kwargs["image"] = controls[0]
                pipeline_kwargs["mask_image"] = controls[1]
            else:
                pipeline_kwargs["control_image"] = controls

        # Handle meta tensors by moving individual components.
        try:
            pipeline = pipeline.to("cuda")
        except NotImplementedError:
            # `.to` raises on meta tensors; fall back to `to_empty` per component.
            for name in ("transformer", "text_encoder", "text_encoder_2", "vae"):
                component = getattr(pipeline, name, None)
                if component is None:
                    continue
                try:
                    setattr(pipeline, name, component.to("cuda"))
                except NotImplementedError:
                    setattr(pipeline, name, component.to_empty(device="cuda"))

        result_images = pipeline(prompt=prompts, generator=generators, **pipeline_kwargs).images
        # The hook appends one entry per sample, per step, per guidance branch.
        num_guidances = (len(caches) // cur_batch_size) // config.eval.num_steps
        num_steps = len(caches) // (cur_batch_size * num_guidances)
        assert (
            len(caches) == cur_batch_size * num_steps * num_guidances
        ), f"Unexpected number of caches: {len(caches)} != {cur_batch_size} * {num_steps} * {num_guidances}"
        for j, (filename, image) in enumerate(zip(filenames, result_images, strict=True)):
            image.save(os.path.join(samples_dirpath, f"{filename}.png"))
            for s in range(num_steps):
                for g in range(num_guidances):
                    c = caches[s * cur_batch_size * num_guidances + g * cur_batch_size + j]
                    c["filename"] = filename
                    c["step"] = s
                    c["guidance"] = g
                    c = tree_map(process, c)
                    torch.save(c, os.path.join(caches_dirpath, f"{filename}-{s:05d}-{g}.pt"))
        caches.clear()

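# Worked example (illustrative numbers, not from the source): with a batch of
# 4 samples, 50 sampling steps, and classifier-free guidance invoking the
# denoiser twice per step, the hook collects 4 * 50 * 2 = 400 entries, so
# num_guidances resolves to (400 // 4) // 50 == 2 and num_steps to
# 400 // (4 * 2) == 50.
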
@configclass
@dataclass
class CollectConfig:
    """Configuration for collecting the calibration dataset.

    Args:
        root (`str`, *optional*, defaults to `"datasets"`):
            Root directory to save the collected dataset.
        dataset_name (`str`, *optional*, defaults to `"qdiff"`):
            Name of the collected dataset.
        data_path (`str`, *optional*, defaults to `"prompts/qdiff.yaml"`):
            Path to the prompt file.
        num_samples (`int`, *optional*, defaults to `128`):
            Number of samples to collect.
    """

    root: str = "datasets"
    dataset_name: str = "qdiff"
    data_path: str = "prompts/qdiff.yaml"
    num_samples: int = 128

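# Usage note (a sketch, assuming omniconfig's @configclass preserves plain
# dataclass construction): the config can also be built directly in code,
# e.g. when scripting rather than using the CLI parser below.
#
#     cfg = CollectConfig(dataset_name="qdiff", num_samples=256)
#     assert cfg.data_path == "prompts/qdiff.yaml"  # default is unchanged
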
if __name__ == "__main__":
    parser = DiffusionPtqRunConfig.get_parser()
    parser.add_config(CollectConfig, scope="collect", prefix="collect")
    configs, _, unused_cfgs, unused_args, unknown_args = parser.parse_known_args()
    ptq_config, collect_config = configs[""], configs["collect"]
    assert isinstance(ptq_config, DiffusionPtqRunConfig)
    assert isinstance(collect_config, CollectConfig)
    if len(unused_cfgs) > 0:
        print(f"Warning: unused configurations {unused_cfgs}")
    if unused_args is not None:
        print(f"Warning: unused arguments {unused_args}")
    assert len(unknown_args) == 0, f"Unknown arguments: {unknown_args}"

    # Caches are grouped by dtype, pipeline, protocol, dataset name, and sample count.
    collect_dirpath = os.path.join(
        collect_config.root,
        str(ptq_config.pipeline.dtype),
        ptq_config.pipeline.name,
        ptq_config.eval.protocol,
        collect_config.dataset_name,
        f"s{collect_config.num_samples}",
    )
    print(f"Saving caches to {collect_dirpath}")

    dataset = get_dataset(
        collect_config.data_path,
        max_dataset_size=collect_config.num_samples,
        return_gt=ptq_config.pipeline.task in ["canny-to-image"],
        repeat=1,
    )

    ptq_config.output.root = collect_dirpath
    os.makedirs(ptq_config.output.root, exist_ok=True)
    collect(ptq_config, collect_config, dataset=dataset)
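
For reference, each cache written by this script is a plain torch.save payload: a dict carrying whatever tensors the CollectHook recorded, plus the "filename", "step", and "guidance" keys set above. A minimal sketch of reading one back (the path below is illustrative):

    import torch

    # Cache files follow the pattern "<output_root>/caches/<filename>-<step:05d>-<guidance>.pt".
    cache = torch.load("caches/example-00000-0.pt")
    print(cache["filename"], cache["step"], cache["guidance"])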