|
import sys |
|
|
|
sys.path.append(".") |
|
import argparse |
|
import os |
|
import random |
|
import numpy as np |
|
import torch |
|
import pickle |
|
from configs.configuration_mmdit import MMDiTConfig |
|
from models.modeling_motifimage import MotifImage |
|
|
|
from safetensors.torch import load_file |
|
|
|
|
|
|
|
def load_sharded_model(model_index_path):
    """
    Load the "dit" weights of a sharded checkpoint from a safetensors index file.

    The index JSON must contain a "weight_map" object that maps each weight
    name to the shard file storing it. Shard filenames are resolved relative
    to the directory of the index file. Only weights whose name contains
    "dit" are returned, and a shard is read (onto the CPU, at most once) only
    when at least one such weight lives in it.

    Args:
        model_index_path (str): Path to the model.safetensors.index.json file.

    Returns:
        dict: Weight name -> value loaded from the shard, for every entry of
        the weight map whose name contains "dit".

    Raises:
        KeyError: If the index JSON has no "weight_map" entry.
    """
    # Imported locally: the module-level imports of this file do not include
    # `json`, so relying on a global name here raised NameError.
    import json

    with open(model_index_path, "r", encoding="utf-8") as f:
        index = json.load(f)

    folder = os.path.dirname(model_index_path)

    # Cache of shard filename -> loaded state dict, so each file is read once.
    shards = {}
    dit_state_dict = {}
    for weight_name, filename in index["weight_map"].items():
        if "dit" not in weight_name:
            # Filtered out of the result anyway; do not load its shard for it.
            continue
        if filename not in shards:
            shards[filename] = load_file(os.path.join(folder, filename), device="cpu")
        dit_state_dict[weight_name] = shards[filename][weight_name]

    return dit_state_dict
|
|
|
|
|
def _read_prompts(prompt_file):
    """Return one prompt per line of `prompt_file`, trailing whitespace stripped."""
    with open(prompt_file) as f:
        return [line.rstrip() for line in f]


def _load_checkpoint(ckpt_path, keep_all_entries):
    """
    Read the checkpoint at `ckpt_path`.

    `torch.load` is tried first; if it fails with `pickle.UnpicklingError`
    the error is printed and the file is read with safetensors `load_file`
    instead. Unless `keep_all_entries` is true, only entries whose key
    contains "dit" are kept.
    """
    try:
        # NOTE(security): weights_only=False unpickles arbitrary objects, which
        # can execute code. Only load checkpoints from a trusted source.
        # map_location="cpu": the model is still on the CPU at this point and
        # is moved to the GPU only after the weights have been copied in.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    except pickle.UnpicklingError as e:
        print(f"Error loading checkpoint: {e}")
        checkpoint = load_file(ckpt_path)

    if keep_all_entries:
        return checkpoint
    return {k: v for k, v in checkpoint.items() if "dit" in k}


def _copy_shadow_params(model, shadow_params):
    """Copy `shadow_params` into `model.parameters()` position by position."""
    params = list(model.parameters())
    shadow_params = list(shadow_params)
    # zip() would silently stop at the shorter sequence and leave part of the
    # model at its initial values, so a count mismatch is reported instead.
    if len(params) != len(shadow_params):
        raise ValueError(
            f"EMA checkpoint holds {len(shadow_params)} shadow parameters "
            f"but the model has {len(params)} parameters."
        )
    for param, ema_param in zip(params, shadow_params):
        param.data.copy_(ema_param.data)


def main(args):
    """
    Generate images for every prompt in `args.prompt_file` and save them as PNGs.

    Steps:
      1. Read the prompts (one per line); exit with status 1 if the file is missing.
      2. Build a `MotifImage` from the JSON config at `args.model_config`, with
         height and width set to `args.resolution`.
      3. Load the weights from `args.model_ckpt`. When the path contains
         "ema_model.bin" the checkpoint's "shadow_params" are copied into the
         model parameters; otherwise the entries whose key contains "dit" are
         passed to `load_state_dict`.
      4. Move the model to CUDA in bfloat16 and switch it to eval mode.
      5. For every seed and guidance scale, sample the prompts in batches of
         `args.batch_size` and write the images to
         `<args.output_dir>/seed_<seed>/scale_<guidance_scale>/`.

    Args:
        args: Parsed command-line namespace. Attributes read here: prompt_file,
            model_config, model_ckpt, resolution, guidance_scales, seed,
            output_dir, batch_size, steps, streaming, noisy_pad, zero_masking.

    Raises:
        SystemExit: With code 1 when `args.prompt_file` is not an existing file.
        ValueError: When an EMA checkpoint's shadow-parameter count differs
            from the model's parameter count.
    """
    if not os.path.isfile(args.prompt_file):
        print(f"Error: The prompt file '{args.prompt_file}' does not exist.")
        sys.exit(1)

    prompts = _read_prompts(args.prompt_file)

    config = MMDiTConfig.from_json_file(args.model_config)
    config.height = args.resolution
    config.width = args.resolution

    model = MotifImage(config)

    # The "dit" key filter must not be applied to an EMA checkpoint: it would
    # drop the "shadow_params" entry that is read right below.
    is_ema_ckpt = "ema_model.bin" in args.model_ckpt
    checkpoint = _load_checkpoint(args.model_ckpt, keep_all_entries=is_ema_ckpt)
    if is_ema_ckpt:
        _copy_shadow_params(model, checkpoint["shadow_params"])
    else:
        model.load_state_dict(checkpoint)

    model = model.cuda()
    model = model.to(dtype=torch.bfloat16)
    model.eval()

    guidance_scales = args.guidance_scales if args.guidance_scales else [5.0]

    # Accept either a single int seed or a list of seeds.
    seeds = [args.seed] if isinstance(args.seed, int) else args.seed

    for seed in seeds:
        for guidance_scale in guidance_scales:
            output_dir = os.path.join(args.output_dir, f"seed_{seed}", f"scale_{guidance_scale}")
            os.makedirs(output_dir, exist_ok=True)

            for i in range(0, len(prompts), args.batch_size):
                # Re-seed before every batch so each batch starts from the
                # same RNG state for a given seed.
                torch.manual_seed(seed)
                random.seed(seed)
                np.random.seed(seed)

                batch_prompts = prompts[i : i + args.batch_size]
                imgs = model.sample(
                    batch_prompts,
                    args.steps,
                    resolution=[args.resolution, args.resolution],
                    guidance_scale=guidance_scale,
                    step_scaling=1.0,
                    use_linear_quadratic_schedule=True,
                    linear_quadratic_emulating_steps=250,
                    get_intermediate_steps=args.streaming,
                    noisy_pad=args.noisy_pad,
                    zero_masking=args.zero_masking,
                )

                if args.streaming:
                    # With intermediate steps requested, the result is unpacked
                    # as a (final images, intermediate images) pair.
                    imgs, intermediate_imgs = imgs
                    if isinstance(intermediate_imgs, list):
                        # j indexes the list entry, k the image within that entry.
                        for j, intermediate_img in enumerate(intermediate_imgs):
                            for k, img in enumerate(intermediate_img):
                                img.save(os.path.join(output_dir, f"{i + k:03d}_{j:03d}_intermediate.png"))
                    else:
                        intermediate_imgs.save(os.path.join(output_dir, f"{i:03d}_0_intermediate.png"))

                # File names carry the prompt's index in the prompt file.
                for j, img in enumerate(imgs):
                    img.save(os.path.join(output_dir, f"{i + j:03d}_check.png"))
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    parser = argparse.ArgumentParser(description="Generate images with model")

    # (flag, add_argument keyword options), registered in this exact order so
    # the generated --help output keeps its ordering.
    # --lora-ckpt is accepted by the parser but is not read by main().
    option_table = (
        ("--model-config", dict(type=str, required=True, help="Path to the model configuration file")),
        ("--model-ckpt", dict(type=str, required=True, help="Path to the model checkpoint file")),
        (
            "--seed",
            dict(
                type=int,
                nargs="*",
                default=[7777],
                help="Random seed(s) for reproducibility (can provide multiple)",
            ),
        ),
        ("--steps", dict(type=int, default=50, help="Number of steps for image generation")),
        ("--resolution", dict(type=int, default=256, help="Resolution of output images")),
        ("--batch-size", dict(type=int, default=32, help="Batch size for image generation")),
        ("--streaming", dict(action="store_true", help="Enable streaming mode for intermediate steps")),
        ("--noisy-pad", dict(action="store_true")),
        ("--zero-masking", dict(action="store_true")),
        ("--prompt-file", dict(type=str, default="prompt_128.txt", help="Path to the prompt file")),
        ("--guidance-scales", dict(type=float, nargs="*", default=None, help="List of guidance scales")),
        ("--output-dir", dict(type=str, default="output", help="Base output directory for generated images")),
        ("--lora-ckpt", dict(action="store_true")),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    main(args)
|
|