from collections import defaultdict
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
import tqdm
from diffusers.models import AutoencoderKL
from diffusers.utils.torch_utils import randn_tensor
from loguru import logger
from PIL import Image, ImageFilter
from transformers import CLIPTextModel, CLIPTokenizerFast, T5EncoderModel, T5Tokenizer

from models.mixin.flow_mixin import FlowMixin
from models.modeling_dit import MotifDiT

TOKEN_MAX_LENGTH: int = 256
DROP_PROB: float = 0.1
LATENT_CHANNELS: int = 4
VAE_DOWNSCALE_FACTOR: int = 8
SD3_LATENT_CHANNEL: int = 16


def generate_intervals(steps, ratio, start=1.0):
    intervals = torch.linspace(start, 0, steps=steps)
    intervals = intervals.pow(ratio)
    return intervals
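
# Example: generate_intervals(5, 2.0) gives tensor([1.0000, 0.5625, 0.2500, 0.0625, 0.0000]);
# larger `ratio` values push the interval boundaries toward 0 more quickly.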


class MotifImage(nn.Module, FlowMixin):
    """
    MotifImage Text-to-Image model.

    This model combines a Diffusion transformer with a rectified flow loss and multiple text encoders.
    It uses a VAE (Variational Autoencoder) for image encoding and decoding.

    Args:
        config (MMDiTConfig): Configuration object for the MMDiT model.

    Attributes:
        dit (MotifDiT): MotifDiT model instance.
        cond_drop_prob (float): Probability of dropping text embeddings during training.
        use_weighting (bool): Whether to apply per-timestep weighting in the rectified flow loss.
        vae (AutoencoderKL): Variational Autoencoder for image encoding and decoding.
        t5 (T5EncoderModel): T5 encoder model for text encoding.
        t5_tokenizer (T5Tokenizer): T5 tokenizer for text tokenization.
        clip_l (CLIPTextModel): CLIP (Contrastive Language-Image Pre-training) text encoder (large).
        clip_l_tokenizer (CLIPTokenizerFast): CLIP tokenizer (large) for text tokenization.
        clip_g (CLIPTextModel): CLIP text encoder (giant).
        clip_g_tokenizer (CLIPTokenizerFast): CLIP tokenizer (giant) for text tokenization.
        tokenizers (List[Union[T5Tokenizer, CLIPTokenizerFast]]): List of tokenizers.
        text_encoders (List[Union[T5EncoderModel, CLIPTextModel]]): List of text encoder models.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dit = MotifDiT(config)
        self.cond_drop_prob = DROP_PROB
        self.use_weighting = False
        self._get_encoders()
        self._freeze_encoders()

    def forward(self, images: torch.Tensor, raw_text: List[str]) -> List[torch.Tensor]:
        """
        Forward pass of the MotifImage model.

        Args:
            images (torch.Tensor): Input images tensor, [0-1] ranged.
            raw_text (List[str]): Input text prompts.

        Returns:
            List[torch.Tensor]: Single-element list containing the rectified flow matching loss.
        """
        # 1. Encode images and texts
        with torch.no_grad():
            latents = self.vae.encode(images).latent_dist.sample() * self.vae.config.scaling_factor
        tokens, masks = self.tokenization(raw_text)
        tokens = [token.to(latents.device) for token in tokens]
        masks = [mask.to(latents.device) for mask in masks]
        text_embeddings, pooled_text_embeddings = self.text_encoding(tokens, masks)
        text_embeddings = self._drop_text_emb(text_embeddings)
        text_embeddings = [text_embedding.float() for text_embedding in text_embeddings]
        pooled_text_embeddings = pooled_text_embeddings.float()

        # 2. Get noisy input via the rectified flow
        is_finetuning = self.config.height > 256
        noise, noise_latents, t = self.get_noisy_input(latents, is_finetuning=is_finetuning)

        timesteps = self.discritize_timestep(t, self.n_timesteps)

        # 3. Forward pass through the dit
        preds = self.dit(noise_latents, timesteps, text_embeddings, pooled_text_embeddings)

        # 4. Rectified flow matching loss
        loss = self.rectified_flow_loss(latents, noise, t, preds, use_weighting=self.use_weighting)

        return [loss]

    def _get_encoders(self) -> None:
        """Initialize the VAE and text encoders."""
        if self.config.vae_type == "SD3":
            self.vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae")
        elif self.config.vae_type == "SDXL":
            self.vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
        else:
            raise ValueError(f"VAE type must be `SD3` or `SDXL`  but self.config.vae_type is {self.config.vae_type}")

        # Text encoders
        # 1. T5-XXL from Google
        self.t5 = T5EncoderModel.from_pretrained("google/flan-t5-xxl").to(dtype=torch.bfloat16)
        self.t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")

        # 2. CLIP-L from OpenAI
        self.clip_l = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(dtype=torch.bfloat16)
        self.clip_l_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")

        # 3. CLIP-G from LAION
        self.clip_g = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k").to(dtype=torch.bfloat16)
        self.clip_g_tokenizer = CLIPTokenizerFast.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

        self.tokenizers = [self.t5_tokenizer, self.clip_l_tokenizer, self.clip_g_tokenizer]
        self.text_encoders = [self.t5, self.clip_l, self.clip_g]

    def state_dict(self, destination=None, prefix="", keep_vars=False):
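        # Exclude the frozen pretrained encoders (T5, CLIP, VAE) from the checkpoint so that
        # only the trainable DiT weights are saved.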
        state_dict = super(MotifImage, self).state_dict(destination, prefix, keep_vars)
        exclude_keys = ["t5.", "clip_l.", "clip_g.", "vae."]
        for key in list(state_dict.keys()):
            if any(key.startswith(exclude_key) for exclude_key in exclude_keys):
                state_dict.pop(key)
        return state_dict

    def load_state_dict(self, state_dict, strict=False):
        """
        Load state dict and merge LoRA parameters if present.

        Args:
            state_dict (dict): State dictionary containing model parameters
            strict (bool): Whether to strictly enforce that the keys in state_dict match the keys in this module

        Returns:
            tuple: (missing_keys, unexpected_keys) lists of parameters that were missing or unexpected
        """
        # Check if state_dict contains LoRA parameters
        has_lora = any("lora_" in key for key in state_dict.keys())

        if has_lora:
            # If model doesn't have LoRA enabled but state_dict has LoRA params, enable it
            if not hasattr(self.dit, "peft_config"):
                logger.info("Enabling LoRA for parameter merging...")
                # Use default values if not already configured
                lora_rank = getattr(self.config, "lora_rank", 64)
                lora_alpha = getattr(self.config, "lora_alpha", 8)
                self.enable_lora(lora_rank, lora_alpha)

        if has_lora:
            try:
                # Load LoRA parameters
                # state_dict = {
                #     k.replace("base_layer.", ""): v
                #     for k, v in state_dict.items()
                #     if "lora_" not in k and "lora" not in k
                # }
                missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False)
                # Merge LoRA weights with base model
                logger.info("Merging LoRA parameters with base model...")
                for name, module in self.dit.named_modules():
                    if hasattr(module, "merge_and_unload"):
                        module.merge_and_unload()

                logger.info("Successfully merged LoRA parameters")

            except Exception as e:
                logger.error(f"Error merging LoRA parameters: {str(e)}")
                raise
        else:
            missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False)

        # Log summary of missing/unexpected parameters
        missing_top_levels = set()
        for key in missing_keys:
            top_level_name = key.split(".")[0]
            missing_top_levels.add(top_level_name)
        if missing_top_levels:
            logger.debug("Missing keys during loading at top level:")
            for name in missing_top_levels:
                logger.debug(name)

        if unexpected_keys:
            logger.debug("Unexpected keys found:")
            for key in unexpected_keys:
                logger.debug(key)

        return missing_keys, unexpected_keys

    def _freeze_encoders(self) -> None:
        """
        Freeze all encoder parameters (VAE, CLIP-L, CLIP-G, T5) so only the DiT is trained.
        """
        for encoder_module in [self.vae, self.clip_l, self.clip_g, self.t5]:
            for param in encoder_module.parameters():
                param.requires_grad = False

    def tokenization(
        self, raw_texts: List[str], repeat_if_short: bool = False
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Tokenizes a BATCH of input texts using multiple tokenizers efficiently.
        Optionally repeats each text to fill the max length if it's shorter,
        BEFORE passing the pre-processed batch to the tokenizer.

        Args:
            raw_texts (List[str]): A list of input text strings (the batch).
            repeat_if_short (bool): If True and a text is short, repeat that text
                                    to fill the context length. Defaults to False.

        Returns:
            Tuple[List[torch.Tensor], List[torch.Tensor]]:
                - A list containing one batch tensor of input IDs per tokenizer.
                Each tensor shape: [batch_size, max_length]
                - A list containing one batch tensor of attention masks per tokenizer.
                Each tensor shape: [batch_size, max_length]
        """
        final_batch_tokens = []
        final_batch_masks = []

        # Process the batch with each tokenizer
        for tokenizer in self.tokenizers:
            effective_max_length = min(TOKEN_MAX_LENGTH, tokenizer.model_max_length)

            # 1. Pre-process the batch: Create a new list of potentially repeated strings.
            processed_texts_for_tokenizer = []
            for text_item in raw_texts:
                # Start with the original text for this item
                processed_text = text_item

                if repeat_if_short:
                    # Apply repetition logic individually based on text_item's length
                    num_initial_tokens = len(text_item.split())
                    available_length = effective_max_length - 2  # Heuristic: reserve room for special tokens (BOS/EOS)

                    if num_initial_tokens > 0 and num_initial_tokens < available_length:
                        num_additional_repeats = available_length // (num_initial_tokens + 1)
                        if num_additional_repeats > 0:
                            total_repeats = 1 + num_additional_repeats
                            processed_text = " ".join([text_item] * total_repeats)

                # Add the processed text (original or repeated) to the list for this tokenizer
                processed_texts_for_tokenizer.append(processed_text)

            # 2. Tokenize the entire batch of processed texts at once.
            #    Pass the list `processed_texts_for_tokenizer` directly to the tokenizer.
            #    The tokenizer's __call__ method should handle the batch efficiently.
            batch_tok_output = tokenizer(  # Call the tokenizer ONCE with the full list
                processed_texts_for_tokenizer,
                padding="max_length",
                max_length=effective_max_length,
                return_tensors="pt",
                truncation=True,
            )

            # 3. Store the resulting batch tensors directly.
            #    The tokenizer should return tensors with shape [batch_size, max_length].
            final_batch_tokens.append(batch_tok_output.input_ids)
            final_batch_masks.append(batch_tok_output.attention_mask)

        return final_batch_tokens, final_batch_masks
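
    # Example of the `repeat_if_short` pre-processing (illustrative; CLIP's model_max_length is 77):
    # for the 3-word prompt "a red fox", available_length = 77 - 2 = 75, so
    # num_additional_repeats = 75 // 4 = 18 and the string is repeated 19 times before tokenization,
    # after which the tokenizer truncates it back to the effective max length.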

    @torch.no_grad()
    def text_encoding(
        self, tokens: List[torch.Tensor], masks: List[torch.Tensor], noisy_pad: bool = False, zero_masking: bool = True
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """
        Encode the tokenized text using multiple text encoders.

        Args:
            tokens (List[torch.Tensor]): List of tokenized text tensors (T5, CLIP-L, CLIP-G).
            masks (List[torch.Tensor]): List of attention masks, one per tokenizer.
            noisy_pad (bool): If True, add small Gaussian noise to the embeddings at padding positions.
            zero_masking (bool): If True, zero out the embeddings at padding positions.

        Returns:
            Tuple[List[torch.Tensor], torch.Tensor]: Tuple containing a list of text embeddings and pooled text embeddings.
        """
        t5_tokens, clip_l_tokens, clip_g_tokens = tokens
        t5_masks, clip_l_masks, clip_g_masks = masks
        t5_emb = self.t5(t5_tokens, attention_mask=t5_masks)[0]
        if zero_masking:
            t5_emb = t5_emb * (t5_tokens != self.t5_tokenizer.pad_token_id).unsqueeze(-1)
        if noisy_pad:
            t5_pad_noise = (
                (t5_tokens == self.t5_tokenizer.pad_token_id).unsqueeze(-1) * torch.randn_like(t5_emb) * 0.008
            )
            t5_emb = t5_emb + t5_pad_noise

        clip_l_emb = self.clip_l(input_ids=clip_l_tokens, output_hidden_states=True)
        clip_g_emb = self.clip_g(input_ids=clip_g_tokens, output_hidden_states=True)
        clip_l_emb_pooled = clip_l_emb.pooler_output  # B x 768
        clip_g_emb_pooled = clip_g_emb.pooler_output  # B x 1280

        clip_l_emb = clip_l_emb.last_hidden_state  # B x L x 768,
        clip_g_emb = clip_g_emb.last_hidden_state  # B x L x 1280,

        def masking_wo_first_eos(token, eos):
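            # Build a mask that keeps all non-EOS tokens plus the first EOS (which CLIP uses as the
            # pooled sentence token); subsequent EOS/padding positions are zeroed out.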
            idx = (token != eos).sum(dim=1)
            mask = token != eos
            arange = torch.arange(mask.size(0), device=mask.device)
            mask[arange, idx] = True
            mask = mask.unsqueeze(-1)  # B x L x 1
            return mask

        if zero_masking:
            clip_l_emb = clip_l_emb * masking_wo_first_eos(
                clip_l_tokens, self.clip_l_tokenizer.eos_token_id
            )  # B x L x 768,
            clip_g_emb = clip_g_emb * masking_wo_first_eos(
                clip_g_tokens, self.clip_g_tokenizer.eos_token_id
            )  # B x L x 768,

        if noisy_pad:
            clip_l_pad_noise = (
                ~masking_wo_first_eos(clip_l_tokens, self.clip_l_tokenizer.eos_token_id)
                * torch.randn_like(clip_l_emb)
                * 0.08
            )
            clip_g_pad_noise = (
                ~masking_wo_first_eos(clip_g_tokens, self.clip_g_tokenizer.eos_token_id)
                * torch.randn_like(clip_g_emb)
                * 0.08
            )
            clip_l_emb = clip_l_emb + clip_l_pad_noise
            clip_g_emb = clip_g_emb + clip_g_pad_noise

        encodings = [t5_emb, clip_l_emb, clip_g_emb]
        pooled_encodings = torch.cat([clip_l_emb_pooled, clip_g_emb_pooled], dim=-1)  # cat by channel, B x 2048

        return encodings, pooled_encodings

    @torch.no_grad()
    def prompt_embedding(self, prompts: List[str], device, noisy_pad=False, zero_masking=True):
        tokens, masks = self.tokenization(prompts)
        tokens = [token.to(device) for token in tokens]
        masks = [mask.to(device) for mask in masks]
        text_embeddings, pooled_text_embeddings = self.text_encoding(
            tokens, masks, noisy_pad=noisy_pad, zero_masking=zero_masking
        )
        text_embeddings = [text_embedding.bfloat16() for text_embedding in text_embeddings]
        pooled_text_embeddings = pooled_text_embeddings.bfloat16()
        return text_embeddings, pooled_text_embeddings

    @torch.no_grad()
    def sample(
        self,
        raw_text: List[str],
        steps: int = 50,
        guidance_scale: float = 7.5,
        resolution: List[int] = (256, 256),
        pre_latent=None,
        pre_timestep=None,
        step_scaling=1.0,
        noisy_pad=False,
        zero_masking=False,
        negative_prompt: Optional[List[str]] = None,
        device: str = "cuda",
        rescale_cfg=-1.0,
        clip_t=[0.0, 1.0],
        use_linear_quadratic_schedule=False,
        linear_quadratic_emulating_steps=250,
        prompt_rewriter=None,
        moderator=None,
        get_intermediate_steps: bool = False,
    ) -> Union[List[Image.Image], Tuple[List[Image.Image], List[List[Image.Image]]]]:
        """
        Sample images using flow matching. Optionally returns intermediate step images
        calculated via observed average velocity method.

        Args:
            raw_text (List[str]): raw text prompts
            steps (int, optional): number of function evaluations for the flow matching ODE. Defaults to 50.
            guidance_scale (float, optional): classifier-free guidance scale. Defaults to 7.5.
            resolution (List[int], optional): input and output resolution of raw images. Defaults to (256, 256).
            device (str, optional): Defaults to 'cuda'.
            pre_latent (Tensor, optional): optional pre-defined latents to start generation from,
                e.g. for denoising or image editing.
            pre_timestep (float [0,1], optional): the pre-defined timestep. With `pre_latent`, image generation
                can start from an intermediate timestep.
            step_scaling (float, defaults to 1.0): scaling factor for each ODE-solving step.
            use_linear_quadratic_schedule (bool, defaults to False): if True, use the linear-quadratic t schedule;
                otherwise a linear t schedule is used.
            linear_quadratic_emulating_steps (int, defaults to 250): N value in the linear-quadratic t schedule from
                the Meta MovieGen paper. Reference: (https://ai.meta.com/static-resource/movie-gen-research-paper) Figure 10
            get_intermediate_steps (bool, optional): Whether to calculate and return intermediate step images.
                Calculation is based on initial_noise - avg(velocity). Defaults to False.

        Returns:
            Union[List[PIL.Image], Tuple[List[PIL.Image], List[List[PIL.Image]]]]:
                If get_intermediate_steps is False: Returns a list of final PIL images.
                If get_intermediate_steps is True: Returns a tuple containing:
                    - List[PIL.Image]: Final output PIL images.
                    - List[List[PIL.Image]]: List of intermediate PIL images. Each inner list contains
                                              the batch of images for one intermediate step.
        """
        if prompt_rewriter:
            prompts = [prompt_rewriter.rewrite(prompt) for prompt in raw_text]
        else:
            prompts = raw_text

        # Simplified check for rewriter status
        if prompts == raw_text and prompt_rewriter is not None:
            logger.debug("Prompt rewriter did not change the prompts.")
        elif prompt_rewriter is None:
            logger.debug("Prompt rewriter not provided.")

        if moderator is None:
            is_safe_prompt = [True for _ in prompts]
        else:
            is_safe_prompt = [moderator.is_safe_content(prompt, threshold=0.7) for prompt in prompts]
            if not all(is_safe_prompt):
                logger.warning("Noxious prompt detected. Output image(s) will be blurred.")

        b = len(prompts)
        h, w = resolution

        # --- [Initial Latent Noise (e = x_1)] ---
        latent_channels = SD3_LATENT_CHANNEL if self.config.vae_type == "SD3" else LATENT_CHANNELS
        if pre_latent is None:
            initial_noise = randn_tensor(  # Store initial noise separately
                (b, latent_channels, h // VAE_DOWNSCALE_FACTOR, w // VAE_DOWNSCALE_FACTOR),
                device=device,
                dtype=torch.float32,  # Use float32 for calculations
            )
        else:
            initial_noise = pre_latent.to(device=device, dtype=torch.float32)
            if pre_timestep is not None and pre_timestep < 1.0:  # Check if it's truly intermediate
                logger.warning(
                    "Using pre_latent as initial_noise for average calculation, but pre_timestep suggests it's not pure noise. Results might be unexpected."
                )

        latents = initial_noise.clone()  # Working latents for the ODE solver

        # --- [Text Embeddings & CFG Setup] ---
        text_embeddings, pooled_text_embeddings = self.prompt_embedding(
            prompts, latents.device, noisy_pad=noisy_pad, zero_masking=zero_masking
        )
        text_embeddings = [emb.to(device=latents.device, dtype=torch.bfloat16) for emb in text_embeddings]
        pooled_text_embeddings = pooled_text_embeddings.to(device=latents.device, dtype=torch.bfloat16)

        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            negative_text_embeddings = [
                torch.zeros_like(text_embedding, device=text_embedding.device) for text_embedding in text_embeddings
            ]
            negative_pooled_text_embeddings = torch.zeros_like(
                pooled_text_embeddings, device=pooled_text_embeddings.device
            )
            text_embeddings = [
                torch.cat([text_embedding, negative_text_embedding], dim=0)
                for text_embedding, negative_text_embedding in zip(text_embeddings, negative_text_embeddings)
            ]
            pooled_text_embeddings = torch.cat([pooled_text_embeddings, negative_pooled_text_embeddings], dim=0)

            # if negative_prompt is None:
            #     negative_prompt = [""] * b
            #     logger.debug("No negative prompt provided, using empty strings for CFG.")
            # negative_text_embeddings, negative_pooled_text_embeddings = self.prompt_embedding(negative_prompt, latents.device)
            # negative_text_embeddings = [emb.to(device=latents.device, dtype=torch.bfloat16) for emb in negative_text_embeddings]
            # negative_pooled_text_embeddings = negative_pooled_text_embeddings.to(device=latents.device, dtype=torch.bfloat16)

            # text_embeddings = [torch.cat([pos_emb, neg_emb], dim=0) for pos_emb, neg_emb in zip(text_embeddings, negative_text_embeddings)]
            # pooled_text_embeddings = torch.cat([pooled_text_embeddings, negative_pooled_text_embeddings], dim=0)

        # --- [Timestep Schedule (Sigmas)] ---
        # linear t schedule
        sigmas = torch.linspace(1, 0, steps + 1) if not pre_timestep else torch.linspace(pre_timestep, 0, steps + 1)

        if use_linear_quadratic_schedule:
            # linear-quadratic t schedule
            assert steps % 2 == 0
            N = linear_quadratic_emulating_steps
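            # The first `steps // 2` sigmas follow the fine N-step linear schedule (decreasing by 1/N
            # per step); a quadratic tail then decays from roughly 1 - steps / (2 * N) down to 0,
            # emulating the rest of the N-step trajectory with only `steps // 2` evaluations.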
            sigmas = torch.concat(
                [
                    torch.linspace(1, 0, N + 1)[: steps // 2],
                    torch.linspace(0, 1, steps // 2 + 1) ** 2 * (steps // 2 * 1 / N - 1) - (steps // 2 * 1 / N - 1),
                ]
            )

        # --- [Initialization for Intermediate Step Calculation] ---
        # intermediate_latents will store the latent states for intermediate steps
        intermediate_latents = [] if get_intermediate_steps else None
        predicted_velocities = []  # Store dx from each step
        sigma_history = []
        # --- [Sampling Loop] ---
        for infer_step, t in tqdm.tqdm(enumerate(sigmas[:-1]), total=len(sigmas[:-1]), desc="Sampling"):
            # Prepare input for DiT model
            if do_classifier_free_guidance:
                input_latents = torch.cat([latents] * 2, dim=0)
            else:
                input_latents = latents

            # Prepare timestep input
            timestep = (t * 1000).round().long().to(latents.device)
            timestep = timestep.expand(input_latents.shape[0]).to(torch.bfloat16)  # Ensure timestep is bfloat16

            # Predict velocity dx = v(x_t, t) ≈ e - x_0
            dx = self.dit(input_latents.to(torch.bfloat16), timestep, text_embeddings, pooled_text_embeddings)
            dt = sigmas[infer_step + 1] - sigmas[infer_step]  # dt is negative
            sigma_history.append(dt)

            # Apply Classifier-Free Guidance
            if do_classifier_free_guidance:
                cond_dx, uncond_dx = dx.chunk(2)
                current_guidance_scale = guidance_scale if clip_t[0] <= t and t <= clip_t[1] else 1.0
                dx = uncond_dx + current_guidance_scale * (cond_dx - uncond_dx)

                if rescale_cfg > 0.0:
                    std_pos = torch.std(cond_dx, dim=[1, 2, 3], keepdim=True, unbiased=False) + 1e-5
                    std_cfg = torch.std(dx, dim=[1, 2, 3], keepdim=True, unbiased=False) + 1e-5
                    factor = std_pos / std_cfg
                    factor = rescale_cfg * factor + (1.0 - rescale_cfg)
                    dx = dx * factor

            # --- Store the predicted velocity for averaging ---
            predicted_velocities.append(dx.clone())

            # --- Update Latents using standard Euler step ---
            latents = latents + dt * dx

            # --- Calculate and Store Intermediate Latent State (if requested) ---
            if get_intermediate_steps:
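                # "Observed average velocity": weight each stored velocity by its normalized step size
                # and subtract the weighted average from the initial noise, giving a rough one-shot
                # estimate of the clean latent at this point in the trajectory (decoded later below).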
                dxs = torch.stack(predicted_velocities)

                sigma_sum = sum(sigma_history)
                normalized_sigma_history = [s / (sigma_sum) for s in sigma_history]
                dts = torch.tensor(normalized_sigma_history, device=dxs.device, dtype=dxs.dtype).view(-1, 1, 1, 1, 1)

                avg_dx = torch.sum(dxs * dts, dim=0)
                observed_state = initial_noise - avg_dx  # Calculate the desired intermediate state
                intermediate_latents.append(observed_state.clone())  # Store its latent representation

        # --- [Decode Final Latents to PIL Images] ---
        self.vae = self.vae.to(device=latents.device, dtype=torch.float32)  # Ensure VAE is ready
        final_latents_scaled = latents.to(torch.float32) / self.vae.config.scaling_factor
        final_image_tensors = self.vae.decode(final_latents_scaled, return_dict=False)[0] + self.vae.config.shift_factor
        final_image_tensors = ((final_image_tensors + 1.0) / 2.0).clamp(0.0, 1.0)

        final_pil_images = []
        for i, image_tensor in enumerate(final_image_tensors):
            img = T.ToPILImage()(image_tensor.cpu())
            if not is_safe_prompt[i]:
                img = img.filter(ImageFilter.GaussianBlur(radius=30))
            final_pil_images.append(img)

        # --- [Decode Intermediate Latents to PIL Images (if requested)] ---
        if get_intermediate_steps:
            intermediate_pil_images = []
            # Ensure VAE is still ready (it should be from final decoding)
            for step_latents in tqdm.tqdm(intermediate_latents, desc="Decoding intermediates"):
                step_latents_scaled = (
                    step_latents.to(dtype=torch.float32, device="cuda") / self.vae.config.scaling_factor
                )
                step_image_tensors = (
                    self.vae.decode(step_latents_scaled, return_dict=False)[0] + self.vae.config.shift_factor
                )
                step_image_tensors = ((step_image_tensors + 1.0) / 2.0).clamp(0.0, 1.0)

                current_step_pil = []
                for i, image_tensor in enumerate(step_image_tensors):
                    img = T.ToPILImage()(image_tensor.cpu())
                    # Apply moderation blur consistency
                    if not is_safe_prompt[i]:
                        img = img.filter(ImageFilter.GaussianBlur(radius=30))
                    current_step_pil.append(img)
                intermediate_pil_images.append(current_step_pil)  # Append list of images for this step

            return final_pil_images, intermediate_pil_images  # Return both final and intermediate images
        else:
            return final_pil_images  # Return only final images

    @torch.no_grad()
    def eval_with_loss(self, images, raw_text):
        latents = self.vae.encode(images).latent_dist.sample() * self.vae.config.scaling_factor

        tokens, masks = self.tokenization(raw_text)
        tokens = [token.to(latents.device) for token in tokens]
        masks = [mask.to(latents.device) for mask in masks]
        text_embeddings, pooled_text_embeddings = self.text_encoding(tokens, masks)
        text_embeddings = [text_embedding.float() for text_embedding in text_embeddings]
        pooled_text_embeddings = pooled_text_embeddings.float()

        # 2. Get noisy input via the rectified flow
        is_finetuning = self.config.height > 256
        noise, noise_latents, t = self.get_noisy_input(latents, is_finetuning=is_finetuning)
        timesteps = self.discritize_timestep(t, self.n_timesteps)

        # 3. Forward pass through the dit
        preds = self.dit(noise_latents, timesteps, text_embeddings, pooled_text_embeddings)

        # 4. Rectified flow matching loss
        # Compare against the clean latents, matching the training loss in `forward`.
        loss = self.rectified_flow_loss(latents, noise, t, preds, reduce="none", use_weighting=False).mean(
            dim=[1, 2, 3]
        )

        intervals = np.linspace(0, 1, 9)
        t_interval = [(intervals[i], intervals[i + 1]) for i in range(len(intervals) - 1)]

        loss_bins = defaultdict(list)
        for i, interval in enumerate(t_interval, 0):
            idx = (interval[0] < t) & (t < interval[1])
            loss_bins[i].append(loss[idx])

        return loss_bins
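

# Minimal usage sketch (illustrative only; the config constructor fields and checkpoint handling
# below are assumptions, not defined in this module):
#
#   config = MMDiTConfig(vae_type="SD3", height=256)  # hypothetical config values
#   model = MotifImage(config).to("cuda").eval()
#   images = model.sample(["a watercolor painting of a lighthouse"], steps=50,
#                         guidance_scale=7.5, resolution=(256, 256))
#   images[0].save("lighthouse.png")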