File size: 6,801 Bytes
6cd6a16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
"""
Dongpin Oh: [email protected]
"""

import time

import numpy as np
import torch
import torch.nn.functional as F
from diffusers.utils.torch_utils import randn_tensor
from loguru import logger


class FlowMixin:
    """Mixin class for flow-based models."""

    _is_noise_initialized = False
    MIN_STD: float = 0.0  # minimum size of std for the flow matching
    CLAMP_CONTINUOUS_TIME: float = 0.0

    def _timestep_shifting(self, timesteps: torch.Tensor, alpha: float = 3.0) -> torch.Tensor:
        """
        Adjust the timesteps for higher resolution images by adding more noise.
        higher resolution have more pixels we need more noise to destory their signal.

        NOTE that the timestep must be reversed unlike original SD3 style timestep shifting,
        since the flow-timestep is reversed (t=0: original image; t=1: pure noise)
        Args:
            timesteps (torch.Tensor): The original timesteps.
            alpha (float, optional): Scaling factor for timestep shifting. Defaults to 3.0.

        Returns:
            torch.Tensor: The reversed and shifted timesteps.
        """
        shifted_t = alpha * timesteps / (1 + (alpha - 1) * timesteps)
        reversed_t = 1 - shifted_t
        return reversed_t

    def get_noisy_input(
        self,
        input: torch.Tensor,
        normal_mean: float = 0.0,
        normal_std: float = 1.0,
        is_finetuning: bool = False,
        t: torch.Tensor | None = None,
        n_timesteps: int = 1000,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Generate noisy input, based on the optimal-transport flow.

        Args:
            input (torch.Tensor): Input tensor.
            normal_mean (float, optional): Mean of the normal distribution for noise. Defaults to 0.0.
            normal_std (float, optional): Standard deviation of the normal distribution for noise. Defaults to 1.0.
            is_finetuning (bool, optional): Whether the model is in finetuning mode. Defaults to False.
            t (torch.Tensor | None, optional): Predefined timesteps. If None, timesteps are sampled. Defaults to None.
            n_timesteps (int, opyionsl): Number of discrete timesteps. Defaults to 1000.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing noise, noisy input, and timestep.
                - The timestep t is a continuous timestep ranged [0, 1] which cannot directly be used
                  for the timestep embedding (needs to be discretized by self.discritize_timestep()).
        """
        b = input.shape[0]

        if not FlowMixin._is_noise_initialized:
            logger.warning("The torch random seed is changed when generating the initial random noise.")
            current_time = int(time.time() * 1000)
            torch.manual_seed(current_time)
            FlowMixin._is_noise_initialized = True

        noise = randn_tensor(input.shape).cuda()

        # Sample timestep from a log-normal distribution with mean 0 and std 1
        if t is None:
            # NOTE: timestep is sampled from log-normal distribution to make the model
            # focus on the intermediate timesteps, which are the most informative part of flow-ODE
            t = torch.randn(b).cuda() * normal_std + normal_mean
            t = torch.sigmoid(t)
        else:
            t = t

        # Clamp t to be within the interval [0, 1] for numerical stability
        t = t.clamp(0 + self.CLAMP_CONTINUOUS_TIME, 1 - self.CLAMP_CONTINUOUS_TIME)

        if is_finetuning:
            t = self._timestep_shifting(t)

        # Reshape t to match the dimensions required
        for _ in range(len(noise.shape) - 1):
            t = t.unsqueeze(1)

        # Generate the noisy input
        noisy_input = (1 - t) * input + (self.MIN_STD + (1 - self.MIN_STD) * t) * noise

        t_squeezed = t.squeeze()
        if t_squeezed.dim() == 0:
            t_squeezed = t_squeezed.unsqueeze(0)
        return noise, noisy_input, t_squeezed

    @torch.no_grad()
    def _logit_norm(self, t: torch.Tensor, m: float = 0, s: float = 1) -> torch.Tensor:
        """
        Compute the loss-weight for the flow-matching loss.
        It will be focusing (giving high weights) on the intermidate timestep, since
        such timesteps are hard to be matched, according to https://arxiv.org/pdf/2403.03206.pdf
        Args:
            t (torch.Tensor): Timestep tensor. (0 to 1)
            m (float, optional): Mean of the logit distribution. Defaults to 0.
            s (float, optional): Standard deviation of the logit distribution. Defaults to 1.

        Returns:
            torch.Tensor: Weight tensor for the flow-matching loss.
        """
        coef = (1 / (s * ((2 * np.pi) ** 0.5))) * (1 / (t * (1 - t)))

        def logit(x):
            return torch.log(x) - torch.log(1 - x)

        exp = torch.exp(-((logit(t) - m) ** 2) / (2 * s**2))
        return coef * exp

    def rectified_flow_loss(
        self,
        input: torch.Tensor,
        noise: torch.Tensor,
        t: torch.Tensor,
        preds: torch.Tensor,
        use_weighting: bool = False,
        reduce: str = "mean",
    ) -> torch.Tensor:
        """
        Compute the rectified flow loss, https://arxiv.org/pdf/2403.03206.pdf

        Args:
            input (torch.Tensor): Input tensor.
            noise (torch.Tensor): Noise tensor.
            t (torch.Tensor): Timestep tensor.
            preds (torch.Tensor): Predicted tensor.
            use_weighting (bool, optional): Whether to use weighting for the loss. Defaults to False.
            reduce (str, optional): Reduction method for the loss. Options are 'mean' or 'none'. Defaults to 'mean'.

        Returns:
            torch.Tensor: Rectified flow loss.
        """

        # Matching dimension for broadcasting
        t = t.reshape(t.shape[0], *[1 for _ in range(len(input.shape) - len(t.shape))])

        target_flow = (1 - self.MIN_STD) * noise - input
        loss = F.mse_loss(preds.float(), target_flow.float(), reduction="none")
        if use_weighting:
            weight = self._logit_norm(t).detach()
            loss = loss * weight
        if reduce == "mean":
            loss = loss.mean()
        elif reduce == "none":
            loss = loss
        else:
            raise NotImplementedError

        return loss

    def discritize_timestep(self, t: torch.Tensor, n_timesteps: int = 1000) -> torch.Tensor:
        """
        Discretize the continuous timestep.

        Args:
            t (torch.Tensor): Continuous timestep.
            n_timesteps (int, optional): Number of discrete timesteps. Defaults to 1000.

        Returns:
            torch.Tensor: Discretized timestep tensor.
        """
        return (t * n_timesteps).round().long()