Commit 34b0078 · Parent: 62e0e7b

Feat (script): Added option to validate on MLPerf validation set & to load a pre-quantized checkpoint.

Changed: quant_sdxl/quant_sdxl.py (+62 −24)
quant_sdxl/quant_sdxl.py  CHANGED

@@ -32,11 +32,12 @@ from brevitas.graph.quantize import layerwise_quantize
 from brevitas.inject.enum import StatsOp
 from brevitas.nn.equalized_layer import EqualizedModule
 from brevitas.utils.torch_utils import KwargsForwardHook
+import brevitas.config as config
 
 from brevitas_examples.common.parse_utils import add_bool_arg
 from brevitas_examples.stable_diffusion.sd_quant.export import export_quant_params
 from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttention
-
+from brevitas_examples.stable_diffusion.mlperf_evaluation.accuracy import compute_mlperf_fid
 
 TEST_SEED = 123456
 torch.manual_seed(TEST_SEED)

@@ -125,6 +126,20 @@ def main(args):
         raise RuntimeError("LoRA layers should be fused in before calling into quantization.")
 
     pipe.set_progress_bar_config(disable=True)
+
+    if args.load_checkpoint is not None:
+        with load_quant_model_mode(pipe.unet):
+            pipe = pipe.to('cpu')
+            print(f"Loading checkpoint: {args.load_checkpoint}... ", end="")
+            pipe.unet.load_state_dict(torch.load(args.load_checkpoint, map_location='cpu'))
+            print(f"Checkpoint loaded!")
+        pipe = pipe.to(args.device)
+
+    if args.load_checkpoint is not None:
+        # Don't run full activation equalization if we're loading a quantized checkpoint
+        num_ae_prompts = 2
+    else:
+        num_ae_prompts = len(calibration_prompts)
     with activation_equalization_mode(
             pipe.unet,
             alpha=args.act_eq_alpha,

@@ -138,7 +153,7 @@ def main(args):
         total_steps = args.calibration_steps
         run_val_inference(
             pipe,
-            calibration_prompts,
+            calibration_prompts[:num_ae_prompts],
             total_steps=total_steps,
             test_latents=latents,
             guidance_scale=args.guidance_scale)

@@ -186,26 +201,32 @@ def main(args):
 
     pipe.set_progress_bar_config(disable=True)
 
-    print("Applying activation calibration")
-    with torch.no_grad(), calibration_mode(pipe.unet):
-        run_val_inference(
-            pipe,
-            calibration_prompts,
-            total_steps=args.calibration_steps,
-            test_latents=latents,
-            guidance_scale=args.guidance_scale)
-
-    print("Applying bias correction")
-    with torch.no_grad(), bias_correction_mode(pipe.unet):
-        run_val_inference(
-            pipe,
-            calibration_prompts,
-            total_steps=args.calibration_steps,
-            test_latents=latents,
-            guidance_scale=args.guidance_scale)
-
-    if args.checkpoint_name is not None:
-        torch.save(pipe.unet.state_dict(), os.path.join(output_dir, args.checkpoint_name))
+    if args.load_checkpoint is None:
+        print("Applying activation calibration")
+        with torch.no_grad(), calibration_mode(pipe.unet):
+            run_val_inference(
+                pipe,
+                calibration_prompts,
+                total_steps=args.calibration_steps,
+                test_latents=latents,
+                guidance_scale=args.guidance_scale)
+
+        print("Applying bias correction")
+        with torch.no_grad(), bias_correction_mode(pipe.unet):
+            run_val_inference(
+                pipe,
+                calibration_prompts,
+                total_steps=args.calibration_steps,
+                test_latents=latents,
+                guidance_scale=args.guidance_scale)
+
+    if args.checkpoint_name is not None:
+        torch.save(pipe.unet.state_dict(), os.path.join(output_dir, args.checkpoint_name))
+
+    # Perform inference
+    if args.validation_prompts > 0:
+        print(f"Computing validation accuracy")
+        compute_mlperf_fid(args.model, args.path_to_coco, pipe, args.validation_prompts, output_dir)
 
     if args.export_target:
         pipe.unet.to('cpu').to(dtype)

@@ -229,6 +250,18 @@ if __name__ == "__main__":
         type=int,
         default=500,
         help='Number of prompts to use for calibration. Default: %(default)s')
+    parser.add_argument(
+        '--validation-prompts',
+        type=int,
+        default=0,
+        help='Number of prompts to use for validation. Default: %(default)s')
+    parser.add_argument(
+        '--path-to-coco',
+        type=str,
+        default=None,
+        help=
+        'Path to MLPerf compliant Coco dataset. Required when the --validation-prompts > 0 flag is set. Default: None'
+    )
     parser.add_argument(
         '--checkpoint-name',
         type=str,

@@ -237,11 +270,16 @@ if __name__ == "__main__":
         'Name to use to store the checkpoint in the output dir. If not provided, no checkpoint is saved.'
     )
     parser.add_argument(
-        '--
+        '--load-checkpoint',
         type=str,
         default=None,
+        help='Path to checkpoint to load. If provided, PTQ techniques are skipped.')
+    parser.add_argument(
+        '--path-to-latents',
+        type=str,
+        required=True,
         help=
-        '
+        'Path to pre-defined latents.')
     parser.add_argument('--guidance-scale', type=float, default=8., help='Guidance scale.')
     parser.add_argument(
         '--calibration-steps', type=float, default=8, help='Steps used during calibration')
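Taken together, the new options support a two-pass workflow: a first run performs PTQ and stores the quantized UNet state dict under --checkpoint-name; a second run reloads it with --load-checkpoint, in which case activation equalization sees only two prompts and activation calibration and bias correction are skipped entirely, before scoring the pipeline with compute_mlperf_fid on the MLPerf COCO subset. A hypothetical second-pass invocation (the flag names come from this commit; the paths and counts are placeholders, not from the source):

    python quant_sdxl/quant_sdxl.py \
        --path-to-latents latents.pt \
        --load-checkpoint quantized_unet.pth \
        --validation-prompts 500 \
        --path-to-coco /datasets/coco-mlperf

Note that --path-to-latents is now required, and --path-to-coco must point to an MLPerf-compliant COCO copy whenever --validation-prompts is greater than 0.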