File size: 6,334 Bytes
d09c0e5
 
 
 
 
 
 
 
 
 
327b52c
d09c0e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327b52c
d09c0e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0012f0c
 
d09c0e5
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import sys

sys.path.append(".")
import argparse
import json
import os
import pickle
import random

import numpy as np
import torch
from safetensors.torch import load_file

from configs.configuration_mmdit import MMDiTConfig
from models.modeling_motifimage import MotifImage

# from tools.motif_api import PromptRewriter
# from tools.nsfw_filtering import ContentFilter

def load_sharded_model(model_index_path):
    """
    Load a sharded safetensors checkpoint described by an index file.

    Each shard file is read exactly once, then individual weights are
    stitched back together. Only weights whose name contains ``'dit'``
    are kept, matching the DiT-only loading done in ``main()``.

    Args:
        model_index_path (str): Path to the model.safetensors.index.json file.

    Returns:
        dict: Mapping of weight name -> tensor, restricted to 'dit' weights.
    """
    with open(model_index_path, "r") as f:
        index = json.load(f)

    weight_map = index["weight_map"]
    folder = os.path.dirname(model_index_path)

    # Load each shard file exactly once (several weights share a shard).
    sharded_state_dicts = {}
    for filename in set(weight_map.values()):
        sharded_state_dicts[filename] = load_file(os.path.join(folder, filename), device="cpu")

    # Merge per-weight tensors from their shards, keeping only DiT weights.
    return {
        weight_name: sharded_state_dicts[filename][weight_name]
        for weight_name, filename in weight_map.items()
        if "dit" in weight_name
    }


def main(args):
    """
    Generate images for every prompt in ``args.prompt_file`` across each
    (seed, guidance scale) combination, saving PNGs under
    ``args.output_dir/seed_<seed>/scale_<scale>/``.

    Exits with status 1 if the prompt file does not exist.
    """
    # Check if the prompt file exists before doing any heavy setup.
    if not os.path.isfile(args.prompt_file):
        print(f"Error: The prompt file '{args.prompt_file}' does not exist.")
        sys.exit(1)

    # One prompt per line; strip trailing whitespace/newlines.
    with open(args.prompt_file) as f:
        prompts = [prompt.rstrip() for prompt in f.readlines()]

    # Load model configuration and build the model at the requested resolution.
    config = MMDiTConfig.from_json_file(args.model_config)
    config.height = args.resolution
    config.width = args.resolution

    model = MotifImage(config)

    # Load the checkpoint. Try torch.load first; fall back to safetensors
    # when the file is not a pickle archive. Load onto CPU -- the model is
    # moved to CUDA below, so this avoids requiring the saving device.
    try:
        ema_instance = torch.load(args.model_ckpt, map_location="cpu", weights_only=False)
    except pickle.UnpicklingError as e:
        print(f"Error loading checkpoint: {e}")
        ema_instance = load_file(args.model_ckpt)

    if "ema_model.bin" in args.model_ckpt:
        # EMA checkpoint: copy shadow parameters onto the model in order.
        # NOTE: the 'dit' key filter must NOT be applied here -- it would
        # drop the "shadow_params" key and raise a KeyError.
        for param, ema_param in zip(model.parameters(), ema_instance["shadow_params"]):
            param.data.copy_(ema_param.data)
    else:
        # Non-EMA checkpoint: keep only DiT weights before loading.
        ema_instance = {k: v for k, v in ema_instance.items() if "dit" in k}
        model.load_state_dict(ema_instance)

    model = model.cuda()
    model = model.to(dtype=torch.bfloat16)
    model.eval()

    # Use guidance scales from args or set default
    guidance_scales = args.guidance_scales if args.guidance_scales else [5.0]

    # If a single seed is passed without nargs, wrap it in a list
    if isinstance(args.seed, int):
        seeds = [args.seed]
    else:
        seeds = args.seed

    for seed in seeds:
        for guidance_scale in guidance_scales:
            # Output directory structure: base_dir/seed_xxx/scale_yyy
            output_dir = os.path.join(args.output_dir, f"seed_{seed}", f"scale_{guidance_scale}")
            os.makedirs(output_dir, exist_ok=True)
            # Process prompts in batches of args.batch_size (helps memory at
            # high resolution).
            for i in range(0, len(prompts), args.batch_size):
                # Re-seed every batch so each batch is independently
                # reproducible regardless of batching.
                torch.manual_seed(seed)
                random.seed(seed)
                np.random.seed(seed)

                batch_prompts = prompts[i : i + args.batch_size]
                imgs = model.sample(
                    batch_prompts,
                    args.steps,
                    resolution=[args.resolution, args.resolution],
                    guidance_scale=guidance_scale,
                    step_scaling=1.0,
                    use_linear_quadratic_schedule=True,
                    linear_quadratic_emulating_steps=250,
                    get_intermediate_steps=args.streaming,
                    noisy_pad=args.noisy_pad,
                    zero_masking=args.zero_masking,
                )
                if args.streaming:
                    # Streaming mode returns (final images, intermediates).
                    imgs, intermediate_imgs = imgs
                    if isinstance(intermediate_imgs, list):
                        for j, intermediate_img in enumerate(intermediate_imgs):
                            for k, img in enumerate(intermediate_img):
                                img.save(os.path.join(output_dir, f"{i + k:03d}_{j:03d}_intermediate.png"))
                    else:
                        # If intermediate_imgs is a single Image, save it directly
                        intermediate_imgs.save(os.path.join(output_dir, f"{i:03d}_0_intermediate.png"))
                for j, img in enumerate(imgs):
                    img.save(os.path.join(output_dir, f"{i + j:03d}_check.png"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate images with model")
    parser.add_argument("--model-config", type=str, required=True, help="Path to the model configuration file")
    parser.add_argument("--model-ckpt", type=str, required=True, help="Path to the model checkpoint file")
    parser.add_argument(
        "--seed", type=int, nargs="*", default=[7777], help="Random seed(s) for reproducibility (can provide multiple)"
    )
    # parser.add_argument("--slg", type=int, nargs="*", default=None, help="")
    parser.add_argument("--steps", type=int, default=50, help="Number of steps for image generation")
    parser.add_argument("--resolution", type=int, default=256, help="Resolution of output images")
    parser.add_argument("--batch-size", type=int, default=32,help="Batch size for image generation")
    parser.add_argument("--streaming", action="store_true", help="Enable streaming mode for intermediate steps")
    parser.add_argument("--noisy-pad", action="store_true")
    parser.add_argument("--zero-masking", action="store_true")
    parser.add_argument("--prompt-file", type=str, default="prompt_128.txt", help="Path to the prompt file")
    parser.add_argument("--guidance-scales", type=float, nargs="*", default=None, help="List of guidance scales")
    parser.add_argument("--output-dir", type=str, default="output", help="Base output directory for generated images")
    parser.add_argument("--lora-ckpt", action="store_true")
    args = parser.parse_args()

    main(args)