lym00 committed on
Commit 55d907c · verified · 1 Parent(s): 5f212ea

Upload config.py

Files changed (1)
  1. config.py +422 -0
config.py ADDED
@@ -0,0 +1,422 @@
+ # -*- coding: utf-8 -*-
+ """Diffusion pipeline configuration module."""
+
+ import gc
+ import typing as tp
+ from dataclasses import dataclass, field
+
+ import torch
+ from diffusers.pipelines import (
+     AutoPipelineForText2Image,
+     DiffusionPipeline,
+     FluxKontextPipeline,
+     FluxControlPipeline,
+     FluxFillPipeline,
+     SanaPipeline,
+ )
+ from omniconfig import configclass
+ from torch import nn
+ from transformers import PreTrainedModel, PreTrainedTokenizer, T5EncoderModel
+
+ from deepcompressor.data.utils.dtype import eval_dtype
+ from deepcompressor.quantizer.processor import Quantizer
+ from deepcompressor.utils import tools
+ from deepcompressor.utils.hooks import AccumBranchHook, ProcessHook
+
+ from ....nn.patch.linear import ConcatLinear, ShiftedLinear
+ from ....nn.patch.lowrank import LowRankBranch
+ from ..nn.patch import (
+     replace_fused_linear_with_concat_linear,
+     replace_up_block_conv_with_concat_conv,
+     shift_input_activations,
+ )
+
+ __all__ = ["DiffusionPipelineConfig"]
+
+
+ @configclass
+ @dataclass
+ class LoRAConfig:
+     """LoRA configuration.
+
+     Args:
+         path (`str`):
+             The path of the LoRA branch.
+         weight_name (`str`):
+             The weight name of the LoRA branch.
+         alpha (`float`, *optional*, defaults to `1.0`):
+             The alpha value of the LoRA branch.
+     """
+
+     path: str
+     weight_name: str
+     alpha: float = 1.0
+
+
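+ # A minimal, hypothetical example of a LoRA entry (the repo id and weight
+ # file name below are placeholders):
+ #
+ #     LoRAConfig(path="user/flux-lora", weight_name="pytorch_lora_weights.safetensors", alpha=1.0)
+
+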
+ @configclass
+ @dataclass
+ class DiffusionPipelineConfig:
+     """Diffusion pipeline configuration.
+
+     Args:
+         name (`str`):
+             The name of the pipeline.
+         path (`str`, *optional*, defaults to `""`):
+             The path of the pipeline. If empty, a default path is resolved from the name.
+         dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+             The data type of the pipeline.
+         device (`str`, *optional*, defaults to `"cuda"`):
+             The device of the pipeline.
+         shift_activations (`bool`, *optional*, defaults to `False`):
+             Whether to shift activations.
+         lora (`LoRAConfig`, *optional*, defaults to `None`):
+             The LoRA branch to load into the pipeline.
+         task (`str`, *optional*, defaults to `"text-to-image"`):
+             The task of the pipeline, overridden in `__post_init__` for control and fill variants.
+     """
+
+     _pipeline_factories: tp.ClassVar[
+         dict[str, tp.Callable[[str, str, torch.dtype, torch.device, bool], DiffusionPipeline]]
+     ] = {}
+     _text_extractors: tp.ClassVar[
+         dict[
+             str,
+             tp.Callable[
+                 [DiffusionPipeline, tuple[type[PreTrainedModel], ...]],
+                 list[tuple[str, PreTrainedModel, PreTrainedTokenizer]],
+             ],
+         ]
+     ] = {}
+
+     name: str
+     path: str = ""
+     dtype: torch.dtype = field(
+         default_factory=lambda s=torch.float32: eval_dtype(s, with_quant_dtype=False, with_none=False)
+     )
+     device: str = "cuda"
+     shift_activations: bool = False
+     lora: LoRAConfig | None = None
+     family: str = field(init=False)
+     task: str = "text-to-image"
+
+     def __post_init__(self):
+         # The family is the prefix before the first "-", e.g. "flux.1" for "flux.1-dev".
+         self.family = self.name.split("-")[0]
+
+         if self.name == "flux.1-canny-dev":
+             self.task = "canny-to-image"
+         elif self.name == "flux.1-depth-dev":
+             self.task = "depth-to-image"
+         elif self.name == "flux.1-fill-dev":
+             self.task = "inpainting"
+
+     def build(
+         self, *, dtype: str | torch.dtype | None = None, device: str | torch.device | None = None
+     ) -> DiffusionPipeline:
+         """Build the diffusion pipeline.
+
+         Args:
+             dtype (`str` or `torch.dtype`, *optional*):
+                 The data type of the pipeline. Defaults to the configured dtype.
+             device (`str` or `torch.device`, *optional*):
+                 The device of the pipeline. Defaults to the configured device.
+
+         Returns:
+             `DiffusionPipeline`:
+                 The diffusion pipeline.
+         """
+         if dtype is None:
+             dtype = self.dtype
+         if device is None:
+             device = self.device
+         _factory = self._pipeline_factories.get(self.name, self._default_build)
+         return _factory(
+             name=self.name, path=self.path, dtype=dtype, device=device, shift_activations=self.shift_activations
+         )
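+
+     # A hedged usage sketch (the model name is from the default path table in
+     # `_default_build`; the dtype here is illustrative):
+     #
+     #     config = DiffusionPipelineConfig(name="flux.1-dev")
+     #     pipe = config.build(dtype=torch.bfloat16, device="cuda")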
+
+     def extract_text_encoders(
+         self, pipeline: DiffusionPipeline, supported: tuple[type[PreTrainedModel], ...] = (T5EncoderModel,)
+     ) -> list[tuple[str, PreTrainedModel, PreTrainedTokenizer]]:
+         """Extract the text encoders and tokenizers from the pipeline.
+
+         Args:
+             pipeline (`DiffusionPipeline`):
+                 The diffusion pipeline.
+             supported (`tuple[type[PreTrainedModel], ...]`, *optional*, defaults to `(T5EncoderModel,)`):
+                 The supported text encoder types. If empty, all text encoders are extracted.
+
+         Returns:
+             `list[tuple[str, PreTrainedModel, PreTrainedTokenizer]]`:
+                 The list of (text encoder name, model, tokenizer) tuples.
+         """
+         _extractor = self._text_extractors.get(self.name, self._default_extract_text_encoders)
+         return _extractor(pipeline, supported)
+
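+     # For instance (hypothetical loop; FLUX pipelines expose a T5-based
+     # `text_encoder_2`, which matches the default `supported` types):
+     #
+     #     for enc_name, encoder, tokenizer in config.extract_text_encoders(pipe):
+     #         print(enc_name, type(encoder).__name__)
+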
+     @classmethod
+     def register_pipeline_factory(
+         cls,
+         names: str | tuple[str, ...],
+         /,
+         factory: tp.Callable[[str, str, torch.dtype, torch.device, bool], DiffusionPipeline],
+         *,
+         overwrite: bool = False,
+     ) -> None:
+         """Register a pipeline factory.
+
+         Args:
+             names (`str` or `tuple[str, ...]`):
+                 The name(s) of the pipeline(s) handled by the factory.
+             factory (`Callable[[str, str, torch.dtype, torch.device, bool], DiffusionPipeline]`):
+                 The pipeline factory function.
+             overwrite (`bool`, *optional*, defaults to `False`):
+                 Whether to overwrite the existing factory for the pipeline.
+         """
+         if isinstance(names, str):
+             names = (names,)
+         for name in names:
+             if name in cls._pipeline_factories and not overwrite:
+                 raise ValueError(f"Pipeline factory {name} already exists.")
+             cls._pipeline_factories[name] = factory
+
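+     # Sketch of registering a custom factory (the pipeline name and factory
+     # below are assumptions for illustration). Note that `build` invokes the
+     # factory with keyword arguments:
+     #
+     #     def my_factory(name, path, dtype, device, shift_activations):
+     #         return AutoPipelineForText2Image.from_pretrained(path, torch_dtype=dtype).to(device)
+     #
+     #     DiffusionPipelineConfig.register_pipeline_factory("my-pipeline", my_factory)
+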
+     @classmethod
+     def register_text_extractor(
+         cls,
+         names: str | tuple[str, ...],
+         /,
+         extractor: tp.Callable[
+             [DiffusionPipeline, tuple[type[PreTrainedModel], ...]],
+             list[tuple[str, PreTrainedModel, PreTrainedTokenizer]],
+         ],
+         *,
+         overwrite: bool = False,
+     ) -> None:
+         """Register a text extractor.
+
+         Args:
+             names (`str` or `tuple[str, ...]`):
+                 The name(s) of the pipeline(s) handled by the extractor.
+             extractor (`Callable[[DiffusionPipeline, tuple[type[PreTrainedModel], ...]], list[tuple[str, PreTrainedModel, PreTrainedTokenizer]]]`):
+                 The text extractor function.
+             overwrite (`bool`, *optional*, defaults to `False`):
+                 Whether to overwrite the existing extractor for the pipeline.
+         """
+         if isinstance(names, str):
+             names = (names,)
+         for name in names:
+             if name in cls._text_extractors and not overwrite:
+                 raise ValueError(f"Text extractor {name} already exists.")
+             cls._text_extractors[name] = extractor
+
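+     # Extractors are registered the same way, e.g. reusing the default
+     # implementation for a custom pipeline name (illustrative):
+     #
+     #     DiffusionPipelineConfig.register_text_extractor(
+     #         "my-pipeline", DiffusionPipelineConfig._default_extract_text_encoders
+     #     )
+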
+     def load_lora(  # noqa: C901
+         self, pipeline: DiffusionPipeline, smooth_cache: dict[str, torch.Tensor] | None = None
+     ) -> DiffusionPipeline:
+         """Load the configured LoRA branch into the pipeline's model as low-rank hooks."""
+         smooth_cache = smooth_cache or {}
+         model = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
+         assert isinstance(model, nn.Module)
+         if self.lora is not None:
+             logger = tools.logging.getLogger(__name__)
+             logger.info(f"Load LoRA branches from {self.lora.path}")
+             lora_state_dict, alphas = pipeline.lora_state_dict(
+                 self.lora.path, return_alphas=True, weight_name=self.lora.weight_name
+             )
+             tools.logging.Formatter.indent_inc()
+             for name, module in model.named_modules():
+                 if isinstance(module, (nn.Linear, ConcatLinear, ShiftedLinear)):
+                     lora_a_key, lora_b_key = f"transformer.{name}.lora_A.weight", f"transformer.{name}.lora_B.weight"
+                     if lora_a_key in lora_state_dict:
+                         assert lora_b_key in lora_state_dict
+                         logger.info(f"+ Load LoRA branch for {name}")
+                         tools.logging.Formatter.indent_inc()
+                         a = lora_state_dict.pop(lora_a_key)
+                         b = lora_state_dict.pop(lora_b_key)
+                         assert isinstance(a, torch.Tensor)
+                         assert isinstance(b, torch.Tensor)
+                         assert a.shape[1] == module.in_features
+                         assert b.shape[0] == module.out_features
+                         if isinstance(module, ConcatLinear):
+                             logger.debug(
+                                 f"- split LoRA branch into {len(module.linears)} parts ({module.in_features_list})"
+                             )
+                             # Split the LoRA down-projection to match the fused linear's input slices.
+                             m_splits = module.linears
+                             a_splits = a.split(module.in_features_list, dim=1)
+                             b_splits = [b] * len(a_splits)
+                         else:
+                             m_splits, a_splits, b_splits = [module], [a], [b]
+                         for m, a, b in zip(m_splits, a_splits, b_splits, strict=True):
+                             assert a.shape[0] == b.shape[1]
+                             if isinstance(m, ShiftedLinear):
+                                 s, m = m.shift, m.linear
+                                 logger.debug(f"- shift LoRA input by {s.item() if s.numel() == 1 else s}")
+                             else:
+                                 s = None
+                             assert isinstance(m, nn.Linear)
+                             device, dtype = m.weight.device, m.weight.dtype
+                             a, b = a.to(device=device, dtype=torch.float64), b.to(device=device, dtype=torch.float64)
+                             if s is not None:
+                                 # Fold the input shift into a bias correction: s = alpha * B @ A @ shift.
+                                 if s.numel() == 1:
+                                     s = torch.matmul(b, a.sum(dim=1).mul_(s.double())).mul_(self.lora.alpha)
+                                 else:
+                                     s = torch.matmul(b, torch.matmul(a, s.view(-1).double())).mul_(self.lora.alpha)
+                             if hasattr(m, "in_smooth_cache_key"):
+                                 logger.debug(f"- smooth LoRA input using {m.in_smooth_cache_key} smooth scale")
+                                 # Scale the down-projection by the input smoothing scale.
+                                 ss = smooth_cache[m.in_smooth_cache_key].to(device=device, dtype=torch.float64)
+                                 a = a.mul_(ss.view(1, -1))
+                                 del ss
+                             if hasattr(m, "out_smooth_cache_key"):
+                                 logger.debug(f"- smooth LoRA output using {m.out_smooth_cache_key} smooth scale")
+                                 # Divide the up-projection (and the folded shift) by the output smoothing scale.
+                                 ss = smooth_cache[m.out_smooth_cache_key].to(device=device, dtype=torch.float64)
+                                 b = b.div_(ss.view(-1, 1))
+                                 if s is not None:
+                                     s = s.div_(ss.view(-1))
+                                 del ss
+                             branch_hook, quant_hook = None, None
+                             for hook in m._forward_pre_hooks.values():
+                                 if isinstance(hook, AccumBranchHook) and isinstance(hook.branch, LowRankBranch):
+                                     branch_hook = hook
+                                 if isinstance(hook, ProcessHook) and isinstance(hook.processor, Quantizer):
+                                     quant_hook = hook
+                             if branch_hook is not None:
+                                 logger.debug("- fuse with existing LoRA branch")
+                                 assert isinstance(branch_hook.branch, LowRankBranch)
+                                 _a = branch_hook.branch.a.weight.data
+                                 _b = branch_hook.branch.b.weight.data
+                                 if branch_hook.branch.alpha != self.lora.alpha:
+                                     # Alphas differ: fold each alpha into its up-projection and use alpha = 1.
+                                     a, b = a.to(dtype=dtype), b.mul_(self.lora.alpha).to(dtype=dtype)
+                                     _b = _b.to(dtype=torch.float64).mul_(branch_hook.branch.alpha).to(dtype=dtype)
+                                     alpha = 1
+                                 else:
+                                     a, b = a.to(dtype=dtype), b.to(dtype=dtype)
+                                     alpha = self.lora.alpha
+                                 # Concatenate both branches along the rank dimension.
+                                 branch_hook.branch = LowRankBranch(
+                                     m.in_features,
+                                     m.out_features,
+                                     rank=a.shape[0] + branch_hook.branch.rank,
+                                     alpha=alpha,
+                                 ).to(device=device, dtype=dtype)
+                                 branch_hook.branch.a.weight.data[: a.shape[0], :] = a
+                                 branch_hook.branch.b.weight.data[:, : b.shape[1]] = b
+                                 branch_hook.branch.a.weight.data[a.shape[0] :, :] = _a
+                                 branch_hook.branch.b.weight.data[:, b.shape[1] :] = _b
+                             else:
+                                 logger.debug("- create a new LoRA branch")
+                                 branch = LowRankBranch(
+                                     m.in_features, m.out_features, rank=a.shape[0], alpha=self.lora.alpha
+                                 )
+                                 branch = branch.to(device=device, dtype=dtype)
+                                 branch.a.weight.data.copy_(a.to(dtype=dtype))
+                                 branch.b.weight.data.copy_(b.to(dtype=dtype))
+                                 # The low-rank branch hook must be registered before the quantization hook.
+                                 if quant_hook is not None:
+                                     logger.debug(f"- remove quantization hook from {name}")
+                                     quant_hook.remove(m)
+                                 logger.debug(f"- register LoRA branch to {name}")
+                                 branch.as_hook().register(m)
+                                 if quant_hook is not None:
+                                     logger.debug(f"- re-register quantization hook to {name}")
+                                     quant_hook.register(m)
+                             if s is not None:
+                                 # Apply the folded shift as a bias correction.
+                                 assert m.bias is not None
+                                 m.bias.data.copy_((m.bias.double().sub_(s)).to(dtype))
+                         del m_splits, a_splits, b_splits, a, b, s
+                         gc.collect()
+                         torch.cuda.empty_cache()
+                         tools.logging.Formatter.indent_dec()
+             tools.logging.Formatter.indent_dec()
+             if len(lora_state_dict) > 0:
+                 logger.warning(f"Unused LoRA weights: {lora_state_dict.keys()}")
+         branches = nn.ModuleList()
+         for _, module in model.named_modules():
+             for hook in module._forward_hooks.values():
+                 if isinstance(hook, AccumBranchHook) and isinstance(hook.branch, LowRankBranch):
+                     branches.append(hook.branch)
+         # Register collected branches so they are tracked as part of the model.
+         model.register_module("_low_rank_branches", branches)
+         return pipeline
+
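+     # Hypothetical call (the smooth-scale cache, when used, comes from the
+     # quantization flow and is keyed by each module's `*_smooth_cache_key`;
+     # `smooth_scales` is a placeholder name):
+     #
+     #     pipe = config.load_lora(pipe, smooth_cache=smooth_scales)
+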
+     @staticmethod
+     def _default_build(
+         name: str, path: str, dtype: str | torch.dtype, device: str | torch.device, shift_activations: bool
+     ) -> DiffusionPipeline:
+         if not path:
+             if name == "sdxl":
+                 path = "stabilityai/stable-diffusion-xl-base-1.0"
+             elif name == "sdxl-turbo":
+                 path = "stabilityai/sdxl-turbo"
+             elif name == "pixart-sigma":
+                 path = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
+             elif name == "flux.1-kontext-dev":
+                 path = "black-forest-labs/FLUX.1-Kontext-dev"
+             elif name == "flux.1-dev":
+                 path = "black-forest-labs/FLUX.1-dev"
+             elif name == "flux.1-canny-dev":
+                 path = "black-forest-labs/FLUX.1-Canny-dev"
+             elif name == "flux.1-depth-dev":
+                 path = "black-forest-labs/FLUX.1-Depth-dev"
+             elif name == "flux.1-fill-dev":
+                 path = "black-forest-labs/FLUX.1-Fill-dev"
+             elif name == "flux.1-schnell":
+                 path = "black-forest-labs/FLUX.1-schnell"
+             else:
+                 raise ValueError(f"Path for {name} is not specified.")
+         if name in ["flux.1-kontext-dev"]:
+             pipeline = FluxKontextPipeline.from_pretrained(path, torch_dtype=dtype)
+         elif name in ["flux.1-canny-dev", "flux.1-depth-dev"]:
+             pipeline = FluxControlPipeline.from_pretrained(path, torch_dtype=dtype)
+         elif name == "flux.1-fill-dev":
+             pipeline = FluxFillPipeline.from_pretrained(path, torch_dtype=dtype)
+         elif name.startswith("sana-"):
+             if dtype == torch.bfloat16:
+                 pipeline = SanaPipeline.from_pretrained(path, variant="bf16", torch_dtype=dtype, use_safetensors=True)
+                 pipeline.vae.to(dtype)
+                 pipeline.text_encoder.to(dtype)
+             else:
+                 pipeline = SanaPipeline.from_pretrained(path, torch_dtype=dtype)
+         else:
+             pipeline = AutoPipelineForText2Image.from_pretrained(path, torch_dtype=dtype)
+
+         # Debug output
+         print(">>> DEVICE:", device)
+         print(">>> PIPELINE TYPE:", type(pipeline))
+
+         # Try to move each component to the target device. Note that `to_empty()`
+         # takes `device` as a keyword-only argument and allocates uninitialized
+         # storage (it does not copy loaded weights), hence the guarded fallback
+         # to `.to()`.
+         for component in ["unet", "transformer", "vae", "text_encoder"]:
+             module = getattr(pipeline, component, None)
+             if isinstance(module, torch.nn.Module):
+                 try:
+                     print(f">>> Moving {component} to {device} using to_empty()")
+                     module.to_empty(device=device)
+                 except Exception as e:
+                     print(f">>> WARNING: {component}.to_empty({device}) failed: {e}")
+                     try:
+                         print(f">>> Falling back to {component}.to({device})")
+                         module.to(device)
+                     except Exception as ee:
+                         print(f">>> ERROR: {component}.to({device}) also failed: {ee}")
+
+         # Identify the main model (for patching)
+         model = getattr(pipeline, "unet", None) or getattr(pipeline, "transformer", None)
+         if model is not None:
+             replace_fused_linear_with_concat_linear(model)
+             replace_up_block_conv_with_concat_conv(model)
+             if shift_activations:
+                 shift_input_activations(model)
+         else:
+             print(">>> WARNING: No model (unet/transformer) found for patching")
+
+         return pipeline
+
+     @staticmethod
+     def _default_extract_text_encoders(
+         pipeline: DiffusionPipeline, supported: tuple[type[PreTrainedModel], ...]
+     ) -> list[tuple[str, PreTrainedModel, PreTrainedTokenizer]]:
+         """Extract the text encoders and tokenizers from the pipeline.
+
+         Args:
+             pipeline (`DiffusionPipeline`):
+                 The diffusion pipeline.
+             supported (`tuple[type[PreTrainedModel], ...]`):
+                 The supported text encoder types. If empty, all text encoders are extracted.
+
+         Returns:
+             `list[tuple[str, PreTrainedModel, PreTrainedTokenizer]]`:
+                 The list of (text encoder name, model, tokenizer) tuples.
+         """
+         results: list[tuple[str, PreTrainedModel, PreTrainedTokenizer]] = []
+         for key in vars(pipeline).keys():
+             if key.startswith("text_encoder"):
+                 # Pair each `text_encoder*` attribute with its matching `tokenizer*`.
+                 suffix = key[len("text_encoder") :]
+                 encoder, tokenizer = getattr(pipeline, f"text_encoder{suffix}"), getattr(pipeline, f"tokenizer{suffix}")
+                 if not supported or isinstance(encoder, supported):
+                     results.append((key, encoder, tokenizer))
+         return results
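+
+
+ # End-to-end sketch of the intended flow (model name and dtype are
+ # illustrative; `load_lora` only adds branches when `lora` is configured):
+ #
+ #     config = DiffusionPipelineConfig(name="flux.1-dev", shift_activations=True)
+ #     pipeline = config.build(dtype=torch.bfloat16)
+ #     pipeline = config.load_lora(pipeline)
+ #     encoders = config.extract_text_encoders(pipeline)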