lym00 commited on
Commit
5f212ea
·
verified ·
1 Parent(s): c65ecd1

Upload struct.py

Browse files
Files changed (1) hide show
  1. struct.py +2019 -0
struct.py ADDED
@@ -0,0 +1,2019 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Utility functions for Diffusion Models."""
3
+
4
+ import enum
5
+ import typing as tp
6
+ from abc import abstractmethod
7
+ from collections import OrderedDict, defaultdict
8
+ from dataclasses import dataclass, field
9
+
10
+ # region imports
11
+ import torch.nn as nn
12
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, SwiGLU
13
+ from diffusers.models.attention import BasicTransformerBlock, FeedForward, JointTransformerBlock
14
+ from diffusers.models.attention_processor import Attention, SanaLinearAttnProcessor2_0
15
+ from diffusers.models.embeddings import (
16
+ CombinedTimestepGuidanceTextProjEmbeddings,
17
+ CombinedTimestepTextProjEmbeddings,
18
+ ImageHintTimeEmbedding,
19
+ ImageProjection,
20
+ ImageTimeEmbedding,
21
+ PatchEmbed,
22
+ PixArtAlphaTextProjection,
23
+ TextImageProjection,
24
+ TextImageTimeEmbedding,
25
+ TextTimeEmbedding,
26
+ TimestepEmbedding,
27
+ )
28
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormSingle, AdaLayerNormZero
29
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
30
+ from diffusers.models.transformers.pixart_transformer_2d import PixArtTransformer2DModel
31
+ from diffusers.models.transformers.sana_transformer import GLUMBConv, SanaTransformer2DModel, SanaTransformerBlock
32
+ from diffusers.models.transformers.transformer_2d import Transformer2DModel
33
+ from diffusers.models.transformers.transformer_flux import (
34
+ FluxSingleTransformerBlock,
35
+ FluxTransformer2DModel,
36
+ FluxTransformerBlock,
37
+ FluxAttention
38
+ )
39
+ from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
40
+ from diffusers.models.unets.unet_2d import UNet2DModel
41
+ from diffusers.models.unets.unet_2d_blocks import (
42
+ CrossAttnDownBlock2D,
43
+ CrossAttnUpBlock2D,
44
+ DownBlock2D,
45
+ UNetMidBlock2D,
46
+ UNetMidBlock2DCrossAttn,
47
+ UpBlock2D,
48
+ )
49
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
50
+ from diffusers.pipelines import (
51
+ FluxControlPipeline,
52
+ FluxFillPipeline,
53
+ FluxPipeline,
54
+ FluxKontextPipeline,
55
+ PixArtAlphaPipeline,
56
+ PixArtSigmaPipeline,
57
+ SanaPipeline,
58
+ StableDiffusion3Pipeline,
59
+ StableDiffusionPipeline,
60
+ StableDiffusionXLPipeline,
61
+ )
62
+
63
+ from deepcompressor.nn.patch.conv import ConcatConv2d, ShiftedConv2d
64
+ from deepcompressor.nn.patch.linear import ConcatLinear, ShiftedLinear
65
+ from deepcompressor.nn.struct.attn import (
66
+ AttentionConfigStruct,
67
+ AttentionStruct,
68
+ BaseTransformerStruct,
69
+ FeedForwardConfigStruct,
70
+ FeedForwardStruct,
71
+ TransformerBlockStruct,
72
+ )
73
+ from deepcompressor.nn.struct.base import BaseModuleStruct
74
+ from deepcompressor.utils.common import join_name
75
+
76
+ from .attention import DiffusionAttentionProcessor
77
+
78
+ # endregion
79
+
80
+
81
+ __all__ = ["DiffusionModelStruct", "DiffusionBlockStruct", "DiffusionModelStruct"]
82
+
83
+
84
+ DIT_BLOCK_CLS = tp.Union[
85
+ BasicTransformerBlock,
86
+ JointTransformerBlock,
87
+ FluxSingleTransformerBlock,
88
+ FluxTransformerBlock,
89
+ SanaTransformerBlock,
90
+ ]
91
+ UNET_BLOCK_CLS = tp.Union[
92
+ DownBlock2D,
93
+ CrossAttnDownBlock2D,
94
+ UNetMidBlock2D,
95
+ UNetMidBlock2DCrossAttn,
96
+ UpBlock2D,
97
+ CrossAttnUpBlock2D,
98
+ ]
99
+ DIT_CLS = tp.Union[
100
+ Transformer2DModel,
101
+ PixArtTransformer2DModel,
102
+ SD3Transformer2DModel,
103
+ FluxTransformer2DModel,
104
+ SanaTransformer2DModel,
105
+ ]
106
+ UNET_CLS = tp.Union[UNet2DModel, UNet2DConditionModel]
107
+ MODEL_CLS = tp.Union[DIT_CLS, UNET_CLS]
108
+ UNET_PIPELINE_CLS = tp.Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
109
+ DIT_PIPELINE_CLS = tp.Union[
110
+ StableDiffusion3Pipeline,
111
+ PixArtAlphaPipeline,
112
+ PixArtSigmaPipeline,
113
+ FluxPipeline,
114
+ FluxKontextPipeline,
115
+ FluxControlPipeline,
116
+ FluxFillPipeline,
117
+ SanaPipeline,
118
+ ]
119
+ PIPELINE_CLS = tp.Union[UNET_PIPELINE_CLS, DIT_PIPELINE_CLS]
120
+
121
+ ATTENTION_CLS = tp.Union[
122
+ # existing types...
123
+ FluxAttention,
124
+ ]
125
+
126
+ @dataclass(kw_only=True)
127
+ class DiffusionModuleStruct(BaseModuleStruct):
128
+ def named_key_modules(self) -> tp.Generator[tuple[str, str, nn.Module, BaseModuleStruct, str], None, None]:
129
+ if isinstance(self.module, (nn.Linear, nn.Conv2d)):
130
+ yield self.key, self.name, self.module, self.parent, self.fname
131
+ else:
132
+ for name, module in self.module.named_modules():
133
+ if name and isinstance(module, (nn.Linear, nn.Conv2d)):
134
+ module_name = join_name(self.name, name)
135
+ field_name = join_name(self.fname, name)
136
+ yield self.key, module_name, module, self.parent, field_name
137
+
138
+
139
+ @dataclass(kw_only=True)
140
+ class DiffusionBlockStruct(BaseModuleStruct):
141
+ @abstractmethod
142
+ def iter_attention_structs(self) -> tp.Generator["DiffusionAttentionStruct", None, None]: ...
143
+
144
+ @abstractmethod
145
+ def iter_transformer_block_structs(self) -> tp.Generator["DiffusionTransformerBlockStruct", None, None]: ...
146
+
147
+
148
+ @dataclass(kw_only=True)
149
+ class DiffusionModelStruct(DiffusionBlockStruct):
150
+ pre_module_structs: OrderedDict[str, DiffusionModuleStruct] = field(init=False, repr=False)
151
+ post_module_structs: OrderedDict[str, DiffusionModuleStruct] = field(init=False, repr=False)
152
+
153
+ @property
154
+ @abstractmethod
155
+ def num_blocks(self) -> int: ...
156
+
157
+ @property
158
+ @abstractmethod
159
+ def block_structs(self) -> list[DiffusionBlockStruct]: ...
160
+
161
+ @abstractmethod
162
+ def get_prev_module_keys(self) -> tuple[str, ...]: ...
163
+
164
+ @abstractmethod
165
+ def get_post_module_keys(self) -> tuple[str, ...]: ...
166
+
167
+ @abstractmethod
168
+ def _get_iter_block_activations_args(
169
+ self, **input_kwargs
170
+ ) -> tuple[list[nn.Module], list[DiffusionModuleStruct | DiffusionBlockStruct], list[bool], list[bool]]: ...
171
+
172
+ def _get_iter_pre_module_activations_args(
173
+ self,
174
+ ) -> tuple[list[nn.Module], list[DiffusionModuleStruct], list[bool], list[bool]]:
175
+ layers, layer_structs, recomputes, use_prev_layer_outputs = [], [], [], []
176
+ for layer_struct in self.pre_module_structs.values():
177
+ layers.append(layer_struct.module)
178
+ layer_structs.append(layer_struct)
179
+ recomputes.append(False)
180
+ use_prev_layer_outputs.append(False)
181
+ return layers, layer_structs, recomputes, use_prev_layer_outputs
182
+
183
+ def _get_iter_post_module_activations_args(
184
+ self,
185
+ ) -> tuple[list[nn.Module], list[DiffusionModuleStruct], list[bool], list[bool]]:
186
+ layers, layer_structs, recomputes, use_prev_layer_outputs = [], [], [], []
187
+ for layer_struct in self.post_module_structs.values():
188
+ layers.append(layer_struct.module)
189
+ layer_structs.append(layer_struct)
190
+ recomputes.append(False)
191
+ use_prev_layer_outputs.append(False)
192
+ return layers, layer_structs, recomputes, use_prev_layer_outputs
193
+
194
+ def get_iter_layer_activations_args(
195
+ self, skip_pre_modules: bool, skip_post_modules: bool, **input_kwargs
196
+ ) -> tuple[list[nn.Module], list[DiffusionModuleStruct | DiffusionBlockStruct], list[bool], list[bool]]:
197
+ """
198
+ Get the arguments for iterating over the layers and their activations.
199
+
200
+ Args:
201
+ skip_pre_modules (`bool`):
202
+ Whether to skip the pre-modules
203
+ skip_post_modules (`bool`):
204
+ Whether to skip the post-modules
205
+
206
+ Returns:
207
+ `tuple[list[nn.Module], list[DiffusionModuleStruct | DiffusionBlockStruct], list[bool], list[bool]]`:
208
+ the layers, the layer structs, the recomputes, and the use_prev_layer_outputs
209
+ """
210
+ layers, structs, recomputes, uses = [], [], [], []
211
+ if not skip_pre_modules:
212
+ layers, structs, recomputes, uses = self._get_iter_pre_module_activations_args()
213
+ _layers, _structs, _recomputes, _uses = self._get_iter_block_activations_args(**input_kwargs)
214
+ layers.extend(_layers)
215
+ structs.extend(_structs)
216
+ recomputes.extend(_recomputes)
217
+ uses.extend(_uses)
218
+ if not skip_post_modules:
219
+ _layers, _structs, _recomputes, _uses = self._get_iter_post_module_activations_args()
220
+ layers.extend(_layers)
221
+ structs.extend(_structs)
222
+ recomputes.extend(_recomputes)
223
+ uses.extend(_uses)
224
+ return layers, structs, recomputes, uses
225
+
226
+ def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Module, BaseModuleStruct, str], None, None]:
227
+ for module in self.pre_module_structs.values():
228
+ yield from module.named_key_modules()
229
+ for block in self.block_structs:
230
+ yield from block.named_key_modules()
231
+ for module in self.post_module_structs.values():
232
+ yield from module.named_key_modules()
233
+
234
+ def iter_attention_structs(self) -> tp.Generator["AttentionStruct", None, None]:
235
+ for block in self.block_structs:
236
+ yield from block.iter_attention_structs()
237
+
238
+ def iter_transformer_block_structs(self) -> tp.Generator["DiffusionTransformerBlockStruct", None, None]:
239
+ for block in self.block_structs:
240
+ yield from block.iter_transformer_block_structs()
241
+
242
+ def get_named_layers(
243
+ self, skip_pre_modules: bool, skip_post_modules: bool, skip_blocks: bool = False
244
+ ) -> OrderedDict[str, DiffusionBlockStruct | DiffusionModuleStruct]:
245
+ named_layers = OrderedDict()
246
+ if not skip_pre_modules:
247
+ named_layers.update(self.pre_module_structs)
248
+ if not skip_blocks:
249
+ for block in self.block_structs:
250
+ named_layers[block.name] = block
251
+ if not skip_post_modules:
252
+ named_layers.update(self.post_module_structs)
253
+ return named_layers
254
+
255
+ @staticmethod
256
+ def _default_construct(
257
+ module: tp.Union[PIPELINE_CLS, MODEL_CLS],
258
+ /,
259
+ parent: tp.Optional[BaseModuleStruct] = None,
260
+ fname: str = "",
261
+ rname: str = "",
262
+ rkey: str = "",
263
+ idx: int = 0,
264
+ **kwargs,
265
+ ) -> "DiffusionModelStruct":
266
+ if isinstance(module, UNET_PIPELINE_CLS):
267
+ module = module.unet
268
+ elif isinstance(module, DIT_PIPELINE_CLS):
269
+ module = module.transformer
270
+ if isinstance(module, UNET_CLS):
271
+ return UNetStruct.construct(module, parent=parent, fname=fname, rname=rname, rkey=rkey, idx=idx, **kwargs)
272
+ elif isinstance(module, DIT_CLS):
273
+ return DiTStruct.construct(module, parent=parent, fname=fname, rname=rname, rkey=rkey, idx=idx, **kwargs)
274
+ raise NotImplementedError(f"Unsupported module type: {type(module)}")
275
+
276
+ @classmethod
277
+ def _get_default_key_map(cls) -> dict[str, set[str]]:
278
+ unet_key_map = UNetStruct._get_default_key_map()
279
+ dit_key_map = DiTStruct._get_default_key_map()
280
+ flux_key_map = FluxStruct._get_default_key_map()
281
+ key_map: dict[str, set[str]] = defaultdict(set)
282
+ for rkey, keys in unet_key_map.items():
283
+ key_map[rkey].update(keys)
284
+ for rkey, keys in dit_key_map.items():
285
+ key_map[rkey].update(keys)
286
+ for rkey, keys in flux_key_map.items():
287
+ key_map[rkey].update(keys)
288
+ return {k: v for k, v in key_map.items() if v}
289
+
290
+ @staticmethod
291
+ def _simplify_keys(keys: tp.Iterable[str], *, key_map: dict[str, set[str]]) -> list[str]:
292
+ """Simplify the keys based on the key map.
293
+
294
+ Args:
295
+ keys (`Iterable[str]`):
296
+ The keys to simplify.
297
+ key_map (`dict[str, set[str]]`):
298
+ The key map.
299
+
300
+ Returns:
301
+ `list[str]`:
302
+ The simplified keys.
303
+ """
304
+ # we first sort key_map by length of values in descending order
305
+ key_map = dict(sorted(key_map.items(), key=lambda item: len(item[1]), reverse=True))
306
+ ukeys, skeys = set(keys), set()
307
+ for k, v in key_map.items():
308
+ if k in ukeys:
309
+ skeys.add(k)
310
+ ukeys.discard(k)
311
+ ukeys.difference_update(v)
312
+ continue
313
+ if ukeys.issuperset(v):
314
+ skeys.add(k)
315
+ ukeys.difference_update(v)
316
+ assert not ukeys, f"Unrecognized keys: {ukeys}"
317
+ return sorted(skeys)
318
+
319
+
320
+ @dataclass(kw_only=True)
321
+ class DiffusionAttentionStruct(AttentionStruct):
322
+ module: Attention = field(repr=False, kw_only=False)
323
+ """the module of AttentionBlock"""
324
+ parent: tp.Optional["DiffusionTransformerBlockStruct"] = field(repr=False)
325
+
326
+ def filter_kwargs(self, kwargs: dict) -> dict:
327
+ """Filter layer kwargs to attn kwargs."""
328
+ if isinstance(self.parent.module, BasicTransformerBlock):
329
+ if kwargs.get("cross_attention_kwargs", None) is None:
330
+ attn_kwargs = {}
331
+ else:
332
+ attn_kwargs = dict(kwargs["cross_attention_kwargs"].items())
333
+ attn_kwargs.pop("gligen", None)
334
+ if self.idx == 0:
335
+ attn_kwargs["attention_mask"] = kwargs.get("attention_mask", None)
336
+ else:
337
+ attn_kwargs["attention_mask"] = kwargs.get("encoder_attention_mask", None)
338
+ else:
339
+ attn_kwargs = {}
340
+ return attn_kwargs
341
+
342
+ @staticmethod
343
+ def _default_construct(
344
+ module: Attention,
345
+ /,
346
+ parent: tp.Optional["DiffusionTransformerBlockStruct"] = None,
347
+ fname: str = "",
348
+ rname: str = "",
349
+ rkey: str = "",
350
+ idx: int = 0,
351
+ **kwargs,
352
+ ) -> "DiffusionAttentionStruct":
353
+ if isinstance(module, FluxAttention):
354
+ # FluxAttention has different attribute names than standard attention
355
+ with_rope = True
356
+ num_query_heads = module.heads # FluxAttention uses 'heads', not 'num_heads'
357
+ num_key_value_heads = module.heads # FLUX typically uses same for q/k/v
358
+
359
+ # FluxAttention doesn't have 'to_out', but may have other output projections
360
+ # Check what output projection attributes actually exist
361
+ o_proj = None
362
+ o_proj_rname = ""
363
+
364
+ # Try to find the correct output projection
365
+ if hasattr(module, 'to_out') and module.to_out is not None:
366
+ o_proj = module.to_out[0] if isinstance(module.to_out, (list, tuple)) else module.to_out
367
+ o_proj_rname = "to_out.0" if isinstance(module.to_out, (list, tuple)) else "to_out"
368
+ elif hasattr(module, 'to_add_out'):
369
+ o_proj = module.to_add_out
370
+ o_proj_rname = "to_add_out"
371
+
372
+ q_proj, k_proj, v_proj = module.to_q, module.to_k, module.to_v
373
+ q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "to_k", "to_v"
374
+ q, k, v = module.to_q, module.to_k, module.to_v
375
+ q_rname, k_rname, v_rname = "to_q", "to_k", "to_v"
376
+
377
+ # Handle the add_* projections that FluxAttention has
378
+ add_q_proj = getattr(module, "add_q_proj", None)
379
+ add_k_proj = getattr(module, "add_k_proj", None)
380
+ add_v_proj = getattr(module, "add_v_proj", None)
381
+ add_o_proj = getattr(module, "to_add_out", None)
382
+ add_q_proj_rname = "add_q_proj" if add_q_proj else ""
383
+ add_k_proj_rname = "add_k_proj" if add_k_proj else ""
384
+ add_v_proj_rname = "add_v_proj" if add_v_proj else ""
385
+ add_o_proj_rname = "to_add_out" if add_o_proj else ""
386
+
387
+ kwargs = (
388
+ "encoder_hidden_states",
389
+ "attention_mask",
390
+ "image_rotary_emb",
391
+ )
392
+ cross_attention = add_k_proj is not None
393
+ elif module.is_cross_attention:
394
+ q_proj, k_proj, v_proj = module.to_q, None, None
395
+ add_q_proj, add_k_proj, add_v_proj, add_o_proj = None, module.to_k, module.to_v, None
396
+ q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "", ""
397
+ add_q_proj_rname, add_k_proj_rname, add_v_proj_rname, add_o_proj_rname = "", "to_k", "to_v", ""
398
+ else:
399
+ q_proj, k_proj, v_proj = module.to_q, module.to_k, module.to_v
400
+ add_q_proj = getattr(module, "add_q_proj", None)
401
+ add_k_proj = getattr(module, "add_k_proj", None)
402
+ add_v_proj = getattr(module, "add_v_proj", None)
403
+ add_o_proj = getattr(module, "to_add_out", None)
404
+ q_proj_rname, k_proj_rname, v_proj_rname = "to_q", "to_k", "to_v"
405
+ add_q_proj_rname, add_k_proj_rname, add_v_proj_rname = "add_q_proj", "add_k_proj", "add_v_proj"
406
+ add_o_proj_rname = "to_add_out"
407
+ if getattr(module, "to_out", None) is not None:
408
+ o_proj = module.to_out[0]
409
+ o_proj_rname = "to_out.0"
410
+ assert isinstance(o_proj, nn.Linear)
411
+ elif parent is not None:
412
+ assert isinstance(parent.module, FluxSingleTransformerBlock)
413
+ assert isinstance(parent.module.proj_out, ConcatLinear)
414
+ assert len(parent.module.proj_out.linears) == 2
415
+ o_proj = parent.module.proj_out.linears[0]
416
+ o_proj_rname = ".proj_out.linears.0"
417
+ else:
418
+ raise RuntimeError("Cannot find the output projection.")
419
+ if isinstance(module.processor, DiffusionAttentionProcessor):
420
+ with_rope = module.processor.rope is not None
421
+ elif module.processor.__class__.__name__.startswith("Flux"):
422
+ with_rope = True
423
+ else:
424
+ with_rope = False # TODO: fix for other processors
425
+ config = AttentionConfigStruct(
426
+ hidden_size=q_proj.weight.shape[1],
427
+ add_hidden_size=add_k_proj.weight.shape[1] if add_k_proj is not None else 0,
428
+ inner_size=q_proj.weight.shape[0],
429
+ num_query_heads=module.heads,
430
+ num_key_value_heads=module.to_k.weight.shape[0] // (module.to_q.weight.shape[0] // module.heads),
431
+ with_qk_norm=module.norm_q is not None,
432
+ with_rope=with_rope,
433
+ linear_attn=isinstance(module.processor, SanaLinearAttnProcessor2_0),
434
+ )
435
+ return DiffusionAttentionStruct(
436
+ module=module,
437
+ parent=parent,
438
+ fname=fname,
439
+ idx=idx,
440
+ rname=rname,
441
+ rkey=rkey,
442
+ config=config,
443
+ q_proj=q_proj,
444
+ k_proj=k_proj,
445
+ v_proj=v_proj,
446
+ o_proj=o_proj,
447
+ add_q_proj=add_q_proj,
448
+ add_k_proj=add_k_proj,
449
+ add_v_proj=add_v_proj,
450
+ add_o_proj=add_o_proj,
451
+ q=None, # TODO: add q, k, v
452
+ k=None,
453
+ v=None,
454
+ q_proj_rname=q_proj_rname,
455
+ k_proj_rname=k_proj_rname,
456
+ v_proj_rname=v_proj_rname,
457
+ o_proj_rname=o_proj_rname,
458
+ add_q_proj_rname=add_q_proj_rname,
459
+ add_k_proj_rname=add_k_proj_rname,
460
+ add_v_proj_rname=add_v_proj_rname,
461
+ add_o_proj_rname=add_o_proj_rname,
462
+ q_rname="",
463
+ k_rname="",
464
+ v_rname="",
465
+ )
466
+
467
+
468
+ @dataclass(kw_only=True)
469
+ class DiffusionFeedForwardStruct(FeedForwardStruct):
470
+ module: FeedForward = field(repr=False, kw_only=False)
471
+ """the module of FeedForward"""
472
+ parent: tp.Optional["DiffusionTransformerBlockStruct"] = field(repr=False)
473
+ # region modules
474
+ moe_gate: None = field(init=False, repr=False, default=None)
475
+ experts: list[nn.Module] = field(init=False, repr=False)
476
+ # endregion
477
+ # region names
478
+ moe_gate_rname: str = field(init=False, repr=False, default="")
479
+ experts_rname: str = field(init=False, repr=False, default="")
480
+ # endregion
481
+
482
+ # region aliases
483
+
484
+ @property
485
+ def up_proj(self) -> nn.Linear:
486
+ return self.up_projs[0]
487
+
488
+ @property
489
+ def down_proj(self) -> nn.Linear:
490
+ return self.down_projs[0]
491
+
492
+ @property
493
+ def up_proj_rname(self) -> str:
494
+ return self.up_proj_rnames[0]
495
+
496
+ @property
497
+ def down_proj_rname(self) -> str:
498
+ return self.down_proj_rnames[0]
499
+
500
+ @property
501
+ def up_proj_name(self) -> str:
502
+ return self.up_proj_names[0]
503
+
504
+ @property
505
+ def down_proj_name(self) -> str:
506
+ return self.down_proj_names[0]
507
+
508
+ # endregion
509
+
510
+ def __post_init__(self) -> None:
511
+ assert len(self.up_projs) == len(self.down_projs) == 1
512
+ assert len(self.up_proj_rnames) == len(self.down_proj_rnames) == 1
513
+ self.experts = [self.module]
514
+ super().__post_init__()
515
+
516
+ @staticmethod
517
+ def _default_construct(
518
+ module: FeedForward | FluxSingleTransformerBlock | GLUMBConv,
519
+ /,
520
+ parent: tp.Optional["DiffusionTransformerBlockStruct"] = None,
521
+ fname: str = "",
522
+ rname: str = "",
523
+ rkey: str = "",
524
+ idx: int = 0,
525
+ **kwargs,
526
+ ) -> "DiffusionFeedForwardStruct":
527
+ if isinstance(module, FeedForward):
528
+ layer_1, layer_2 = module.net[0], module.net[2]
529
+ assert isinstance(layer_1, (GEGLU, GELU, ApproximateGELU, SwiGLU))
530
+ up_proj, up_proj_rname = layer_1.proj, "net.0.proj"
531
+ assert isinstance(up_proj, nn.Linear)
532
+ down_proj, down_proj_rname = layer_2, "net.2"
533
+ if isinstance(layer_1, GEGLU):
534
+ act_type = "gelu_glu"
535
+ elif isinstance(layer_1, SwiGLU):
536
+ act_type = "swish_glu"
537
+ else:
538
+ assert layer_1.__class__.__name__.lower().endswith("gelu")
539
+ act_type = "gelu"
540
+ if isinstance(layer_2, ShiftedLinear):
541
+ down_proj, down_proj_rname = layer_2.linear, "net.2.linear"
542
+ act_type = "gelu_shifted"
543
+ assert isinstance(down_proj, nn.Linear)
544
+ ffn = module
545
+ elif isinstance(module, FluxSingleTransformerBlock):
546
+ up_proj, up_proj_rname = module.proj_mlp, "proj_mlp"
547
+ act_type = "gelu"
548
+ assert isinstance(module.proj_out, ConcatLinear)
549
+ assert len(module.proj_out.linears) == 2
550
+ layer_2 = module.proj_out.linears[1]
551
+ if isinstance(layer_2, ShiftedLinear):
552
+ down_proj, down_proj_rname = layer_2.linear, "proj_out.linears.1.linear"
553
+ act_type = "gelu_shifted"
554
+ else:
555
+ down_proj, down_proj_rname = layer_2, "proj_out.linears.1"
556
+ ffn = nn.Sequential(up_proj, module.act_mlp, layer_2)
557
+ assert not rname, f"Unsupported rname: {rname}"
558
+ elif isinstance(module, GLUMBConv):
559
+ ffn = module
560
+ up_proj, up_proj_rname = module.conv_inverted, "conv_inverted"
561
+ down_proj, down_proj_rname = module.conv_point, "conv_point"
562
+ act_type = "silu_conv_silu_glu"
563
+ else:
564
+ raise NotImplementedError(f"Unsupported module type: {type(module)}")
565
+ config = FeedForwardConfigStruct(
566
+ hidden_size=up_proj.weight.shape[1],
567
+ intermediate_size=down_proj.weight.shape[1],
568
+ intermediate_act_type=act_type,
569
+ num_experts=1,
570
+ )
571
+ return DiffusionFeedForwardStruct(
572
+ module=ffn, # this may be a virtual module
573
+ parent=parent,
574
+ fname=fname,
575
+ idx=idx,
576
+ rname=rname,
577
+ rkey=rkey,
578
+ config=config,
579
+ up_projs=[up_proj],
580
+ down_projs=[down_proj],
581
+ up_proj_rnames=[up_proj_rname],
582
+ down_proj_rnames=[down_proj_rname],
583
+ )
584
+
585
+
586
+ @dataclass(kw_only=True)
587
+ class DiffusionTransformerBlockStruct(TransformerBlockStruct, DiffusionBlockStruct):
588
+ # region relative keys
589
+ norm_rkey: tp.ClassVar[str] = "transformer_norm"
590
+ add_norm_rkey: tp.ClassVar[str] = "transformer_add_norm"
591
+ attn_struct_cls: tp.ClassVar[type[DiffusionAttentionStruct]] = DiffusionAttentionStruct
592
+ ffn_struct_cls: tp.ClassVar[type[DiffusionFeedForwardStruct]] = DiffusionFeedForwardStruct
593
+ # endregion
594
+
595
+ parent: tp.Optional["DiffusionTransformerStruct"] = field(repr=False)
596
+ # region child modules
597
+ post_attn_norms: list[nn.LayerNorm] = field(init=False, repr=False, default_factory=list)
598
+ post_attn_add_norms: list[nn.LayerNorm] = field(init=False, repr=False, default_factory=list)
599
+ post_ffn_norm: None = field(init=False, repr=False, default=None)
600
+ post_add_ffn_norm: None = field(init=False, repr=False, default=None)
601
+ # endregion
602
+ # region relative names
603
+ post_attn_norm_rnames: list[str] = field(init=False, repr=False, default_factory=list)
604
+ post_attn_add_norm_rnames: list[str] = field(init=False, repr=False, default_factory=list)
605
+ post_ffn_norm_rname: str = field(init=False, repr=False, default="")
606
+ post_add_ffn_norm_rname: str = field(init=False, repr=False, default="")
607
+ # endregion
608
+ # region attributes
609
+ norm_type: str
610
+ add_norm_type: str
611
+ # endregion
612
+ # region absolute keys
613
+ norm_key: str = field(init=False, repr=False)
614
+ add_norm_key: str = field(init=False, repr=False)
615
+ # endregion
616
+ # region child structs
617
+ pre_attn_norm_structs: list[DiffusionModuleStruct | None] = field(init=False, repr=False)
618
+ pre_attn_add_norm_structs: list[DiffusionModuleStruct | None] = field(init=False, repr=False)
619
+ pre_ffn_norm_struct: DiffusionModuleStruct = field(init=False, repr=False, default=None)
620
+ pre_add_ffn_norm_struct: DiffusionModuleStruct | None = field(init=False, repr=False, default=None)
621
+ attn_structs: list[DiffusionAttentionStruct] = field(init=False, repr=False)
622
+ ffn_struct: DiffusionFeedForwardStruct | None = field(init=False, repr=False)
623
+ add_ffn_struct: DiffusionFeedForwardStruct | None = field(init=False, repr=False)
624
+ # endregion
625
+
626
+ def __post_init__(self) -> None:
627
+ super().__post_init__()
628
+ self.norm_key = join_name(self.key, self.norm_rkey, sep="_")
629
+ self.add_norm_key = join_name(self.key, self.add_norm_rkey, sep="_")
630
+ self.attn_norm_structs = [
631
+ DiffusionModuleStruct(norm, parent=self, fname="pre_attn_norm", rname=rname, rkey=self.norm_rkey, idx=idx)
632
+ for idx, (norm, rname) in enumerate(zip(self.pre_attn_norms, self.pre_attn_norm_rnames, strict=True))
633
+ ]
634
+ self.add_attn_norm_structs = [
635
+ DiffusionModuleStruct(
636
+ norm, parent=self, fname="pre_attn_add_norm", rname=rname, rkey=self.add_norm_rkey, idx=idx
637
+ )
638
+ for idx, (norm, rname) in enumerate(
639
+ zip(self.pre_attn_add_norms, self.pre_attn_add_norm_rnames, strict=True)
640
+ )
641
+ ]
642
+ if self.pre_ffn_norm is not None:
643
+ self.pre_ffn_norm_struct = DiffusionModuleStruct(
644
+ self.pre_ffn_norm, parent=self, fname="pre_ffn_norm", rname=self.pre_ffn_norm_rname, rkey=self.norm_rkey
645
+ )
646
+ if self.pre_add_ffn_norm is not None:
647
+ self.pre_add_ffn_norm_struct = DiffusionModuleStruct(
648
+ self.pre_add_ffn_norm,
649
+ parent=self,
650
+ fname="pre_add_ffn_norm",
651
+ rname=self.pre_add_ffn_norm_rname,
652
+ rkey=self.add_norm_rkey,
653
+ )
654
+
655
+ def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Module, BaseModuleStruct, str], None, None]:
656
+ for attn_norm in self.attn_norm_structs:
657
+ if attn_norm.module is not None:
658
+ yield from attn_norm.named_key_modules()
659
+ for add_attn_norm in self.add_attn_norm_structs:
660
+ if add_attn_norm.module is not None:
661
+ yield from add_attn_norm.named_key_modules()
662
+ for attn_struct in self.attn_structs:
663
+ yield from attn_struct.named_key_modules()
664
+ if self.pre_ffn_norm_struct is not None:
665
+ if self.pre_attn_norms and self.pre_attn_norms[0] is not self.pre_ffn_norm:
666
+ yield from self.pre_ffn_norm_struct.named_key_modules()
667
+ if self.ffn_struct is not None:
668
+ yield from self.ffn_struct.named_key_modules()
669
+ if self.pre_add_ffn_norm_struct is not None:
670
+ if self.pre_attn_add_norms and self.pre_attn_add_norms[0] is not self.pre_add_ffn_norm:
671
+ yield from self.pre_add_ffn_norm_struct.named_key_modules()
672
+ if self.add_ffn_struct is not None:
673
+ yield from self.add_ffn_struct.named_key_modules()
674
+
675
+ @staticmethod
676
+ def _default_construct(
677
+ module: DIT_BLOCK_CLS,
678
+ /,
679
+ parent: tp.Optional["DiffusionTransformerStruct"] = None,
680
+ fname: str = "",
681
+ rname: str = "",
682
+ rkey: str = "",
683
+ idx: int = 0,
684
+ **kwargs,
685
+ ) -> "DiffusionTransformerBlockStruct":
686
+ if isinstance(module, (BasicTransformerBlock, SanaTransformerBlock)):
687
+ parallel = False
688
+ if isinstance(module, SanaTransformerBlock):
689
+ norm_type = add_norm_type = "ada_norm_single"
690
+ else:
691
+ norm_type = add_norm_type = module.norm_type
692
+ pre_attn_norms, pre_attn_norm_rnames = [], []
693
+ attns, attn_rnames = [], []
694
+ pre_attn_add_norms, pre_attn_add_norm_rnames = [], []
695
+ assert module.norm1 is not None
696
+ assert module.attn1 is not None
697
+ pre_attn_norms.append(module.norm1)
698
+ pre_attn_norm_rnames.append("norm1")
699
+ attns.append(module.attn1)
700
+ attn_rnames.append("attn1")
701
+ pre_attn_add_norms.append(module.attn1.norm_cross)
702
+ pre_attn_add_norm_rnames.append("attn1.norm_cross")
703
+ if module.attn2 is not None:
704
+ if norm_type == "ada_norm_single":
705
+ pre_attn_norms.append(None)
706
+ pre_attn_norm_rnames.append("")
707
+ else:
708
+ assert module.norm2 is not None
709
+ pre_attn_norms.append(module.norm2)
710
+ pre_attn_norm_rnames.append("norm2")
711
+ attns.append(module.attn2)
712
+ attn_rnames.append("attn2")
713
+ pre_attn_add_norms.append(module.attn2.norm_cross)
714
+ pre_attn_add_norm_rnames.append("attn2.norm_cross")
715
+ if norm_type == "ada_norm_single":
716
+ assert module.norm2 is not None
717
+ pre_ffn_norm, pre_ffn_norm_rname = module.norm2, "norm2"
718
+ else:
719
+ pre_ffn_norm, pre_ffn_norm_rname = module.norm3, "" if module.norm3 is None else "norm3"
720
+ ffn, ffn_rname = module.ff, "" if module.ff is None else "ff"
721
+ pre_add_ffn_norm, pre_add_ffn_norm_rname, add_ffn, add_ffn_rname = None, "", None, ""
722
+ elif isinstance(module, JointTransformerBlock):
723
+ parallel = False
724
+ norm_type = "ada_norm_zero"
725
+ pre_attn_norms, pre_attn_norm_rnames = [module.norm1], ["norm1"]
726
+ if isinstance(module.norm1_context, AdaLayerNormZero):
727
+ add_norm_type = "ada_norm_zero"
728
+ else:
729
+ add_norm_type = "ada_norm_continous"
730
+ pre_attn_add_norms, pre_attn_add_norm_rnames = [module.norm1_context], ["norm1_context"]
731
+ attns, attn_rnames = [module.attn], ["attn"]
732
+ pre_ffn_norm, pre_ffn_norm_rname = module.norm2, "norm2"
733
+ ffn, ffn_rname = module.ff, "ff"
734
+ pre_add_ffn_norm, pre_add_ffn_norm_rname = module.norm2_context, "norm2_context"
735
+ add_ffn, add_ffn_rname = module.ff_context, "ff_context"
736
+ elif isinstance(module, FluxSingleTransformerBlock):
737
+ parallel = True
738
+ norm_type = add_norm_type = "ada_norm_zero_single"
739
+ pre_attn_norms, pre_attn_norm_rnames = [module.norm], ["norm"]
740
+ attns, attn_rnames = [module.attn], ["attn"]
741
+ pre_attn_add_norms, pre_attn_add_norm_rnames = [], []
742
+ pre_ffn_norm, pre_ffn_norm_rname = module.norm, "norm"
743
+ ffn, ffn_rname = module, ""
744
+ pre_add_ffn_norm, pre_add_ffn_norm_rname, add_ffn, add_ffn_rname = None, "", None, ""
745
+ elif isinstance(module, FluxTransformerBlock):
746
+ parallel = False
747
+ norm_type = add_norm_type = "ada_norm_zero"
748
+ pre_attn_norms, pre_attn_norm_rnames = [module.norm1], ["norm1"]
749
+ attns, attn_rnames = [module.attn], ["attn"]
750
+ pre_attn_add_norms, pre_attn_add_norm_rnames = [module.norm1_context], ["norm1_context"]
751
+ pre_ffn_norm, pre_ffn_norm_rname = module.norm2, "norm2"
752
+ ffn, ffn_rname = module.ff, "ff"
753
+ pre_add_ffn_norm, pre_add_ffn_norm_rname = module.norm2_context, "norm2_context"
754
+ add_ffn, add_ffn_rname = module.ff_context, "ff_context"
755
+ else:
756
+ raise NotImplementedError(f"Unsupported module type: {type(module)}")
757
+ return DiffusionTransformerBlockStruct(
758
+ module=module,
759
+ parent=parent,
760
+ fname=fname,
761
+ idx=idx,
762
+ rname=rname,
763
+ rkey=rkey,
764
+ parallel=parallel,
765
+ pre_attn_norms=pre_attn_norms,
766
+ pre_attn_add_norms=pre_attn_add_norms,
767
+ attns=attns,
768
+ pre_ffn_norm=pre_ffn_norm,
769
+ ffn=ffn,
770
+ pre_add_ffn_norm=pre_add_ffn_norm,
771
+ add_ffn=add_ffn,
772
+ pre_attn_norm_rnames=pre_attn_norm_rnames,
773
+ pre_attn_add_norm_rnames=pre_attn_add_norm_rnames,
774
+ attn_rnames=attn_rnames,
775
+ pre_ffn_norm_rname=pre_ffn_norm_rname,
776
+ ffn_rname=ffn_rname,
777
+ pre_add_ffn_norm_rname=pre_add_ffn_norm_rname,
778
+ add_ffn_rname=add_ffn_rname,
779
+ norm_type=norm_type,
780
+ add_norm_type=add_norm_type,
781
+ )
782
+
783
+ @classmethod
784
+ def _get_default_key_map(cls) -> dict[str, set[str]]:
785
+ """Get the default allowed keys."""
786
+ key_map: dict[str, set[str]] = defaultdict(set)
787
+ norm_rkey = norm_key = cls.norm_rkey
788
+ add_norm_rkey = add_norm_key = cls.add_norm_rkey
789
+ key_map[norm_rkey].add(norm_key)
790
+ key_map[add_norm_rkey].add(add_norm_key)
791
+ attn_cls = cls.attn_struct_cls
792
+ attn_key = attn_rkey = cls.attn_rkey
793
+ qkv_proj_key = qkv_proj_rkey = join_name(attn_key, attn_cls.qkv_proj_rkey, sep="_")
794
+ out_proj_key = out_proj_rkey = join_name(attn_key, attn_cls.out_proj_rkey, sep="_")
795
+ add_qkv_proj_key = add_qkv_proj_rkey = join_name(attn_key, attn_cls.add_qkv_proj_rkey, sep="_")
796
+ add_out_proj_key = add_out_proj_rkey = join_name(attn_key, attn_cls.add_out_proj_rkey, sep="_")
797
+ key_map[attn_rkey].add(qkv_proj_key)
798
+ key_map[attn_rkey].add(out_proj_key)
799
+ if attn_cls.add_qkv_proj_rkey.startswith("add_") and attn_cls.add_out_proj_rkey.startswith("add_"):
800
+ add_attn_rkey = join_name(attn_rkey, "add", sep="_")
801
+ key_map[add_attn_rkey].add(add_qkv_proj_key)
802
+ key_map[add_attn_rkey].add(add_out_proj_key)
803
+ key_map[qkv_proj_rkey].add(qkv_proj_key)
804
+ key_map[out_proj_rkey].add(out_proj_key)
805
+ key_map[add_qkv_proj_rkey].add(add_qkv_proj_key)
806
+ key_map[add_out_proj_rkey].add(add_out_proj_key)
807
+ ffn_cls = cls.ffn_struct_cls
808
+ ffn_key = ffn_rkey = cls.ffn_rkey
809
+ add_ffn_key = add_ffn_rkey = cls.add_ffn_rkey
810
+ up_proj_key = up_proj_rkey = join_name(ffn_key, ffn_cls.up_proj_rkey, sep="_")
811
+ down_proj_key = down_proj_rkey = join_name(ffn_key, ffn_cls.down_proj_rkey, sep="_")
812
+ add_up_proj_key = add_up_proj_rkey = join_name(add_ffn_key, ffn_cls.up_proj_rkey, sep="_")
813
+ add_down_proj_key = add_down_proj_rkey = join_name(add_ffn_key, ffn_cls.down_proj_rkey, sep="_")
814
+ key_map[ffn_rkey].add(up_proj_key)
815
+ key_map[ffn_rkey].add(down_proj_key)
816
+ key_map[add_ffn_rkey].add(add_up_proj_key)
817
+ key_map[add_ffn_rkey].add(add_down_proj_key)
818
+ key_map[up_proj_rkey].add(up_proj_key)
819
+ key_map[down_proj_rkey].add(down_proj_key)
820
+ key_map[add_up_proj_rkey].add(add_up_proj_key)
821
+ key_map[add_down_proj_rkey].add(add_down_proj_key)
822
+ return {k: v for k, v in key_map.items() if v}
823
+
824
+
825
+ @dataclass(kw_only=True)
826
+ class DiffusionTransformerStruct(BaseTransformerStruct, DiffusionBlockStruct):
827
+ # region relative keys
828
+ proj_in_rkey: tp.ClassVar[str] = "transformer_proj_in"
829
+ proj_out_rkey: tp.ClassVar[str] = "transformer_proj_out"
830
+ transformer_block_rkey: tp.ClassVar[str] = ""
831
+ transformer_block_struct_cls: tp.ClassVar[type[DiffusionTransformerBlockStruct]] = DiffusionTransformerBlockStruct
832
+ # endregion
833
+
834
+ module: Transformer2DModel = field(repr=False, kw_only=False)
835
+ # region modules
836
+ norm_in: nn.GroupNorm | None
837
+ """Input normalization"""
838
+ proj_in: nn.Linear | nn.Conv2d
839
+ """Input projection"""
840
+ norm_out: nn.GroupNorm | None
841
+ """Output normalization"""
842
+ proj_out: nn.Linear | nn.Conv2d
843
+ """Output projection"""
844
+ transformer_blocks: nn.ModuleList = field(repr=False)
845
+ """Transformer blocks"""
846
+ # endregion
847
+ # region relative names
848
+ transformer_blocks_rname: str
849
+ # endregion
850
+ # region absolute names
851
+ transformer_blocks_name: str = field(init=False, repr=False)
852
+ transformer_block_names: list[str] = field(init=False, repr=False)
853
+ # endregion
854
+ # region child structs
855
+ transformer_block_structs: list[DiffusionTransformerBlockStruct] = field(init=False, repr=False)
856
+ # endregion
857
+
858
+ # region aliases
859
+
860
+ @property
861
+ def num_blocks(self) -> int:
862
+ return len(self.transformer_blocks)
863
+
864
+ @property
865
+ def block_structs(self) -> list[DiffusionBlockStruct]:
866
+ return self.transformer_block_structs
867
+
868
+ @property
869
+ def block_names(self) -> list[str]:
870
+ return self.transformer_block_names
871
+
872
+ # endregion
873
+
874
+ def __post_init__(self):
875
+ super().__post_init__()
876
+ transformer_block_rnames = [
877
+ f"{self.transformer_blocks_rname}.{idx}" for idx in range(len(self.transformer_blocks))
878
+ ]
879
+ self.transformer_blocks_name = join_name(self.name, self.transformer_blocks_rname)
880
+ self.transformer_block_names = [join_name(self.name, rname) for rname in transformer_block_rnames]
881
+ self.transformer_block_structs = [
882
+ self.transformer_block_struct_cls.construct(
883
+ layer,
884
+ parent=self,
885
+ fname="transformer_block",
886
+ rname=rname,
887
+ rkey=self.transformer_block_rkey,
888
+ idx=idx,
889
+ )
890
+ for idx, (layer, rname) in enumerate(zip(self.transformer_blocks, transformer_block_rnames, strict=True))
891
+ ]
892
+
893
+ @staticmethod
894
+ def _default_construct(
895
+ module: Transformer2DModel,
896
+ /,
897
+ parent: BaseModuleStruct = None,
898
+ fname: str = "",
899
+ rname: str = "",
900
+ rkey: str = "",
901
+ idx: int = 0,
902
+ **kwargs,
903
+ ) -> "DiffusionTransformerStruct":
904
+ if isinstance(module, Transformer2DModel):
905
+ assert module.is_input_continuous, "input must be continuous"
906
+ transformer_blocks, transformer_blocks_rname = module.transformer_blocks, "transformer_blocks"
907
+ norm_in, norm_in_rname = module.norm, "norm"
908
+ proj_in, proj_in_rname = module.proj_in, "proj_in"
909
+ proj_out, proj_out_rname = module.proj_out, "proj_out"
910
+ norm_out, norm_out_rname = None, ""
911
+ else:
912
+ raise NotImplementedError(f"Unsupported module type: {type(module)}")
913
+ return DiffusionTransformerStruct(
914
+ module=module,
915
+ parent=parent,
916
+ fname=fname,
917
+ idx=idx,
918
+ rname=rname,
919
+ rkey=rkey,
920
+ norm_in=norm_in,
921
+ proj_in=proj_in,
922
+ transformer_blocks=transformer_blocks,
923
+ proj_out=proj_out,
924
+ norm_out=norm_out,
925
+ norm_in_rname=norm_in_rname,
926
+ proj_in_rname=proj_in_rname,
927
+ transformer_blocks_rname=transformer_blocks_rname,
928
+ norm_out_rname=norm_out_rname,
929
+ proj_out_rname=proj_out_rname,
930
+ )
931
+
932
+ @classmethod
933
+ def _get_default_key_map(cls) -> dict[str, set[str]]:
934
+ """Get the default allowed keys."""
935
+ key_map: dict[str, set[str]] = defaultdict(set)
936
+ proj_in_rkey = proj_in_key = cls.proj_in_rkey
937
+ proj_out_rkey = proj_out_key = cls.proj_out_rkey
938
+ key_map[proj_in_rkey].add(proj_in_key)
939
+ key_map[proj_out_rkey].add(proj_out_key)
940
+ block_cls = cls.transformer_block_struct_cls
941
+ block_key = block_rkey = cls.transformer_block_rkey
942
+ block_key_map = block_cls._get_default_key_map()
943
+ for rkey, keys in block_key_map.items():
944
+ rkey = join_name(block_rkey, rkey, sep="_")
945
+ for key in keys:
946
+ key = join_name(block_key, key, sep="_")
947
+ key_map[rkey].add(key)
948
+ return {k: v for k, v in key_map.items() if v}
949
+
950
+
951
+ @dataclass(kw_only=True)
952
+ class DiffusionResnetStruct(BaseModuleStruct):
953
+ # region relative keys
954
+ conv_rkey: tp.ClassVar[str] = "conv"
955
+ shortcut_rkey: tp.ClassVar[str] = "shortcut"
956
+ time_proj_rkey: tp.ClassVar[str] = "time_proj"
957
+ # endregion
958
+
959
+ module: ResnetBlock2D = field(repr=False, kw_only=False)
960
+ """the module of Resnet"""
961
+ config: FeedForwardConfigStruct
962
+ # region child modules
963
+ norms: list[nn.GroupNorm]
964
+ convs: list[list[nn.Conv2d]]
965
+ shortcut: nn.Conv2d | None
966
+ time_proj: nn.Linear | None
967
+ # endregion
968
+ # region relative names
969
+ norm_rnames: list[str]
970
+ conv_rnames: list[list[str]]
971
+ shortcut_rname: str
972
+ time_proj_rname: str
973
+ # endregion
974
+ # region absolute names
975
+ norm_names: list[str] = field(init=False, repr=False)
976
+ conv_names: list[list[str]] = field(init=False, repr=False)
977
+ shortcut_name: str = field(init=False, repr=False)
978
+ time_proj_name: str = field(init=False, repr=False)
979
+ # endregion
980
+ # region absolute keys
981
+ conv_key: str = field(init=False, repr=False)
982
+ shortcut_key: str = field(init=False, repr=False)
983
+ time_proj_key: str = field(init=False, repr=False)
984
+ # endregion
985
+
986
+ def __post_init__(self):
987
+ super().__post_init__()
988
+ self.norm_names = [join_name(self.name, rname) for rname in self.norm_rnames]
989
+ self.conv_names = [[join_name(self.name, rname) for rname in rnames] for rnames in self.conv_rnames]
990
+ self.shortcut_name = join_name(self.name, self.shortcut_rname)
991
+ self.time_proj_name = join_name(self.name, self.time_proj_rname)
992
+ self.conv_key = join_name(self.key, self.conv_rkey, sep="_")
993
+ self.shortcut_key = join_name(self.key, self.shortcut_rkey, sep="_")
994
+ self.time_proj_key = join_name(self.key, self.time_proj_rkey, sep="_")
995
+
996
+ def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Module, BaseModuleStruct, str], None, None]:
997
+ for convs, names in zip(self.convs, self.conv_names, strict=True):
998
+ for conv, name in zip(convs, names, strict=True):
999
+ yield self.conv_key, name, conv, self, "conv"
1000
+ if self.shortcut is not None:
1001
+ yield self.shortcut_key, self.shortcut_name, self.shortcut, self, "shortcut"
1002
+ if self.time_proj is not None:
1003
+ yield self.time_proj_key, self.time_proj_name, self.time_proj, self, "time_proj"
1004
+
1005
+ @staticmethod
1006
+ def construct(
1007
+ module: ResnetBlock2D,
1008
+ /,
1009
+ parent: BaseModuleStruct = None,
1010
+ fname: str = "",
1011
+ rname: str = "",
1012
+ rkey: str = "",
1013
+ idx: int = 0,
1014
+ **kwargs,
1015
+ ) -> "DiffusionResnetStruct":
1016
+ if isinstance(module, ResnetBlock2D):
1017
+ assert module.upsample is None, "upsample must be None"
1018
+ assert module.downsample is None, "downsample must be None"
1019
+ act_type = module.nonlinearity.__class__.__name__.lower()
1020
+ shifted = False
1021
+ if isinstance(module.conv1, ConcatConv2d):
1022
+ conv1_convs, conv1_names = [], []
1023
+ for conv_idx, conv in enumerate(module.conv1.convs):
1024
+ if isinstance(conv, ShiftedConv2d):
1025
+ shifted = True
1026
+ conv1_convs.append(conv.conv)
1027
+ conv1_names.append(f"conv1.convs.{conv_idx}.conv")
1028
+ else:
1029
+ assert isinstance(conv, nn.Conv2d)
1030
+ conv1_convs.append(conv)
1031
+ conv1_names.append(f"conv1.convs.{conv_idx}")
1032
+ elif isinstance(module.conv1, ShiftedConv2d):
1033
+ shifted = True
1034
+ conv1_convs = [module.conv1.conv]
1035
+ conv1_names = ["conv1.conv"]
1036
+ else:
1037
+ assert isinstance(module.conv1, nn.Conv2d)
1038
+ conv1_convs, conv1_names = [module.conv1], ["conv1"]
1039
+ if isinstance(module.conv2, ConcatConv2d):
1040
+ conv2_convs, conv2_names = [], []
1041
+ for conv_idx, conv in enumerate(module.conv2.convs):
1042
+ if isinstance(conv, ShiftedConv2d):
1043
+ shifted = True
1044
+ conv2_convs.append(conv.conv)
1045
+ conv2_names.append(f"conv2.convs.{conv_idx}.conv")
1046
+ else:
1047
+ assert isinstance(conv, nn.Conv2d)
1048
+ conv2_convs.append(conv)
1049
+ conv2_names.append(f"conv2.convs.{conv_idx}")
1050
+ elif isinstance(module.conv2, ShiftedConv2d):
1051
+ shifted = True
1052
+ conv2_convs = [module.conv2.conv]
1053
+ conv2_names = ["conv2.conv"]
1054
+ else:
1055
+ assert isinstance(module.conv2, nn.Conv2d)
1056
+ conv2_convs, conv2_names = [module.conv2], ["conv2"]
1057
+ convs, conv_rnames = [conv1_convs, conv2_convs], [conv1_names, conv2_names]
1058
+ norms, norm_rnames = [module.norm1, module.norm2], ["norm1", "norm2"]
1059
+ shortcut, shortcut_rname = module.conv_shortcut, "" if module.conv_shortcut is None else "conv_shortcut"
1060
+ time_proj, time_proj_rname = module.time_emb_proj, "" if module.time_emb_proj is None else "time_emb_proj"
1061
+ if shifted:
1062
+ assert all(hasattr(conv, "shifted") and conv.shifted for level_convs in convs for conv in level_convs)
1063
+ act_type += "_shifted"
1064
+ else:
1065
+ raise NotImplementedError(f"Unsupported module type: {type(module)}")
1066
+ config = FeedForwardConfigStruct(
1067
+ hidden_size=convs[0][0].weight.shape[1],
1068
+ intermediate_size=convs[0][0].weight.shape[0],
1069
+ intermediate_act_type=act_type,
1070
+ num_experts=1,
1071
+ )
1072
+ return DiffusionResnetStruct(
1073
+ module=module,
1074
+ parent=parent,
1075
+ fname=fname,
1076
+ idx=idx,
1077
+ rname=rname,
1078
+ rkey=rkey,
1079
+ config=config,
1080
+ norms=norms,
1081
+ convs=convs,
1082
+ shortcut=shortcut,
1083
+ time_proj=time_proj,
1084
+ norm_rnames=norm_rnames,
1085
+ conv_rnames=conv_rnames,
1086
+ shortcut_rname=shortcut_rname,
1087
+ time_proj_rname=time_proj_rname,
1088
+ )
1089
+
1090
+ @classmethod
1091
+ def _get_default_key_map(cls) -> dict[str, set[str]]:
1092
+ """Get the default allowed keys."""
1093
+ key_map: dict[str, set[str]] = defaultdict(set)
1094
+ conv_key = conv_rkey = cls.conv_rkey
1095
+ shortcut_key = shortcut_rkey = cls.shortcut_rkey
1096
+ time_proj_key = time_proj_rkey = cls.time_proj_rkey
1097
+ key_map[conv_rkey].add(conv_key)
1098
+ key_map[shortcut_rkey].add(shortcut_key)
1099
+ key_map[time_proj_rkey].add(time_proj_key)
1100
+ return {k: v for k, v in key_map.items() if v}
1101
+
1102
+
1103
+ @dataclass(kw_only=True)
1104
+ class UNetBlockStruct(DiffusionBlockStruct):
1105
+ class BlockType(enum.StrEnum):
1106
+ DOWN = "down"
1107
+ MID = "mid"
1108
+ UP = "up"
1109
+
1110
+ # region relative keys
1111
+ resnet_rkey: tp.ClassVar[str] = "resblock"
1112
+ sampler_rkey: tp.ClassVar[str] = "sample"
1113
+ transformer_rkey: tp.ClassVar[str] = ""
1114
+ resnet_struct_cls: tp.ClassVar[type[DiffusionResnetStruct]] = DiffusionResnetStruct
1115
+ transformer_struct_cls: tp.ClassVar[type[DiffusionTransformerStruct]] = DiffusionTransformerStruct
1116
+ # endregion
1117
+
1118
+ parent: tp.Optional["UNetStruct"] = field(repr=False)
1119
+ # region attributes
1120
+ block_type: BlockType
1121
+ # endregion
1122
+ # region modules
1123
+ resnets: nn.ModuleList = field(repr=False)
1124
+ transformers: nn.ModuleList = field(repr=False)
1125
+ sampler: nn.Conv2d | None
1126
+ # endregion
1127
+ # region relative names
1128
+ resnets_rname: str
1129
+ transformers_rname: str
1130
+ sampler_rname: str
1131
+ # endregion
1132
+ # region absolute names
1133
+ resnets_name: str = field(init=False, repr=False)
1134
+ transformers_name: str = field(init=False, repr=False)
1135
+ sampler_name: str = field(init=False, repr=False)
1136
+ resnet_names: list[str] = field(init=False, repr=False)
1137
+ transformer_names: list[str] = field(init=False, repr=False)
1138
+ # endregion
1139
+ # region absolute keys
1140
+ sampler_key: str = field(init=False, repr=False)
1141
+ # endregion
1142
+ # region child structs
1143
+ resnet_structs: list[DiffusionResnetStruct] = field(init=False, repr=False)
1144
+ transformer_structs: list[DiffusionTransformerStruct] = field(init=False, repr=False)
1145
+ # endregion
1146
+
1147
+ @property
1148
+ def downsample(self) -> nn.Conv2d | None:
1149
+ return self.sampler if self.is_downsample_block() else None
1150
+
1151
+ @property
1152
+ def upsample(self) -> nn.Conv2d | None:
1153
+ return self.sampler if self.is_upsample_block() else None
1154
+
1155
+ def __post_init__(self) -> None:
1156
+ super().__post_init__()
1157
+ if self.is_downsample_block():
1158
+ assert len(self.resnets) == len(self.transformers) or len(self.transformers) == 0
1159
+ if self.parent is not None and isinstance(self.parent, UNetStruct):
1160
+ assert self.rname == f"{self.parent.down_blocks_rname}.{self.idx}"
1161
+ elif self.is_mid_block():
1162
+ assert len(self.resnets) == len(self.transformers) + 1 or len(self.transformers) == 0
1163
+ if self.parent is not None and isinstance(self.parent, UNetStruct):
1164
+ assert self.rname == self.parent.mid_block_name
1165
+ assert self.idx == 0
1166
+ else:
1167
+ assert self.is_upsample_block(), f"Unsupported block type: {self.block_type}"
1168
+ assert len(self.resnets) == len(self.transformers) or len(self.transformers) == 0
1169
+ if self.parent is not None and isinstance(self.parent, UNetStruct):
1170
+ assert self.rname == f"{self.parent.up_blocks_rname}.{self.idx}"
1171
+ resnet_rnames = [f"{self.resnets_rname}.{idx}" for idx in range(len(self.resnets))]
1172
+ transformer_rnames = [f"{self.transformers_rname}.{idx}" for idx in range(len(self.transformers))]
1173
+ self.resnets_name = join_name(self.name, self.resnets_rname)
1174
+ self.transformers_name = join_name(self.name, self.transformers_rname)
1175
+ self.resnet_names = [join_name(self.name, rname) for rname in resnet_rnames]
1176
+ self.transformer_names = [join_name(self.name, rname) for rname in transformer_rnames]
1177
+ self.sampler_name = join_name(self.name, self.sampler_rname)
1178
+ self.sampler_key = join_name(self.key, self.sampler_rkey, sep="_")
1179
+ self.resnet_structs = [
1180
+ self.resnet_struct_cls.construct(
1181
+ resnet, parent=self, fname="resnet", rname=rname, rkey=self.resnet_rkey, idx=idx
1182
+ )
1183
+ for idx, (resnet, rname) in enumerate(zip(self.resnets, resnet_rnames, strict=True))
1184
+ ]
1185
+ self.transformer_structs = [
1186
+ self.transformer_struct_cls.construct(
1187
+ transformer, parent=self, fname="transformer", rname=rname, rkey=self.transformer_rkey, idx=idx
1188
+ )
1189
+ for idx, (transformer, rname) in enumerate(zip(self.transformers, transformer_rnames, strict=True))
1190
+ ]
1191
+
1192
+ def is_downsample_block(self) -> bool:
1193
+ return self.block_type == self.BlockType.DOWN
1194
+
1195
+ def is_mid_block(self) -> bool:
1196
+ return self.block_type == self.BlockType.MID
1197
+
1198
+ def is_upsample_block(self) -> bool:
1199
+ return self.block_type == self.BlockType.UP
1200
+
1201
+ def has_downsample(self) -> bool:
1202
+ return self.is_downsample_block() and self.sampler is not None
1203
+
1204
+ def has_upsample(self) -> bool:
1205
+ return self.is_upsample_block() and self.sampler is not None
1206
+
1207
+ def named_key_modules(self) -> tp.Generator[tp.Tuple[str, str, nn.Module, BaseModuleStruct, str], None, None]:
1208
+ for resnet in self.resnet_structs:
1209
+ yield from resnet.named_key_modules()
1210
+ for transformer in self.transformer_structs:
1211
+ yield from transformer.named_key_modules()
1212
+ if self.sampler is not None:
1213
+ yield self.sampler_key, self.sampler_name, self.sampler, self, "sampler"
1214
+
1215
+ def iter_attention_structs(self) -> tp.Generator[DiffusionAttentionStruct, None, None]:
1216
+ for transformer in self.transformer_structs:
1217
+ yield from transformer.iter_attention_structs()
1218
+
1219
+ def iter_transformer_block_structs(self) -> tp.Generator[DiffusionTransformerBlockStruct, None, None]:
1220
+ for transformer in self.transformer_structs:
1221
+ yield from transformer.iter_transformer_block_structs()
1222
+
1223
+ @staticmethod
1224
+ def _default_construct(
1225
+ module: UNET_BLOCK_CLS,
1226
+ /,
1227
+ parent: tp.Optional["UNetStruct"] = None,
1228
+ fname: str = "",
1229
+ rname: str = "",
1230
+ rkey: str = "",
1231
+ idx: int = 0,
1232
+ **kwargs,
1233
+ ) -> "UNetBlockStruct":
1234
+ resnets, resnets_rname = module.resnets, "resnets"
1235
+ if isinstance(module, (DownBlock2D, CrossAttnDownBlock2D)):
1236
+ block_type = UNetBlockStruct.BlockType.DOWN
1237
+ if isinstance(module, CrossAttnDownBlock2D) and module.attentions is not None:
1238
+ transformers, transformers_rname = module.attentions, "attentions"
1239
+ else:
1240
+ transformers, transformers_rname = [], ""
1241
+ if module.downsamplers is None:
1242
+ sampler, sampler_rname = None, ""
1243
+ else:
1244
+ assert len(module.downsamplers) == 1
1245
+ downsampler = module.downsamplers[0]
1246
+ assert isinstance(downsampler, Downsample2D)
1247
+ sampler, sampler_rname = downsampler.conv, "downsamplers.0.conv"
1248
+ assert isinstance(sampler, nn.Conv2d)
1249
+ elif isinstance(module, (UNetMidBlock2D, UNetMidBlock2DCrossAttn)):
1250
+ block_type = UNetBlockStruct.BlockType.MID
1251
+ if (isinstance(module, UNetMidBlock2DCrossAttn) or module.add_attention) and module.attentions is not None:
1252
+ transformers, transformers_rname = module.attentions, "attentions"
1253
+ else:
1254
+ transformers, transformers_rname = [], ""
1255
+ sampler, sampler_rname = None, ""
1256
+ elif isinstance(module, (UpBlock2D, CrossAttnUpBlock2D)):
1257
+ block_type = UNetBlockStruct.BlockType.UP
1258
+ if isinstance(module, CrossAttnUpBlock2D) and module.attentions is not None:
1259
+ transformers, transformers_rname = module.attentions, "attentions"
1260
+ else:
1261
+ transformers, transformers_rname = [], ""
1262
+ if module.upsamplers is None:
1263
+ sampler, sampler_rname = None, ""
1264
+ else:
1265
+ assert len(module.upsamplers) == 1
1266
+ upsampler = module.upsamplers[0]
1267
+ assert isinstance(upsampler, Upsample2D)
1268
+ sampler, sampler_rname = upsampler.conv, "upsamplers.0.conv"
1269
+ assert isinstance(sampler, nn.Conv2d)
1270
+ else:
1271
+ raise NotImplementedError(f"Unsupported module type: {type(module)}")
1272
+ return UNetBlockStruct(
1273
+ module=module,
1274
+ parent=parent,
1275
+ fname=fname,
1276
+ idx=idx,
1277
+ rname=rname,
1278
+ rkey=rkey,
1279
+ block_type=block_type,
1280
+ resnets=resnets,
1281
+ transformers=transformers,
1282
+ sampler=sampler,
1283
+ resnets_rname=resnets_rname,
1284
+ transformers_rname=transformers_rname,
1285
+ sampler_rname=sampler_rname,
1286
+ )
1287
+
1288
+ @classmethod
1289
+ def _get_default_key_map(cls) -> dict[str, set[str]]:
1290
+ """Get the default allowed keys."""
1291
+ key_map: dict[str, set[str]] = defaultdict(set)
1292
+ resnet_cls = cls.resnet_struct_cls
1293
+ resnet_key = resnet_rkey = cls.resnet_rkey
1294
+ resnet_key_map = resnet_cls._get_default_key_map()
1295
+ for rkey, keys in resnet_key_map.items():
1296
+ rkey = join_name(resnet_rkey, rkey, sep="_")
1297
+ for key in keys:
1298
+ key = join_name(resnet_key, key, sep="_")
1299
+ key_map[rkey].add(key)
1300
+ key_map[resnet_rkey].add(key)
1301
+ transformer_cls = cls.transformer_struct_cls
1302
+ transformer_key = transformer_rkey = cls.transformer_rkey
1303
+ transformer_key_map = transformer_cls._get_default_key_map()
1304
+ for rkey, keys in transformer_key_map.items():
1305
+ trkey = join_name(transformer_rkey, rkey, sep="_")
1306
+ for key in keys:
1307
+ key = join_name(transformer_key, key, sep="_")
1308
+ key_map[rkey].add(key)
1309
+ key_map[trkey].add(key)
1310
+ return {k: v for k, v in key_map.items() if v}
1311
+
1312
+
1313
+ @dataclass(kw_only=True)
1314
+ class UNetStruct(DiffusionModelStruct):
1315
+ # region relative keys
1316
+ input_embed_rkey: tp.ClassVar[str] = "input_embed"
1317
+ """hidden_states = input_embed(hidden_states), e.g., conv_in"""
1318
+ time_embed_rkey: tp.ClassVar[str] = "time_embed"
1319
+ """temb = time_embed(timesteps, hidden_states)"""
1320
+ add_time_embed_rkey: tp.ClassVar[str] = "time_embed"
1321
+ """add_temb = add_time_embed(timesteps, encoder_hidden_states)"""
1322
+ text_embed_rkey: tp.ClassVar[str] = "text_embed"
1323
+ """encoder_hidden_states = text_embed(encoder_hidden_states)"""
1324
+ norm_out_rkey: tp.ClassVar[str] = "output_embed"
1325
+ """hidden_states = norm_out(hidden_states), e.g., conv_norm_out"""
1326
+ proj_out_rkey: tp.ClassVar[str] = "output_embed"
1327
+ """hidden_states = output_embed(hidden_states), e.g., conv_out"""
1328
+ down_block_rkey: tp.ClassVar[str] = "down"
1329
+ mid_block_rkey: tp.ClassVar[str] = "mid"
1330
+ up_block_rkey: tp.ClassVar[str] = "up"
1331
+ down_block_struct_cls: tp.ClassVar[type[UNetBlockStruct]] = UNetBlockStruct
1332
+ mid_block_struct_cls: tp.ClassVar[type[UNetBlockStruct]] = UNetBlockStruct
1333
+ up_block_struct_cls: tp.ClassVar[type[UNetBlockStruct]] = UNetBlockStruct
1334
+ # endregion
1335
+
1336
+ # region child modules
1337
+ # region pre-block modules
1338
+ input_embed: nn.Conv2d
1339
+ time_embed: TimestepEmbedding
1340
+ """Time embedding"""
1341
+ add_time_embed: (
1342
+ TextTimeEmbedding
1343
+ | TextImageTimeEmbedding
1344
+ | TimestepEmbedding
1345
+ | ImageTimeEmbedding
1346
+ | ImageHintTimeEmbedding
1347
+ | None
1348
+ )
1349
+ """Additional time embedding"""
1350
+ text_embed: nn.Linear | ImageProjection | TextImageProjection | None
1351
+ """Text embedding"""
1352
+ # region post-block modules
1353
+ norm_out: nn.GroupNorm | None
1354
+ proj_out: nn.Conv2d
1355
+ # endregion
1356
+ # endregion
1357
+ down_blocks: nn.ModuleList = field(repr=False)
1358
+ mid_block: nn.Module = field(repr=False)
1359
+ up_blocks: nn.ModuleList = field(repr=False)
1360
+ # endregion
1361
+ # region relative names
1362
+ input_embed_rname: str
1363
+ time_embed_rname: str
1364
+ add_time_embed_rname: str
1365
+ text_embed_rname: str
1366
+ norm_out_rname: str
1367
+ proj_out_rname: str
1368
+ down_blocks_rname: str
1369
+ mid_block_rname: str
1370
+ up_blocks_rname: str
1371
+ # endregion
1372
+ # region absolute names
1373
+ input_embed_name: str = field(init=False, repr=False)
1374
+ time_embed_name: str = field(init=False, repr=False)
1375
+ add_time_embed_name: str = field(init=False, repr=False)
1376
+ text_embed_name: str = field(init=False, repr=False)
1377
+ norm_out_name: str = field(init=False, repr=False)
1378
+ proj_out_name: str = field(init=False, repr=False)
1379
+ down_blocks_name: str = field(init=False, repr=False)
1380
+ mid_block_name: str = field(init=False, repr=False)
1381
+ up_blocks_name: str = field(init=False, repr=False)
1382
+ down_block_names: list[str] = field(init=False, repr=False)
1383
+ up_block_names: list[str] = field(init=False, repr=False)
1384
+ # endregion
1385
+ # region absolute keys
1386
+ input_embed_key: str = field(init=False, repr=False)
1387
+ time_embed_key: str = field(init=False, repr=False)
1388
+ add_time_embed_key: str = field(init=False, repr=False)
1389
+ text_embed_key: str = field(init=False, repr=False)
1390
+ norm_out_key: str = field(init=False, repr=False)
1391
+ proj_out_key: str = field(init=False, repr=False)
1392
+ # endregion
1393
+ # region child structs
1394
+ down_block_structs: list[UNetBlockStruct] = field(init=False, repr=False)
1395
+ mid_block_struct: UNetBlockStruct = field(init=False, repr=False)
1396
+ up_block_structs: list[UNetBlockStruct] = field(init=False, repr=False)
1397
+ # endregion
1398
+
1399
+ @property
1400
+ def num_down_blocks(self) -> int:
1401
+ return len(self.down_blocks)
1402
+
1403
+ @property
1404
+ def num_up_blocks(self) -> int:
1405
+ return len(self.up_blocks)
1406
+
1407
+ @property
1408
+ def num_blocks(self) -> int:
1409
+ return self.num_down_blocks + 1 + self.num_up_blocks
1410
+
1411
+ @property
1412
+ def block_structs(self) -> list[UNetBlockStruct]:
1413
+ return [*self.down_block_structs, self.mid_block_struct, *self.up_block_structs]
1414
+
1415
+ def __post_init__(self) -> None:
1416
+ super().__post_init__()
1417
+ down_block_rnames = [f"{self.down_blocks_rname}.{idx}" for idx in range(len(self.down_blocks))]
1418
+ up_block_rnames = [f"{self.up_blocks_rname}.{idx}" for idx in range(len(self.up_blocks))]
1419
+ self.down_blocks_name = join_name(self.name, self.down_blocks_rname)
1420
+ self.mid_block_name = join_name(self.name, self.mid_block_rname)
1421
+ self.up_blocks_name = join_name(self.name, self.up_blocks_rname)
1422
+ self.down_block_names = [join_name(self.name, rname) for rname in down_block_rnames]
1423
+ self.up_block_names = [join_name(self.name, rname) for rname in up_block_rnames]
1424
+ self.pre_module_structs = {}
1425
+ for fname in ("time_embed", "add_time_embed", "text_embed", "input_embed"):
1426
+ module, rname, rkey = getattr(self, fname), getattr(self, f"{fname}_rname"), getattr(self, f"{fname}_rkey")
1427
+ setattr(self, f"{fname}_key", join_name(self.key, rkey, sep="_"))
1428
+ if module is not None or rname:
1429
+ setattr(self, f"{fname}_name", join_name(self.name, rname))
1430
+ else:
1431
+ setattr(self, f"{fname}_name", "")
1432
+ if module is not None:
1433
+ assert rname, f"rname of {fname} must not be empty"
1434
+ self.pre_module_structs[getattr(self, f"{fname}_name")] = DiffusionModuleStruct(
1435
+ module=module, parent=self, fname=fname, rname=rname, rkey=rkey
1436
+ )
1437
+ self.post_module_structs = {}
1438
+ for fname in ("norm_out", "proj_out"):
1439
+ module, rname, rkey = getattr(self, fname), getattr(self, f"{fname}_rname"), getattr(self, f"{fname}_rkey")
1440
+ setattr(self, f"{fname}_key", join_name(self.key, rkey, sep="_"))
1441
+ if module is not None or rname:
1442
+ setattr(self, f"{fname}_name", join_name(self.name, rname))
1443
+ else:
1444
+ setattr(self, f"{fname}_name", "")
1445
+ if module is not None:
1446
+ self.post_module_structs[getattr(self, f"{fname}_name")] = DiffusionModuleStruct(
1447
+ module=module, parent=self, fname=fname, rname=rname, rkey=rkey
1448
+ )
1449
+ self.down_block_structs = [
1450
+ self.down_block_struct_cls.construct(
1451
+ block, parent=self, fname="down_block", rname=rname, rkey=self.down_block_rkey, idx=idx
1452
+ )
1453
+ for idx, (block, rname) in enumerate(zip(self.down_blocks, down_block_rnames, strict=True))
1454
+ ]
1455
+ self.mid_block_struct = self.mid_block_struct_cls.construct(
1456
+ self.mid_block, parent=self, fname="mid_block", rname=self.mid_block_name, rkey=self.mid_block_rkey
1457
+ )
1458
+ self.up_block_structs = [
1459
+ self.up_block_struct_cls.construct(
1460
+ block, parent=self, fname="up_block", rname=rname, rkey=self.up_block_rkey, idx=idx
1461
+ )
1462
+ for idx, (block, rname) in enumerate(zip(self.up_blocks, up_block_rnames, strict=True))
1463
+ ]
1464
+
1465
+ def get_prev_module_keys(self) -> tuple[str, ...]:
1466
+ return tuple({self.input_embed_key, self.time_embed_key, self.add_time_embed_key, self.text_embed_key})
1467
+
1468
+ def get_post_module_keys(self) -> tuple[str, ...]:
1469
+ return tuple({self.norm_out_key, self.proj_out_key})
1470
+
1471
+ def _get_iter_block_activations_args(
1472
+ self, **input_kwargs
1473
+ ) -> tuple[list[nn.Module], list[DiffusionModuleStruct | DiffusionBlockStruct], list[bool], list[bool]]:
1474
+ layers, layer_structs, recomputes, use_prev_layer_outputs = [], [], [], []
1475
+ num_down_blocks = len(self.down_blocks)
1476
+ num_up_blocks = len(self.up_blocks)
1477
+ layers.extend(self.down_blocks)
1478
+ layer_structs.extend(self.down_block_structs)
1479
+ use_prev_layer_outputs.append(False)
1480
+ use_prev_layer_outputs.extend([True] * (num_down_blocks - 1))
1481
+ recomputes.append(False)
1482
+ # region check whether down block's outputs are changed
1483
+ _mid_block_additional_residual = input_kwargs.get("mid_block_additional_residual", None)
1484
+ _down_block_additional_residuals = input_kwargs.get("down_block_additional_residuals", None)
1485
+ _is_adapter = input_kwargs.get("down_intrablock_additional_residuals", None) is not None
1486
+ if not _is_adapter and _mid_block_additional_residual is None and _down_block_additional_residuals is not None:
1487
+ _is_adapter = True
1488
+ for down_block in self.down_blocks:
1489
+ if hasattr(down_block, "has_cross_attention") and down_block.has_cross_attention:
1490
+ # outputs unchanged
1491
+ recomputes.append(False)
1492
+ elif _is_adapter:
1493
+ # outputs changed
1494
+ recomputes.append(True)
1495
+ else:
1496
+ # outputs unchanged
1497
+ recomputes.append(False)
1498
+ # endregion
1499
+ layers.append(self.mid_block)
1500
+ layer_structs.append(self.mid_block_struct)
1501
+ use_prev_layer_outputs.append(False)
1502
+ # recomputes is already appened in the previous down blocks
1503
+ layers.extend(self.up_blocks)
1504
+ layer_structs.extend(self.up_block_structs)
1505
+ use_prev_layer_outputs.append(False)
1506
+ use_prev_layer_outputs.extend([True] * (num_up_blocks - 1))
1507
+ recomputes += [True] * num_up_blocks
1508
+ return layers, layer_structs, recomputes, use_prev_layer_outputs
1509
+
1510
+ @staticmethod
1511
+ def _default_construct(
1512
+ module: tp.Union[UNET_PIPELINE_CLS, UNET_CLS],
1513
+ /,
1514
+ parent: tp.Optional[BaseModuleStruct] = None,
1515
+ fname: str = "",
1516
+ rname: str = "",
1517
+ rkey: str = "",
1518
+ idx: int = 0,
1519
+ **kwargs,
1520
+ ) -> "UNetStruct":
1521
+ if isinstance(module, UNET_PIPELINE_CLS):
1522
+ module = module.unet
1523
+ if isinstance(module, (UNet2DConditionModel, UNet2DModel)):
1524
+ input_embed, time_embed = module.conv_in, module.time_embedding
1525
+ input_embed_rname, time_embed_rname = "conv_in", "time_embedding"
1526
+ text_embed, text_embed_rname = None, ""
1527
+ add_time_embed, add_time_embed_rname = None, ""
1528
+ if hasattr(module, "encoder_hid_proj"):
1529
+ text_embed, text_embed_rname = module.encoder_hid_proj, "encoder_hid_proj"
1530
+ if hasattr(module, "add_embedding"):
1531
+ add_time_embed, add_time_embed_rname = module.add_embedding, "add_embedding"
1532
+ norm_out, norm_out_rname = module.conv_norm_out, "conv_norm_out"
1533
+ proj_out, proj_out_rname = module.conv_out, "conv_out"
1534
+ down_blocks, down_blocks_rname = module.down_blocks, "down_blocks"
1535
+ mid_block, mid_block_rname = module.mid_block, "mid_block"
1536
+ up_blocks, up_blocks_rname = module.up_blocks, "up_blocks"
1537
+ return UNetStruct(
1538
+ module=module,
1539
+ parent=parent,
1540
+ fname=fname,
1541
+ idx=idx,
1542
+ rname=rname,
1543
+ rkey=rkey,
1544
+ input_embed=input_embed,
1545
+ time_embed=time_embed,
1546
+ add_time_embed=add_time_embed,
1547
+ text_embed=text_embed,
1548
+ norm_out=norm_out,
1549
+ proj_out=proj_out,
1550
+ down_blocks=down_blocks,
1551
+ mid_block=mid_block,
1552
+ up_blocks=up_blocks,
1553
+ input_embed_rname=input_embed_rname,
1554
+ time_embed_rname=time_embed_rname,
1555
+ add_time_embed_rname=add_time_embed_rname,
1556
+ text_embed_rname=text_embed_rname,
1557
+ norm_out_rname=norm_out_rname,
1558
+ proj_out_rname=proj_out_rname,
1559
+ down_blocks_rname=down_blocks_rname,
1560
+ mid_block_rname=mid_block_rname,
1561
+ up_blocks_rname=up_blocks_rname,
1562
+ )
1563
+ raise NotImplementedError(f"Unsupported module type: {type(module)}")
1564
+
1565
+ @classmethod
1566
+ def _get_default_key_map(cls) -> dict[str, set[str]]:
1567
+ """Get the default allowed keys."""
1568
+ key_map: dict[str, set[str]] = defaultdict(set)
1569
+ for idx, (block_key, block_cls) in enumerate(
1570
+ (
1571
+ (cls.down_block_rkey, cls.down_block_struct_cls),
1572
+ (cls.mid_block_rkey, cls.mid_block_struct_cls),
1573
+ (cls.up_block_rkey, cls.up_block_struct_cls),
1574
+ )
1575
+ ):
1576
+ block_key_map: dict[str, set[str]] = defaultdict(set)
1577
+ if idx != 1:
1578
+ sampler_key = join_name(block_key, block_cls.sampler_rkey, sep="_")
1579
+ sampler_rkey = block_cls.sampler_rkey
1580
+ block_key_map[sampler_rkey].add(sampler_key)
1581
+ _block_key_map = block_cls._get_default_key_map()
1582
+ for rkey, keys in _block_key_map.items():
1583
+ for key in keys:
1584
+ key = join_name(block_key, key, sep="_")
1585
+ block_key_map[rkey].add(key)
1586
+ for rkey, keys in block_key_map.items():
1587
+ key_map[rkey].update(keys)
1588
+ if block_key:
1589
+ key_map[block_key].update(keys)
1590
+ keys: set[str] = set()
1591
+ keys.add(cls.input_embed_rkey)
1592
+ keys.add(cls.time_embed_rkey)
1593
+ keys.add(cls.add_time_embed_rkey)
1594
+ keys.add(cls.text_embed_rkey)
1595
+ keys.add(cls.norm_out_rkey)
1596
+ keys.add(cls.proj_out_rkey)
1597
+ for mapped_keys in key_map.values():
1598
+ for key in mapped_keys:
1599
+ keys.add(key)
1600
+ if "embed" not in keys and "embed" not in key_map:
1601
+ key_map["embed"].add(cls.input_embed_rkey)
1602
+ key_map["embed"].add(cls.time_embed_rkey)
1603
+ key_map["embed"].add(cls.add_time_embed_rkey)
1604
+ key_map["embed"].add(cls.text_embed_rkey)
1605
+ key_map["embed"].add(cls.norm_out_rkey)
1606
+ key_map["embed"].add(cls.proj_out_rkey)
1607
+ for key in keys:
1608
+ if key in key_map:
1609
+ key_map[key].clear()
1610
+ key_map[key].add(key)
1611
+ return {k: v for k, v in key_map.items() if v}
1612
+
1613
+
1614
+ @dataclass(kw_only=True)
1615
+ class DiTStruct(DiffusionModelStruct, DiffusionTransformerStruct):
1616
+ # region relative keys
1617
+ input_embed_rkey: tp.ClassVar[str] = "input_embed"
1618
+ """hidden_states = input_embed(hidden_states), e.g., conv_in"""
1619
+ time_embed_rkey: tp.ClassVar[str] = "time_embed"
1620
+ """temb = time_embed(timesteps)"""
1621
+ text_embed_rkey: tp.ClassVar[str] = "text_embed"
1622
+ """encoder_hidden_states = text_embed(encoder_hidden_states)"""
1623
+ norm_in_rkey: tp.ClassVar[str] = "input_embed"
1624
+ """hidden_states = norm_in(hidden_states)"""
1625
+ proj_in_rkey: tp.ClassVar[str] = "input_embed"
1626
+ """hidden_states = proj_in(hidden_states)"""
1627
+ norm_out_rkey: tp.ClassVar[str] = "output_embed"
1628
+ """hidden_states = norm_out(hidden_states)"""
1629
+ proj_out_rkey: tp.ClassVar[str] = "output_embed"
1630
+ """hidden_states = proj_out(hidden_states)"""
1631
+ transformer_block_rkey: tp.ClassVar[str] = ""
1632
+ # endregion
1633
+
1634
+ # region child modules
1635
+ input_embed: PatchEmbed
1636
+ time_embed: AdaLayerNormSingle | CombinedTimestepTextProjEmbeddings | TimestepEmbedding
1637
+ text_embed: PixArtAlphaTextProjection | nn.Linear
1638
+ norm_in: None = field(init=False, repr=False, default=None)
1639
+ proj_in: None = field(init=False, repr=False, default=None)
1640
+ norm_out: nn.LayerNorm | AdaLayerNormContinuous | None
1641
+ proj_out: nn.Linear
1642
+ # endregion
1643
+ # region relative names
1644
+ input_embed_rname: str
1645
+ time_embed_rname: str
1646
+ text_embed_rname: str
1647
+ norm_in_rname: str = field(init=False, repr=False, default="")
1648
+ proj_in_rname: str = field(init=False, repr=False, default="")
1649
+ norm_out_rname: str
1650
+ proj_out_rname: str
1651
+ # endregion
1652
+ # region absolute names
1653
+ input_embed_name: str = field(init=False, repr=False)
1654
+ time_embed_name: str = field(init=False, repr=False)
1655
+ text_embed_name: str = field(init=False, repr=False)
1656
+ # endregion
1657
+ # region absolute keys
1658
+ input_embed_key: str = field(init=False, repr=False)
1659
+ time_embed_key: str = field(init=False, repr=False)
1660
+ text_embed_key: str = field(init=False, repr=False)
1661
+ norm_out_key: str = field(init=False, repr=False)
1662
+ # endregion
1663
+
1664
+ @property
1665
+ def num_blocks(self) -> int:
1666
+ return len(self.transformer_blocks)
1667
+
1668
+ @property
1669
+ def block_structs(self) -> list[DiffusionTransformerBlockStruct]:
1670
+ return self.transformer_block_structs
1671
+
1672
+ @property
1673
+ def block_names(self) -> list[str]:
1674
+ return self.transformer_block_names
1675
+
1676
+ def __post_init__(self) -> None:
1677
+ super().__post_init__()
1678
+ self.pre_module_structs = {}
1679
+ for fname in ("input_embed", "time_embed", "text_embed"):
1680
+ module, rname, rkey = getattr(self, fname), getattr(self, f"{fname}_rname"), getattr(self, f"{fname}_rkey")
1681
+ setattr(self, f"{fname}_key", join_name(self.key, rkey, sep="_"))
1682
+ if module is not None or rname:
1683
+ setattr(self, f"{fname}_name", join_name(self.name, rname))
1684
+ else:
1685
+ setattr(self, f"{fname}_name", "")
1686
+ if module is not None:
1687
+ self.pre_module_structs.setdefault(
1688
+ getattr(self, f"{fname}_name"),
1689
+ DiffusionModuleStruct(module=module, parent=self, fname=fname, rname=rname, rkey=rkey),
1690
+ )
1691
+ self.post_module_structs = {}
1692
+ self.norm_out_key = join_name(self.key, self.norm_out_rkey, sep="_")
1693
+ for fname in ("norm_out", "proj_out"):
1694
+ module, rname, rkey = getattr(self, fname), getattr(self, f"{fname}_rname"), getattr(self, f"{fname}_rkey")
1695
+ if module is not None:
1696
+ self.post_module_structs.setdefault(
1697
+ getattr(self, f"{fname}_name"),
1698
+ DiffusionModuleStruct(module=module, parent=self, fname=fname, rname=rname, rkey=rkey),
1699
+ )
1700
+
1701
+ def get_prev_module_keys(self) -> tuple[str, ...]:
1702
+ return tuple({self.input_embed_key, self.time_embed_key, self.text_embed_key})
1703
+
1704
+ def get_post_module_keys(self) -> tuple[str, ...]:
1705
+ return tuple({self.norm_out_key, self.proj_out_key})
1706
+
1707
+ def _get_iter_block_activations_args(
1708
+ self, **input_kwargs
1709
+ ) -> tuple[list[nn.Module], list[DiffusionModuleStruct | DiffusionBlockStruct], list[bool], list[bool]]:
1710
+ """
1711
+ Get the arguments for iterating over the layers and their activations.
1712
+
1713
+ Args:
1714
+ skip_pre_modules (`bool`):
1715
+ Whether to skip the pre-modules
1716
+ skip_post_modules (`bool`):
1717
+ Whether to skip the post-modules
1718
+
1719
+ Returns:
1720
+ `tuple[list[nn.Module], list[DiffusionModuleStruct | DiffusionBlockStruct], list[bool], list[bool]]`:
1721
+ the layers, the layer structs, the recomputes, and the use_prev_layer_outputs
1722
+ """
1723
+ layers, layer_structs, recomputes, use_prev_layer_outputs = [], [], [], []
1724
+ layers.extend(self.transformer_blocks)
1725
+ layer_structs.extend(self.transformer_block_structs)
1726
+ use_prev_layer_outputs.append(False)
1727
+ use_prev_layer_outputs.extend([True] * (len(self.transformer_blocks) - 1))
1728
+ recomputes.extend([False] * len(self.transformer_blocks))
1729
+ return layers, layer_structs, recomputes, use_prev_layer_outputs
1730
+
1731
+ @staticmethod
1732
+ def _default_construct(
1733
+ module: tp.Union[DIT_PIPELINE_CLS, DIT_CLS],
1734
+ /,
1735
+ parent: tp.Optional[BaseModuleStruct] = None,
1736
+ fname: str = "",
1737
+ rname: str = "",
1738
+ rkey: str = "",
1739
+ idx: int = 0,
1740
+ **kwargs,
1741
+ ) -> "DiTStruct":
1742
+ if isinstance(module, DIT_PIPELINE_CLS):
1743
+ module = module.transformer
1744
+ if isinstance(module, FluxTransformer2DModel):
1745
+ return FluxStruct.construct(module, parent=parent, fname=fname, rname=rname, rkey=rkey, idx=idx, **kwargs)
1746
+ else:
1747
+ if isinstance(module, PixArtTransformer2DModel):
1748
+ input_embed, input_embed_rname = module.pos_embed, "pos_embed"
1749
+ time_embed, time_embed_rname = module.adaln_single, "adaln_single"
1750
+ text_embed, text_embed_rname = module.caption_projection, "caption_projection"
1751
+ norm_out, norm_out_rname = module.norm_out, "norm_out"
1752
+ proj_out, proj_out_rname = module.proj_out, "proj_out"
1753
+ transformer_blocks, transformer_blocks_rname = module.transformer_blocks, "transformer_blocks"
1754
+ # ! in fact, `module.adaln_single.emb` is `time_embed`,
1755
+ # ! `module.adaln_single.linear` is `transformer_norm`
1756
+ # ! but since PixArt shares the `transformer_norm`, we categorize it as `time_embed`
1757
+ elif isinstance(module, SanaTransformer2DModel):
1758
+ input_embed, input_embed_rname = module.patch_embed, "patch_embed"
1759
+ time_embed, time_embed_rname = module.time_embed, "time_embed"
1760
+ text_embed, text_embed_rname = module.caption_projection, "caption_projection"
1761
+ norm_out, norm_out_rname = module.norm_out, "norm_out"
1762
+ proj_out, proj_out_rname = module.proj_out, "proj_out"
1763
+ transformer_blocks, transformer_blocks_rname = module.transformer_blocks, "transformer_blocks"
1764
+ elif isinstance(module, SD3Transformer2DModel):
1765
+ input_embed, input_embed_rname = module.pos_embed, "pos_embed"
1766
+ time_embed, time_embed_rname = module.time_text_embed, "time_text_embed"
1767
+ text_embed, text_embed_rname = module.context_embedder, "context_embedder"
1768
+ norm_out, norm_out_rname = module.norm_out, "norm_out"
1769
+ proj_out, proj_out_rname = module.proj_out, "proj_out"
1770
+ transformer_blocks, transformer_blocks_rname = module.transformer_blocks, "transformer_blocks"
1771
+ else:
1772
+ raise NotImplementedError(f"Unsupported module type: {type(module)}")
1773
+ return DiTStruct(
1774
+ module=module,
1775
+ parent=parent,
1776
+ fname=fname,
1777
+ idx=idx,
1778
+ rname=rname,
1779
+ rkey=rkey,
1780
+ input_embed=input_embed,
1781
+ time_embed=time_embed,
1782
+ text_embed=text_embed,
1783
+ transformer_blocks=transformer_blocks,
1784
+ norm_out=norm_out,
1785
+ proj_out=proj_out,
1786
+ input_embed_rname=input_embed_rname,
1787
+ time_embed_rname=time_embed_rname,
1788
+ text_embed_rname=text_embed_rname,
1789
+ norm_out_rname=norm_out_rname,
1790
+ proj_out_rname=proj_out_rname,
1791
+ transformer_blocks_rname=transformer_blocks_rname,
1792
+ )
1793
+
1794
+ @classmethod
1795
+ def _get_default_key_map(cls) -> dict[str, set[str]]:
1796
+ """Get the default allowed keys."""
1797
+ key_map: dict[str, set[str]] = defaultdict(set)
1798
+ block_cls = cls.transformer_block_struct_cls
1799
+ block_key = block_rkey = cls.transformer_block_rkey
1800
+ block_key_map = block_cls._get_default_key_map()
1801
+ for rkey, keys in block_key_map.items():
1802
+ brkey = join_name(block_rkey, rkey, sep="_")
1803
+ for key in keys:
1804
+ key = join_name(block_key, key, sep="_")
1805
+ key_map[rkey].add(key)
1806
+ key_map[brkey].add(key)
1807
+ if block_rkey:
1808
+ key_map[block_rkey].add(key)
1809
+ keys: set[str] = set()
1810
+ keys.add(cls.input_embed_rkey)
1811
+ keys.add(cls.time_embed_rkey)
1812
+ keys.add(cls.text_embed_rkey)
1813
+ keys.add(cls.norm_in_rkey)
1814
+ keys.add(cls.proj_in_rkey)
1815
+ keys.add(cls.norm_out_rkey)
1816
+ keys.add(cls.proj_out_rkey)
1817
+ for mapped_keys in key_map.values():
1818
+ for key in mapped_keys:
1819
+ keys.add(key)
1820
+ if "embed" not in keys and "embed" not in key_map:
1821
+ key_map["embed"].add(cls.input_embed_rkey)
1822
+ key_map["embed"].add(cls.time_embed_rkey)
1823
+ key_map["embed"].add(cls.text_embed_rkey)
1824
+ key_map["embed"].add(cls.norm_in_rkey)
1825
+ key_map["embed"].add(cls.proj_in_rkey)
1826
+ key_map["embed"].add(cls.norm_out_rkey)
1827
+ key_map["embed"].add(cls.proj_out_rkey)
1828
+ for key in keys:
1829
+ if key in key_map:
1830
+ key_map[key].clear()
1831
+ key_map[key].add(key)
1832
+ return {k: v for k, v in key_map.items() if v}
1833
+
1834
+
1835
+ @dataclass(kw_only=True)
1836
+ class FluxStruct(DiTStruct):
1837
+ # region relative keys
1838
+ single_transformer_block_rkey: tp.ClassVar[str] = ""
1839
+ single_transformer_block_struct_cls: tp.ClassVar[type[DiffusionTransformerBlockStruct]] = (
1840
+ DiffusionTransformerBlockStruct
1841
+ )
1842
+ # endregion
1843
+
1844
+ module: FluxTransformer2DModel = field(repr=False, kw_only=False)
1845
+ """the module of FluxTransformer2DModel"""
1846
+ # region child modules
1847
+ input_embed: nn.Linear
1848
+ time_embed: CombinedTimestepGuidanceTextProjEmbeddings | CombinedTimestepTextProjEmbeddings
1849
+ text_embed: nn.Linear
1850
+ single_transformer_blocks: nn.ModuleList = field(repr=False)
1851
+ # endregion
1852
+ # region relative names
1853
+ single_transformer_blocks_rname: str
1854
+ # endregion
1855
+ # region absolute names
1856
+ single_transformer_blocks_name: str = field(init=False, repr=False)
1857
+ single_transformer_block_names: list[str] = field(init=False, repr=False)
1858
+ # endregion
1859
+ # region child structs
1860
+ single_transformer_block_structs: list[DiffusionTransformerBlockStruct] = field(init=False)
1861
+ # endregion
1862
+
1863
+ @property
1864
+ def num_blocks(self) -> int:
1865
+ return len(self.transformer_block_structs) + len(self.single_transformer_block_structs)
1866
+
1867
+ @property
1868
+ def block_structs(self) -> list[DiffusionTransformerBlockStruct]:
1869
+ return [*self.transformer_block_structs, *self.single_transformer_block_structs]
1870
+
1871
+ @property
1872
+ def block_names(self) -> list[str]:
1873
+ return [*self.transformer_block_names, *self.single_transformer_block_names]
1874
+
1875
+ def __post_init__(self) -> None:
1876
+ super().__post_init__()
1877
+ single_transformer_block_rnames = [
1878
+ f"{self.single_transformer_blocks_rname}.{idx}" for idx in range(len(self.single_transformer_blocks))
1879
+ ]
1880
+ self.single_transformer_blocks_name = join_name(self.name, self.single_transformer_blocks_rname)
1881
+ self.single_transformer_block_names = [join_name(self.name, rname) for rname in single_transformer_block_rnames]
1882
+ self.single_transformer_block_structs = [
1883
+ self.single_transformer_block_struct_cls.construct(
1884
+ block,
1885
+ parent=self,
1886
+ fname="single_transformer_block",
1887
+ rname=rname,
1888
+ rkey=self.single_transformer_block_rkey,
1889
+ idx=idx,
1890
+ )
1891
+ for idx, (block, rname) in enumerate(
1892
+ zip(self.single_transformer_blocks, single_transformer_block_rnames, strict=True)
1893
+ )
1894
+ ]
1895
+
1896
+ def _get_iter_block_activations_args(
1897
+ self, **input_kwargs
1898
+ ) -> tuple[list[nn.Module], list[DiffusionModuleStruct | DiffusionBlockStruct], list[bool], list[bool]]:
1899
+ layers, layer_structs, recomputes, use_prev_layer_outputs = super()._get_iter_block_activations_args()
1900
+ layers.extend(self.single_transformer_blocks)
1901
+ layer_structs.extend(self.single_transformer_block_structs)
1902
+ use_prev_layer_outputs.append(False)
1903
+ use_prev_layer_outputs.extend([True] * (len(self.single_transformer_blocks) - 1))
1904
+ recomputes.extend([False] * len(self.single_transformer_blocks))
1905
+ return layers, layer_structs, recomputes, use_prev_layer_outputs
1906
+
1907
+ @staticmethod
1908
+ def _default_construct(
1909
+ module: tp.Union[FluxPipeline, FluxKontextPipeline, FluxControlPipeline, FluxTransformer2DModel],
1910
+ /,
1911
+ parent: tp.Optional[BaseModuleStruct] = None,
1912
+ fname: str = "",
1913
+ rname: str = "",
1914
+ rkey: str = "",
1915
+ idx: int = 0,
1916
+ **kwargs,
1917
+ ) -> "FluxStruct":
1918
+ if isinstance(module, (FluxPipeline, FluxKontextPipeline, FluxControlPipeline)):
1919
+ module = module.transformer
1920
+ if isinstance(module, FluxTransformer2DModel):
1921
+ input_embed, time_embed, text_embed = module.x_embedder, module.time_text_embed, module.context_embedder
1922
+ input_embed_rname, time_embed_rname, text_embed_rname = "x_embedder", "time_text_embed", "context_embedder"
1923
+ norm_out, norm_out_rname = module.norm_out, "norm_out"
1924
+ proj_out, proj_out_rname = module.proj_out, "proj_out"
1925
+ transformer_blocks, transformer_blocks_rname = module.transformer_blocks, "transformer_blocks"
1926
+ single_transformer_blocks = module.single_transformer_blocks
1927
+ single_transformer_blocks_rname = "single_transformer_blocks"
1928
+ return FluxStruct(
1929
+ module=module,
1930
+ parent=parent,
1931
+ fname=fname,
1932
+ idx=idx,
1933
+ rname=rname,
1934
+ rkey=rkey,
1935
+ input_embed=input_embed,
1936
+ time_embed=time_embed,
1937
+ text_embed=text_embed,
1938
+ transformer_blocks=transformer_blocks,
1939
+ single_transformer_blocks=single_transformer_blocks,
1940
+ norm_out=norm_out,
1941
+ proj_out=proj_out,
1942
+ input_embed_rname=input_embed_rname,
1943
+ time_embed_rname=time_embed_rname,
1944
+ text_embed_rname=text_embed_rname,
1945
+ norm_out_rname=norm_out_rname,
1946
+ proj_out_rname=proj_out_rname,
1947
+ transformer_blocks_rname=transformer_blocks_rname,
1948
+ single_transformer_blocks_rname=single_transformer_blocks_rname,
1949
+ )
1950
+ raise NotImplementedError(f"Unsupported module type: {type(module)}")
1951
+
1952
+ @classmethod
1953
+ def _get_default_key_map(cls) -> dict[str, set[str]]:
1954
+ """Get the default allowed keys."""
1955
+ key_map: dict[str, set[str]] = defaultdict(set)
1956
+ for block_rkey, block_cls in (
1957
+ (cls.transformer_block_rkey, cls.transformer_block_struct_cls),
1958
+ (cls.single_transformer_block_rkey, cls.single_transformer_block_struct_cls),
1959
+ ):
1960
+ block_key = block_rkey
1961
+ block_key_map = block_cls._get_default_key_map()
1962
+ for rkey, keys in block_key_map.items():
1963
+ brkey = join_name(block_rkey, rkey, sep="_")
1964
+ for key in keys:
1965
+ key = join_name(block_key, key, sep="_")
1966
+ key_map[rkey].add(key)
1967
+ key_map[brkey].add(key)
1968
+ if block_rkey:
1969
+ key_map[block_rkey].add(key)
1970
+ keys: set[str] = set()
1971
+ keys.add(cls.input_embed_rkey)
1972
+ keys.add(cls.time_embed_rkey)
1973
+ keys.add(cls.text_embed_rkey)
1974
+ keys.add(cls.norm_in_rkey)
1975
+ keys.add(cls.proj_in_rkey)
1976
+ keys.add(cls.norm_out_rkey)
1977
+ keys.add(cls.proj_out_rkey)
1978
+ for mapped_keys in key_map.values():
1979
+ for key in mapped_keys:
1980
+ keys.add(key)
1981
+ if "embed" not in keys and "embed" not in key_map:
1982
+ key_map["embed"].add(cls.input_embed_rkey)
1983
+ key_map["embed"].add(cls.time_embed_rkey)
1984
+ key_map["embed"].add(cls.text_embed_rkey)
1985
+ key_map["embed"].add(cls.norm_in_rkey)
1986
+ key_map["embed"].add(cls.proj_in_rkey)
1987
+ key_map["embed"].add(cls.norm_out_rkey)
1988
+ key_map["embed"].add(cls.proj_out_rkey)
1989
+ for key in keys:
1990
+ if key in key_map:
1991
+ key_map[key].clear()
1992
+ key_map[key].add(key)
1993
+ return {k: v for k, v in key_map.items() if v}
1994
+
1995
+
1996
+ DiffusionAttentionStruct.register_factory(Attention, DiffusionAttentionStruct._default_construct)
1997
+
1998
+ DiffusionFeedForwardStruct.register_factory(
1999
+ (FeedForward, FluxSingleTransformerBlock, GLUMBConv), DiffusionFeedForwardStruct._default_construct
2000
+ )
2001
+
2002
+ DiffusionTransformerBlockStruct.register_factory(DIT_BLOCK_CLS, DiffusionTransformerBlockStruct._default_construct)
2003
+
2004
+ UNetBlockStruct.register_factory(UNET_BLOCK_CLS, UNetBlockStruct._default_construct)
2005
+
2006
+ UNetStruct.register_factory(tp.Union[UNET_PIPELINE_CLS, UNET_CLS], UNetStruct._default_construct)
2007
+
2008
+ FluxStruct.register_factory(
2009
+ tp.Union[FluxPipeline, FluxKontextPipeline, FluxControlPipeline, FluxTransformer2DModel], FluxStruct._default_construct
2010
+ )
2011
+
2012
+ DiTStruct.register_factory(tp.Union[DIT_PIPELINE_CLS, DIT_CLS], DiTStruct._default_construct)
2013
+
2014
+ DiffusionTransformerStruct.register_factory(Transformer2DModel, DiffusionTransformerStruct._default_construct)
2015
+
2016
+ DiffusionModelStruct.register_factory(tp.Union[PIPELINE_CLS, MODEL_CLS], DiffusionModelStruct._default_construct)
2017
+
2018
+ # Register the factory (usually at the bottom of the file)
2019
+ DiffusionAttentionStruct.register_factory(ATTENTION_CLS, DiffusionAttentionStruct._default_construct)