lym00 commited on
Commit
b1bdeca
·
verified ·
1 Parent(s): 0c1b5cc

Upload config.py

Browse files
Files changed (1) hide show
  1. config.py +12 -12
config.py CHANGED
@@ -370,18 +370,18 @@ class DiffusionPipelineConfig:
370
  print(">>> PIPELINE TYPE:", type(pipeline))
371
 
372
  # Try to move each component using .to_empty()
373
- for name in ["unet", "transformer", "vae", "text_encoder"]:
374
- module = getattr(pipeline, name, None)
375
- if isinstance(module, torch.nn.Module):
376
- try:
377
- print(f">>> Moving {name} to {device} using to_empty()")
378
- module.to_empty(device)
379
- except Exception as e:
380
- print(f">>> WARNING: {name}.to_empty({device}) failed: {e}")
381
- try:
382
- print(f">>> Falling back to {name}.to({device})")
383
- module.to(device)
384
- except Exception as ee:
385
  print(f">>> ERROR: {name}.to({device}) also failed: {ee}")
386
 
387
  # Identify main model (for patching)
 
370
  print(">>> PIPELINE TYPE:", type(pipeline))
371
 
372
  # Try to move each component using .to_empty()
373
+ for name in ["unet", "transformer", "vae", "text_encoder"]:
374
+ module = getattr(pipeline, name, None)
375
+ if isinstance(module, torch.nn.Module):
376
+ try:
377
+ print(f">>> Moving {name} to {device} using to_empty()")
378
+ module.to_empty(device=device) # Use keyword argument
379
+ except Exception as e:
380
+ print(f">>> WARNING: {name}.to_empty({device}) failed: {e}")
381
+ try:
382
+ print(f">>> Falling back to {name}.to({device})")
383
+ module.to(device)
384
+ except Exception as ee:
385
  print(f">>> ERROR: {name}.to({device}) also failed: {ee}")
386
 
387
  # Identify main model (for patching)