# Motif-Image-6B-Preview / inference.py
import sys

sys.path.append(".")

import argparse
import json
import os
import pickle
import random

import numpy as np
import torch
from safetensors.torch import load_file

from configs.configuration_mmdit import MMDiTConfig
from models.modeling_motif_vision import MotifVision

# from tools.motif_api import PromptRewriter
# from tools.nsfw_filtering import ContentFilter


def load_sharded_model(model_index_path):
    """Load a sharded model from a safetensors index file.

    Args:
        model_index_path (str): Path to the model.safetensors.index.json file.

    Returns:
        dict: Merged state dict containing only the DiT weights.
    """
    with open(model_index_path, "r") as f:
        index = json.load(f)
sharded_state_dicts = {}
folder = os.path.dirname(model_index_path)
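    # Load each shard file referenced by the index exactly once, onto CPU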
for weight_name, filename in index["weight_map"].items():
if filename not in sharded_state_dicts:
sharded_state_dicts[filename] = load_file(os.path.join(folder, filename), device="cpu")
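
    # Re-key the per-shard tensors into one flat state dict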
merged_state_dict = {}
for weight_name, filename in index["weight_map"].items():
merged_state_dict[weight_name] = sharded_state_dicts[filename][weight_name]
merged_state_dict = {k: v for k, v in merged_state_dict.items() if 'dit' in k}
return merged_state_dict
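

# Example usage (illustrative path; the function is not called elsewhere in this
# script). strict=False is needed because the merged dict keeps only DiT weights:
#   state_dict = load_sharded_model("checkpoints/model.safetensors.index.json")
#   model.load_state_dict(state_dict, strict=False)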


def main(args):
    # Check that the prompt file exists before doing anything expensive
    if not os.path.isfile(args.prompt_file):
        print(f"Error: The prompt file '{args.prompt_file}' does not exist.")
        sys.exit(1)

    # Read the prompts, one per line
    with open(args.prompt_file) as f:
        prompts = [line.rstrip() for line in f]

    # Load the model configuration and build the model
    config = MMDiTConfig.from_json_file(args.model_config)
    config.vae_type = args.vae_type  # Override the VAE type from the CLI
    config.height = args.resolution
    config.width = args.resolution
    model = MotifVision(config)

    # Load the checkpoint: try a pickled torch checkpoint first, then fall back to safetensors
    try:
        ema_instance = torch.load(args.model_ckpt, weights_only=False)
    except pickle.UnpicklingError as e:
        print(f"Error loading checkpoint: {e}")
        ema_instance = load_file(args.model_ckpt)

    if "ema_model.bin" in args.model_ckpt:
        # EMA checkpoint: copy the shadow parameters into the model. The
        # "shadow_params" key must survive loading, so the DiT-only filtering
        # is applied only on the non-EMA path below.
        for param, ema_param in zip(model.parameters(), ema_instance["shadow_params"]):
            param.data.copy_(ema_param.data)
    else:
        # Non-EMA checkpoint: keep only the DiT weights and load them directly
        ema_instance = {k: v for k, v in ema_instance.items() if "dit" in k}
        model.load_state_dict(ema_instance)

    model = model.cuda().to(dtype=torch.bfloat16)
    model.eval()

    # Use guidance scales from args, or fall back to a single default
    guidance_scales = args.guidance_scales if args.guidance_scales else [5.0]

    # Normalize a bare int seed into a list (argparse with nargs="*" already yields a list)
    seeds = [args.seed] if isinstance(args.seed, int) else args.seed

    for seed in seeds:
        for guidance_scale in guidance_scales:
            # Output directory structure: base_dir/seed_xxx/scale_yyy
            output_dir = os.path.join(args.output_dir, f"seed_{seed}", f"scale_{guidance_scale}")
            os.makedirs(output_dir, exist_ok=True)

            # Generate in batches so high-resolution runs stay within memory
            for i in range(0, len(prompts), args.batch_size):  # Process batch_size prompts at a time
                # Set random seeds for reproducibility
                torch.manual_seed(seed)
                random.seed(seed)
                np.random.seed(seed)
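                # Re-seeding every batch means each batch draws its noise from the
                # same initial RNG state, so outputs depend only on the batch's prompts.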
batch_prompts = prompts[i : i + args.batch_size]
imgs = model.sample(
batch_prompts,
args.steps,
resolution=[args.resolution, args.resolution],
guidance_scale=guidance_scale,
step_scaling=1.0,
use_linear_quadratic_schedule=True,
linear_quadratic_emulating_steps=250,
get_intermediate_steps=args.streaming,
noisy_pad=args.noisy_pad,
zero_masking=args.zero_masking,
)
if args.streaming:
imgs, intermediate_imgs = imgs
if isinstance(intermediate_imgs, list):
for j, intermediate_img in enumerate(intermediate_imgs):
for k, img in enumerate(intermediate_img):
img.save(os.path.join(output_dir, f"{i + k:03d}_{j:03d}_intermediate.png"))
else:
# If intermediate_imgs is a single Image, save it directly
intermediate_imgs.save(os.path.join(output_dir, f"{i:03d}_0_intermediate.png"))
for j, img in enumerate(imgs):
img.save(os.path.join(output_dir, f"{i + j:03d}_check.png"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate images with the model")
parser.add_argument("--model-config", type=str, required=True, help="Path to the model configuration file")
parser.add_argument("--model-ckpt", type=str, required=True, help="Path to the model checkpoint file")
parser.add_argument(
"--seed", type=int, nargs="*", default=[7777], help="Random seed(s) for reproducibility (can provide multiple)"
)
# parser.add_argument("--slg", type=int, nargs="*", default=None, help="")
parser.add_argument("--steps", type=int, default=50, help="Number of steps for image generation")
parser.add_argument("--resolution", type=int, default=256, help="Resolution of output images")
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--streaming", action="store_true")
parser.add_argument("--noisy-pad", action="store_true")
parser.add_argument("--zero-masking", action="store_true")
parser.add_argument("--vae-type", type=str, default="SD3", help="Type of VAE")
parser.add_argument("--prompt-file", type=str, default="prompt_128.txt", help="Path to the prompt file")
parser.add_argument("--guidance-scales", type=float, nargs="*", default=None, help="List of guidance scales")
parser.add_argument("--output-dir", type=str, default="output", help="Base output directory for generated images")
parser.add_argument("--lora-ckpt", action="store_true")
args = parser.parse_args()
main(args)
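
# Example invocation (illustrative paths and values):
#   python inference.py \
#       --model-config configs/config.json \
#       --model-ckpt checkpoints/ema_model.bin \
#       --prompt-file prompts.txt \
#       --resolution 512 --steps 50 \
#       --guidance-scales 4.5 7.0 \
#       --seed 1234 5678 \
#       --output-dir output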