# Motif-Image-6B-Preview / inference.py
# beomgyu-kim's picture
# refactor/motifimage (#2)
# 327b52c verified
import sys
sys.path.append(".")
import argparse
import os
import random
import numpy as np
import torch
import pickle
from configs.configuration_mmdit import MMDiTConfig
from models.modeling_motifimage import MotifImage
from safetensors.torch import load_file
# from tools.motif_api import PromptRewriter
# from tools.nsfw_filtering import ContentFilter
def load_sharded_model(model_index_path):
    """
    Loads the 'dit' weights of a sharded model from a safetensors index file.

    Args:
        model_index_path (str): Path to the model.safetensors.index.json file.
            Shard filenames listed under its "weight_map" key are resolved
            relative to the directory that contains this index file.

    Returns:
        dict: Maps weight name to the loaded tensor, restricted to weights
        whose name contains the substring 'dit'. Shards are loaded with
        device="cpu".

    Raises:
        OSError: If the index file cannot be opened.
        KeyError: If the index has no "weight_map" entry, or a shard does not
            contain a 'dit' weight that the index assigns to it.
    """
    # Imported here because the module-level imports do not provide `json`;
    # without it the json.load call below fails with NameError.
    import json

    with open(model_index_path, "r", encoding="utf-8") as f:
        weight_map = json.load(f)["weight_map"]

    folder = os.path.dirname(model_index_path)
    shards = {}  # shard filename -> state dict, so each file is read only once
    merged_state_dict = {}
    for weight_name, filename in weight_map.items():
        if filename not in shards:
            shards[filename] = load_file(os.path.join(folder, filename), device="cpu")
        # Only the weights with 'dit' in their name are returned.
        if "dit" in weight_name:
            merged_state_dict[weight_name] = shards[filename][weight_name]
    return merged_state_dict
def _load_checkpoint_weights(model, ckpt_path):
    """Copy the weights stored at `ckpt_path` into `model` in place.

    The file is first read with `torch.load`; if that raises
    `pickle.UnpicklingError` it is read again with safetensors' `load_file`.

    Args:
        model: Module whose parameters are overwritten.
        ckpt_path (str): Path to the checkpoint. When it contains the
            substring "ema_model.bin", the checkpoint's "shadow_params" entry
            is copied onto `model.parameters()` pairwise, in iteration order.
            Otherwise the entries whose key contains "dit" are passed to
            `model.load_state_dict`.
    """
    try:
        # NOTE(security): weights_only=False unpickles arbitrary objects, so
        # only trusted checkpoints should be loaded here.
        # map_location="cpu" keeps loading independent of the device the
        # checkpoint was saved from; the caller moves the model to the GPU.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    except pickle.UnpicklingError as e:
        print(f"Error loading checkpoint: {e}")
        checkpoint = load_file(ckpt_path)

    if "ema_model.bin" in ckpt_path:
        # EMA checkpoint loading. The "dit" key filter must NOT be applied
        # here: "shadow_params" does not contain "dit", so filtering first
        # would drop it and make this lookup raise KeyError.
        for param, ema_param in zip(model.parameters(), checkpoint["shadow_params"]):
            param.data.copy_(ema_param.data)
    else:
        # Non-EMA checkpoint loading: keep only the entries with "dit" in the key.
        model.load_state_dict({k: v for k, v in checkpoint.items() if "dit" in k})


def main(args):
    """Generate images for every prompt, seed and guidance scale and save them.

    Each line of `args.prompt_file` (trailing whitespace stripped) is one
    prompt. Prompts are processed in batches of `args.batch_size`, and the
    results are written as PNG files under
    `<args.output_dir>/seed_<seed>/scale_<guidance_scale>/`:

    * final images as `<prompt index>_check.png`;
    * with `args.streaming`, intermediate images as
      `<prompt index>_<intermediate index>_intermediate.png`.

    Args:
        args (argparse.Namespace): Parsed command-line options. The attributes
            read are `prompt_file`, `model_config`, `model_ckpt`, `resolution`,
            `guidance_scales`, `seed`, `output_dir`, `batch_size`, `steps`,
            `streaming`, `noisy_pad` and `zero_masking`.

    Raises:
        SystemExit: With code 1 when `args.prompt_file` is not an existing file.
    """
    # Check if the prompt file exists
    if not os.path.isfile(args.prompt_file):
        print(f"Error: The prompt file '{args.prompt_file}' does not exist.")
        sys.exit(1)

    # List of prompts, one per line
    with open(args.prompt_file, encoding="utf-8") as f:
        prompts = [line.rstrip() for line in f]

    # Load model configuration and model
    config = MMDiTConfig.from_json_file(args.model_config)
    config.height = args.resolution
    config.width = args.resolution
    model = MotifImage(config)

    # Load checkpoint
    _load_checkpoint_weights(model, args.model_ckpt)

    model = model.cuda()
    model = model.to(dtype=torch.bfloat16)
    model.eval()

    # Use guidance scales from args or fall back to the default
    guidance_scales = args.guidance_scales or [5.0]

    # If a single seed is passed without nargs, wrap it in a list
    seeds = [args.seed] if isinstance(args.seed, int) else args.seed

    for seed in seeds:
        for guidance_scale in guidance_scales:
            # Output directory structure: base_dir/seed_xxx/scale_yyy
            output_dir = os.path.join(args.output_dir, f"seed_{seed}", f"scale_{guidance_scale}")
            os.makedirs(output_dir, exist_ok=True)

            # Process `batch_size` prompts at a time
            for i in range(0, len(prompts), args.batch_size):
                # Re-seed before every batch so each batch starts from the
                # same RNG state for a given seed.
                torch.manual_seed(seed)
                random.seed(seed)
                np.random.seed(seed)

                batch_prompts = prompts[i : i + args.batch_size]

                # Gradients are not needed for sampling. NOTE(review):
                # model.sample may already disable them itself; this guard is
                # harmless in that case.
                with torch.no_grad():
                    imgs = model.sample(
                        batch_prompts,
                        args.steps,
                        resolution=[args.resolution, args.resolution],
                        guidance_scale=guidance_scale,
                        step_scaling=1.0,
                        use_linear_quadratic_schedule=True,
                        linear_quadratic_emulating_steps=250,
                        get_intermediate_steps=args.streaming,
                        noisy_pad=args.noisy_pad,
                        zero_masking=args.zero_masking,
                    )

                if args.streaming:
                    # In streaming mode the result is unpacked into
                    # (final images, intermediate images).
                    imgs, intermediate_imgs = imgs
                    if isinstance(intermediate_imgs, list):
                        for j, intermediate_img in enumerate(intermediate_imgs):
                            for k, img in enumerate(intermediate_img):
                                img.save(os.path.join(output_dir, f"{i + k:03d}_{j:03d}_intermediate.png"))
                    else:
                        # If intermediate_imgs is a single Image, save it directly
                        intermediate_imgs.save(os.path.join(output_dir, f"{i:03d}_0_intermediate.png"))

                for j, img in enumerate(imgs):
                    img.save(os.path.join(output_dir, f"{i + j:03d}_check.png"))
def _build_arg_parser():
    """Create the command-line parser for the image generation script."""
    cli = argparse.ArgumentParser(description="Generate images with model")

    # Model inputs
    cli.add_argument(
        "--model-config",
        type=str,
        required=True,
        help="Path to the model configuration file",
    )
    cli.add_argument(
        "--model-ckpt",
        type=str,
        required=True,
        help="Path to the model checkpoint file",
    )

    # Sampling controls
    cli.add_argument(
        "--seed",
        type=int,
        nargs="*",
        default=[7777],
        help="Random seed(s) for reproducibility (can provide multiple)",
    )
    # cli.add_argument("--slg", type=int, nargs="*", default=None, help="")
    cli.add_argument(
        "--steps",
        type=int,
        default=50,
        help="Number of steps for image generation",
    )
    cli.add_argument(
        "--resolution",
        type=int,
        default=256,
        help="Resolution of output images",
    )
    cli.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for image generation",
    )
    cli.add_argument(
        "--streaming",
        action="store_true",
        help="Enable streaming mode for intermediate steps",
    )
    cli.add_argument("--noisy-pad", action="store_true")
    cli.add_argument("--zero-masking", action="store_true")

    # Prompt source, guidance and output location
    cli.add_argument(
        "--prompt-file",
        type=str,
        default="prompt_128.txt",
        help="Path to the prompt file",
    )
    cli.add_argument(
        "--guidance-scales",
        type=float,
        nargs="*",
        default=None,
        help="List of guidance scales",
    )
    cli.add_argument(
        "--output-dir",
        type=str,
        default="output",
        help="Base output directory for generated images",
    )
    cli.add_argument("--lora-ckpt", action="store_true")
    return cli


if __name__ == "__main__":
    # Parse the command line and hand the options to the generation routine.
    args = _build_arg_parser().parse_args()
    main(args)