beomgyu-kim committed on
Commit d09c0e5 · 1 Parent(s): 6cd6a16

Add inference script and main model execution logic

Files changed (2)
  1. inference.py +149 -0
  2. run_inference.sh +10 -0
inference.py ADDED
@@ -0,0 +1,149 @@
+ import sys
+
+ sys.path.append(".")
+ import argparse
+ import json
+ import os
+ import pickle
+ import random
+
+ import numpy as np
+ import torch
+ from safetensors.torch import load_file
+
+ from configs.configuration_mmdit import MMDiTConfig
+ from models.modeling_motif_vision import MotifVision
+
+ # from tools.motif_api import PromptRewriter
+ # from tools.nsfw_filtering import ContentFilter
+
+ def load_sharded_model(model_index_path):
+     """
+     Loads a sharded model from a safetensors index file.
+
+     Args:
+         model_index_path (str): Path to the model.safetensors.index.json file.
+
+     Returns:
+         dict: A merged state dict containing only the DiT weights.
+     """
+     with open(model_index_path, "r") as f:
+         index = json.load(f)
+
+     # Load each shard once, keyed by filename.
+     sharded_state_dicts = {}
+     folder = os.path.dirname(model_index_path)
+     for weight_name, filename in index["weight_map"].items():
+         if filename not in sharded_state_dicts:
+             sharded_state_dicts[filename] = load_file(os.path.join(folder, filename), device="cpu")
+
+     # Merge the shards into a single state dict.
+     merged_state_dict = {}
+     for weight_name, filename in index["weight_map"].items():
+         merged_state_dict[weight_name] = sharded_state_dicts[filename][weight_name]
+
+     # Keep only the diffusion transformer ("dit") weights.
+     merged_state_dict = {k: v for k, v in merged_state_dict.items() if "dit" in k}
+     return merged_state_dict
+
+
+ def main(args):
+     # Check that the prompt file exists
+     if not os.path.isfile(args.prompt_file):
+         print(f"Error: The prompt file '{args.prompt_file}' does not exist.")
+         sys.exit(1)
+
+     # List of prompts, one per line
+     with open(args.prompt_file) as f:
+         prompts = [prompt.rstrip() for prompt in f.readlines()]
+
+     # Load the model configuration and build the model
+     config = MMDiTConfig.from_json_file(args.model_config)
+     config.vae_type = args.vae_type  # override the VAE type from the CLI
+     config.height = args.resolution
+     config.width = args.resolution
+
+     model = MotifVision(config)
+
+     # Load the checkpoint: try torch.load first, fall back to safetensors
+     try:
+         ema_instance = torch.load(args.model_ckpt, weights_only=False)
+         ema_instance = {k: v for k, v in ema_instance.items() if "dit" in k}
+     except pickle.UnpicklingError as e:
+         print(f"Error loading checkpoint: {e}")
+         ema_instance = load_file(args.model_ckpt)
+         ema_instance = {k: v for k, v in ema_instance.items() if "dit" in k}
+
+     if "ema_model.bin" in args.model_ckpt:
+         # EMA checkpoint: copy the shadow parameters into the model
+         for param, ema_param in zip(model.parameters(), ema_instance["shadow_params"]):
+             param.data.copy_(ema_param.data)
+     else:
+         # Non-EMA checkpoint: load the state dict directly
+         model.load_state_dict(ema_instance)
+
+     model = model.cuda()
+     model = model.to(dtype=torch.bfloat16)
+     model.eval()
+
+     # Use guidance scales from args or fall back to a default
+     guidance_scales = args.guidance_scales if args.guidance_scales else [5.0]
+
+     # If a single seed is passed without nargs, wrap it in a list
+     if isinstance(args.seed, int):
+         seeds = [args.seed]
+     else:
+         seeds = args.seed
+
+     for seed in seeds:
+         for guidance_scale in guidance_scales:
+             # Output directory structure: base_dir/seed_xxx/scale_yyy
+             output_dir = os.path.join(args.output_dir, f"seed_{seed}", f"scale_{guidance_scale}")
+             os.makedirs(output_dir, exist_ok=True)
+             # Generate in batches so high-resolution runs stay within memory
+             for i in range(0, len(prompts), args.batch_size):  # process batch_size prompts at a time
+                 # Set random seeds
+                 torch.manual_seed(seed)
+                 random.seed(seed)
+                 np.random.seed(seed)
+
+                 batch_prompts = prompts[i : i + args.batch_size]
+                 imgs = model.sample(
+                     batch_prompts,
+                     args.steps,
+                     resolution=[args.resolution, args.resolution],
+                     guidance_scale=guidance_scale,
+                     step_scaling=1.0,
+                     use_linear_quadratic_schedule=True,
+                     linear_quadratic_emulating_steps=250,
+                     get_intermediate_steps=args.streaming,
+                     noisy_pad=args.noisy_pad,
+                     zero_masking=args.zero_masking,
+                 )
+                 if args.streaming:
+                     imgs, intermediate_imgs = imgs
+                     if isinstance(intermediate_imgs, list):
+                         for j, intermediate_img in enumerate(intermediate_imgs):
+                             for k, img in enumerate(intermediate_img):
+                                 img.save(os.path.join(output_dir, f"{i + k:03d}_{j:03d}_intermediate.png"))
+                     else:
+                         # If intermediate_imgs is a single Image, save it directly
+                         intermediate_imgs.save(os.path.join(output_dir, f"{i:03d}_0_intermediate.png"))
+                 for j, img in enumerate(imgs):
+                     img.save(os.path.join(output_dir, f"{i + j:03d}_check.png"))
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Generate images with the model")
+     parser.add_argument("--model-config", type=str, required=True, help="Path to the model configuration file")
+     parser.add_argument("--model-ckpt", type=str, required=True, help="Path to the model checkpoint file")
+     parser.add_argument(
+         "--seed", type=int, nargs="*", default=[7777], help="Random seed(s) for reproducibility (can provide multiple)"
+     )
+     # parser.add_argument("--slg", type=int, nargs="*", default=None, help="")
+     parser.add_argument("--steps", type=int, default=50, help="Number of steps for image generation")
+     parser.add_argument("--resolution", type=int, default=256, help="Resolution of output images")
+     parser.add_argument("--batch-size", type=int, default=32, help="Number of prompts per batch")
+     parser.add_argument("--streaming", action="store_true", help="Also save intermediate denoising steps")
+     parser.add_argument("--noisy-pad", action="store_true")
+     parser.add_argument("--zero-masking", action="store_true")
+     parser.add_argument("--vae-type", type=str, default="SD3", help="Type of VAE")
+     parser.add_argument("--prompt-file", type=str, default="prompt_128.txt", help="Path to the prompt file")
+     parser.add_argument("--guidance-scales", type=float, nargs="*", default=None, help="List of guidance scales")
+     parser.add_argument("--output-dir", type=str, default="output", help="Base output directory for generated images")
+     parser.add_argument("--lora-ckpt", action="store_true")
+     args = parser.parse_args()
+
+     main(args)
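
Note that `load_sharded_model` is defined but never called from `main`, which only handles single-file checkpoints. A minimal sketch of wiring it in, assuming a Hugging Face-style sharded checkpoint and a hypothetical index path (run from the repo root so the project imports resolve):

```python
# Hypothetical usage of load_sharded_model; the index path is an example only.
from inference import load_sharded_model
from configs.configuration_mmdit import MMDiTConfig
from models.modeling_motif_vision import MotifVision

config = MMDiTConfig.from_json_file("configs/mmdit_xlarge_hq.json")
model = MotifVision(config)

# The helper merges every shard listed in the index's "weight_map" and keeps
# only keys containing "dit", so strict=False tolerates any non-DiT entries
# (e.g. VAE or text encoder) missing from the merged state dict.
state_dict = load_sharded_model("checkpoints/model.safetensors.index.json")
model.load_state_dict(state_dict, strict=False)
```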
run_inference.sh ADDED
@@ -0,0 +1,10 @@
+ python inference.py \
+     --model-config configs/mmdit_xlarge_hq.json \
+     --model-ckpt checkpoints/pytorch_model_fsdp.bin \
+     --seed 7777 \
+     --steps 30 \
+     --resolution 1024 \
+     --prompt-file prompts/sample_prompts.txt \
+     --guidance-scales 4.0 \
+     --output-dir outputs \
+     --batch-size 1
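
Because `--seed` and `--guidance-scales` both accept multiple values (`nargs="*"`), a single invocation can sweep every seed/scale combination, with each combination written to its own `seed_*/scale_*` subdirectory. A sketch of the same run driven from Python with a small sweep (the extra seed and scale values are illustrative):

```python
# Illustrative seed/guidance-scale sweep via subprocess;
# flag names match inference.py's argparse definitions.
import subprocess

subprocess.run(
    [
        "python", "inference.py",
        "--model-config", "configs/mmdit_xlarge_hq.json",
        "--model-ckpt", "checkpoints/pytorch_model_fsdp.bin",
        "--seed", "7777", "1234",           # nargs="*": several seeds in one run
        "--guidance-scales", "4.0", "7.5",  # likewise for guidance scales
        "--steps", "30",
        "--resolution", "1024",
        "--prompt-file", "prompts/sample_prompts.txt",
        "--output-dir", "outputs",
        "--batch-size", "1",
    ],
    check=True,
)
```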