amd
/

hecui102 commited on
Commit
1e9bde1
·
verified ·
1 Parent(s): 2f94d9d

Upload 2 files

Browse files
Files changed (2) hide show
  1. inference.py +386 -0
  2. run.sh +51 -0
inference.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse, os, sys, glob
2
+ import datetime, time
3
+ from omegaconf import OmegaConf
4
+ from tqdm import tqdm
5
+ from einops import rearrange, repeat
6
+ from collections import OrderedDict
7
+
8
+ import torch
9
+ import torchvision
10
+ import torchvision.transforms as transforms
11
+ from pytorch_lightning import seed_everything
12
+ from PIL import Image
13
+ sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
14
+ from lvdm.models.samplers.ddim import DDIMSampler
15
+ from lvdm.models.samplers.ddim_multiplecond import DDIMSampler as DDIMSampler_multicond
16
+ from utils.utils import instantiate_from_config
17
+ import random
18
+
19
+
20
+ def get_filelist(data_dir, postfixes):
21
+ patterns = [os.path.join(data_dir, f"*.{postfix}") for postfix in postfixes]
22
+ file_list = []
23
+ for pattern in patterns:
24
+ file_list.extend(glob.glob(pattern))
25
+ file_list.sort()
26
+ return file_list
27
+
28
+ def load_model_checkpoint(model, ckpt):
29
+ state_dict = torch.load(ckpt, map_location="cpu")
30
+ if "state_dict" in list(state_dict.keys()):
31
+ state_dict = state_dict["state_dict"]
32
+ model.load_state_dict(state_dict, strict=True)
33
+ return model
34
+
35
+
36
+ def load_prompts(prompt_file):
37
+ f = open(prompt_file, 'r')
38
+ prompt_list = []
39
+ for idx, line in enumerate(f.readlines()):
40
+ l = line.strip()
41
+ if len(l) != 0:
42
+ prompt_list.append(l)
43
+ f.close()
44
+ return prompt_list
45
+
46
+ def load_data_prompts(data_dir, video_size=(256,256), video_frames=16, interp=False):
47
+ transform = transforms.Compose([
48
+ transforms.Resize(min(video_size)),
49
+ transforms.CenterCrop(video_size),
50
+ transforms.ToTensor(),
51
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
52
+ ## load prompts
53
+ '''
54
+ prompt_file = get_filelist(data_dir, ['txt'])
55
+ assert len(prompt_file) > 0, "Error: found NO prompt file!"
56
+ ###### default prompt
57
+ default_idx = 0
58
+ default_idx = min(default_idx, len(prompt_file)-1)
59
+ if len(prompt_file) > 1:
60
+ print(f"Warning: multiple prompt files exist. The one {os.path.split(prompt_file[default_idx])[1]} is used.")
61
+ ## only use the first one (sorted by name) if multiple exist
62
+
63
+ ## load video
64
+ '''
65
+ file_list = get_filelist(data_dir, ['jpg', 'png', 'jpeg', 'JPEG', 'PNG'])
66
+ # assert len(file_list) == n_samples, "Error: data and prompts are NOT paired!"
67
+ data_list = []
68
+ filename_list = []
69
+ #prompt_list = load_prompts(prompt_file[default_idx])
70
+ prompt_list = []
71
+ valid_extensions = {'.jpg', '.jpeg', '.png'}
72
+ prompt_list = []
73
+ for filename in file_list:
74
+ name, ext = os.path.splitext(filename)
75
+ if ext.lower() in valid_extensions:
76
+ prompt_list.append(name)
77
+ #prompt_list = [i.split('.')[:-1] for i in file_list]
78
+ prompt_list = [i.split('/')[-1] for i in prompt_list]
79
+ print(prompt_list)
80
+
81
+ n_samples = len(prompt_list)
82
+ for idx in range(n_samples):
83
+ if interp:
84
+ image1 = Image.open(file_list[2*idx]).convert('RGB')
85
+ image_tensor1 = transform(image1).unsqueeze(1) # [c,1,h,w]
86
+ image2 = Image.open(file_list[2*idx+1]).convert('RGB')
87
+ image_tensor2 = transform(image2).unsqueeze(1) # [c,1,h,w]
88
+ frame_tensor1 = repeat(image_tensor1, 'c t h w -> c (repeat t) h w', repeat=video_frames//2)
89
+ frame_tensor2 = repeat(image_tensor2, 'c t h w -> c (repeat t) h w', repeat=video_frames//2)
90
+ frame_tensor = torch.cat([frame_tensor1, frame_tensor2], dim=1)
91
+ _, filename = os.path.split(file_list[idx*2])
92
+ else:
93
+ image = Image.open(file_list[idx]).convert('RGB')
94
+ #import cv2
95
+ #img = cv2.imread(file_list[idx])
96
+ #print(img)
97
+ #print(transform)
98
+ image_tensor = transform(image).unsqueeze(1) # [c,1,h,w]
99
+ frame_tensor = repeat(image_tensor, 'c t h w -> c (repeat t) h w', repeat=video_frames)
100
+ #print(frame_tensor)
101
+ _, filename = os.path.split(file_list[idx])
102
+
103
+ data_list.append(frame_tensor)
104
+ filename_list.append(filename)
105
+
106
+ return filename_list, data_list, prompt_list
107
+
108
+
109
+ def save_results(prompt, samples, filename, fakedir, fps=8, loop=False):
110
+ filename = filename.split('.')[0]+'.mp4'
111
+ prompt = prompt[0] if isinstance(prompt, list) else prompt
112
+
113
+ ## save video
114
+ videos = [samples]
115
+ savedirs = [fakedir]
116
+ for idx, video in enumerate(videos):
117
+ if video is None:
118
+ continue
119
+ # b,c,t,h,w
120
+ video = video.detach().cpu()
121
+ video = torch.clamp(video.float(), -1., 1.)
122
+ n = video.shape[0]
123
+ video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
124
+ if loop:
125
+ video = video[:-1,...]
126
+
127
+ frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video] #[3, 1*h, n*w]
128
+ grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, h, n*w]
129
+ grid = (grid + 1.0) / 2.0
130
+ grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
131
+ path = os.path.join(savedirs[idx], filename)
132
+ torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'}) ## crf indicates the quality
133
+
134
+
135
+ def save_results_seperate(prompt, samples, filename, fakedir, fps=10, loop=False):
136
+ prompt = prompt[0] if isinstance(prompt, list) else prompt
137
+
138
+ ## save video
139
+ videos = [samples]
140
+ savedirs = [fakedir]
141
+ for idx, video in enumerate(videos):
142
+ if video is None:
143
+ continue
144
+ # b,c,t,h,w
145
+ video = video.detach().cpu()
146
+ if loop: # remove the last frame
147
+ video = video[:,:,:-1,...]
148
+ video = torch.clamp(video.float(), -1., 1.)
149
+ n = video.shape[0]
150
+ for i in range(n):
151
+ grid = video[i,...]
152
+ grid = (grid + 1.0) / 2.0
153
+ grid = (grid * 255).to(torch.uint8).permute(1, 2, 3, 0) #thwc
154
+ #path = os.path.join(savedirs[idx].replace('samples', 'samples_separate'), f'{filename.split(".")[0]}_sample{i}.mp4')
155
+ path = os.path.join(savedirs[idx], f'{filename}.mp4')
156
+ #print(path)
157
+ torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '0'})
158
+
159
+ def get_latent_z(model, videos):
160
+ b, c, t, h, w = videos.shape
161
+ x = rearrange(videos, 'b c t h w -> (b t) c h w')
162
+ z = model.encode_first_stage(x)
163
+ z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
164
+ return z
165
+
166
+
167
+ def image_guided_synthesis(model, prompts, videos, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1., \
168
+ unconditional_guidance_scale=1.0, cfg_img=None, fs=None, text_input=False, multiple_cond_cfg=False, loop=False, interp=False, timestep_spacing='uniform', guidance_rescale=0.0, **kwargs):
169
+ ddim_sampler = DDIMSampler(model) if not multiple_cond_cfg else DDIMSampler_multicond(model)
170
+ batch_size = noise_shape[0]
171
+ fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
172
+
173
+ if not text_input:
174
+ prompts = [""]*batch_size
175
+
176
+ img = videos[:,:,0] #bchw
177
+ #print('img:', img.dtype)
178
+ img_emb = model.embedder(img) ## blc
179
+ #print('img_emb', img_emb.dtype)
180
+ img_emb = model.image_proj_model(img_emb)
181
+ #print(img_emb)
182
+
183
+ cond_emb = model.get_learned_conditioning(prompts)
184
+ cond = {"c_crossattn": [torch.cat([cond_emb,img_emb], dim=1)]}
185
+ if model.model.conditioning_key == 'hybrid':
186
+ z = get_latent_z(model, videos) # b c t h w
187
+ if loop or interp:
188
+ img_cat_cond = torch.zeros_like(z)
189
+ img_cat_cond[:,:,0,:,:] = z[:,:,0,:,:]
190
+ img_cat_cond[:,:,-1,:,:] = z[:,:,-1,:,:]
191
+ else:
192
+ img_cat_cond = z[:,:,:1,:,:]
193
+ img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=z.shape[2])
194
+ cond["c_concat"] = [img_cat_cond] # b c 1 h w
195
+
196
+ if unconditional_guidance_scale != 1.0:
197
+ if model.uncond_type == "empty_seq":
198
+ prompts = batch_size * [""]
199
+ #prompts = batch_size * ["missing body parts, temporal flickering"]
200
+ #prompts = batch_size * ["low quality, temporal flickering, stuttering motion, low resolution, lack of detail, soft focus, bad hands, extra limbs, distorted facial features, incorrect proportions, missing body parts"]
201
+ uc_emb = model.get_learned_conditioning(prompts)
202
+ uc_emb = torch.load('./reneg_checkpoint.bin').to(model.device)
203
+ #print(uc_emb.shape)
204
+ elif model.uncond_type == "zero_embed":
205
+ uc_emb = torch.zeros_like(cond_emb)
206
+ uc_img_emb = model.embedder(torch.zeros_like(img)) ## b l c
207
+ uc_img_emb = model.image_proj_model(uc_img_emb)
208
+ #uc_emb = torch.load('/group/ossdphi_algo_scratch_11/xiaominl/projects/ImageReward/train-sd2.1-base-null_emb-lr0.005-hpsv2-fp16-bz64/checkpoint-1000/checkpoint.bin').to(model.device)
209
+ #print(uc_emb.shape)
210
+ uc = {"c_crossattn": [torch.cat([uc_emb,uc_img_emb],dim=1)]}
211
+ if model.model.conditioning_key == 'hybrid':
212
+ uc["c_concat"] = [img_cat_cond]
213
+ else:
214
+ uc = None
215
+
216
+ ## we need one more unconditioning image=yes, text=""
217
+ if multiple_cond_cfg and cfg_img != 1.0:
218
+ uc_2 = {"c_crossattn": [torch.cat([uc_emb,img_emb],dim=1)]}
219
+ if model.model.conditioning_key == 'hybrid':
220
+ uc_2["c_concat"] = [img_cat_cond]
221
+ kwargs.update({"unconditional_conditioning_img_nonetext": uc_2})
222
+ else:
223
+ kwargs.update({"unconditional_conditioning_img_nonetext": None})
224
+
225
+ z0 = None
226
+ cond_mask = None
227
+
228
+ batch_variants = []
229
+ for _ in range(n_samples):
230
+
231
+ if z0 is not None:
232
+ cond_z0 = z0.clone()
233
+ kwargs.update({"clean_cond": True})
234
+ else:
235
+ cond_z0 = None
236
+ if ddim_sampler is not None:
237
+
238
+ samples, _ = ddim_sampler.sample(S=ddim_steps,
239
+ conditioning=cond,
240
+ batch_size=batch_size,
241
+ shape=noise_shape[1:],
242
+ verbose=False,
243
+ unconditional_guidance_scale=unconditional_guidance_scale,
244
+ unconditional_conditioning=uc,
245
+ eta=ddim_eta,
246
+ cfg_img=cfg_img,
247
+ mask=cond_mask,
248
+ x0=cond_z0,
249
+ fs=fs,
250
+ timestep_spacing=timestep_spacing,
251
+ guidance_rescale=guidance_rescale,
252
+ **kwargs
253
+ )
254
+
255
+ ## reconstruct from latent to pixel space
256
+ batch_images = model.decode_first_stage(samples)
257
+ batch_variants.append(batch_images)
258
+ ## variants, batch, c, t, h, w
259
+ batch_variants = torch.stack(batch_variants)
260
+ return batch_variants.permute(1, 0, 2, 3, 4, 5)
261
+
262
+
263
+ def run_inference(args, gpu_num, gpu_no):
264
+ ## model config
265
+ config = OmegaConf.load(args.config)
266
+ model_config = config.pop("model", OmegaConf.create())
267
+
268
+ ## set use_checkpoint as False as when using deepspeed, it encounters an error "deepspeed backend not set"
269
+ model_config['params']['unet_config']['params']['use_checkpoint'] = False
270
+ model = instantiate_from_config(model_config)
271
+ model = model.cuda(gpu_no)
272
+ model.perframe_ae = args.perframe_ae
273
+ assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
274
+ model = load_model_checkpoint(model, args.ckpt_path)
275
+ print('load model!!!!!!!!!!!!!')
276
+ if args.unet_path != '' and args.use_unet == 1:
277
+ model.model.diffusion_model.load_state_dict(torch.load(args.unet_path), strict=True)
278
+ model.image_proj_model.load_state_dict(torch.load(args.img_proj_path), strict=True)
279
+ print('load unet down!', args.unet_path)
280
+ print('load image proj down!', args.img_proj_path)
281
+
282
+ model.eval()
283
+
284
+ ## run over data
285
+ assert (args.height % 16 == 0) and (args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!"
286
+ assert args.bs == 1, "Current implementation only support [batch size = 1]!"
287
+ ## latent noise shape
288
+ h, w = args.height // 8, args.width // 8
289
+ channels = model.model.diffusion_model.out_channels
290
+ n_frames = args.video_length
291
+ print(f'Inference with {n_frames} frames')
292
+ noise_shape = [args.bs, channels, n_frames, h, w]
293
+
294
+ #fakedir = os.path.join(args.savedir, "samples")
295
+ fakedir = args.savedir
296
+ fakedir_separate = os.path.join(args.savedir, "samples_separate")
297
+
298
+ os.makedirs(fakedir, exist_ok=True)
299
+ #os.makedirs(fakedir_separate, exist_ok=True)
300
+
301
+ ## prompt file setting
302
+ assert os.path.exists(args.prompt_dir), "Error: prompt file Not Found!"
303
+ filename_list, data_list, prompt_list = load_data_prompts(args.prompt_dir, video_size=(args.height, args.width), video_frames=n_frames, interp=args.interp)
304
+ num_samples = len(prompt_list)
305
+ samples_split = num_samples // gpu_num
306
+ print('Prompts testing [rank:%d] %d/%d samples loaded.'%(gpu_no, samples_split, num_samples))
307
+ #indices = random.choices(list(range(0, num_samples)), k=samples_per_device)
308
+ indices = list(range(samples_split*gpu_no, samples_split*(gpu_no+1)))
309
+ prompt_list_rank = [prompt_list[i] for i in indices]
310
+ data_list_rank = [data_list[i] for i in indices]
311
+ filename_list_rank = [filename_list[i] for i in indices]
312
+
313
+ start = time.time()
314
+ with torch.no_grad(), torch.cuda.amp.autocast():
315
+ for idx, indice in tqdm(enumerate(range(0, len(prompt_list_rank), args.bs)), desc='Sample Batch'):
316
+ prompts = prompt_list_rank[indice:indice+args.bs]
317
+ videos = data_list_rank[indice:indice+args.bs]
318
+ filenames = filename_list_rank[indice:indice+args.bs]
319
+ if isinstance(videos, list):
320
+ videos = torch.stack(videos, dim=0).to("cuda")
321
+ else:
322
+ videos = videos.unsqueeze(0).to("cuda")
323
+
324
+ batch_samples = image_guided_synthesis(model, prompts, videos, noise_shape, args.n_samples, args.ddim_steps, args.ddim_eta, \
325
+ args.unconditional_guidance_scale, args.cfg_img, args.frame_stride, args.text_input, args.multiple_cond_cfg, args.loop, args.interp, args.timestep_spacing, args.guidance_rescale)
326
+
327
+ ## save each example individually
328
+ for nn, samples in enumerate(batch_samples):
329
+ ## samples : [n_samples,c,t,h,w]
330
+ prompt = prompts[nn]
331
+ filename = filenames[nn]
332
+ # save_results(prompt, samples, filename, fakedir, fps=8, loop=args.loop)
333
+ save_results_seperate(prompt, samples, filename, fakedir, fps=8, loop=args.loop)
334
+
335
+ print(f"Saved in {args.savedir}. Time used: {(time.time() - start):.2f} seconds")
336
+
337
+
338
+ def get_parser():
339
+ parser = argparse.ArgumentParser()
340
+ parser.add_argument("--savedir", type=str, default=None, help="results saving path")
341
+ parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path")
342
+ parser.add_argument("--config", type=str, help="config (yaml) path")
343
+ parser.add_argument("--prompt_dir", type=str, default=None, help="a data dir containing videos and prompts")
344
+ parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt",)
345
+ parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM",)
346
+ parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)",)
347
+ parser.add_argument("--bs", type=int, default=1, help="batch size for inference, should be one")
348
+ parser.add_argument("--height", type=int, default=512, help="image height, in pixel space")
349
+ parser.add_argument("--width", type=int, default=512, help="image width, in pixel space")
350
+ parser.add_argument("--frame_stride", type=int, default=3, help="frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)")
351
+ parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0, help="prompt classifier-free guidance")
352
+ parser.add_argument("--seed", type=int, default=123, help="seed for seed_everything")
353
+ parser.add_argument("--video_length", type=int, default=16, help="inference video length")
354
+ parser.add_argument("--negative_prompt", action='store_true', default=False, help="negative prompt")
355
+ parser.add_argument("--text_input", action='store_true', default=False, help="input text to I2V model or not")
356
+ parser.add_argument("--multiple_cond_cfg", action='store_true', default=False, help="use multi-condition cfg or not")
357
+ parser.add_argument("--cfg_img", type=float, default=None, help="guidance scale for image conditioning")
358
+ parser.add_argument("--timestep_spacing", type=str, default="uniform", help="The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.")
359
+ parser.add_argument("--guidance_rescale", type=float, default=0.0, help="guidance rescale in [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891)")
360
+ parser.add_argument("--perframe_ae", action='store_true', default=False, help="if we use per-frame AE decoding, set it to True to save GPU memory, especially for the model of 576x1024")
361
+
362
+ ## currently not support looping video and generative frame interpolation
363
+ parser.add_argument("--loop", action='store_true', default=False, help="generate looping videos or not")
364
+ parser.add_argument("--interp", action='store_true', default=False, help="generate generative frame interpolation or not")
365
+ parser.add_argument("--use_unet", type=int, default=0, help="")
366
+ parser.add_argument("--unet_path", type=str, default="path/to/unet", help="")
367
+ parser.add_argument("--img_proj_path", type=str, default="path/to/img proj", help="")
368
+
369
+
370
+ return parser
371
+
372
+
373
+ if __name__ == '__main__':
374
+ now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
375
+ print("@DynamiCrafter cond-Inference: %s"%now)
376
+ parser = get_parser()
377
+ args = parser.parse_args()
378
+
379
+ seed = args.seed
380
+ if seed < 0:
381
+ seed = random.randint(0, 2 ** 31)
382
+ seed_everything(seed)
383
+ #x = torch.randn(1,2,4,4)
384
+ rank, gpu_num = 0, 1
385
+ run_inference(args, gpu_num, rank)
386
+
run.sh ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version=$1 ## 1024, 512, 256
2
+ GPU=3
3
+ ckpt='./stage_1.ckpt'
4
+ config='configs/inference_512_v1.0_07B.yaml'
5
+ prompt_dir="/VBench"
6
+ base_res_dir="result"
7
+
8
+ # 判断分辨率
9
+ if [ "$1" == "256" ]; then
10
+ H=256
11
+ FS=3
12
+ elif [ "$1" == "512" ]; then
13
+ H=320
14
+ FS=24
15
+ elif [ "$1" == "1024" ]; then
16
+ H=576
17
+ FS=10
18
+ else
19
+ echo "Invalid input. Please enter 256, 512, or 1024."
20
+ exit 1
21
+ fi
22
+
23
+
24
+ seed='123'
25
+ sub_dir='0'
26
+ res_dir="${base_res_dir}/${sub_dir}"
27
+
28
+ echo "Running seed=$seed -> Saving to: $res_dir"
29
+ echo $prompt_dir
30
+
31
+ CUDA_VISIBLE_DEVICES=$GPU python3 scripts/evaluation/inference.py \
32
+ --seed ${seed} \
33
+ --ckpt_path $ckpt \
34
+ --config $config \
35
+ --savedir $res_dir \
36
+ --n_samples 1 \
37
+ --bs 1 --height ${H} --width $1 \
38
+ --unconditional_guidance_scale 7.5 \
39
+ --ddim_steps 16 \
40
+ --ddim_eta 1.0 \
41
+ --prompt_dir $prompt_dir \
42
+ --text_input \
43
+ --video_length 16 \
44
+ --frame_stride ${FS} \
45
+ --use_unet 1 \
46
+ --unet_path 'stae_2/output/unet.pt' \
47
+ --img_proj_path 'stage_2/output/img_proj.pt' \
48
+ --timestep_spacing 'uniform_trailing' \
49
+ --guidance_rescale 0.7 \
50
+ --perframe_ae
51
+