""" Dongpin Oh: dongpin.oh@moreh.io """ import time import numpy as np import torch import torch.nn.functional as F from diffusers.utils.torch_utils import randn_tensor from loguru import logger class FlowMixin: """Mixin class for flow-based models.""" _is_noise_initialized = False MIN_STD: float = 0.0 # minimum size of std for the flow matching CLAMP_CONTINUOUS_TIME: float = 0.0 def _timestep_shifting(self, timesteps: torch.Tensor, alpha: float = 3.0) -> torch.Tensor: """ Adjust the timesteps for higher resolution images by adding more noise. higher resolution have more pixels we need more noise to destory their signal. NOTE that the timestep must be reversed unlike original SD3 style timestep shifting, since the flow-timestep is reversed (t=0: original image; t=1: pure noise) Args: timesteps (torch.Tensor): The original timesteps. alpha (float, optional): Scaling factor for timestep shifting. Defaults to 3.0. Returns: torch.Tensor: The reversed and shifted timesteps. """ shifted_t = alpha * timesteps / (1 + (alpha - 1) * timesteps) reversed_t = 1 - shifted_t return reversed_t def get_noisy_input( self, input: torch.Tensor, normal_mean: float = 0.0, normal_std: float = 1.0, is_finetuning: bool = False, t: torch.Tensor | None = None, n_timesteps: int = 1000, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Generate noisy input, based on the optimal-transport flow. Args: input (torch.Tensor): Input tensor. normal_mean (float, optional): Mean of the normal distribution for noise. Defaults to 0.0. normal_std (float, optional): Standard deviation of the normal distribution for noise. Defaults to 1.0. is_finetuning (bool, optional): Whether the model is in finetuning mode. Defaults to False. t (torch.Tensor | None, optional): Predefined timesteps. If None, timesteps are sampled. Defaults to None. n_timesteps (int, opyionsl): Number of discrete timesteps. Defaults to 1000. Returns: tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing noise, noisy input, and timestep. - The timestep t is a continuous timestep ranged [0, 1] which cannot directly be used for the timestep embedding (needs to be discretized by self.discritize_timestep()). """ b = input.shape[0] if not FlowMixin._is_noise_initialized: logger.warning("The torch random seed is changed when generating the initial random noise.") current_time = int(time.time() * 1000) torch.manual_seed(current_time) FlowMixin._is_noise_initialized = True noise = randn_tensor(input.shape).cuda() # Sample timestep from a log-normal distribution with mean 0 and std 1 if t is None: # NOTE: timestep is sampled from log-normal distribution to make the model # focus on the intermediate timesteps, which are the most informative part of flow-ODE t = torch.randn(b).cuda() * normal_std + normal_mean t = torch.sigmoid(t) else: t = t # Clamp t to be within the interval [0, 1] for numerical stability t = t.clamp(0 + self.CLAMP_CONTINUOUS_TIME, 1 - self.CLAMP_CONTINUOUS_TIME) if is_finetuning: t = self._timestep_shifting(t) # Reshape t to match the dimensions required for _ in range(len(noise.shape) - 1): t = t.unsqueeze(1) # Generate the noisy input noisy_input = (1 - t) * input + (self.MIN_STD + (1 - self.MIN_STD) * t) * noise t_squeezed = t.squeeze() if t_squeezed.dim() == 0: t_squeezed = t_squeezed.unsqueeze(0) return noise, noisy_input, t_squeezed @torch.no_grad() def _logit_norm(self, t: torch.Tensor, m: float = 0, s: float = 1) -> torch.Tensor: """ Compute the loss-weight for the flow-matching loss. 
    @torch.no_grad()
    def _logit_norm(self, t: torch.Tensor, m: float = 0, s: float = 1) -> torch.Tensor:
        """
        Compute the loss weight for the flow-matching loss.

        It focuses (gives high weight) on the intermediate timesteps, since such
        timesteps are the hardest to match, according to
        https://arxiv.org/pdf/2403.03206.pdf

        Args:
            t (torch.Tensor): Timestep tensor (0 to 1).
            m (float, optional): Mean of the logit distribution. Defaults to 0.
            s (float, optional): Standard deviation of the logit distribution. Defaults to 1.

        Returns:
            torch.Tensor: Weight tensor for the flow-matching loss.
        """
        coef = (1 / (s * ((2 * np.pi) ** 0.5))) * (1 / (t * (1 - t)))

        def logit(x):
            return torch.log(x) - torch.log(1 - x)

        exp = torch.exp(-((logit(t) - m) ** 2) / (2 * s**2))
        return coef * exp

    def rectified_flow_loss(
        self,
        input: torch.Tensor,
        noise: torch.Tensor,
        t: torch.Tensor,
        preds: torch.Tensor,
        use_weighting: bool = False,
        reduce: str = "mean",
    ) -> torch.Tensor:
        """
        Compute the rectified flow loss, https://arxiv.org/pdf/2403.03206.pdf

        Args:
            input (torch.Tensor): Input tensor.
            noise (torch.Tensor): Noise tensor.
            t (torch.Tensor): Timestep tensor.
            preds (torch.Tensor): Predicted tensor.
            use_weighting (bool, optional): Whether to use weighting for the loss. Defaults to False.
            reduce (str, optional): Reduction method for the loss. Options are 'mean' or 'none'. Defaults to 'mean'.

        Returns:
            torch.Tensor: Rectified flow loss.
        """
        # Match dimensions for broadcasting
        t = t.reshape(t.shape[0], *[1 for _ in range(len(input.shape) - len(t.shape))])

        # Target velocity of the optimal-transport path: d/dt x_t = (1 - MIN_STD) * noise - input
        target_flow = (1 - self.MIN_STD) * noise - input
        loss = F.mse_loss(preds.float(), target_flow.float(), reduction="none")

        if use_weighting:
            weight = self._logit_norm(t).detach()
            loss = loss * weight

        if reduce == "mean":
            loss = loss.mean()
        elif reduce != "none":
            raise NotImplementedError(f"Unsupported reduction: {reduce}")
        return loss

    def discritize_timestep(self, t: torch.Tensor, n_timesteps: int = 1000) -> torch.Tensor:
        """
        Discretize the continuous timestep.

        Args:
            t (torch.Tensor): Continuous timestep.
            n_timesteps (int, optional): Number of discrete timesteps. Defaults to 1000.

        Returns:
            torch.Tensor: Discretized timestep tensor.
        """
        return (t * n_timesteps).round().long()
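

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative only, not part of the mixin API).
# It assumes a CUDA device is available, since get_noisy_input() allocates its
# noise on the GPU, and substitutes a hypothetical zero-predicting stand-in
# for a real flow network.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _DummyFlowModel(FlowMixin):
        """Stand-in for a real network that would predict the flow field."""

        def predict(self, noisy_input: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
            # A real model would embed the discretized timestep and run a
            # denoising network; here we just return zeros of the right shape.
            _ = self.discritize_timestep(t)  # (B,) long tensor in [0, n_timesteps]
            return torch.zeros_like(noisy_input)

    model = _DummyFlowModel()
    latents = torch.randn(4, 8, 32, 32).cuda()  # hypothetical VAE latents
    noise, noisy, t = model.get_noisy_input(latents)
    preds = model.predict(noisy, t)
    loss = model.rectified_flow_loss(latents, noise, t, preds, use_weighting=True)
    logger.info(f"rectified flow loss: {loss.item():.4f}")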