Update README.md
README.md (changed)
@@ -293,7 +293,98 @@ Potential fix: app.diffusion.nn.struct.py
Potential Fix: app.diffusion.dataset.collect.calib.py

```python
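# NOTE: imports assumed by this snippet (the diff itself does not show them);
# DiffusionPtqRunConfig, CollectHook, get_control, hash_str_to_int, and
# process are presumably provided by the surrounding app.diffusion modules.
import os

import datasets
import torch
import torch.nn as nn
from torch.utils._pytree import tree_map  # assumed source of ``tree_map``
from tqdm import tqdm
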
def collect(config: DiffusionPtqRunConfig, dataset: datasets.Dataset):
    samples_dirpath = os.path.join(config.output.root, "samples")
    caches_dirpath = os.path.join(config.output.root, "caches")
    os.makedirs(samples_dirpath, exist_ok=True)
    os.makedirs(caches_dirpath, exist_ok=True)
    caches = []

    pipeline = config.pipeline.build()
    model = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
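    # UNet pipelines (e.g. Stable Diffusion) expose ``unet``; DiT-style
    # pipelines (e.g. PixArt, FLUX) expose ``transformer`` instead.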
    assert isinstance(model, nn.Module)
    model.register_forward_hook(CollectHook(caches=caches), with_kwargs=True)
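    # From here on, every denoiser forward pass lands in ``caches``; the
    # bookkeeping after the pipeline call below relies on one entry per
    # sample per step per guidance branch.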

    batch_size = config.eval.batch_size
    print(f"In total {len(dataset)} samples")
    print(f"Evaluating with batch size {batch_size}")
    pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
    for batch in tqdm(
        dataset.iter(batch_size=batch_size, drop_last_batch=False),
        desc="Data",
        leave=False,
        dynamic_ncols=True,
        total=(len(dataset) + batch_size - 1) // batch_size,
    ):
        filenames = batch["filename"]
        prompts = batch["prompt"]
        seeds = [hash_str_to_int(name) for name in filenames]
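        # Hashing the filename into the seed keeps generation deterministic
        # per sample, regardless of batch composition.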
        generators = [torch.Generator(device=pipeline.device).manual_seed(seed) for seed in seeds]
        pipeline_kwargs = config.eval.get_pipeline_kwargs()

        task = config.pipeline.task
        control_root = config.eval.control_root
        if task in ["canny-to-image", "depth-to-image", "inpainting"]:
            controls = get_control(
                task,
                batch["image"],
                names=batch["filename"],
                # NOTE: ``collect_config`` is undefined in this snippet; it
                # presumably refers to the dataset-collection section of ``config``.
                data_root=os.path.join(
                    control_root, collect_config.dataset_name, f"{dataset.config_name}-{config.eval.num_samples}"
                ),
            )
            if task == "inpainting":
                pipeline_kwargs["image"] = controls[0]
                pipeline_kwargs["mask_image"] = controls[1]
            else:
                pipeline_kwargs["control_image"] = controls

        # Handle meta tensors by moving individual components
        try:
            pipeline = pipeline.to("cuda")
        except NotImplementedError:
            # Move individual pipeline components that have a to_empty method
            if hasattr(pipeline, "transformer") and pipeline.transformer is not None:
                try:
                    pipeline.transformer = pipeline.transformer.to("cuda")
                except NotImplementedError:
                    pipeline.transformer = pipeline.transformer.to_empty(device="cuda")

            if hasattr(pipeline, "text_encoder") and pipeline.text_encoder is not None:
                try:
                    pipeline.text_encoder = pipeline.text_encoder.to("cuda")
                except NotImplementedError:
                    pipeline.text_encoder = pipeline.text_encoder.to_empty(device="cuda")

            if hasattr(pipeline, "text_encoder_2") and pipeline.text_encoder_2 is not None:
                try:
                    pipeline.text_encoder_2 = pipeline.text_encoder_2.to("cuda")
                except NotImplementedError:
                    pipeline.text_encoder_2 = pipeline.text_encoder_2.to_empty(device="cuda")

            if hasattr(pipeline, "vae") and pipeline.vae is not None:
                try:
                    pipeline.vae = pipeline.vae.to("cuda")
                except NotImplementedError:
                    pipeline.vae = pipeline.vae.to_empty(device="cuda")

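        # Caveat: ``to_empty`` allocates storage on the target device without
        # copying values, so those parameters stay uninitialized until real
        # weights are loaded into the module afterwards.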
        result_images = pipeline(prompt=prompts, generator=generators, **pipeline_kwargs).images
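        # One pipeline call ran ``num_steps`` denoising steps, each with
        # ``num_guidances`` denoiser invocations (e.g. conditional plus
        # unconditional under classifier-free guidance), and the hook cached
        # one entry per sample of every invocation.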
        num_guidances = (len(caches) // batch_size) // config.eval.num_steps
        num_steps = len(caches) // (batch_size * num_guidances)
        assert (
            len(caches) == batch_size * num_steps * num_guidances
        ), f"Unexpected number of caches: {len(caches)} != {batch_size} * {config.eval.num_steps} * {num_guidances}"
        for j, (filename, image) in enumerate(zip(filenames, result_images, strict=True)):
            image.save(os.path.join(samples_dirpath, f"{filename}.png"))
            for s in range(num_steps):
                for g in range(num_guidances):
                    c = caches[s * batch_size * num_guidances + g * batch_size + j]
                    c["filename"] = filename
                    c["step"] = s
                    c["guidance"] = g
                    c = tree_map(lambda x: process(x), c)
                    torch.save(c, os.path.join(caches_dirpath, f"{filename}-{s:05d}-{g}.pt"))
        caches.clear()
```
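
The snippet references several helpers that the diff does not define (`CollectHook`, `hash_str_to_int`, `process`). Below is a minimal sketch of what they plausibly look like, assuming `CollectHook` is a standard PyTorch forward hook (`with_kwargs=True` means it is called as `hook(module, args, kwargs, output)`) that splits batched inputs per sample, so the cache count matches the assertion above. All names and details here are assumptions, not the repository's actual implementations:

```python
import hashlib

import torch
import torch.nn as nn
from torch.utils._pytree import tree_map


class CollectHook:
    """Hypothetical sketch: record each denoiser call's inputs, one cache per sample."""

    def __init__(self, caches: list):
        self.caches = caches

    def __call__(self, module: nn.Module, args: tuple, kwargs: dict, output):
        # Assume the first positional argument is the batched latent tensor.
        for i in range(args[0].shape[0]):
            per_sample = lambda x: x[i : i + 1].detach() if isinstance(x, torch.Tensor) else x
            self.caches.append(
                {"args": tree_map(per_sample, args), "kwargs": tree_map(per_sample, kwargs)}
            )
        return output


def hash_str_to_int(s: str, bits: int = 32) -> int:
    """Hypothetical sketch: deterministically map a string to an integer seed."""
    return int.from_bytes(hashlib.sha256(s.encode("utf-8")).digest(), "big") % (1 << bits)


def process(x):
    """Hypothetical sketch: detach tensors to CPU before ``torch.save``."""
    return x.detach().cpu() if isinstance(x, torch.Tensor) else x
```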

References