import logging
import math
import os
import pickle
from dataclasses import dataclass
from typing import Any, List, Optional

import imageio
import lpips
import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader

from dataset import LD3Dataset
from utils import move_tensor_to_device, compute_distance_between_two, compute_distance_between_two_L1
| |
|
def save_gif(snapshot_path: str):
    """Assemble every 'log_best_*.png' snapshot in *snapshot_path* into one gif.

    Files are ordered by the integer suffix in their filename
    (e.g. ``log_best_12.png`` -> 12) so frames appear in training order.
    """
    frame_files = sorted(
        (f for f in os.listdir(snapshot_path) if "log_best" in f),
        key=lambda f: int(f.split("_")[-1].replace(".png", "")),
    )
    frames = [imageio.imread(os.path.join(snapshot_path, f)) for f in frame_files]
    imageio.mimsave(os.path.join(snapshot_path, "gif.gif"), frames, duration=100.)
    print(f"Saved gif to {os.path.join(snapshot_path, 'gif.gif')}")
| |
|
| |
|
def visual(input_, name="test.png", img_resolution=32, img_channels=3):
    """Save a batch of images (values in [-1, 1]) as a single grid PNG.

    The grid height is the largest divisor of the batch size that does not
    exceed sqrt(batch_size), so the grid is as close to square as possible.
    """
    input_ = (input_ + 1.) / 2.
    batch_size = input_.shape[0]
    root = int(math.sqrt(batch_size))
    # Largest divisor of batch_size that is <= floor(sqrt(batch_size)).
    gridh = max(d for d in range(1, root + 1) if batch_size % d == 0)
    gridw = batch_size // gridh

    grid = (input_ * 255.).clip(0, 255).to(torch.uint8)
    # (gh*gw, C, H, W) -> (gh, H, gw, W, C) -> (gh*H, gw*W, C)
    grid = grid.reshape(gridh, gridw, *grid.shape[1:]).permute(0, 3, 1, 4, 2)
    grid = grid.reshape(gridh * img_resolution, gridw * img_resolution, img_channels)
    PIL.Image.fromarray(grid.cpu().numpy(), 'RGB').save(name)
| |
|
def custom_collate_fn(batch):
    """Collate samples field-by-field, passing ``None`` through unchanged.

    Each sample is a tuple of fields; if any sample has ``None`` for a field
    (e.g. unconditional models have no condition tensor), that whole field
    collates to ``None`` instead of a batched tensor.
    """
    default_collate = torch.utils.data._utils.collate.default_collate
    return [
        None if any(item is None for item in field) else default_collate(field)
        for field in zip(*batch)
    ]
| |
|
@dataclass
class TrainingConfig:
    """Hyper-parameters and data for training the LD3 time discretization.

    Fix: the data fields were annotated with the builtin function ``any``
    instead of ``typing.Any``; runtime behavior is identical but type
    checkers reject it.
    """

    train_data: Any  # training split (LD3Dataset-compatible)
    valid_data: Any  # validation split
    train_batch_size: int
    valid_batch_size: int
    lr_time_1: float  # learning rate for params1 (RMSprop)
    lr_time_2: float  # learning rate for params2 (SGD)
    shift_lr: float  # learning rate for the per-sample latent shift
    shift_lr_decay: float = 0.5  # multiplied into shift_lr when a round changes no latent
    min_lr_time_1: float = 5e-5  # floor for the params1 learning rate
    min_lr_time_2: float = 1e-6  # floor for the params2 learning rate
    win_rate: float = 0.5  # window rate bounding the params2 offset
    patient: int = 5  # validation checks without improvement before reloading best
    lr2_patient: int = 5  # times lr2 may hit its floor before training stops
    lr_time_decay: float = 0.8  # multiplicative lr decay on patience exhaustion
    momentum_time_1: float = 0.9
    weight_decay_time_1: float = 0.0
    loss_type: str = "LPIPS"  # one of "LPIPS", "L2", "L1"
    visualize: bool = False  # dump sample grids during validation
    no_v1: bool = False  # skip stage 1 (train params2 only)
    prior_timesteps: Optional[List[float]] = None  # optional schedule to match before training
    match_prior: bool = False
| | |
@dataclass
class ModelConfig:
    """Model, solver, and output configuration for the LD3 trainer.

    Fix: the opaque object fields were annotated with the builtin function
    ``any`` instead of ``typing.Any``; runtime behavior is identical but
    type checkers reject it.
    """

    net: Any  # denoising network passed to the ODE solver
    decoding_fn: Any  # maps solver output back to image space
    noise_schedule: Any  # provides lambda/time conversions and prior transform
    solver: Any  # ODE solver exposing sample_simple / get_time_steps
    solver_name: str
    order: int  # solver order
    steps: int  # number of sampling steps (NFEs); params have steps + 1 entries
    prior_bound: float  # max L2 drift allowed for shifted latents (0 disables shifting)
    resolution: int  # image height/width in pixels
    channels: int  # image channels
    time_mode: str  # discretization parameterization: 'time' or 'lambda'
    solver_extra_params: Optional[dict] = None  # extra kwargs forwarded to sample_simple
    snapshot_path: str = "logs"  # directory for checkpoints and visualizations
    device: Optional[str] = None  # None -> auto-select cuda if available
| |
|
class LD3Trainer:
    """Trains a learned time discretization (LD3) for a diffusion ODE solver.

    Two learnable vectors of length ``steps + 1`` are optimized:
    ``params1`` parameterizes the base timestep schedule (stage "Ver1"),
    and ``params2`` a small bounded per-step offset (stage "Ver2").
    Optionally, input latents are also nudged within an L2 ball of radius
    ``prior_bound`` around their original values to further reduce loss.
    """

    def __init__(
        self, model_config: ModelConfig, training_config: TrainingConfig
    ) -> None:
        # Model / solver configuration.
        self.net = model_config.net
        self.decoding_fn = model_config.decoding_fn
        self.noise_schedule = model_config.noise_schedule
        self.solver = model_config.solver
        self.solver_name = model_config.solver_name
        self.order = model_config.order
        self.steps = model_config.steps
        self.prior_bound = model_config.prior_bound
        self.resolution = model_config.resolution
        self.channels = model_config.channels
        self.time_mode = model_config.time_mode

        # Learning-rate schedule hyper-parameters.
        self.lr_time_1 = training_config.lr_time_1
        self.lr_time_2 = training_config.lr_time_2
        self.shift_lr = training_config.shift_lr
        self.shift_lr_decay = training_config.shift_lr_decay
        self.min_lr_time_1 = training_config.min_lr_time_1
        self.min_lr_time_2 = training_config.min_lr_time_2
        self.lr_time_decay = training_config.lr_time_decay
        self.momentum_time_1 = training_config.momentum_time_1
        self.weight_decay_time_1 = training_config.weight_decay_time_1

        # Data and loaders.
        self.train_data = training_config.train_data
        self.valid_data = training_config.valid_data
        self.train_batch_size = training_config.train_batch_size
        self.valid_batch_size = training_config.valid_batch_size
        self._create_valid_loaders()
        self._create_train_loader()

        # Progress / early-stopping state.
        self.cur_iter = 0
        self.cur_round = 0
        self.count_worse = 0
        self.count_min_lr1_hit = 0
        self.count_min_lr2_hit = 0
        self.best_loss = float("inf")

        self.patient = training_config.patient
        self.lr2_patient = training_config.lr2_patient
        self.no_v1 = training_config.no_v1
        self.win_rate = training_config.win_rate
        self.snapshot_path = model_config.snapshot_path
        os.makedirs(self.snapshot_path, exist_ok=True)
        self.visualize = training_config.visualize

        # Learnable parameters and their optimizers (RMSprop for the base
        # schedule, plain SGD for the offset).
        self._set_device(model_config.device)
        self.params1, self.params2 = self._initialize_params()
        self.optimizer_lamb1 = torch.optim.RMSprop(
            [self.params1],
            lr=training_config.lr_time_1,
            momentum=training_config.momentum_time_1,
            weight_decay=training_config.weight_decay_time_1,
        )
        self.optimizer_lamb2 = torch.optim.SGD(
            [self.params2], lr=training_config.lr_time_2
        )
        self.prior_timesteps = training_config.prior_timesteps
        self.match_prior = training_config.match_prior

        # Noise-schedule bounds (lambda is the log-SNR axis).
        self.solver_extra_params = model_config.solver_extra_params or {}
        self.lambda_min = self.noise_schedule.lambda_min
        self.lambda_max = self.noise_schedule.lambda_max
        self.time_max = self.noise_schedule.inverse_lambda(self.lambda_min)
        self.time_min = self.noise_schedule.inverse_lambda(self.lambda_max)

        # Reference discretizations for plotting.
        self._compute_baseline()

        self.loss_type = training_config.loss_type
        self.loss_fn = self._initialize_loss_fn()
        # Per-sample losses from the most recent _solve_ode call.
        self.loss_vector = None

    def _train_to_match_prior(self, prior_timesteps=None):
        """Pre-fit ``params1`` so the learned schedule matches given prior timesteps.

        Args:
            prior_timesteps: target values; falls back to ``self.prior_timesteps``.
                No-op when both are ``None``.
        """
        if prior_timesteps is None:
            prior_timesteps = self.prior_timesteps

        if prior_timesteps is None:
            return
        logging.info(f"Matching prior timesteps")
        # NOTE(review): the priors are passed through -log before inverse_lambda,
        # i.e. they are presumably sigma values with lambda = -log(sigma) — confirm.
        prior_timesteps = self.noise_schedule.inverse_lambda(-np.log(prior_timesteps)).to(self.device).float()

        dis_model = discretize_model_wrapper(
            self.params1,
            self.params2,
            self.lambda_max,
            self.lambda_min,
            self.noise_schedule,
            self.time_mode,
            self.win_rate,
        )

        # Only the base schedule is fitted against the prior.
        self.params1.requires_grad = True
        self.params2.requires_grad = False

        loss_time = float("inf")
        # NOTE(review): no iteration cap — loops until MSE < 1e-3 and may not
        # terminate if the prior schedule is unreachable. Confirm intended.
        while loss_time > 1e-3:
            self.optimizer_lamb1.zero_grad()
            self.optimizer_lamb2.zero_grad()
            times1, times2 = dis_model()
            loss_time = (times1 - prior_timesteps).pow(2).mean()
            logging.info(f"Loss time: {loss_time}")
            loss_time.backward()
            self.optimizer_lamb1.step()

    def _initialize_loss_fn(self):
        """Return the per-sample image distance used as the training loss."""
        if self.loss_type == 'LPIPS':
            return lpips.LPIPS(net='vgg').to(self.device)
        elif self.loss_type == 'L2':
            return lambda x, y: compute_distance_between_two(x, y, self.channels, self.resolution)
        elif self.loss_type == 'L1':
            return lambda x, y: compute_distance_between_two_L1(x, y, self.channels, self.resolution)
        else:
            raise NotImplementedError

    def _initialize_params(self):
        """Create the two learnable schedule vectors (each of length steps + 1)."""
        # NOTE(review): .cuda() assumes CUDA is available even though
        # _set_device can fall back to CPU — confirm CUDA-only usage.
        params1 = torch.nn.Parameter(torch.ones(self.steps + 1, dtype=torch.float32).cuda(), requires_grad=True)
        params2 = torch.nn.Parameter(torch.zeros(self.steps + 1, dtype=torch.float32).cuda(), requires_grad=True)
        return params1, params2

    def _set_device(self, device):
        """Select the compute device; defaults to cuda when available."""
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _create_valid_loaders(self):
        """Build the two validation loaders: one batched like training (used for
        latent shifting) and one with the validation batch size (used for scoring)."""
        self.valid_loader = DataLoader(self.valid_data, batch_size=self.train_batch_size, shuffle=False, collate_fn=custom_collate_fn)
        self.valid_only_loader = DataLoader(self.valid_data, batch_size=self.valid_batch_size, shuffle=False, collate_fn=custom_collate_fn)

    def _create_train_loader(self):
        """Build the (shuffled) training loader."""
        self.train_loader = DataLoader(self.train_data, batch_size=self.train_batch_size, shuffle=True, collate_fn=custom_collate_fn)

    def _solve_ode(self, timesteps=None, img=None, latent=None, condition=None, uncondition=None, valid=False):
        """Run the ODE solver from ``latent`` and score the decoded result against ``img``.

        Uses the learned discretization unless explicit ``timesteps`` are given.
        Side effects: sets ``self.loss_vector`` (per-sample losses),
        ``self.t_steps1/2`` and ``self.logSNR1/2``; during training also dumps
        the current schedule to ``<snapshot_path>/t_steps.pt``.

        Returns:
            (mean loss, decoded samples, reference images) — the last two as float tensors.
        """
        batch_size = latent.shape[0]
        latent = latent.reshape(batch_size, self.channels, self.resolution, self.resolution)
        dis_model = discretize_model_wrapper(
            self.params1,
            self.params2,
            self.lambda_max,
            self.lambda_min,
            self.noise_schedule,
            self.time_mode,
            self.win_rate,
        )

        if timesteps is None:
            timesteps1, timesteps2 = dis_model()
        else:
            # Explicit schedule: use it for both sequences.
            timesteps1 = timesteps
            timesteps2 = timesteps

        # Persist the current learned schedule for offline inspection.
        if not valid and timesteps is None:
            tst = torch.cat([timesteps1, timesteps2], dim=0).detach().cpu()
            torch.save(tst, os.path.join(self.snapshot_path, f"t_steps.pt"))

        self.t_steps1 = timesteps1.detach()
        self.t_steps2 = timesteps2.detach()
        lamb1 = self.noise_schedule.marginal_lambda(timesteps1)
        lamb2 = self.noise_schedule.marginal_lambda(timesteps2)
        self.logSNR1 = lamb1.detach().cpu()
        self.logSNR2 = lamb2.detach().cpu()

        x_next_ = self.noise_schedule.prior_transformation(latent)
        x_next_ = self.solver.sample_simple(
            model_fn=self.net,
            x=x_next_,
            timesteps=timesteps1,
            timesteps2=timesteps2,
            order=self.order,
            NFEs=self.steps,
            condition=condition,
            unconditional_condition=uncondition,
            **self.solver_extra_params,
        )
        x_next_ = self.decoding_fn(x_next_)
        self.loss_vector = self.loss_fn(img.float(), x_next_.float()).squeeze()
        loss = self.loss_vector.mean()
        logging.info(f"{self._current_version} Loss: {loss.item()}")

        return loss, x_next_.float(), img.float()

    @property
    def _current_version(self):
        """Label of the current training stage, for log messages."""
        return 'Ver1' if self._is_in_version_1() else 'Ver2'

    def _is_in_version_1(self):
        """True while still in stage 1 (base-schedule training).

        NOTE(review): relies on ``self.training_rounds_v1``, which is only
        assigned in ``train()`` — calling this earlier raises AttributeError.
        """
        return self.cur_round < self.training_rounds_v1

    def _compute_baseline(self):
        """Precompute baseline discretizations (logSNR-uniform, time-uniform,
        time-quadratic, EDM) used for comparison plots in ``_visual_times``."""
        self.straight_line = torch.linspace(self.lambda_min, self.lambda_max, self.steps + 1)
        self.time_logSNR = self.noise_schedule.inverse_lambda(self.straight_line).to(self.device)
        time_max = self.noise_schedule.inverse_lambda(self.lambda_min)
        time_min = self.noise_schedule.inverse_lambda(self.lambda_max)
        self.time_s = torch.linspace(time_max.item(), time_min.item(), 1000)
        self.time_straight = torch.linspace(time_max.item(), time_min.item(), self.steps + 1)
        self.time_straight = self.time_straight.to(self.device)
        self.straight_time = self.noise_schedule.marginal_lambda(self.time_s)
        # Quadratic spacing in time: uniform in t**(1/2), then squared back.
        t_order = 2
        self.time_q = torch.linspace((time_max**(1/t_order)).item(), (time_min**(1/t_order)).item(), 1000)**t_order
        self.quadratic_time = torch.linspace((time_max**(1/t_order)).item(), (time_min**(1/t_order)).item(), self.steps + 1)**t_order

        self.quadratic_time = self.quadratic_time.to(self.device)
        self.time_quadratic = self.noise_schedule.marginal_lambda(self.time_q)

        self.time_edm = self.solver.get_time_steps('edm', time_max.item(), time_min.item(), 999, self.device)
        self.lambda_edm = self.noise_schedule.marginal_lambda(self.time_edm)

    def _run_validation(self):
        """Score the current schedule over the whole validation set.

        Returns:
            (mean loss over batches, all decoded samples, all reference images)
        """
        total_loss = 0.
        count = 0
        outputs = list()
        targets = list()
        with torch.no_grad():
            for img, latent, ori_latent, condition, uncondition in self.valid_only_loader:
                img = img.to(self.device)
                latent = latent.to(self.device).reshape(latent.shape[0], -1)
                ori_latent = ori_latent.to(self.device).reshape(latent.shape[0], -1)
                if condition is not None:
                    condition = condition.to(self.device)
                if uncondition is not None:
                    uncondition = uncondition.to(self.device)
                loss, output, target = self._solve_ode(img=img, latent=latent, condition=condition, uncondition=uncondition, valid=True)

                total_loss += loss.item()
                count += 1
                outputs.append(output)
                targets.append(target)

        output = torch.cat(outputs, dim=0)
        target = torch.cat(targets, dim=0)
        return total_loss / count, output, target

    def _visual_times(self) -> None:
        """Plot the learned discretization (in logSNR) against the
        logSNR-uniform, time-uniform, time-quadratic, and EDM baselines,
        and save the figure as log_best_<iter>.png."""
        log_path = os.path.join(self.snapshot_path, f"log_best_{self.cur_iter}.png")

        plt.plot(self.logSNR1.cpu().numpy(), 'o', label="Our discretization1")
        plt.plot(self.logSNR2.cpu().numpy(), 'x', label="Our discretization2")
        x_axis = np.linspace(0, self.steps, self.steps + 1)
        plt.plot(x_axis, self.straight_line.cpu().numpy(), label="Baseline logSNR")
        x_axis = np.linspace(0, self.steps, 1000)
        plt.plot(x_axis, self.straight_time.cpu().numpy(), label="Baseline time uniform")
        plt.plot(x_axis, self.time_quadratic.cpu().numpy(), label="Baseline time quadratic")
        plt.plot(x_axis, self.lambda_edm.cpu().numpy(), label="Baseline time edm")

        plt.xlabel("Reverse step i")
        plt.ylabel("LogSNR(t_i)")
        plt.legend()
        plt.tight_layout()
        plt.savefig(log_path)
        plt.close()

    def _save_checkpoint(self):
        """Persist the current best schedule parameters plus the (possibly
        shifted) train/validation datasets."""
        snapshot = {}
        snapshot["params1"] = self.params1.data
        snapshot["params2"] = self.params2.data
        snapshot["best_t_steps"] = torch.cat([self.t_steps1, self.t_steps2], dim=0)

        # Stage-1 bests additionally go to best_v1.pt; best_v2.pt is always
        # refreshed so stage 2 resumes from the latest best.
        if self._is_in_version_1():
            torch.save(snapshot, os.path.join(self.snapshot_path, "best_v1.pt"))
        torch.save(snapshot, os.path.join(self.snapshot_path, "best_v2.pt"))
        torch.save(snapshot, os.path.join(self.snapshot_path, f"best_t_steps_{self.cur_iter}.pt"))

        # NOTE(review): the open() handles are never explicitly closed —
        # consider a with-statement.
        pickle.dump(self.train_data, open(os.path.join(self.snapshot_path, "train_data.pkl"), "wb"))
        pickle.dump(self.valid_data, open(os.path.join(self.snapshot_path, "valid_data.pkl"), "wb"))

    def _load_checkpoint(self, reload_data: bool):
        """Restore the best parameters for the current stage; optionally also
        reload the pickled datasets and rebuild the loaders."""
        if self._is_in_version_1():
            snapshot = torch.load(os.path.join(self.snapshot_path, "best_v1.pt"))
        else:
            snapshot = torch.load(os.path.join(self.snapshot_path, "best_v2.pt"))

        self.params1.data = snapshot["params1"].cuda()
        self.params2.data = snapshot["params2"].cuda()

        if reload_data:
            self.train_data = pickle.load(open(os.path.join(self.snapshot_path, "train_data.pkl"), "rb"))
            self.valid_data = pickle.load(open(os.path.join(self.snapshot_path, "valid_data.pkl"), "rb"))
            self._create_train_loader()
            self._create_valid_loaders()

    def _examine_checkpoint(self, iter: int) -> None:
        """Validate, checkpoint on improvement, and apply patience-based lr decay.

        On ``patient`` consecutive non-improvements: reload the best
        checkpoint, decay the stage's learning rate toward its floor, and
        count floor hits (used by the stop criterion in stage 2).
        """
        logging.info(f"{self._current_version} Saving snapshot at iter {iter}")
        total_loss, output, target = self._run_validation()

        if (iter % 5 == 0 or total_loss < self.best_loss) and self.visualize:
            visual(torch.cat([output[:8], target[:8]], dim=0), os.path.join(self.snapshot_path, f"learned_newnoise_ep{iter}.png"), img_resolution=self.resolution)

        if total_loss < self.best_loss:
            self.best_loss = total_loss
            self.count_worse = 0
            self._save_checkpoint()
            self._visual_times()
            save_gif(self.snapshot_path)
        else:
            self.count_worse += 1
            logging.info(f"{self._current_version} Count worse: {self.count_worse}")

        logging.info(f"{self._current_version} Validation loss: {total_loss}, best loss: {self.best_loss}")
        logging.info(f"{self._current_version} Iter {iter} snapshot saved!")

        if self.count_worse >= self.patient:
            logging.info(f"{self._current_version} Loading best model")
            self._load_checkpoint(reload_data=True)
            self.count_worse = 0

            self.optimizer_lamb1.param_groups[0]['lr'] = max(self.lr_time_decay * self.optimizer_lamb1.param_groups[0]['lr'], self.min_lr_time_1)
            logging.info(f"{self._current_version} Decay time1 lr to {self.optimizer_lamb1.param_groups[0]['lr']}")

            if self._is_in_version_1():
                if self.optimizer_lamb1.param_groups[0]['lr'] <= self.min_lr_time_1:
                    self.count_min_lr1_hit += 1
            else:
                self.optimizer_lamb2.param_groups[0]['lr'] = max(self.lr_time_decay * self.optimizer_lamb2.param_groups[0]['lr'], self.min_lr_time_2)
                logging.info(f"{self._current_version} Decay time2 lr to {self.optimizer_lamb2.param_groups[0]['lr']}")
                if self.optimizer_lamb2.param_groups[0]['lr'] <= self.min_lr_time_2:
                    self.count_min_lr2_hit += 1

    def _set_trainable_params(self, is_train: bool, is_no_v1: bool) -> None:
        """Toggle requires_grad on the schedule parameters per stage.

        Training: stage 1 trains params1 only, stage 2 trains both;
        with ``is_no_v1`` only params2 trains. Evaluation: both frozen.
        """
        if is_train:
            self.params1.requires_grad = True
            self.params2.requires_grad = not self._is_in_version_1()

            if is_no_v1:
                self.params1.requires_grad = False
                self.params2.requires_grad = True

        else:
            self.params1.requires_grad = False
            self.params2.requires_grad = False

    def _log_valid_distance(self, ori_latent: torch.tensor, latent: torch.tensor):
        """Log the per-sample L2 drift of shifted latents from their originals."""
        assert ori_latent.shape == latent.shape, "Shape of ori_latent and latent mismatched"
        sq = (latent.reshape(latent.shape[0], -1) - ori_latent.reshape(latent.shape[0], -1)).pow(2)
        distances = sq.sum(dim=1).sqrt().detach().cpu().numpy()
        logging.info(f"{self._current_version} Distance: {distances}")

    def _update_dataloader(self, ori_latents: List[torch.tensor],
                           latents: List[torch.tensor],
                           targets: List[torch.tensor],
                           conditions: List[Optional[torch.tensor]],
                           unconditions: List[Optional[torch.tensor]],
                           is_train: bool):
        """Rebuild the train or validation dataset/loader from shifted latents."""
        custom_train_dataset = LD3Dataset(ori_latents, latents, targets, conditions, unconditions)
        if is_train:
            self.train_data = custom_train_dataset
            self._create_train_loader()
        else:
            self.valid_data = custom_train_dataset
            self._create_valid_loaders()

    def _update_latents(self, latent, condition, uncondition, ori_latent, img, latent_params, loss_vector_ref, prior_bound):
        """Project optimizer-shifted latents back into the prior_bound ball and
        keep only the per-sample shifts that improved the loss.

        Returns:
            (updated latents, boolean mask of samples that improved)
        """
        parameter_data_detached = latent_params.detach()
        cloned_ori_latent = ori_latent.clone()
        diff = parameter_data_detached.data - cloned_ori_latent
        diff_norm = diff.norm(dim=1, keepdim=True)
        # Samples that drifted outside the ball are projected back to its surface.
        pass_bound = diff_norm > prior_bound
        pass_bound = pass_bound.flatten()
        parameter_data_detached.data[pass_bound] = cloned_ori_latent[pass_bound] + prior_bound * diff[pass_bound] / diff_norm[pass_bound]

        # Re-evaluate with the projected latents; refreshes self.loss_vector.
        _, _, _ = self._solve_ode(img=img, latent=parameter_data_detached.data, condition=condition, uncondition=uncondition, valid=False)

        to_update_mask = self.loss_vector < loss_vector_ref
        parameter_data_detached.data = parameter_data_detached.data.reshape(-1, self.channels, self.resolution, self.resolution)
        latent[to_update_mask] = parameter_data_detached.data[to_update_mask]
        return latent, to_update_mask

    def _train_one_round(self):
        """One pass over the train loader (schedule + latent updates) and, when
        latent shifting is on, the validation loader (latent updates only).

        Returns:
            (no_change, should_stop): whether no validation latent improved,
            and whether the lr2-floor stop criterion fired.
        """
        no_change = True
        logging.info(f"{self._current_version} Round {self.cur_round}")

        if self.cur_round > 0:
            self._load_checkpoint(reload_data=False)
            self.count_worse = 0

        self._examine_checkpoint(self.cur_iter)

        for loader_idx, loader in enumerate([self.train_loader, self.valid_loader]):
            # The validation pass only matters when latents may be shifted.
            if loader_idx == 1 and self.prior_bound == 0.0:
                continue

            self._set_trainable_params(is_train=loader_idx == 0, is_no_v1=self.no_v1)

            ori_latents, latents, targets, conditions, unconditions = [], [], [], [], []
            for img, latent, ori_latent, condition, uncondition in loader:
                img, latent, ori_latent, condition, uncondition = move_tensor_to_device(img, latent, ori_latent, condition, uncondition, device=self.device)
                if loader_idx == 1:
                    self._log_valid_distance(ori_latent, latent)

                batch_size = ori_latent.shape[0]
                ori_latent = ori_latent.reshape(batch_size, -1)
                # The latents themselves are optimized via a throwaway SGD step.
                latent_to_update = latent.clone().detach().reshape(batch_size, -1).to(self.device)
                latent_params = torch.nn.Parameter(latent_to_update)
                latent_params.requires_grad = True

                latent_optimizer = torch.optim.SGD([latent_params], lr=self.shift_lr)
                # NOTE(review): leftover debugging hook — remove for production runs.
                if img.device != latent_params.device:
                    breakpoint()
                loss, _, _ = self._solve_ode(img=img, latent=latent_params, condition=condition, uncondition=uncondition, valid=False)
                loss_vector_ref = self.loss_vector.clone().detach()
                loss.backward()
                logging.info(f"{self._current_version} Iter {self.cur_iter} {'Train' if loader_idx == 0 else 'Val'} Loss: {loss.item()}")

                latent_optimizer.step()
                latent_optimizer.zero_grad()

                # Schedule parameters only update on the training pass.
                if loader_idx == 0:
                    torch.nn.utils.clip_grad_norm_(self.params1, 1.0)
                    torch.nn.utils.clip_grad_norm_(self.params2, 1.0)
                    self.optimizer_lamb1.step()
                    self.optimizer_lamb1.zero_grad()
                    self.optimizer_lamb2.step()
                    self.optimizer_lamb2.zero_grad()

                self.cur_iter += 1
                self._examine_checkpoint(self.cur_iter)
                if self.count_min_lr2_hit >= self.lr2_patient:
                    logging.info(f"{self._current_version} Reach min lr2 5 times. Stop training.")
                    return no_change, True

                with torch.no_grad():
                    latent, to_update_mask = self._update_latents(latent, condition, uncondition, ori_latent, img, latent_params, loss_vector_ref, self.prior_bound)
                    if loader_idx == 1 and to_update_mask.sum().item() > 0:
                        no_change = False

                # Accumulate the (possibly shifted) samples on CPU to rebuild the dataset.
                ori_latent = ori_latent.reshape(-1, self.channels, self.resolution, self.resolution).detach().cpu()
                latent = latent.reshape(-1, self.channels, self.resolution, self.resolution).detach().cpu()
                img = img.detach().cpu()
                condition = condition.detach().cpu() if condition is not None else None
                uncondition = uncondition.detach().cpu() if uncondition is not None else None

                for j in range(latent.shape[0]):
                    ori_latents.append(ori_latent[j])
                    targets.append(img[j])
                    latents.append(latent[j])
                    conditions.append(condition[j] if condition is not None else None)
                    unconditions.append(uncondition[j] if uncondition is not None else None)

            if self.prior_bound > 0:
                self._update_dataloader(ori_latents, latents, targets, conditions, unconditions, is_train=loader_idx == 0)

        return no_change, False

    def train(self, training_rounds_v1: int, training_rounds_v2: int) -> None:
        """Run stage 1 (base schedule) then stage 2 (offset) training.

        Args:
            training_rounds_v1: rounds with only params1 trainable ("Ver1").
            training_rounds_v2: additional rounds with params2 trainable ("Ver2").
        """
        total_round = training_rounds_v1 + training_rounds_v2
        self.training_rounds_v1 = training_rounds_v1

        if self.match_prior:
            self._train_to_match_prior()

        while self.cur_round < total_round:
            no_latent_change, should_stop = self._train_one_round()
            if should_stop:
                return
            self.cur_round += 1

            # If no validation latent improved, shrink the shift step size.
            if no_latent_change and self.prior_bound > 0:
                self.shift_lr *= self.shift_lr_decay

        logging.info(f"{self._current_version} Max round reached, stopping")
| |
|
def discretize_model_wrapper(input1, input2, lambda_max, lambda_min, noise_schedule, mode, window_rate=0.5):
    """Return a zero-argument callable mapping the learnable vectors to a pair
    of timestep sequences.

    ``input1`` parameterizes a monotone base schedule via softmax + cumsum,
    normalized onto the full range; ``input2`` is a per-step offset clamped to
    a fraction (``window_rate``) of the smallest gap, with both endpoints held
    fixed. ``mode`` selects whether the parameterization lives on the time
    axis ('time') or on the lambda (log-SNR) axis.
    """

    def time_parameterized():
        # Interval endpoints on the time axis, derived from the lambda bounds.
        t_max = noise_schedule.inverse_lambda(lambda_min).to(input1.device)
        t_min = noise_schedule.inverse_lambda(lambda_max).to(input1.device)
        # Softmax weights -> cumulative sum, reversed so values run high-to-low.
        weights = torch.nn.functional.softmax(input1, dim=0)
        cumulative = torch.cumsum(weights, dim=0).flip(0)
        unit = (cumulative - cumulative[-1]) / (cumulative[0] - cumulative[-1])
        base_steps = unit * (t_max - t_min) + t_min
        # Bound the offset by window_rate of the smallest adjacent gap.
        frozen = base_steps.clone().detach()
        move_limit = (frozen[1:] - frozen[:-1]).abs().min().item() * window_rate
        offset = torch.clamp(input2, min=-move_limit, max=move_limit)
        # Endpoints never move.
        endpoint_mask = torch.ones_like(unit)
        endpoint_mask[0] = 0.
        endpoint_mask[-1] = 0.
        return base_steps, base_steps + (offset * endpoint_mask)

    def lambda_parameterized():
        # Same construction, but on the lambda axis; map to time at the end.
        weights = F.softmax(input1, dim=0)
        cumulative = torch.cumsum(weights, dim=0)
        unit = (cumulative - cumulative.min()) / (cumulative.max() - cumulative.min())
        base_lambdas = unit * (lambda_max - lambda_min) + lambda_min
        endpoint_mask = torch.ones_like(base_lambdas)

        # Bound the offset by window_rate of the smallest gap in the raw parameters.
        frozen = input1.clone().detach()
        move_limit = (frozen[1:] - frozen[:-1]).abs().min().item() * window_rate
        offset = torch.clamp(input2, min=-move_limit, max=move_limit)

        endpoint_mask[0] = 0.
        endpoint_mask[-1] = 0.

        shifted_lambdas = base_lambdas + offset * endpoint_mask

        return (
            noise_schedule.inverse_lambda(base_lambdas),
            noise_schedule.inverse_lambda(shifted_lambdas),
        )

    return time_parameterized if mode == 'time' else lambda_parameterized
| |
|