|
""" |
|
Dongpin Oh: [email protected] |
|
""" |
|
|
|
import time |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from diffusers.utils.torch_utils import randn_tensor |
|
from loguru import logger |
|
|
|
|
|
class FlowMixin: |
|
"""Mixin class for flow-based models.""" |
|
|
|
_is_noise_initialized = False |
|
MIN_STD: float = 0.0 |
|
CLAMP_CONTINUOUS_TIME: float = 0.0 |
|
|
|
def _timestep_shifting(self, timesteps: torch.Tensor, alpha: float = 3.0) -> torch.Tensor: |
|
""" |
|
        Adjust the timesteps for higher-resolution images by adding more noise.

        Higher-resolution images have more pixels, so more noise is needed to destroy their signal.
|
|
|
        NOTE that the timestep must be reversed, unlike the original SD3-style timestep shifting,

        since the flow timestep here is reversed (t=0: original image; t=1: pure noise).
|
Args: |
|
timesteps (torch.Tensor): The original timesteps. |
|
alpha (float, optional): Scaling factor for timestep shifting. Defaults to 3.0. |
|
|
|
Returns: |
|
torch.Tensor: The reversed and shifted timesteps. |
|
""" |
|
shifted_t = alpha * timesteps / (1 + (alpha - 1) * timesteps) |
|
reversed_t = 1 - shifted_t |
|
return reversed_t |
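
    # Illustrative worked example (not called anywhere in this mixin): with the default
    # alpha = 3.0, a timestep t = 0.5 maps to shifted_t = 3 * 0.5 / (1 + 2 * 0.5) = 0.75,
    # and the method returns the reversed value 1 - 0.75 = 0.25.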
|
|
|
def get_noisy_input( |
|
self, |
|
input: torch.Tensor, |
|
normal_mean: float = 0.0, |
|
normal_std: float = 1.0, |
|
is_finetuning: bool = False, |
|
t: torch.Tensor | None = None, |
|
n_timesteps: int = 1000, |
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
""" |
|
Generate noisy input, based on the optimal-transport flow. |
|
|
|
Args: |
|
input (torch.Tensor): Input tensor. |
|
normal_mean (float, optional): Mean of the normal distribution for noise. Defaults to 0.0. |
|
normal_std (float, optional): Standard deviation of the normal distribution for noise. Defaults to 1.0. |
|
is_finetuning (bool, optional): Whether the model is in finetuning mode. Defaults to False. |
|
t (torch.Tensor | None, optional): Predefined timesteps. If None, timesteps are sampled. Defaults to None. |
|
            n_timesteps (int, optional): Number of discrete timesteps. Defaults to 1000.
|
|
|
Returns: |
|
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing noise, noisy input, and timestep. |
|
                - The returned timestep t is continuous and lies in [0, 1]; it cannot be used directly

                  for the timestep embedding and must first be discretized via self.discritize_timestep().
|
""" |
|
b = input.shape[0] |
|
|
|
if not FlowMixin._is_noise_initialized: |
|
logger.warning("The torch random seed is changed when generating the initial random noise.") |
|
current_time = int(time.time() * 1000) |
|
torch.manual_seed(current_time) |
|
FlowMixin._is_noise_initialized = True |
|
|
|
        # Draw the noise endpoint of the flow on the same device and dtype as the input.
        noise = randn_tensor(input.shape, device=input.device, dtype=input.dtype)
|
|
|
|
|
        if t is None:

            # Sample continuous timesteps from a logit-normal distribution:
            # a Gaussian in logit space squashed into (0, 1) by the sigmoid.
            t = torch.randn(b, device=input.device) * normal_std + normal_mean

            t = torch.sigmoid(t)
|
|
|
|
|
t = t.clamp(0 + self.CLAMP_CONTINUOUS_TIME, 1 - self.CLAMP_CONTINUOUS_TIME) |
|
|
|
        if is_finetuning:

            # Apply the reversed SD3-style timestep shifting used for high-resolution fine-tuning.
            t = self._timestep_shifting(t)
|
|
|
|
|
        # Reshape t to (b, 1, ..., 1) so it broadcasts over the non-batch dimensions.
        for _ in range(len(noise.shape) - 1):

            t = t.unsqueeze(1)
|
|
|
|
|
        # Optimal-transport interpolation between data and noise; with MIN_STD == 0 this is (1 - t) * input + t * noise.
        noisy_input = (1 - t) * input + (self.MIN_STD + (1 - self.MIN_STD) * t) * noise
|
|
|
t_squeezed = t.squeeze() |
|
if t_squeezed.dim() == 0: |
|
t_squeezed = t_squeezed.unsqueeze(0) |
|
return noise, noisy_input, t_squeezed |
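
    # Usage sketch (illustrative): for an input of shape (b, ...), this returns noise and
    # noisy_input of the same shape plus a continuous t of shape (b,); t is later passed
    # through discritize_timestep() before being fed to a timestep embedding.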
|
|
|
@torch.no_grad() |
|
def _logit_norm(self, t: torch.Tensor, m: float = 0, s: float = 1) -> torch.Tensor: |
|
""" |
|
Compute the loss-weight for the flow-matching loss. |
|
        It puts high weights on the intermediate timesteps, since such timesteps are

        the hardest to match, according to https://arxiv.org/pdf/2403.03206.pdf.
|
Args: |
|
            t (torch.Tensor): Timestep tensor with values in (0, 1).
|
m (float, optional): Mean of the logit distribution. Defaults to 0. |
|
s (float, optional): Standard deviation of the logit distribution. Defaults to 1. |
|
|
|
Returns: |
|
torch.Tensor: Weight tensor for the flow-matching loss. |
|
""" |
|
coef = (1 / (s * ((2 * np.pi) ** 0.5))) * (1 / (t * (1 - t))) |
|
|
|
def logit(x): |
|
return torch.log(x) - torch.log(1 - x) |
|
|
|
exp = torch.exp(-((logit(t) - m) ** 2) / (2 * s**2)) |
|
return coef * exp |
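
    # The weight above is the logit-normal density
    #     w(t) = 1 / (s * sqrt(2 * pi)) * 1 / (t * (1 - t)) * exp(-(logit(t) - m)^2 / (2 * s^2)),
    # e.g. with m = 0, s = 1 and t = 0.5: logit(0.5) = 0, so w(0.5) = 4 / sqrt(2 * pi) ≈ 1.596.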
|
|
|
def rectified_flow_loss( |
|
self, |
|
input: torch.Tensor, |
|
noise: torch.Tensor, |
|
t: torch.Tensor, |
|
preds: torch.Tensor, |
|
use_weighting: bool = False, |
|
reduce: str = "mean", |
|
) -> torch.Tensor: |
|
""" |
|
Compute the rectified flow loss, https://arxiv.org/pdf/2403.03206.pdf |
|
|
|
Args: |
|
input (torch.Tensor): Input tensor. |
|
noise (torch.Tensor): Noise tensor. |
|
t (torch.Tensor): Timestep tensor. |
|
preds (torch.Tensor): Predicted tensor. |
|
use_weighting (bool, optional): Whether to use weighting for the loss. Defaults to False. |
|
reduce (str, optional): Reduction method for the loss. Options are 'mean' or 'none'. Defaults to 'mean'. |
|
|
|
Returns: |
|
torch.Tensor: Rectified flow loss. |
|
""" |
|
|
|
|
|
t = t.reshape(t.shape[0], *[1 for _ in range(len(input.shape) - len(t.shape))]) |
|
|
|
        # Target velocity of the optimal-transport flow: d/dt x_t = (1 - MIN_STD) * noise - input.
        target_flow = (1 - self.MIN_STD) * noise - input
|
loss = F.mse_loss(preds.float(), target_flow.float(), reduction="none") |
|
if use_weighting: |
|
weight = self._logit_norm(t).detach() |
|
loss = loss * weight |
|
if reduce == "mean": |
|
loss = loss.mean() |
|
elif reduce == "none": |
|
loss = loss |
|
else: |
|
raise NotImplementedError |
|
|
|
return loss |
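
    # Note: the reshaped t only affects the loss when use_weighting=True; otherwise the
    # result is a plain MSE between preds and the flow target.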
|
|
|
def discritize_timestep(self, t: torch.Tensor, n_timesteps: int = 1000) -> torch.Tensor: |
|
""" |
|
Discretize the continuous timestep. |
|
|
|
Args: |
|
t (torch.Tensor): Continuous timestep. |
|
n_timesteps (int, optional): Number of discrete timesteps. Defaults to 1000. |
|
|
|
Returns: |
|
torch.Tensor: Discretized timestep tensor. |
|
""" |
|
return (t * n_timesteps).round().long() |
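

# Minimal smoke test (an illustrative sketch, not part of the original module): a tiny
# hypothetical subclass `_ToyFlowModel` exercises get_noisy_input, discritize_timestep,
# and rectified_flow_loss, with `preds` standing in for a network's velocity prediction.
if __name__ == "__main__":

    class _ToyFlowModel(FlowMixin):
        pass

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _ToyFlowModel()
    x = torch.randn(4, 3, 8, 8, device=device)

    noise, noisy_x, t = model.get_noisy_input(x)
    t_discrete = model.discritize_timestep(t)
    preds = noise - x  # a perfect prediction of the flow target when MIN_STD == 0

    loss = model.rectified_flow_loss(x, noise, t, preds, use_weighting=True)
    logger.info(f"continuous t: {t.tolist()}")
    logger.info(f"discrete t: {t_discrete.tolist()}")
    logger.info(f"rectified flow loss (should be ~0 here): {loss.item():.6f}")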
|
|