Upload config.py
Browse files
config.py
CHANGED
@@ -370,18 +370,18 @@ class DiffusionPipelineConfig:
|
|
370 |
print(">>> PIPELINE TYPE:", type(pipeline))
|
371 |
|
372 |
# Try to move each component using .to_empty()
|
373 |
-
for name in ["unet", "transformer", "vae", "text_encoder"]:
|
374 |
-
module = getattr(pipeline, name, None)
|
375 |
-
if isinstance(module, torch.nn.Module):
|
376 |
-
try:
|
377 |
-
print(f">>> Moving {name} to {device} using to_empty()")
|
378 |
-
module.to_empty(device)
|
379 |
-
except Exception as e:
|
380 |
-
print(f">>> WARNING: {name}.to_empty({device}) failed: {e}")
|
381 |
-
try:
|
382 |
-
print(f">>> Falling back to {name}.to({device})")
|
383 |
-
module.to(device)
|
384 |
-
except Exception as ee:
|
385 |
print(f">>> ERROR: {name}.to({device}) also failed: {ee}")
|
386 |
|
387 |
# Identify main model (for patching)
|
|
|
370 |
print(">>> PIPELINE TYPE:", type(pipeline))
|
371 |
|
372 |
# Try to move each component using .to_empty()
|
373 |
+
for name in ["unet", "transformer", "vae", "text_encoder"]:
|
374 |
+
module = getattr(pipeline, name, None)
|
375 |
+
if isinstance(module, torch.nn.Module):
|
376 |
+
try:
|
377 |
+
print(f">>> Moving {name} to {device} using to_empty()")
|
378 |
+
module.to_empty(device=device) # Use keyword argument
|
379 |
+
except Exception as e:
|
380 |
+
print(f">>> WARNING: {name}.to_empty({device}) failed: {e}")
|
381 |
+
try:
|
382 |
+
print(f">>> Falling back to {name}.to({device})")
|
383 |
+
module.to(device)
|
384 |
+
except Exception as ee:
|
385 |
print(f">>> ERROR: {name}.to({device}) also failed: {ee}")
|
386 |
|
387 |
# Identify main model (for patching)
|