amd
/

hecui102 commited on
Commit
bf07b9f
·
verified ·
1 Parent(s): 242d269

Delete inference.py

Browse files
Files changed (1) hide show
  1. inference.py +0 -386
inference.py DELETED
@@ -1,386 +0,0 @@
1
- import argparse, os, sys, glob
2
- import datetime, time
3
- from omegaconf import OmegaConf
4
- from tqdm import tqdm
5
- from einops import rearrange, repeat
6
- from collections import OrderedDict
7
-
8
- import torch
9
- import torchvision
10
- import torchvision.transforms as transforms
11
- from pytorch_lightning import seed_everything
12
- from PIL import Image
13
- sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
14
- from lvdm.models.samplers.ddim import DDIMSampler
15
- from lvdm.models.samplers.ddim_multiplecond import DDIMSampler as DDIMSampler_multicond
16
- from utils.utils import instantiate_from_config
17
- import random
18
-
19
-
20
- def get_filelist(data_dir, postfixes):
21
- patterns = [os.path.join(data_dir, f"*.{postfix}") for postfix in postfixes]
22
- file_list = []
23
- for pattern in patterns:
24
- file_list.extend(glob.glob(pattern))
25
- file_list.sort()
26
- return file_list
27
-
28
- def load_model_checkpoint(model, ckpt):
29
- state_dict = torch.load(ckpt, map_location="cpu")
30
- if "state_dict" in list(state_dict.keys()):
31
- state_dict = state_dict["state_dict"]
32
- model.load_state_dict(state_dict, strict=True)
33
- return model
34
-
35
-
36
- def load_prompts(prompt_file):
37
- f = open(prompt_file, 'r')
38
- prompt_list = []
39
- for idx, line in enumerate(f.readlines()):
40
- l = line.strip()
41
- if len(l) != 0:
42
- prompt_list.append(l)
43
- f.close()
44
- return prompt_list
45
-
46
- def load_data_prompts(data_dir, video_size=(256,256), video_frames=16, interp=False):
47
- transform = transforms.Compose([
48
- transforms.Resize(min(video_size)),
49
- transforms.CenterCrop(video_size),
50
- transforms.ToTensor(),
51
- transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
52
- ## load prompts
53
- '''
54
- prompt_file = get_filelist(data_dir, ['txt'])
55
- assert len(prompt_file) > 0, "Error: found NO prompt file!"
56
- ###### default prompt
57
- default_idx = 0
58
- default_idx = min(default_idx, len(prompt_file)-1)
59
- if len(prompt_file) > 1:
60
- print(f"Warning: multiple prompt files exist. The one {os.path.split(prompt_file[default_idx])[1]} is used.")
61
- ## only use the first one (sorted by name) if multiple exist
62
-
63
- ## load video
64
- '''
65
- file_list = get_filelist(data_dir, ['jpg', 'png', 'jpeg', 'JPEG', 'PNG'])
66
- # assert len(file_list) == n_samples, "Error: data and prompts are NOT paired!"
67
- data_list = []
68
- filename_list = []
69
- #prompt_list = load_prompts(prompt_file[default_idx])
70
- prompt_list = []
71
- valid_extensions = {'.jpg', '.jpeg', '.png'}
72
- prompt_list = []
73
- for filename in file_list:
74
- name, ext = os.path.splitext(filename)
75
- if ext.lower() in valid_extensions:
76
- prompt_list.append(name)
77
- #prompt_list = [i.split('.')[:-1] for i in file_list]
78
- prompt_list = [i.split('/')[-1] for i in prompt_list]
79
- print(prompt_list)
80
-
81
- n_samples = len(prompt_list)
82
- for idx in range(n_samples):
83
- if interp:
84
- image1 = Image.open(file_list[2*idx]).convert('RGB')
85
- image_tensor1 = transform(image1).unsqueeze(1) # [c,1,h,w]
86
- image2 = Image.open(file_list[2*idx+1]).convert('RGB')
87
- image_tensor2 = transform(image2).unsqueeze(1) # [c,1,h,w]
88
- frame_tensor1 = repeat(image_tensor1, 'c t h w -> c (repeat t) h w', repeat=video_frames//2)
89
- frame_tensor2 = repeat(image_tensor2, 'c t h w -> c (repeat t) h w', repeat=video_frames//2)
90
- frame_tensor = torch.cat([frame_tensor1, frame_tensor2], dim=1)
91
- _, filename = os.path.split(file_list[idx*2])
92
- else:
93
- image = Image.open(file_list[idx]).convert('RGB')
94
- #import cv2
95
- #img = cv2.imread(file_list[idx])
96
- #print(img)
97
- #print(transform)
98
- image_tensor = transform(image).unsqueeze(1) # [c,1,h,w]
99
- frame_tensor = repeat(image_tensor, 'c t h w -> c (repeat t) h w', repeat=video_frames)
100
- #print(frame_tensor)
101
- _, filename = os.path.split(file_list[idx])
102
-
103
- data_list.append(frame_tensor)
104
- filename_list.append(filename)
105
-
106
- return filename_list, data_list, prompt_list
107
-
108
-
109
- def save_results(prompt, samples, filename, fakedir, fps=8, loop=False):
110
- filename = filename.split('.')[0]+'.mp4'
111
- prompt = prompt[0] if isinstance(prompt, list) else prompt
112
-
113
- ## save video
114
- videos = [samples]
115
- savedirs = [fakedir]
116
- for idx, video in enumerate(videos):
117
- if video is None:
118
- continue
119
- # b,c,t,h,w
120
- video = video.detach().cpu()
121
- video = torch.clamp(video.float(), -1., 1.)
122
- n = video.shape[0]
123
- video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
124
- if loop:
125
- video = video[:-1,...]
126
-
127
- frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video] #[3, 1*h, n*w]
128
- grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, h, n*w]
129
- grid = (grid + 1.0) / 2.0
130
- grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
131
- path = os.path.join(savedirs[idx], filename)
132
- torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '10'}) ## crf indicates the quality
133
-
134
-
135
- def save_results_seperate(prompt, samples, filename, fakedir, fps=10, loop=False):
136
- prompt = prompt[0] if isinstance(prompt, list) else prompt
137
-
138
- ## save video
139
- videos = [samples]
140
- savedirs = [fakedir]
141
- for idx, video in enumerate(videos):
142
- if video is None:
143
- continue
144
- # b,c,t,h,w
145
- video = video.detach().cpu()
146
- if loop: # remove the last frame
147
- video = video[:,:,:-1,...]
148
- video = torch.clamp(video.float(), -1., 1.)
149
- n = video.shape[0]
150
- for i in range(n):
151
- grid = video[i,...]
152
- grid = (grid + 1.0) / 2.0
153
- grid = (grid * 255).to(torch.uint8).permute(1, 2, 3, 0) #thwc
154
- #path = os.path.join(savedirs[idx].replace('samples', 'samples_separate'), f'{filename.split(".")[0]}_sample{i}.mp4')
155
- path = os.path.join(savedirs[idx], f'{filename}.mp4')
156
- #print(path)
157
- torchvision.io.write_video(path, grid, fps=fps, video_codec='h264', options={'crf': '0'})
158
-
159
- def get_latent_z(model, videos):
160
- b, c, t, h, w = videos.shape
161
- x = rearrange(videos, 'b c t h w -> (b t) c h w')
162
- z = model.encode_first_stage(x)
163
- z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
164
- return z
165
-
166
-
167
- def image_guided_synthesis(model, prompts, videos, noise_shape, n_samples=1, ddim_steps=50, ddim_eta=1., \
168
- unconditional_guidance_scale=1.0, cfg_img=None, fs=None, text_input=False, multiple_cond_cfg=False, loop=False, interp=False, timestep_spacing='uniform', guidance_rescale=0.0, **kwargs):
169
- ddim_sampler = DDIMSampler(model) if not multiple_cond_cfg else DDIMSampler_multicond(model)
170
- batch_size = noise_shape[0]
171
- fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
172
-
173
- if not text_input:
174
- prompts = [""]*batch_size
175
-
176
- img = videos[:,:,0] #bchw
177
- #print('img:', img.dtype)
178
- img_emb = model.embedder(img) ## blc
179
- #print('img_emb', img_emb.dtype)
180
- img_emb = model.image_proj_model(img_emb)
181
- #print(img_emb)
182
-
183
- cond_emb = model.get_learned_conditioning(prompts)
184
- cond = {"c_crossattn": [torch.cat([cond_emb,img_emb], dim=1)]}
185
- if model.model.conditioning_key == 'hybrid':
186
- z = get_latent_z(model, videos) # b c t h w
187
- if loop or interp:
188
- img_cat_cond = torch.zeros_like(z)
189
- img_cat_cond[:,:,0,:,:] = z[:,:,0,:,:]
190
- img_cat_cond[:,:,-1,:,:] = z[:,:,-1,:,:]
191
- else:
192
- img_cat_cond = z[:,:,:1,:,:]
193
- img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=z.shape[2])
194
- cond["c_concat"] = [img_cat_cond] # b c 1 h w
195
-
196
- if unconditional_guidance_scale != 1.0:
197
- if model.uncond_type == "empty_seq":
198
- prompts = batch_size * [""]
199
- #prompts = batch_size * ["missing body parts, temporal flickering"]
200
- #prompts = batch_size * ["low quality, temporal flickering, stuttering motion, low resolution, lack of detail, soft focus, bad hands, extra limbs, distorted facial features, incorrect proportions, missing body parts"]
201
- uc_emb = model.get_learned_conditioning(prompts)
202
- uc_emb = torch.load('./reneg_checkpoint.bin').to(model.device)
203
- #print(uc_emb.shape)
204
- elif model.uncond_type == "zero_embed":
205
- uc_emb = torch.zeros_like(cond_emb)
206
- uc_img_emb = model.embedder(torch.zeros_like(img)) ## b l c
207
- uc_img_emb = model.image_proj_model(uc_img_emb)
208
- #uc_emb = torch.load('/group/ossdphi_algo_scratch_11/xiaominl/projects/ImageReward/train-sd2.1-base-null_emb-lr0.005-hpsv2-fp16-bz64/checkpoint-1000/checkpoint.bin').to(model.device)
209
- #print(uc_emb.shape)
210
- uc = {"c_crossattn": [torch.cat([uc_emb,uc_img_emb],dim=1)]}
211
- if model.model.conditioning_key == 'hybrid':
212
- uc["c_concat"] = [img_cat_cond]
213
- else:
214
- uc = None
215
-
216
- ## we need one more unconditioning image=yes, text=""
217
- if multiple_cond_cfg and cfg_img != 1.0:
218
- uc_2 = {"c_crossattn": [torch.cat([uc_emb,img_emb],dim=1)]}
219
- if model.model.conditioning_key == 'hybrid':
220
- uc_2["c_concat"] = [img_cat_cond]
221
- kwargs.update({"unconditional_conditioning_img_nonetext": uc_2})
222
- else:
223
- kwargs.update({"unconditional_conditioning_img_nonetext": None})
224
-
225
- z0 = None
226
- cond_mask = None
227
-
228
- batch_variants = []
229
- for _ in range(n_samples):
230
-
231
- if z0 is not None:
232
- cond_z0 = z0.clone()
233
- kwargs.update({"clean_cond": True})
234
- else:
235
- cond_z0 = None
236
- if ddim_sampler is not None:
237
-
238
- samples, _ = ddim_sampler.sample(S=ddim_steps,
239
- conditioning=cond,
240
- batch_size=batch_size,
241
- shape=noise_shape[1:],
242
- verbose=False,
243
- unconditional_guidance_scale=unconditional_guidance_scale,
244
- unconditional_conditioning=uc,
245
- eta=ddim_eta,
246
- cfg_img=cfg_img,
247
- mask=cond_mask,
248
- x0=cond_z0,
249
- fs=fs,
250
- timestep_spacing=timestep_spacing,
251
- guidance_rescale=guidance_rescale,
252
- **kwargs
253
- )
254
-
255
- ## reconstruct from latent to pixel space
256
- batch_images = model.decode_first_stage(samples)
257
- batch_variants.append(batch_images)
258
- ## variants, batch, c, t, h, w
259
- batch_variants = torch.stack(batch_variants)
260
- return batch_variants.permute(1, 0, 2, 3, 4, 5)
261
-
262
-
263
- def run_inference(args, gpu_num, gpu_no):
264
- ## model config
265
- config = OmegaConf.load(args.config)
266
- model_config = config.pop("model", OmegaConf.create())
267
-
268
- ## set use_checkpoint as False as when using deepspeed, it encounters an error "deepspeed backend not set"
269
- model_config['params']['unet_config']['params']['use_checkpoint'] = False
270
- model = instantiate_from_config(model_config)
271
- model = model.cuda(gpu_no)
272
- model.perframe_ae = args.perframe_ae
273
- assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
274
- model = load_model_checkpoint(model, args.ckpt_path)
275
- print('load model!!!!!!!!!!!!!')
276
- if args.unet_path != '' and args.use_unet == 1:
277
- model.model.diffusion_model.load_state_dict(torch.load(args.unet_path), strict=True)
278
- model.image_proj_model.load_state_dict(torch.load(args.img_proj_path), strict=True)
279
- print('load unet down!', args.unet_path)
280
- print('load image proj down!', args.img_proj_path)
281
-
282
- model.eval()
283
-
284
- ## run over data
285
- assert (args.height % 16 == 0) and (args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!"
286
- assert args.bs == 1, "Current implementation only support [batch size = 1]!"
287
- ## latent noise shape
288
- h, w = args.height // 8, args.width // 8
289
- channels = model.model.diffusion_model.out_channels
290
- n_frames = args.video_length
291
- print(f'Inference with {n_frames} frames')
292
- noise_shape = [args.bs, channels, n_frames, h, w]
293
-
294
- #fakedir = os.path.join(args.savedir, "samples")
295
- fakedir = args.savedir
296
- fakedir_separate = os.path.join(args.savedir, "samples_separate")
297
-
298
- os.makedirs(fakedir, exist_ok=True)
299
- #os.makedirs(fakedir_separate, exist_ok=True)
300
-
301
- ## prompt file setting
302
- assert os.path.exists(args.prompt_dir), "Error: prompt file Not Found!"
303
- filename_list, data_list, prompt_list = load_data_prompts(args.prompt_dir, video_size=(args.height, args.width), video_frames=n_frames, interp=args.interp)
304
- num_samples = len(prompt_list)
305
- samples_split = num_samples // gpu_num
306
- print('Prompts testing [rank:%d] %d/%d samples loaded.'%(gpu_no, samples_split, num_samples))
307
- #indices = random.choices(list(range(0, num_samples)), k=samples_per_device)
308
- indices = list(range(samples_split*gpu_no, samples_split*(gpu_no+1)))
309
- prompt_list_rank = [prompt_list[i] for i in indices]
310
- data_list_rank = [data_list[i] for i in indices]
311
- filename_list_rank = [filename_list[i] for i in indices]
312
-
313
- start = time.time()
314
- with torch.no_grad(), torch.cuda.amp.autocast():
315
- for idx, indice in tqdm(enumerate(range(0, len(prompt_list_rank), args.bs)), desc='Sample Batch'):
316
- prompts = prompt_list_rank[indice:indice+args.bs]
317
- videos = data_list_rank[indice:indice+args.bs]
318
- filenames = filename_list_rank[indice:indice+args.bs]
319
- if isinstance(videos, list):
320
- videos = torch.stack(videos, dim=0).to("cuda")
321
- else:
322
- videos = videos.unsqueeze(0).to("cuda")
323
-
324
- batch_samples = image_guided_synthesis(model, prompts, videos, noise_shape, args.n_samples, args.ddim_steps, args.ddim_eta, \
325
- args.unconditional_guidance_scale, args.cfg_img, args.frame_stride, args.text_input, args.multiple_cond_cfg, args.loop, args.interp, args.timestep_spacing, args.guidance_rescale)
326
-
327
- ## save each example individually
328
- for nn, samples in enumerate(batch_samples):
329
- ## samples : [n_samples,c,t,h,w]
330
- prompt = prompts[nn]
331
- filename = filenames[nn]
332
- # save_results(prompt, samples, filename, fakedir, fps=8, loop=args.loop)
333
- save_results_seperate(prompt, samples, filename, fakedir, fps=8, loop=args.loop)
334
-
335
- print(f"Saved in {args.savedir}. Time used: {(time.time() - start):.2f} seconds")
336
-
337
-
338
- def get_parser():
339
- parser = argparse.ArgumentParser()
340
- parser.add_argument("--savedir", type=str, default=None, help="results saving path")
341
- parser.add_argument("--ckpt_path", type=str, default=None, help="checkpoint path")
342
- parser.add_argument("--config", type=str, help="config (yaml) path")
343
- parser.add_argument("--prompt_dir", type=str, default=None, help="a data dir containing videos and prompts")
344
- parser.add_argument("--n_samples", type=int, default=1, help="num of samples per prompt",)
345
- parser.add_argument("--ddim_steps", type=int, default=50, help="steps of ddim if positive, otherwise use DDPM",)
346
- parser.add_argument("--ddim_eta", type=float, default=1.0, help="eta for ddim sampling (0.0 yields deterministic sampling)",)
347
- parser.add_argument("--bs", type=int, default=1, help="batch size for inference, should be one")
348
- parser.add_argument("--height", type=int, default=512, help="image height, in pixel space")
349
- parser.add_argument("--width", type=int, default=512, help="image width, in pixel space")
350
- parser.add_argument("--frame_stride", type=int, default=3, help="frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)")
351
- parser.add_argument("--unconditional_guidance_scale", type=float, default=1.0, help="prompt classifier-free guidance")
352
- parser.add_argument("--seed", type=int, default=123, help="seed for seed_everything")
353
- parser.add_argument("--video_length", type=int, default=16, help="inference video length")
354
- parser.add_argument("--negative_prompt", action='store_true', default=False, help="negative prompt")
355
- parser.add_argument("--text_input", action='store_true', default=False, help="input text to I2V model or not")
356
- parser.add_argument("--multiple_cond_cfg", action='store_true', default=False, help="use multi-condition cfg or not")
357
- parser.add_argument("--cfg_img", type=float, default=None, help="guidance scale for image conditioning")
358
- parser.add_argument("--timestep_spacing", type=str, default="uniform", help="The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.")
359
- parser.add_argument("--guidance_rescale", type=float, default=0.0, help="guidance rescale in [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891)")
360
- parser.add_argument("--perframe_ae", action='store_true', default=False, help="if we use per-frame AE decoding, set it to True to save GPU memory, especially for the model of 576x1024")
361
-
362
- ## currently not support looping video and generative frame interpolation
363
- parser.add_argument("--loop", action='store_true', default=False, help="generate looping videos or not")
364
- parser.add_argument("--interp", action='store_true', default=False, help="generate generative frame interpolation or not")
365
- parser.add_argument("--use_unet", type=int, default=0, help="")
366
- parser.add_argument("--unet_path", type=str, default="path/to/unet", help="")
367
- parser.add_argument("--img_proj_path", type=str, default="path/to/img proj", help="")
368
-
369
-
370
- return parser
371
-
372
-
373
- if __name__ == '__main__':
374
- now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
375
- print("@DynamiCrafter cond-Inference: %s"%now)
376
- parser = get_parser()
377
- args = parser.parse_args()
378
-
379
- seed = args.seed
380
- if seed < 0:
381
- seed = random.randint(0, 2 ** 31)
382
- seed_everything(seed)
383
- #x = torch.randn(1,2,4,4)
384
- rank, gpu_num = 0, 1
385
- run_inference(args, gpu_num, rank)
386
-