"""Generate images from a text-prompt file with a MotifImage model.

For every requested seed and guidance scale, the prompts are processed in
batches and the resulting images are written to
``<output-dir>/seed_<seed>/scale_<guidance_scale>/``.
"""
import argparse
import json
import os
import pickle
import random
import sys

# The project packages (configs/, models/) are resolved relative to the
# current working directory, so it must be on sys.path before importing them.
sys.path.append(".")

import numpy as np
import torch

from configs.configuration_mmdit import MMDiTConfig
from models.modeling_motifimage import MotifImage
from safetensors.torch import load_file

# from tools.motif_api import PromptRewriter
# from tools.nsfw_filtering import ContentFilter


def _keep_dit_weights(state_dict):
    """Return only the entries of ``state_dict`` whose key contains 'dit'."""
    return {k: v for k, v in state_dict.items() if "dit" in k}


def load_sharded_model(model_index_path):
    """
    Loads a sharded model from a safetensors index file.

    Args:
        model_index_path (str): Path to the model.safetensors.index.json file.
            The shard files named in its "weight_map" are looked up in the
            same directory as the index file.

    Returns:
        dict: Weight name -> tensor (loaded on CPU), restricted to the
            weights whose name contains 'dit'.
    """
    with open(model_index_path, "r", encoding="utf-8") as f:
        index = json.load(f)

    weight_map = index["weight_map"]
    folder = os.path.dirname(model_index_path)

    # Open every shard file exactly once; dict.fromkeys de-duplicates the
    # filenames while keeping their first-seen order.
    shards = {
        filename: load_file(os.path.join(folder, filename), device="cpu")
        for filename in dict.fromkeys(weight_map.values())
    }

    merged_state_dict = {
        weight_name: shards[filename][weight_name]
        for weight_name, filename in weight_map.items()
    }
    return _keep_dit_weights(merged_state_dict)


def main(args):
    """Load the model described by ``args`` and write generated images to disk.

    Args:
        args (argparse.Namespace): Parsed command-line options; see the
            argument parser at the bottom of this file for the fields.

    Exits the process with status 1 if the prompt file does not exist.
    """
    # Check if the prompt file exists
    if not os.path.isfile(args.prompt_file):
        print(f"Error: The prompt file '{args.prompt_file}' does not exist.")
        sys.exit(1)

    # List of prompts, one per line (trailing whitespace stripped)
    with open(args.prompt_file) as f:
        prompts = [prompt.rstrip() for prompt in f]

    # Load model configuration and model
    config = MMDiTConfig.from_json_file(args.model_config)
    config.height = args.resolution
    config.width = args.resolution
    model = MotifImage(config)

    # Load checkpoint: try a pickled torch checkpoint first and fall back to
    # safetensors when the file cannot be unpickled.
    try:
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects; only load checkpoints from trusted sources.
        # map_location="cpu" keeps loading independent of the device the
        # checkpoint was saved from; the model is moved to CUDA below.
        checkpoint = torch.load(args.model_ckpt, map_location="cpu", weights_only=False)
    except pickle.UnpicklingError as e:
        print(f"Error loading checkpoint: {e}")
        checkpoint = load_file(args.model_ckpt)

    if "ema_model.bin" in args.model_ckpt:
        # EMA checkpoint loading. The 'dit' key filter must NOT be applied
        # here: it would drop the "shadow_params" entry read below.
        # NOTE(review): assumes shadow_params is ordered like
        # model.parameters() - confirm against the EMA writer.
        for param, ema_param in zip(model.parameters(), checkpoint["shadow_params"]):
            param.data.copy_(ema_param.data)
    else:
        # Non-EMA checkpoint loading: keep only the 'dit' weights
        model.load_state_dict(_keep_dit_weights(checkpoint))

    model = model.cuda()
    model = model.to(dtype=torch.bfloat16)
    model.eval()

    # Use guidance scales from args or set default
    guidance_scales = args.guidance_scales or [5.0]

    # If a single seed is passed without nargs, wrap it in a list
    seeds = [args.seed] if isinstance(args.seed, int) else args.seed

    for seed in seeds:
        for guidance_scale in guidance_scales:
            # Output directory structure: base_dir/seed_<seed>/scale_<guidance_scale>
            output_dir = os.path.join(args.output_dir, f"seed_{seed}", f"scale_{guidance_scale}")
            os.makedirs(output_dir, exist_ok=True)

            # Process batch_size prompts at a time (keeps memory bounded when
            # generating high-resolution images)
            for i in range(0, len(prompts), args.batch_size):
                # Re-seed before every batch, so each batch starts from the
                # same RNG state for a given seed.
                torch.manual_seed(seed)
                random.seed(seed)
                np.random.seed(seed)

                batch_prompts = prompts[i : i + args.batch_size]
                imgs = model.sample(
                    batch_prompts,
                    args.steps,
                    resolution=[args.resolution, args.resolution],
                    guidance_scale=guidance_scale,
                    step_scaling=1.0,
                    use_linear_quadratic_schedule=True,
                    linear_quadratic_emulating_steps=250,
                    get_intermediate_steps=args.streaming,
                    noisy_pad=args.noisy_pad,
                    zero_masking=args.zero_masking,
                )

                if args.streaming:
                    # With intermediate steps enabled, sample() returns a
                    # (final images, intermediate images) pair.
                    imgs, intermediate_imgs = imgs
                    if isinstance(intermediate_imgs, list):
                        # j indexes the intermediate step, k the prompt within the batch
                        for j, intermediate_img in enumerate(intermediate_imgs):
                            for k, img in enumerate(intermediate_img):
                                img.save(os.path.join(output_dir, f"{i + k:03d}_{j:03d}_intermediate.png"))
                    else:
                        # If intermediate_imgs is a single Image, save it directly
                        intermediate_imgs.save(os.path.join(output_dir, f"{i:03d}_0_intermediate.png"))

                # File names carry the global prompt index (batch offset + position)
                for j, img in enumerate(imgs):
                    img.save(os.path.join(output_dir, f"{i + j:03d}_check.png"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate images with model")
    parser.add_argument("--model-config", type=str, required=True, help="Path to the model configuration file")
    parser.add_argument("--model-ckpt", type=str, required=True, help="Path to the model checkpoint file")
    parser.add_argument(
        "--seed", type=int, nargs="*", default=[7777], help="Random seed(s) for reproducibility (can provide multiple)"
    )
    # parser.add_argument("--slg", type=int, nargs="*", default=None, help="")
    parser.add_argument("--steps", type=int, default=50, help="Number of steps for image generation")
    parser.add_argument("--resolution", type=int, default=256, help="Resolution of output images")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size for image generation")
    parser.add_argument("--streaming", action="store_true", help="Enable streaming mode for intermediate steps")
    parser.add_argument("--noisy-pad", action="store_true")
    parser.add_argument("--zero-masking", action="store_true")
    parser.add_argument("--prompt-file", type=str, default="prompt_128.txt", help="Path to the prompt file")
    parser.add_argument("--guidance-scales", type=float, nargs="*", default=None, help="List of guidance scales")
    parser.add_argument("--output-dir", type=str, default="output", help="Base output directory for generated images")
    parser.add_argument("--lora-ckpt", action="store_true")

    args = parser.parse_args()
    main(args)