Fix model loading
Browse files- quant_sdxl/quant_sdxl.py +4 -1
quant_sdxl/quant_sdxl.py
CHANGED
|
@@ -102,7 +102,10 @@ def main(args):
|
|
| 102 |
|
| 103 |
# Load model from float checkpoint
|
| 104 |
print(f"Loading model from {args.model}...")
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
| 106 |
print(f"Model loaded from {args.model}.")
|
| 107 |
|
| 108 |
# Move model to target device
|
|
|
|
| 102 |
|
| 103 |
# Load model from float checkpoint
|
| 104 |
print(f"Loading model from {args.model}...")
|
| 105 |
+
variant = 'fp16' if dtype == torch.float16 else None
|
| 106 |
+
pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype, variant=variant, use_safetensors=True)
|
| 107 |
+
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
| 108 |
+
pipe.vae.config.force_upcast=True
|
| 109 |
print(f"Model loaded from {args.model}.")
|
| 110 |
|
| 111 |
# Move model to target device
|