Commit
·
0012f0c
1
Parent(s):
004d6aa
Refactor argument parsing and improve error handling in model loading
Browse filesUpdate .gitignore to exclude PNG files and remove unnecessary output images
- .gitignore +1 -0
- inference.py +2 -4
- models/modeling_dit.py +1 -1
.gitignore
CHANGED
@@ -1 +1,2 @@
|
|
1 |
*.pyc
|
|
|
|
1 |
*.pyc
|
2 |
+
*.png
|
inference.py
CHANGED
@@ -51,7 +51,6 @@ def main(args):
|
|
51 |
|
52 |
# Load model configuration and model
|
53 |
config = MMDiTConfig.from_json_file(args.model_config)
|
54 |
-
config.vae_type = args.vae_type # VAE overriding
|
55 |
config.height = args.resolution
|
56 |
config.width = args.resolution
|
57 |
|
@@ -135,11 +134,10 @@ if __name__ == "__main__":
|
|
135 |
# parser.add_argument("--slg", type=int, nargs="*", default=None, help="")
|
136 |
parser.add_argument("--steps", type=int, default=50, help="Number of steps for image generation")
|
137 |
parser.add_argument("--resolution", type=int, default=256, help="Resolution of output images")
|
138 |
-
parser.add_argument("--batch-size", type=int, default=32)
|
139 |
-
parser.add_argument("--streaming", action="store_true")
|
140 |
parser.add_argument("--noisy-pad", action="store_true")
|
141 |
parser.add_argument("--zero-masking", action="store_true")
|
142 |
-
parser.add_argument("--vae-type", type=str, default="SD3", help="Type of VAE")
|
143 |
parser.add_argument("--prompt-file", type=str, default="prompt_128.txt", help="Path to the prompt file")
|
144 |
parser.add_argument("--guidance-scales", type=float, nargs="*", default=None, help="List of guidance scales")
|
145 |
parser.add_argument("--output-dir", type=str, default="output", help="Base output directory for generated images")
|
|
|
51 |
|
52 |
# Load model configuration and model
|
53 |
config = MMDiTConfig.from_json_file(args.model_config)
|
|
|
54 |
config.height = args.resolution
|
55 |
config.width = args.resolution
|
56 |
|
|
|
134 |
# parser.add_argument("--slg", type=int, nargs="*", default=None, help="")
|
135 |
parser.add_argument("--steps", type=int, default=50, help="Number of steps for image generation")
|
136 |
parser.add_argument("--resolution", type=int, default=256, help="Resolution of output images")
|
137 |
+
parser.add_argument("--batch-size", type=int, default=32,help="Batch size for image generation")
|
138 |
+
parser.add_argument("--streaming", action="store_true", help="Enable streaming mode for intermediate steps")
|
139 |
parser.add_argument("--noisy-pad", action="store_true")
|
140 |
parser.add_argument("--zero-masking", action="store_true")
|
|
|
141 |
parser.add_argument("--prompt-file", type=str, default="prompt_128.txt", help="Path to the prompt file")
|
142 |
parser.add_argument("--guidance-scales", type=float, nargs="*", default=None, help="List of guidance scales")
|
143 |
parser.add_argument("--output-dir", type=str, default="output", help="Base output directory for generated images")
|
models/modeling_dit.py
CHANGED
@@ -13,7 +13,7 @@ try:
|
|
13 |
MotifRMSNorm = motif_ops.T5LayerNorm
|
14 |
ScaledDotProductAttention = None
|
15 |
MotifFlashAttention = motif_ops.flash_attention
|
16 |
-
except
|
17 |
MotifRMSNorm = None
|
18 |
ScaledDotProductAttention = None
|
19 |
MotifFlashAttention = None
|
|
|
13 |
MotifRMSNorm = motif_ops.T5LayerNorm
|
14 |
ScaledDotProductAttention = None
|
15 |
MotifFlashAttention = motif_ops.flash_attention
|
16 |
+
except Exception: # if motif_ops is not available
|
17 |
MotifRMSNorm = None
|
18 |
ScaledDotProductAttention = None
|
19 |
MotifFlashAttention = None
|