"""
Dongpin Oh: [email protected]
"""
import time
import numpy as np
import torch
import torch.nn.functional as F
from diffusers.utils.torch_utils import randn_tensor
from loguru import logger
class FlowMixin:
"""Mixin class for flow-based models."""
_is_noise_initialized = False
    MIN_STD: float = 0.0  # minimum standard deviation of the noise for flow matching
    CLAMP_CONTINUOUS_TIME: float = 0.0  # margin for clamping continuous t away from 0 and 1
def _timestep_shifting(self, timesteps: torch.Tensor, alpha: float = 3.0) -> torch.Tensor:
"""
        Adjust the timesteps for higher-resolution images by adding more noise:
        higher-resolution images have more pixels, so more noise is needed to destroy their signal.
        NOTE that the timestep must be reversed, unlike the original SD3-style timestep shifting,
        since the flow timestep here is reversed (t=0: original image; t=1: pure noise).
Args:
timesteps (torch.Tensor): The original timesteps.
alpha (float, optional): Scaling factor for timestep shifting. Defaults to 3.0.
Returns:
torch.Tensor: The reversed and shifted timesteps.
"""
shifted_t = alpha * timesteps / (1 + (alpha - 1) * timesteps)
reversed_t = 1 - shifted_t
return reversed_t
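    # Worked example (illustrative): with alpha = 3.0 and timesteps = 0.5,
    # shifted_t = 3 * 0.5 / (1 + (3 - 1) * 0.5) = 0.75, so the returned
    # reversed timestep is 1 - 0.75 = 0.25.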
def get_noisy_input(
self,
input: torch.Tensor,
normal_mean: float = 0.0,
normal_std: float = 1.0,
is_finetuning: bool = False,
t: torch.Tensor | None = None,
n_timesteps: int = 1000,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
        Generate a noisy input based on the optimal-transport flow.
Args:
input (torch.Tensor): Input tensor.
normal_mean (float, optional): Mean of the normal distribution for noise. Defaults to 0.0.
normal_std (float, optional): Standard deviation of the normal distribution for noise. Defaults to 1.0.
is_finetuning (bool, optional): Whether the model is in finetuning mode. Defaults to False.
t (torch.Tensor | None, optional): Predefined timesteps. If None, timesteps are sampled. Defaults to None.
            n_timesteps (int, optional): Number of discrete timesteps. Defaults to 1000.
Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing noise, noisy input, and timestep.
                - The returned timestep t is continuous in [0, 1] and cannot be used directly
                  for the timestep embedding; it must first be discretized via self.discritize_timestep().
"""
b = input.shape[0]
if not FlowMixin._is_noise_initialized:
logger.warning("The torch random seed is changed when generating the initial random noise.")
current_time = int(time.time() * 1000)
torch.manual_seed(current_time)
FlowMixin._is_noise_initialized = True
        noise = randn_tensor(input.shape, device=input.device, dtype=input.dtype)
        # Sample the timestep from a logit-normal distribution when none is given
        if t is None:
            # NOTE: sampling t as sigmoid(N(normal_mean, normal_std)) concentrates
            # probability on intermediate timesteps, which are the most informative
            # part of the flow ODE
            t = torch.sigmoid(torch.randn(b, device=input.device) * normal_std + normal_mean)
# Clamp t to be within the interval [0, 1] for numerical stability
t = t.clamp(0 + self.CLAMP_CONTINUOUS_TIME, 1 - self.CLAMP_CONTINUOUS_TIME)
if is_finetuning:
t = self._timestep_shifting(t)
        # Unsqueeze t so it broadcasts over the non-batch dimensions of the input
for _ in range(len(noise.shape) - 1):
t = t.unsqueeze(1)
# Generate the noisy input
noisy_input = (1 - t) * input + (self.MIN_STD + (1 - self.MIN_STD) * t) * noise
t_squeezed = t.squeeze()
if t_squeezed.dim() == 0:
t_squeezed = t_squeezed.unsqueeze(0)
return noise, noisy_input, t_squeezed
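    # Hedged usage sketch (illustrative only; `denoiser` is a hypothetical
    # network, not part of this file):
    #   noise, noisy_x, t = self.get_noisy_input(latents)
    #   t_ids = self.discritize_timestep(t)  # embed-ready integer timesteps
    #   preds = denoiser(noisy_x, t_ids)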
@torch.no_grad()
def _logit_norm(self, t: torch.Tensor, m: float = 0, s: float = 1) -> torch.Tensor:
"""
        Compute the loss weight for the flow-matching loss.
        High weights are assigned to intermediate timesteps, since those are the
        hardest to match, following https://arxiv.org/pdf/2403.03206.pdf
Args:
t (torch.Tensor): Timestep tensor. (0 to 1)
m (float, optional): Mean of the logit distribution. Defaults to 0.
s (float, optional): Standard deviation of the logit distribution. Defaults to 1.
Returns:
torch.Tensor: Weight tensor for the flow-matching loss.
"""
coef = (1 / (s * ((2 * np.pi) ** 0.5))) * (1 / (t * (1 - t)))
def logit(x):
return torch.log(x) - torch.log(1 - x)
exp = torch.exp(-((logit(t) - m) ** 2) / (2 * s**2))
return coef * exp
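    # Worked example (illustrative): with m = 0 and s = 1 the weight peaks at
    # intermediate timesteps. At t = 0.5, logit(t) = 0, so the weight is
    # 1 / sqrt(2 * pi) * 1 / (0.5 * 0.5) ~= 1.60; at t = 0.9 it drops to
    # ~0.40, i.e. mid-trajectory samples contribute roughly 4x more to the loss.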
def rectified_flow_loss(
self,
input: torch.Tensor,
noise: torch.Tensor,
t: torch.Tensor,
preds: torch.Tensor,
use_weighting: bool = False,
reduce: str = "mean",
) -> torch.Tensor:
"""
Compute the rectified flow loss, https://arxiv.org/pdf/2403.03206.pdf
Args:
input (torch.Tensor): Input tensor.
noise (torch.Tensor): Noise tensor.
t (torch.Tensor): Timestep tensor.
preds (torch.Tensor): Predicted tensor.
use_weighting (bool, optional): Whether to use weighting for the loss. Defaults to False.
reduce (str, optional): Reduction method for the loss. Options are 'mean' or 'none'. Defaults to 'mean'.
Returns:
torch.Tensor: Rectified flow loss.
"""
        # Match dimensions for broadcasting
t = t.reshape(t.shape[0], *[1 for _ in range(len(input.shape) - len(t.shape))])
target_flow = (1 - self.MIN_STD) * noise - input
loss = F.mse_loss(preds.float(), target_flow.float(), reduction="none")
if use_weighting:
weight = self._logit_norm(t).detach()
loss = loss * weight
        if reduce == "mean":
            loss = loss.mean()
        elif reduce != "none":
            raise NotImplementedError(f"Unsupported reduction: {reduce}")
return loss
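    # Note: with MIN_STD = 0 the regression target above reduces to
    # (noise - input), which is the constant velocity
    # d/dt [(1 - t) * input + t * noise] of the straight-line
    # (optimal-transport) interpolation built in get_noisy_input().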
def discritize_timestep(self, t: torch.Tensor, n_timesteps: int = 1000) -> torch.Tensor:
"""
        Discretize a continuous timestep in [0, 1] into an integer index for the timestep embedding.
Args:
t (torch.Tensor): Continuous timestep.
n_timesteps (int, optional): Number of discrete timesteps. Defaults to 1000.
Returns:
torch.Tensor: Discretized timestep tensor.
"""
return (t * n_timesteps).round().long()
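

if __name__ == "__main__":
    # Minimal smoke test: a hedged sketch of how the mixin is meant to be used,
    # not part of the original MotifVision training code. `_DemoModel` and the
    # random `preds` stand in for a real denoising network.
    class _DemoModel(FlowMixin):
        pass

    flow = _DemoModel()
    latents = torch.randn(4, 3, 8, 8)  # hypothetical latent batch
    noise, noisy_latents, t = flow.get_noisy_input(latents)
    t_ids = flow.discritize_timestep(t)  # integer timesteps for the embedding
    preds = torch.randn_like(noisy_latents)  # stand-in for a model prediction
    loss = flow.rectified_flow_loss(latents, noise, t, preds, use_weighting=True)
    print(f"t_ids={t_ids.tolist()}, loss={loss.item():.4f}")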