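"""Batch image sampling script for MotifImage (MMDiT).

Loads a model checkpoint (pickled .bin, EMA, or safetensors), then generates
images for every prompt in a prompt file, sweeping over one or more random
seeds and guidance scales. Outputs land under output_dir/seed_xxx/scale_yyy.
"""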
import sys
sys.path.append(".")
import argparse
import json  # required by load_sharded_model (json.load)
import os
import pickle
import random

import numpy as np
import torch
from configs.configuration_mmdit import MMDiTConfig
from models.modeling_motifimage import MotifImage
from safetensors.torch import load_file
# from tools.motif_api import PromptRewriter
# from tools.nsfw_filtering import ContentFilter
def load_sharded_model(model_index_path):
"""
Loads a sharded model from a safetensors index file.
Args:
model_index_path (str): Path to the model.safetensors.index.json file.
"""
with open(model_index_path, 'r') as f:
index = json.load(f)
    # First pass: load each shard file exactly once.
    sharded_state_dicts = {}
    folder = os.path.dirname(model_index_path)
    for weight_name, filename in index["weight_map"].items():
        if filename not in sharded_state_dicts:
            sharded_state_dicts[filename] = load_file(os.path.join(folder, filename), device="cpu")
    # Second pass: pull every weight out of its shard, then keep only DiT weights.
    merged_state_dict = {}
    for weight_name, filename in index["weight_map"].items():
        merged_state_dict[weight_name] = sharded_state_dicts[filename][weight_name]
    merged_state_dict = {k: v for k, v in merged_state_dict.items() if "dit" in k}
return merged_state_dict
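

# Note: load_sharded_model is not called in main() below. A minimal usage
# sketch, assuming a hypothetical sharded checkpoint layout (strict=False
# because the returned dict is filtered down to the DiT weights):
#   state_dict = load_sharded_model("checkpoints/model.safetensors.index.json")
#   model.load_state_dict(state_dict, strict=False)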
def main(args):
# Check if the prompt file exists
if not os.path.isfile(args.prompt_file):
print(f"Error: The prompt file '{args.prompt_file}' does not exist.")
sys.exit(1)
# List of prompts
with open(args.prompt_file) as f:
prompts = [prompt.rstrip() for prompt in f.readlines()]
# Load model configuration and model
config = MMDiTConfig.from_json_file(args.model_config)
config.height = args.resolution
config.width = args.resolution
model = MotifImage(config)
    # Load checkpoint. torch.load handles pickled .bin checkpoints; if the file
    # is a safetensors archive instead, fall back to safetensors' load_file.
    try:
        ema_instance = torch.load(args.model_ckpt, weights_only=False)
    except pickle.UnpicklingError as e:
        print(f"Error loading checkpoint: {e}")
        ema_instance = load_file(args.model_ckpt)
    if "ema_model.bin" in args.model_ckpt:
        # EMA checkpoint: copy the shadow (averaged) parameters into the model.
        # The "dit" filter must not run here, or it would drop "shadow_params".
        for param, ema_param in zip(model.parameters(), ema_instance["shadow_params"]):
            param.data.copy_(ema_param.data)
    else:
        # Plain checkpoint: keep only the DiT weights before loading.
        ema_instance = {k: v for k, v in ema_instance.items() if "dit" in k}
        model.load_state_dict(ema_instance)
model = model.cuda()
model = model.to(dtype=torch.bfloat16)
model.eval()
# Use guidance scales from args or set default
guidance_scales = args.guidance_scales if args.guidance_scales else [5.0]
# If a single seed is passed without nargs, wrap it in a list
if isinstance(args.seed, int):
seeds = [args.seed]
else:
seeds = args.seed
for seed in seeds:
for guidance_scale in guidance_scales:
# Output directory structure: base_dir/seed_xxx/guidance_yyy
output_dir = os.path.join(args.output_dir, f"seed_{seed}", f"scale_{guidance_scale}")
os.makedirs(output_dir, exist_ok=True)
            # Generate in batches of args.batch_size prompts; batching keeps
            # memory bounded when producing high-resolution images.
            for i in range(0, len(prompts), args.batch_size):
                # Re-seed before every batch so each batch is reproducible
                # independently of how many batches ran before it.
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
                batch_prompts = prompts[i : i + args.batch_size]
                # Sample a batch of images using a linear-quadratic timestep
                # schedule that emulates a 250-step trajectory in args.steps steps.
                imgs = model.sample(
batch_prompts,
args.steps,
resolution=[args.resolution, args.resolution],
guidance_scale=guidance_scale,
step_scaling=1.0,
use_linear_quadratic_schedule=True,
linear_quadratic_emulating_steps=250,
get_intermediate_steps=args.streaming,
noisy_pad=args.noisy_pad,
zero_masking=args.zero_masking,
)
                if args.streaming:
                    # In streaming mode, sample() also returns intermediate steps.
                    imgs, intermediate_imgs = imgs
if isinstance(intermediate_imgs, list):
for j, intermediate_img in enumerate(intermediate_imgs):
for k, img in enumerate(intermediate_img):
img.save(os.path.join(output_dir, f"{i + k:03d}_{j:03d}_intermediate.png"))
else:
# If intermediate_imgs is a single Image, save it directly
intermediate_imgs.save(os.path.join(output_dir, f"{i:03d}_0_intermediate.png"))
for j, img in enumerate(imgs):
img.save(os.path.join(output_dir, f"{i + j:03d}_check.png"))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate images with model")
parser.add_argument("--model-config", type=str, required=True, help="Path to the model configuration file")
parser.add_argument("--model-ckpt", type=str, required=True, help="Path to the model checkpoint file")
parser.add_argument(
"--seed", type=int, nargs="*", default=[7777], help="Random seed(s) for reproducibility (can provide multiple)"
)
# parser.add_argument("--slg", type=int, nargs="*", default=None, help="")
parser.add_argument("--steps", type=int, default=50, help="Number of steps for image generation")
parser.add_argument("--resolution", type=int, default=256, help="Resolution of output images")
parser.add_argument("--batch-size", type=int, default=32,help="Batch size for image generation")
parser.add_argument("--streaming", action="store_true", help="Enable streaming mode for intermediate steps")
parser.add_argument("--noisy-pad", action="store_true")
parser.add_argument("--zero-masking", action="store_true")
parser.add_argument("--prompt-file", type=str, default="prompt_128.txt", help="Path to the prompt file")
parser.add_argument("--guidance-scales", type=float, nargs="*", default=None, help="List of guidance scales")
parser.add_argument("--output-dir", type=str, default="output", help="Base output directory for generated images")
parser.add_argument("--lora-ckpt", action="store_true")
args = parser.parse_args()
main(args)
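
# Example invocation (script name and paths are placeholders):
#   python sample.py --model-config configs/mmdit_config.json \
#       --model-ckpt checkpoints/ema_model.bin --prompt-file prompt_128.txt \
#       --resolution 512 --seed 7777 1234 --guidance-scales 4.5 7.0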