"""Training utilities for learning solver time discretizations (``LD3Trainer``)."""

import logging
import math
import os
import pickle
from dataclasses import dataclass
from typing import Any, List, Optional

import imageio
import lpips
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader

from dataset import LD3Dataset
from utils import move_tensor_to_device, compute_distance_between_two, compute_distance_between_two_L1


def save_gif(snapshot_path: str):
    """Bundle every ``log_best_*.png`` found in ``snapshot_path`` into ``gif.gif``.

    Frames are ordered by the integer that follows the last underscore in the
    file name (the ``.png`` suffix is stripped before parsing).
    """
    frame_names = [f for f in os.listdir(snapshot_path) if "log_best" in f]
    frame_names.sort(key=lambda f: int(f.split("_")[-1].replace(".png", "")))
    images = [imageio.imread(os.path.join(snapshot_path, f)) for f in frame_names]
    gif_path = os.path.join(snapshot_path, "gif.gif")
    imageio.mimsave(gif_path, images, duration=100.)
    print(f"Saved gif to {gif_path}")


def visual(input_, name="test.png", img_resolution=32, img_channels=3):
    """Tile a batch of images into a single grid and save it to ``name``.

    Values are mapped with ``(x + 1) / 2 * 255`` and clipped to ``[0, 255]``.
    The reshape/permute below treats ``input_`` as a 4-D batch laid out as
    (batch, channels, height, width) with height == width == ``img_resolution``.
    The image is always written in 'RGB' mode.
    """
    input_ = (input_ + 1.) / 2.
    batch_size = input_.shape[0]

    # Grid height = largest divisor of the batch size not exceeding floor(sqrt(batch)).
    gridh = int(math.sqrt(batch_size))
    for i in range(1, gridh + 1):
        if batch_size % i == 0:
            gridh = i
    gridw = batch_size // gridh

    image = (input_ * 255.).clip(0, 255).to(torch.uint8)
    # (gridh, gridw, C, H, W) -> (gridh, H, gridw, W, C) so rows/cols flatten into one picture.
    image = image.reshape(gridh, gridw, *image.shape[1:]).permute(0, 3, 1, 4, 2)
    image = image.reshape(gridh * img_resolution, gridw * img_resolution, img_channels)
    image = image.cpu().numpy()
    PIL.Image.fromarray(image, 'RGB').save(name)


def custom_collate_fn(batch):
    """Collate a list of sample tuples field by field, tolerating ``None`` fields.

    If any sample has ``None`` in a given position, that whole field collates
    to ``None``; otherwise the default PyTorch collate is applied to the field.
    """
    collated_batch = []
    for samples in zip(*batch):
        if any(item is None for item in samples):
            collated_batch.append(None)
        else:
            collated_batch.append(torch.utils.data._utils.collate.default_collate(samples))
    return collated_batch


@dataclass
class TrainingConfig:
    """Optimisation hyper-parameters and datasets consumed by ``LD3Trainer``."""

    # Datasets handed directly to ``DataLoader``.
    train_data: Any
    valid_data: Any
    train_batch_size: int
    valid_batch_size: int
    # Initial learning rates of the two timestep optimizers (RMSprop / SGD).
    lr_time_1: float
    lr_time_2: float
    # Learning rate of the per-batch SGD step applied to the latents.
    shift_lr: float
    # Factor applied to ``shift_lr`` after a round in which no validation latent moved.
    shift_lr_decay: float = 0.5
    # Floors for the decayed timestep learning rates.
    min_lr_time_1: float = 5e-5
    min_lr_time_2: float = 1e-6
    # Forwarded to ``discretize_model_wrapper`` as ``window_rate``.
    win_rate: float = 0.5
    # Number of non-improving validations before reloading the best checkpoint and decaying lr.
    patient: int = 5
    # Number of times lr2 may sit at its floor before training stops.
    lr2_patient: int = 5
    lr_time_decay: float = 0.8
    momentum_time_1: float = 0.9
    weight_decay_time_1: float = 0.0
    # One of "LPIPS", "L2", "L1".
    loss_type: str = "LPIPS"
    visualize: bool = False
    # When True, params1 is frozen and only params2 is trained.
    no_v1: bool = False
    prior_timesteps: Optional[List[float]] = None
    match_prior: bool = False


@dataclass
class ModelConfig:
    """Model, solver and schedule objects consumed by ``LD3Trainer``."""

    net: Any
    decoding_fn: Any
    noise_schedule: Any
    solver: Any
    solver_name: str
    order: int
    steps: int
    prior_bound: float
    resolution: int
    channels: int
    time_mode: str
    solver_extra_params: Optional[dict] = None
    snapshot_path: str = "logs"
    device: Optional[str] = None


class LD3Trainer:
    """Learns two sets of solver timesteps (and optionally shifted latents).

    ``params1`` parameterises the main discretization; ``params2`` is a bounded
    per-step offset producing a second discretization. Training proceeds in
    rounds: "version 1" rounds first, then "version 2" rounds in which
    ``params2`` is also trainable (see ``_set_trainable_params``).
    """

    def __init__(
        self, model_config: ModelConfig, training_config: TrainingConfig
    ) -> None:
        """Copy configuration, build loaders/optimizers, baselines and the loss."""
        # Model parameters
        self.net = model_config.net
        self.decoding_fn = model_config.decoding_fn
        self.noise_schedule = model_config.noise_schedule
        self.solver = model_config.solver
        self.solver_name = model_config.solver_name
        self.order = model_config.order
        self.steps = model_config.steps
        self.prior_bound = model_config.prior_bound
        self.resolution = model_config.resolution
        self.channels = model_config.channels
        self.time_mode = model_config.time_mode

        # Learning rate parameters
        self.lr_time_1 = training_config.lr_time_1
        self.lr_time_2 = training_config.lr_time_2
        self.shift_lr = training_config.shift_lr
        self.shift_lr_decay = training_config.shift_lr_decay
        self.min_lr_time_1 = training_config.min_lr_time_1
        self.min_lr_time_2 = training_config.min_lr_time_2
        self.lr_time_decay = training_config.lr_time_decay
        self.momentum_time_1 = training_config.momentum_time_1
        self.weight_decay_time_1 = training_config.weight_decay_time_1

        # Training data and batch sizes
        self.train_data = training_config.train_data
        self.valid_data = training_config.valid_data
        self.train_batch_size = training_config.train_batch_size
        self.valid_batch_size = training_config.valid_batch_size
        self._create_valid_loaders()
        self._create_train_loader()

        # Training state
        self.cur_iter = 0
        self.cur_round = 0
        self.count_worse = 0
        self.count_min_lr1_hit = 0
        self.count_min_lr2_hit = 0
        self.best_loss = float("inf")

        # Other parameters
        self.patient = training_config.patient
        self.lr2_patient = training_config.lr2_patient
        self.no_v1 = training_config.no_v1
        self.win_rate = training_config.win_rate
        self.snapshot_path = model_config.snapshot_path
        os.makedirs(self.snapshot_path, exist_ok=True)
        self.visualize = training_config.visualize

        # Device and optimizer setup (the device must be known before the parameters are created)
        self._set_device(model_config.device)
        self.params1, self.params2 = self._initialize_params()
        self.optimizer_lamb1 = torch.optim.RMSprop(
            [self.params1],
            lr=training_config.lr_time_1,
            momentum=training_config.momentum_time_1,
            weight_decay=training_config.weight_decay_time_1,
        )
        self.optimizer_lamb2 = torch.optim.SGD(
            [self.params2], lr=training_config.lr_time_2
        )
        self.prior_timesteps = training_config.prior_timesteps
        self.match_prior = training_config.match_prior

        # Additional attributes
        self.solver_extra_params = model_config.solver_extra_params or {}
        self.lambda_min = self.noise_schedule.lambda_min
        self.lambda_max = self.noise_schedule.lambda_max
        self.time_max = self.noise_schedule.inverse_lambda(self.lambda_min)
        self.time_min = self.noise_schedule.inverse_lambda(self.lambda_max)

        # Initialize baseline
        self._compute_baseline()

        # Initialize loss function
        self.loss_type = training_config.loss_type
        self.loss_fn = self._initialize_loss_fn()
        self.loss_vector = None

    def _train_to_match_prior(self, prior_timesteps=None):
        """Fit ``params1`` so the first discretization matches given prior timesteps.

        Falls back to ``self.prior_timesteps``; returns immediately when neither
        is provided. Optimises the mean squared error between the produced
        timesteps and the prior until it drops to 1e-3 or below.
        """
        if prior_timesteps is None:
            prior_timesteps = self.prior_timesteps
        if prior_timesteps is None:
            return

        logging.info("Matching prior timesteps")
        # NOTE(review): ``np.log`` yields a NumPy array; this relies on
        # ``noise_schedule.inverse_lambda`` accepting it and returning a tensor - confirm.
        prior_timesteps = self.noise_schedule.inverse_lambda(-np.log(prior_timesteps)).to(self.device).float()

        dis_model = discretize_model_wrapper(
            self.params1,
            self.params2,
            self.lambda_max,
            self.lambda_min,
            self.noise_schedule,
            self.time_mode,
            self.win_rate,
        )
        # Only the main discretization is fitted here.
        self.params1.requires_grad = True
        self.params2.requires_grad = False

        loss_time = float("inf")
        # NOTE(review): there is no iteration cap; the loop runs until the loss threshold is met.
        while loss_time > 1e-3:
            self.optimizer_lamb1.zero_grad()
            self.optimizer_lamb2.zero_grad()
            times1, times2 = dis_model()
            loss_time = (times1 - prior_timesteps).pow(2).mean()
            logging.info(f"Loss time: {loss_time}")
            loss_time.backward()
            self.optimizer_lamb1.step()

    def _initialize_loss_fn(self):
        """Return the per-sample loss callable selected by ``self.loss_type``.

        Raises:
            NotImplementedError: for any value other than 'LPIPS', 'L2', 'L1'.
        """
        if self.loss_type == 'LPIPS':
            return lpips.LPIPS(net='vgg').to(self.device)
        if self.loss_type == 'L2':
            return lambda x, y: compute_distance_between_two(x, y, self.channels, self.resolution)
        if self.loss_type == 'L1':
            return lambda x, y: compute_distance_between_two_L1(x, y, self.channels, self.resolution)
        raise NotImplementedError(f"Unsupported loss_type: {self.loss_type!r}")

    def _initialize_params(self):
        """Create the two learnable vectors of length ``steps + 1`` on ``self.device``.

        ``params1`` starts as all ones and ``params2`` as all zeros.
        """
        # Allocate on the configured device rather than unconditionally on CUDA, so the
        # CPU fallback selected by ``_set_device`` actually works.
        params1 = torch.nn.Parameter(
            torch.ones(self.steps + 1, dtype=torch.float32, device=self.device), requires_grad=True
        )
        params2 = torch.nn.Parameter(
            torch.zeros(self.steps + 1, dtype=torch.float32, device=self.device), requires_grad=True
        )
        return params1, params2

    def _set_device(self, device):
        """Use ``device`` when given, otherwise CUDA if available, else CPU."""
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _create_valid_loaders(self):
        """(Re)build both loaders over ``self.valid_data``.

        ``valid_loader`` uses the *training* batch size (it is iterated by the
        latent-update loop in ``_train_one_round``); ``valid_only_loader`` uses
        the validation batch size and is iterated by ``_run_validation``.
        """
        self.valid_loader = DataLoader(
            self.valid_data, batch_size=self.train_batch_size, shuffle=False, collate_fn=custom_collate_fn
        )
        self.valid_only_loader = DataLoader(
            self.valid_data, batch_size=self.valid_batch_size, shuffle=False, collate_fn=custom_collate_fn
        )

    def _create_train_loader(self):
        """(Re)build the shuffled loader over ``self.train_data``."""
        self.train_loader = DataLoader(
            self.train_data, batch_size=self.train_batch_size, shuffle=True, collate_fn=custom_collate_fn
        )

    def _solve_ode(self, timesteps=None, img=None, latent=None, condition=None, uncondition=None, valid=False):
        """Run the solver from ``latent`` and compute the loss against ``img``.

        Args:
            timesteps: Optional explicit timesteps used for both discretizations;
                when ``None`` the learned discretizations are used.
            img: Target images.
            latent: Latents; reshaped to (batch, channels, resolution, resolution).
            condition, uncondition: Passed through to the solver.
            valid: When False (and ``timesteps`` is None) the current timesteps
                are also written to ``t_steps.pt`` in the snapshot directory.

        Returns:
            ``(mean loss, decoded output as float, target as float)``.

        Side effects: updates ``self.t_steps1/2``, ``self.logSNR1/2`` and
        ``self.loss_vector`` (the squeezed per-sample loss).
        """
        batch_size = latent.shape[0]
        latent = latent.reshape(batch_size, self.channels, self.resolution, self.resolution)

        dis_model = discretize_model_wrapper(
            self.params1,
            self.params2,
            self.lambda_max,
            self.lambda_min,
            self.noise_schedule,
            self.time_mode,
            self.win_rate,
        )
        if timesteps is None:
            timesteps1, timesteps2 = dis_model()
        else:
            timesteps1 = timesteps
            timesteps2 = timesteps

        if not valid and timesteps is None:
            tst = torch.cat([timesteps1, timesteps2], dim=0).detach().cpu()
            torch.save(tst, os.path.join(self.snapshot_path, "t_steps.pt"))

        # Always recorded (also during validation): checkpointing and plotting read these.
        self.t_steps1 = timesteps1.detach()
        self.t_steps2 = timesteps2.detach()
        lamb1 = self.noise_schedule.marginal_lambda(timesteps1)
        lamb2 = self.noise_schedule.marginal_lambda(timesteps2)
        self.logSNR1 = lamb1.detach().cpu()
        self.logSNR2 = lamb2.detach().cpu()

        x_next_ = self.noise_schedule.prior_transformation(latent)
        x_next_ = self.solver.sample_simple(
            model_fn=self.net,
            x=x_next_,
            timesteps=timesteps1,
            timesteps2=timesteps2,
            order=self.order,
            NFEs=self.steps,
            condition=condition,
            unconditional_condition=uncondition,
            **self.solver_extra_params,
        )
        x_next_ = self.decoding_fn(x_next_)

        self.loss_vector = self.loss_fn(img.float(), x_next_.float()).squeeze()
        loss = self.loss_vector.mean()
        logging.info(f"{self._current_version} Loss: {loss.item()}")
        return loss, x_next_.float(), img.float()

    @property
    def _current_version(self):
        """Log prefix: 'Ver1' during version-1 rounds, otherwise 'Ver2'."""
        return 'Ver1' if self._is_in_version_1() else 'Ver2'

    def _is_in_version_1(self):
        """True while the current round index is below ``training_rounds_v1``.

        ``training_rounds_v1`` is assigned in ``train``; it does not exist before that call.
        """
        return self.cur_round < self.training_rounds_v1

    def _compute_baseline(self):
        """Precompute the baseline discretizations plotted by ``_visual_times``.

        Builds: uniform in logSNR, uniform in time, quadratic in time, and the
        solver's 'edm' schedule.
        """
        self.straight_line = torch.linspace(self.lambda_min, self.lambda_max, self.steps + 1)
        self.time_logSNR = self.noise_schedule.inverse_lambda(self.straight_line).to(self.device)

        time_max = self.noise_schedule.inverse_lambda(self.lambda_min)
        time_min = self.noise_schedule.inverse_lambda(self.lambda_max)

        # Uniform in time: a dense curve for plotting plus a (steps + 1)-point version.
        self.time_s = torch.linspace(time_max.item(), time_min.item(), 1000)
        self.time_straight = torch.linspace(time_max.item(), time_min.item(), self.steps + 1)
        self.time_straight = self.time_straight.to(self.device)
        self.straight_time = self.noise_schedule.marginal_lambda(self.time_s)

        # Quadratic in time: uniform in t**(1/2), then squared.
        t_order = 2
        q_start = (time_max ** (1 / t_order)).item()
        q_end = (time_min ** (1 / t_order)).item()
        self.time_q = torch.linspace(q_start, q_end, 1000) ** t_order
        self.quadratic_time = torch.linspace(q_start, q_end, self.steps + 1) ** t_order
        self.quadratic_time = self.quadratic_time.to(self.device)
        self.time_quadratic = self.noise_schedule.marginal_lambda(self.time_q)

        # time_edm
        self.time_edm = self.solver.get_time_steps('edm', time_max.item(), time_min.item(), 999, self.device)
        self.lambda_edm = self.noise_schedule.marginal_lambda(self.time_edm)

    def _run_validation(self):
        """Evaluate the current timesteps on ``valid_only_loader`` without gradients.

        Returns:
            ``(mean of per-batch losses, concatenated outputs, concatenated targets)``.
        """
        total_loss = 0.
        count = 0
        outputs = []
        targets = []
        with torch.no_grad():
            # The original-latent field of each batch is not needed for validation.
            for img, latent, _ori_latent, condition, uncondition in self.valid_only_loader:
                img = img.to(self.device)
                latent = latent.to(self.device).reshape(latent.shape[0], -1)
                if condition is not None:
                    condition = condition.to(self.device)
                if uncondition is not None:
                    uncondition = uncondition.to(self.device)
                loss, output, target = self._solve_ode(
                    img=img, latent=latent, condition=condition, uncondition=uncondition, valid=True
                )
                total_loss += loss.item()
                count += 1
                outputs.append(output)
                targets.append(target)
        output = torch.cat(outputs, dim=0)
        target = torch.cat(targets, dim=0)
        return total_loss / count, output, target

    def _visual_times(self) -> None:
        """
        Visualize time discretization of baselines and ours
        """
        log_path = os.path.join(self.snapshot_path, f"log_best_{self.cur_iter}.png")
        plt.plot(self.logSNR1.cpu().numpy(), 'o', label="Our discretization1")
        plt.plot(self.logSNR2.cpu().numpy(), 'x', label="Our discretization2")

        x_axis = np.linspace(0, self.steps, self.steps + 1)
        plt.plot(x_axis, self.straight_line.cpu().numpy(), label="Baseline logSNR")

        # The dense baselines have 1000 points, stretched over the same step axis.
        x_axis = np.linspace(0, self.steps, 1000)
        plt.plot(x_axis, self.straight_time.cpu().numpy(), label="Baseline time uniform")
        plt.plot(x_axis, self.time_quadratic.cpu().numpy(), label="Baseline time quadratic")
        plt.plot(x_axis, self.lambda_edm.cpu().numpy(), label="Baseline time edm")

        plt.xlabel("Reverse step i")
        plt.ylabel("LogSNR(t_i)")
        plt.legend()
        plt.tight_layout()
        plt.savefig(log_path)
        plt.close()

    def _save_checkpoint(self):
        """Persist the current parameters, timesteps and datasets to the snapshot directory."""
        snapshot = {
            "params1": self.params1.data,
            "params2": self.params2.data,
            "best_t_steps": torch.cat([self.t_steps1, self.t_steps2], dim=0),
        }
        if self._is_in_version_1():
            torch.save(snapshot, os.path.join(self.snapshot_path, "best_v1.pt"))
        # ``best_v2.pt`` is written in both versions: the first version-2 round loads it.
        torch.save(snapshot, os.path.join(self.snapshot_path, "best_v2.pt"))
        torch.save(snapshot, os.path.join(self.snapshot_path, f"best_t_steps_{self.cur_iter}.pt"))

        # Save the datasets backing the train / valid loaders.
        with open(os.path.join(self.snapshot_path, "train_data.pkl"), "wb") as f:
            pickle.dump(self.train_data, f)
        with open(os.path.join(self.snapshot_path, "valid_data.pkl"), "wb") as f:
            pickle.dump(self.valid_data, f)

    def _load_checkpoint(self, reload_data: bool):
        """Restore ``params1``/``params2`` from the best checkpoint of the current version.

        Args:
            reload_data: When True, also restore the pickled datasets and rebuild the loaders.
        """
        ckpt_name = "best_v1.pt" if self._is_in_version_1() else "best_v2.pt"
        snapshot = torch.load(os.path.join(self.snapshot_path, ckpt_name), map_location=self.device)
        self.params1.data = snapshot["params1"].to(self.device)
        self.params2.data = snapshot["params2"].to(self.device)

        if reload_data:
            # SECURITY: pickle executes arbitrary code on load. These files are written by
            # ``_save_checkpoint``; never point ``snapshot_path`` at untrusted content.
            with open(os.path.join(self.snapshot_path, "train_data.pkl"), "rb") as f:
                self.train_data = pickle.load(f)
            with open(os.path.join(self.snapshot_path, "valid_data.pkl"), "rb") as f:
                self.valid_data = pickle.load(f)
            self._create_train_loader()
            self._create_valid_loaders()

    def _examine_checkpoint(self, iter: int) -> None:
        """Validate, checkpoint on improvement, and apply patience-based lr decay.

        On a new best loss: save checkpoint, plot the discretization, rebuild the gif.
        Otherwise increment ``count_worse``; once it reaches ``patient`` the best
        checkpoint (including datasets) is reloaded and learning rates are decayed
        (lr1 always, lr2 only in version 2), bounded below by their minimums.
        """
        logging.info(f"{self._current_version} Saving snapshot at iter {iter}")
        total_loss, output, target = self._run_validation()

        if (iter % 5 == 0 or total_loss < self.best_loss) and self.visualize:
            visual(
                torch.cat([output[:8], target[:8]], dim=0),
                os.path.join(self.snapshot_path, f"learned_newnoise_ep{iter}.png"),
                img_resolution=self.resolution,
            )

        if total_loss < self.best_loss:
            # The validation latents do not change during training.
            self.best_loss = total_loss
            self.count_worse = 0
            self._save_checkpoint()
            self._visual_times()
            save_gif(self.snapshot_path)
        else:
            self.count_worse += 1
            logging.info(f"{self._current_version} Count worse: {self.count_worse}")

        logging.info(f"{self._current_version} Validation loss: {total_loss}, best loss: {self.best_loss}")
        logging.info(f"{self._current_version} Iter {iter} snapshot saved!")

        if self.count_worse >= self.patient:
            logging.info(f"{self._current_version} Loading best model")
            self._load_checkpoint(reload_data=True)
            self.count_worse = 0

            group1 = self.optimizer_lamb1.param_groups[0]
            group1['lr'] = max(self.lr_time_decay * group1['lr'], self.min_lr_time_1)
            logging.info(f"{self._current_version} Decay time1 lr to {group1['lr']}")

            if self._is_in_version_1():
                if group1['lr'] <= self.min_lr_time_1:
                    self.count_min_lr1_hit += 1
            else:
                group2 = self.optimizer_lamb2.param_groups[0]
                group2['lr'] = max(self.lr_time_decay * group2['lr'], self.min_lr_time_2)
                logging.info(f"{self._current_version} Decay time2 lr to {group2['lr']}")
                if group2['lr'] <= self.min_lr_time_2:
                    self.count_min_lr2_hit += 1

    def _set_trainable_params(self, is_train: bool, is_no_v1: bool) -> None:
        """Toggle ``requires_grad`` on the two timestep parameter vectors.

        Training: ``params1`` is trainable and ``params2`` only outside version 1,
        unless ``is_no_v1`` is set, in which case only ``params2`` is trainable.
        Not training: both are frozen.
        """
        if is_train:
            self.params1.requires_grad = True
            self.params2.requires_grad = not self._is_in_version_1()
            if is_no_v1:
                self.params1.requires_grad = False
                self.params2.requires_grad = True
        else:
            self.params1.requires_grad = False
            self.params2.requires_grad = False

    def _log_valid_distance(self, ori_latent: torch.Tensor, latent: torch.Tensor):
        """Log the per-sample Euclidean distance between current and original latents."""
        assert ori_latent.shape == latent.shape, "Shape of ori_latent and latent mismatched"
        sq = (latent.reshape(latent.shape[0], -1) - ori_latent.reshape(latent.shape[0], -1)).pow(2)
        distances = sq.sum(dim=1).sqrt().detach().cpu().numpy()
        logging.info(f"{self._current_version} Distance: {distances}")

    def _update_dataloader(
        self,
        ori_latents: List[torch.Tensor],
        latents: List[torch.Tensor],
        targets: List[torch.Tensor],
        conditions: List[Optional[torch.Tensor]],
        unconditions: List[Optional[torch.Tensor]],
        is_train: bool,
    ):
        """Wrap the per-sample lists in an ``LD3Dataset`` and rebuild the matching loader(s)."""
        custom_train_dataset = LD3Dataset(ori_latents, latents, targets, conditions, unconditions)
        if is_train:
            self.train_data = custom_train_dataset
            self._create_train_loader()
        else:
            self.valid_data = custom_train_dataset
            self._create_valid_loaders()

    def _update_latents(self, latent, condition, uncondition, ori_latent, img, latent_params, loss_vector_ref, prior_bound):
        """Accept the stepped latents only where they lower the per-sample loss.

        The stepped latents are first projected back so that their L2 distance to
        ``ori_latent`` does not exceed ``prior_bound``; the loss is then
        re-evaluated and compared element-wise with ``loss_vector_ref``.

        Returns:
            ``(latent, to_update_mask)`` where ``latent`` was updated in place
            at the positions selected by the boolean mask.
        """
        parameter_data_detached = latent_params.detach()
        cloned_ori_latent = ori_latent.clone()

        diff = parameter_data_detached.data - cloned_ori_latent
        diff_norm = diff.norm(dim=1, keepdim=True)
        pass_bound = (diff_norm > prior_bound).flatten()
        # Rescale offending rows onto the sphere of radius ``prior_bound`` around the original latent.
        parameter_data_detached.data[pass_bound] = (
            cloned_ori_latent[pass_bound] + prior_bound * diff[pass_bound] / diff_norm[pass_bound]
        )

        # Called for its side effect only: refreshes ``self.loss_vector`` for the projected latents.
        self._solve_ode(
            img=img, latent=parameter_data_detached.data, condition=condition, uncondition=uncondition, valid=False
        )
        to_update_mask = self.loss_vector < loss_vector_ref

        parameter_data_detached.data = parameter_data_detached.data.reshape(
            -1, self.channels, self.resolution, self.resolution
        )
        latent[to_update_mask] = parameter_data_detached.data[to_update_mask]
        return latent, to_update_mask

    def _train_one_round(self):
        """Run one pass over the train loader and (if ``prior_bound`` != 0) the valid loader.

        On the train loader both latents and timestep parameters are stepped, with a
        validation after every batch. On the valid loader only latents are stepped.

        Returns:
            ``(no_change, should_stop)``: ``no_change`` is True when no validation
            latent was updated; ``should_stop`` is True when lr2 has sat at its
            minimum ``lr2_patient`` times.

        Raises:
            RuntimeError: if a batch's images and latents end up on different devices.
        """
        no_change = True
        logging.info(f"{self._current_version} Round {self.cur_round}")
        if self.cur_round > 0:
            self._load_checkpoint(reload_data=False)
            self.count_worse = 0
        self._examine_checkpoint(self.cur_iter)  # run evaluation current latent and time steps

        for loader_idx, loader in enumerate([self.train_loader, self.valid_loader]):
            is_train_pass = loader_idx == 0
            if not is_train_pass and self.prior_bound == 0.0:
                continue

            self._set_trainable_params(is_train=is_train_pass, is_no_v1=self.no_v1)
            ori_latents, latents, targets, conditions, unconditions = [], [], [], [], []

            for img, latent, ori_latent, condition, uncondition in loader:
                img, latent, ori_latent, condition, uncondition = move_tensor_to_device(
                    img, latent, ori_latent, condition, uncondition, device=self.device
                )
                if not is_train_pass:
                    self._log_valid_distance(ori_latent, latent)

                # Flatten latents
                batch_size = ori_latent.shape[0]
                ori_latent = ori_latent.reshape(batch_size, -1)
                latent_to_update = latent.clone().detach().reshape(batch_size, -1).to(self.device)
                latent_params = torch.nn.Parameter(latent_to_update)
                latent_params.requires_grad = True
                latent_optimizer = torch.optim.SGD([latent_params], lr=self.shift_lr)

                if img.device != latent_params.device:
                    # Fail loudly instead of dropping into an interactive debugger mid-training.
                    raise RuntimeError(
                        f"Device mismatch: img on {img.device}, latents on {latent_params.device}"
                    )

                loss, _, _ = self._solve_ode(
                    img=img, latent=latent_params, condition=condition, uncondition=uncondition, valid=False
                )
                loss_vector_ref = self.loss_vector.clone().detach()
                loss.backward()
                logging.info(f"{self._current_version} Iter {self.cur_iter} {'Train' if loader_idx == 0 else 'Val'} Loss: {loss.item()}")

                latent_optimizer.step()
                latent_optimizer.zero_grad()

                if is_train_pass:
                    torch.nn.utils.clip_grad_norm_(self.params1, 1.0)
                    torch.nn.utils.clip_grad_norm_(self.params2, 1.0)
                    self.optimizer_lamb1.step()
                    self.optimizer_lamb1.zero_grad()
                    self.optimizer_lamb2.step()
                    self.optimizer_lamb2.zero_grad()
                    self.cur_iter += 1
                    self._examine_checkpoint(self.cur_iter)  # evaluate

                    if self.count_min_lr2_hit >= self.lr2_patient:
                        logging.info(f"{self._current_version} Reach min lr2 {self.lr2_patient} times. Stop training.")
                        return no_change, True

                with torch.no_grad():
                    latent, to_update_mask = self._update_latents(
                        latent, condition, uncondition, ori_latent, img, latent_params, loss_vector_ref, self.prior_bound
                    )
                    if not is_train_pass and to_update_mask.sum().item() > 0:
                        # At least one validation latent moved.
                        no_change = False

                # Move everything back to CPU and store it per sample for the rebuilt dataset.
                ori_latent = ori_latent.reshape(-1, self.channels, self.resolution, self.resolution).detach().cpu()
                latent = latent.reshape(-1, self.channels, self.resolution, self.resolution).detach().cpu()
                img = img.detach().cpu()
                condition = condition.detach().cpu() if condition is not None else None
                uncondition = uncondition.detach().cpu() if uncondition is not None else None

                for j in range(latent.shape[0]):
                    ori_latents.append(ori_latent[j])
                    targets.append(img[j])
                    latents.append(latent[j])
                    conditions.append(condition[j] if condition is not None else None)
                    unconditions.append(uncondition[j] if uncondition is not None else None)

            # update dataset
            if self.prior_bound > 0:
                self._update_dataloader(
                    ori_latents, latents, targets, conditions, unconditions, is_train=is_train_pass
                )

        return no_change, False

    def train(self, training_rounds_v1: int, training_rounds_v2: int) -> None:
        """Run ``training_rounds_v1`` version-1 rounds followed by ``training_rounds_v2`` version-2 rounds.

        Optionally pre-fits the timesteps to the prior first. Returns early when a
        round signals that training should stop. After a round in which no validation
        latent changed (and ``prior_bound`` > 0), ``shift_lr`` is decayed.
        """
        total_round = training_rounds_v1 + training_rounds_v2
        self.training_rounds_v1 = training_rounds_v1

        if self.match_prior:
            self._train_to_match_prior()

        while self.cur_round < total_round:
            no_latent_change, should_stop = self._train_one_round()
            if should_stop:
                return
            self.cur_round += 1
            if no_latent_change and self.prior_bound > 0:
                self.shift_lr *= self.shift_lr_decay

        logging.info(f"{self._current_version} Max round reached, stopping")


def discretize_model_wrapper(input1, input2, lambda_max, lambda_min, noise_schedule, mode, window_rate=0.5):
    """Build a zero-argument callable mapping the raw parameters to two timestep sequences.

    Args:
        input1: Raw parameters; a softmax + cumulative sum turns them into a
            monotone sequence spanning the full range.
        input2: Raw per-step offsets; clamped to +/- ``window_rate`` times the
            smallest gap between adjacent steps, with both endpoints held fixed.
        lambda_max, lambda_min: Range of the schedule's lambda values.
        noise_schedule: Must provide ``inverse_lambda``.
        mode: ``'time'`` parameterises directly in time; any other value
            parameterises in lambda and converts to time afterwards.
        window_rate: Fraction of the smallest adjacent gap the offsets may use.

    Returns:
        A function returning ``(timesteps1, timesteps2)``.
    """

    def model_time_fn():
        time1, time2 = input1, input2
        t_max = noise_schedule.inverse_lambda(lambda_min).to(time1.device)
        t_min = noise_schedule.inverse_lambda(lambda_max).to(time1.device)

        time_plus = F.softmax(time1, dim=0)
        # Flipped cumulative sum: decreasing from index 0 to the last index.
        time_md = torch.cumsum(time_plus, dim=0).flip(0)
        normed = (time_md - time_md[-1]) / (time_md[0] - time_md[-1])
        time_steps = normed * (t_max - t_min) + t_min

        # The window is a constant w.r.t. autograd (detached, converted to a Python float).
        cloned_time_steps = time_steps.clone().detach()
        max_move = (cloned_time_steps[1:] - cloned_time_steps[:-1]).abs().min().item() * window_rate
        clipped_time2 = torch.clamp(time2, min=-max_move, max=max_move)

        mask = torch.ones_like(normed)
        mask[0] = 0.
        mask[-1] = 0.
        return time_steps, time_steps + (clipped_time2 * mask)

    def model_lambda_fn():
        lambda1, lambda2 = input1, input2
        lamb_plus = F.softmax(lambda1, dim=0)
        lamb_md = torch.cumsum(lamb_plus, dim=0)
        normed = (lamb_md - lamb_md.min()) / (lamb_md.max() - lamb_md.min())
        lamb_steps1 = normed * (lambda_max - lambda_min) + lambda_min

        # The window must come from the realised steps, as in ``model_time_fn``. It was
        # previously derived from the raw parameters, which gives a zero window at the
        # all-ones initialisation and is unrelated to the actual step spacing.
        cloned_lamb_steps = lamb_steps1.clone().detach()
        max_move = (cloned_lamb_steps[1:] - cloned_lamb_steps[:-1]).abs().min().item() * window_rate
        clipped_lamb2 = torch.clamp(lambda2, min=-max_move, max=max_move)

        mask = torch.ones_like(lamb_steps1)
        mask[0] = 0.
        mask[-1] = 0.
        lamb_steps2 = lamb_steps1 + clipped_lamb2 * mask

        time1 = noise_schedule.inverse_lambda(lamb_steps1)
        time2 = noise_schedule.inverse_lambda(lamb_steps2)
        return time1, time2

    return model_time_fn if mode == 'time' else model_lambda_fn