lym00 commited on
Commit
b5508c9
·
verified ·
1 Parent(s): 0a3f3cb

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +91 -0
README.md CHANGED
@@ -293,7 +293,98 @@ Potential fix: app.diffusion.nn.struct.py
293
  Potential Fix: app.diffusion.dataset.collect.calib.py
294
 
295
  ```python
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  result_images = pipeline(prompt=prompts, generator=generators, **pipeline_kwargs).images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  ```
298
 
299
  References
 
293
  Potential Fix: app.diffusion.dataset.collect.calib.py
294
 
295
  ```python
296
+ def collect(config: DiffusionPtqRunConfig, dataset: datasets.Dataset):
297
+ samples_dirpath = os.path.join(config.output.root, "samples")
298
+ caches_dirpath = os.path.join(config.output.root, "caches")
299
+ os.makedirs(samples_dirpath, exist_ok=True)
300
+ os.makedirs(caches_dirpath, exist_ok=True)
301
+ caches = []
302
+
303
+ pipeline = config.pipeline.build()
304
+ model = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
305
+ assert isinstance(model, nn.Module)
306
+ model.register_forward_hook(CollectHook(caches=caches), with_kwargs=True)
307
+
308
+ batch_size = config.eval.batch_size
309
+ print(f"In total {len(dataset)} samples")
310
+ print(f"Evaluating with batch size {batch_size}")
311
+ pipeline.set_progress_bar_config(desc="Sampling", leave=False, dynamic_ncols=True, position=1)
312
+ for batch in tqdm(
313
+ dataset.iter(batch_size=batch_size, drop_last_batch=False),
314
+ desc="Data",
315
+ leave=False,
316
+ dynamic_ncols=True,
317
+ total=(len(dataset) + batch_size - 1) // batch_size,
318
+ ):
319
+ filenames = batch["filename"]
320
+ prompts = batch["prompt"]
321
+ seeds = [hash_str_to_int(name) for name in filenames]
322
+ generators = [torch.Generator(device=pipeline.device).manual_seed(seed) for seed in seeds]
323
+ pipeline_kwargs = config.eval.get_pipeline_kwargs()
324
+
325
+ task = config.pipeline.task
326
+ control_root = config.eval.control_root
327
+ if task in ["canny-to-image", "depth-to-image", "inpainting"]:
328
+ controls = get_control(
329
+ task,
330
+ batch["image"],
331
+ names=batch["filename"],
332
+ data_root=os.path.join(
333
+ control_root, collect_config.dataset_name, f"{dataset.config_name}-{config.eval.num_samples}"
334
+ ),
335
+ )
336
+ if task == "inpainting":
337
+ pipeline_kwargs["image"] = controls[0]
338
+ pipeline_kwargs["mask_image"] = controls[1]
339
+ else:
340
+ pipeline_kwargs["control_image"] = controls
341
+
342
+ # Handle meta tensors by moving individual components
343
+ try:
344
+ pipeline = pipeline.to("cuda")
345
+ except NotImplementedError:
346
+ # Move individual pipeline components that have to_empty method
347
+ if hasattr(pipeline, 'transformer') and pipeline.transformer is not None:
348
+ try:
349
+ pipeline.transformer = pipeline.transformer.to("cuda")
350
+ except NotImplementedError:
351
+ pipeline.transformer = pipeline.transformer.to_empty(device="cuda")
352
+
353
+ if hasattr(pipeline, 'text_encoder') and pipeline.text_encoder is not None:
354
+ try:
355
+ pipeline.text_encoder = pipeline.text_encoder.to("cuda")
356
+ except NotImplementedError:
357
+ pipeline.text_encoder = pipeline.text_encoder.to_empty(device="cuda")
358
+
359
+ if hasattr(pipeline, 'text_encoder_2') and pipeline.text_encoder_2 is not None:
360
+ try:
361
+ pipeline.text_encoder_2 = pipeline.text_encoder_2.to("cuda")
362
+ except NotImplementedError:
363
+ pipeline.text_encoder_2 = pipeline.text_encoder_2.to_empty(device="cuda")
364
+
365
+ if hasattr(pipeline, 'vae') and pipeline.vae is not None:
366
+ try:
367
+ pipeline.vae = pipeline.vae.to("cuda")
368
+ except NotImplementedError:
369
+ pipeline.vae = pipeline.vae.to_empty(device="cuda")
370
+
371
  result_images = pipeline(prompt=prompts, generator=generators, **pipeline_kwargs).images
372
+ num_guidances = (len(caches) // batch_size) // config.eval.num_steps
373
+ num_steps = len(caches) // (batch_size * num_guidances)
374
+ assert (
375
+ len(caches) == batch_size * num_steps * num_guidances
376
+ ), f"Unexpected number of caches: {len(caches)} != {batch_size} * {config.eval.num_steps} * {num_guidances}"
377
+ for j, (filename, image) in enumerate(zip(filenames, result_images, strict=True)):
378
+ image.save(os.path.join(samples_dirpath, f"{filename}.png"))
379
+ for s in range(num_steps):
380
+ for g in range(num_guidances):
381
+ c = caches[s * batch_size * num_guidances + g * batch_size + j]
382
+ c["filename"] = filename
383
+ c["step"] = s
384
+ c["guidance"] = g
385
+ c = tree_map(lambda x: process(x), c)
386
+ torch.save(c, os.path.join(caches_dirpath, f"{filename}-{s:05d}-{g}.pt"))
387
+ caches.clear()
388
  ```
389
 
390
  References