File size: 6,801 Bytes
6cd6a16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
"""
Dongpin Oh: [email protected]
"""
import time
import numpy as np
import torch
import torch.nn.functional as F
from diffusers.utils.torch_utils import randn_tensor
from loguru import logger
class FlowMixin:
"""Mixin class for flow-based models."""
_is_noise_initialized = False
MIN_STD: float = 0.0 # minimum size of std for the flow matching
CLAMP_CONTINUOUS_TIME: float = 0.0
def _timestep_shifting(self, timesteps: torch.Tensor, alpha: float = 3.0) -> torch.Tensor:
"""
Adjust the timesteps for higher resolution images by adding more noise.
higher resolution have more pixels we need more noise to destory their signal.
NOTE that the timestep must be reversed unlike original SD3 style timestep shifting,
since the flow-timestep is reversed (t=0: original image; t=1: pure noise)
Args:
timesteps (torch.Tensor): The original timesteps.
alpha (float, optional): Scaling factor for timestep shifting. Defaults to 3.0.
Returns:
torch.Tensor: The reversed and shifted timesteps.
"""
shifted_t = alpha * timesteps / (1 + (alpha - 1) * timesteps)
reversed_t = 1 - shifted_t
return reversed_t
def get_noisy_input(
self,
input: torch.Tensor,
normal_mean: float = 0.0,
normal_std: float = 1.0,
is_finetuning: bool = False,
t: torch.Tensor | None = None,
n_timesteps: int = 1000,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Generate noisy input, based on the optimal-transport flow.
Args:
input (torch.Tensor): Input tensor.
normal_mean (float, optional): Mean of the normal distribution for noise. Defaults to 0.0.
normal_std (float, optional): Standard deviation of the normal distribution for noise. Defaults to 1.0.
is_finetuning (bool, optional): Whether the model is in finetuning mode. Defaults to False.
t (torch.Tensor | None, optional): Predefined timesteps. If None, timesteps are sampled. Defaults to None.
n_timesteps (int, opyionsl): Number of discrete timesteps. Defaults to 1000.
Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing noise, noisy input, and timestep.
- The timestep t is a continuous timestep ranged [0, 1] which cannot directly be used
for the timestep embedding (needs to be discretized by self.discritize_timestep()).
"""
b = input.shape[0]
if not FlowMixin._is_noise_initialized:
logger.warning("The torch random seed is changed when generating the initial random noise.")
current_time = int(time.time() * 1000)
torch.manual_seed(current_time)
FlowMixin._is_noise_initialized = True
noise = randn_tensor(input.shape).cuda()
# Sample timestep from a log-normal distribution with mean 0 and std 1
if t is None:
# NOTE: timestep is sampled from log-normal distribution to make the model
# focus on the intermediate timesteps, which are the most informative part of flow-ODE
t = torch.randn(b).cuda() * normal_std + normal_mean
t = torch.sigmoid(t)
else:
t = t
# Clamp t to be within the interval [0, 1] for numerical stability
t = t.clamp(0 + self.CLAMP_CONTINUOUS_TIME, 1 - self.CLAMP_CONTINUOUS_TIME)
if is_finetuning:
t = self._timestep_shifting(t)
# Reshape t to match the dimensions required
for _ in range(len(noise.shape) - 1):
t = t.unsqueeze(1)
# Generate the noisy input
noisy_input = (1 - t) * input + (self.MIN_STD + (1 - self.MIN_STD) * t) * noise
t_squeezed = t.squeeze()
if t_squeezed.dim() == 0:
t_squeezed = t_squeezed.unsqueeze(0)
return noise, noisy_input, t_squeezed
@torch.no_grad()
def _logit_norm(self, t: torch.Tensor, m: float = 0, s: float = 1) -> torch.Tensor:
"""
Compute the loss-weight for the flow-matching loss.
It will be focusing (giving high weights) on the intermidate timestep, since
such timesteps are hard to be matched, according to https://arxiv.org/pdf/2403.03206.pdf
Args:
t (torch.Tensor): Timestep tensor. (0 to 1)
m (float, optional): Mean of the logit distribution. Defaults to 0.
s (float, optional): Standard deviation of the logit distribution. Defaults to 1.
Returns:
torch.Tensor: Weight tensor for the flow-matching loss.
"""
coef = (1 / (s * ((2 * np.pi) ** 0.5))) * (1 / (t * (1 - t)))
def logit(x):
return torch.log(x) - torch.log(1 - x)
exp = torch.exp(-((logit(t) - m) ** 2) / (2 * s**2))
return coef * exp
def rectified_flow_loss(
self,
input: torch.Tensor,
noise: torch.Tensor,
t: torch.Tensor,
preds: torch.Tensor,
use_weighting: bool = False,
reduce: str = "mean",
) -> torch.Tensor:
"""
Compute the rectified flow loss, https://arxiv.org/pdf/2403.03206.pdf
Args:
input (torch.Tensor): Input tensor.
noise (torch.Tensor): Noise tensor.
t (torch.Tensor): Timestep tensor.
preds (torch.Tensor): Predicted tensor.
use_weighting (bool, optional): Whether to use weighting for the loss. Defaults to False.
reduce (str, optional): Reduction method for the loss. Options are 'mean' or 'none'. Defaults to 'mean'.
Returns:
torch.Tensor: Rectified flow loss.
"""
# Matching dimension for broadcasting
t = t.reshape(t.shape[0], *[1 for _ in range(len(input.shape) - len(t.shape))])
target_flow = (1 - self.MIN_STD) * noise - input
loss = F.mse_loss(preds.float(), target_flow.float(), reduction="none")
if use_weighting:
weight = self._logit_norm(t).detach()
loss = loss * weight
if reduce == "mean":
loss = loss.mean()
elif reduce == "none":
loss = loss
else:
raise NotImplementedError
return loss
def discritize_timestep(self, t: torch.Tensor, n_timesteps: int = 1000) -> torch.Tensor:
"""
Discretize the continuous timestep.
Args:
t (torch.Tensor): Continuous timestep.
n_timesteps (int, optional): Number of discrete timesteps. Defaults to 1000.
Returns:
torch.Tensor: Discretized timestep tensor.
"""
return (t * n_timesteps).round().long()
|