recoilme committed
Commit 2a06cb8
1 Parent(s): 9a62483
Files changed (7)
  1. .gitattributes +1 -0
  2. .gitignore +13 -0
  3. README.md +45 -1
  4. config.json +38 -0
  5. diffusion_pytorch_model.safetensors +3 -0
  6. eval.py +167 -0
  7. train_sdxl_vae.py +504 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,13 @@
1
+ # Jupyter Notebook
2
+ __pycache__/
3
+ *.pyc
4
+ .ipynb_checkpoints/
5
+ *.ipynb_checkpoints/*
6
+ .ipynb_checkpoints/*
7
+ src/samples
8
+ # cache
9
+ cache
10
+ datasets
11
+ test
12
+ wandb
13
+ nohup.out
README.md CHANGED
@@ -10,4 +10,48 @@ library_name: diffusers
10
  # SDXL-VAE finetuned
11
  - madebyollin/sdxl-vae-fp16-fix: MSE: 3.680e-03, PSNR: 25.2100, LPIPS: 0.1314
12
  - KBlueLeaf/EQ-SDXL-VAE : MSE: 3.530e-03, PSNR: 25.2827, LPIPS: 0.1298
13
- - AiArtLab/sdxl_vae : MSE: <span style="color:red">3.321e-03</span>, PSNR: <span style="color:red">25.6389</span>, LPIPS: <span style="color:red">0.1251</span>
+ - AiArtLab/sdxl_vae : MSE: <span style="color:red">3.321e-03</span>, PSNR: <span style="color:red">25.6389</span>, LPIPS: <span style="color:red">0.1251</span>
14
+
15
+ ### Training status (in progress):
16
+
17
+ ![result](result.png)
18
+
19
+ ## VAE Training Process
20
+
21
+ - Dataset: 100,000 PNG images
22
+ - Training time: 4 days
23
+ - Hardware: single RTX 4090
24
+ - Resolution: 512px
25
+ - Precision: FP32
26
+ - Effective batch size: 16 (batch size 2 + gradient accumulation 8; see the sketch below)
27
+ - Optimizer: AdamW (8-bit)
28
+
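+ A minimal sketch of this optimizer/accumulation setup with `bitsandbytes` and `accelerate` (illustrative only; `decoder_params` is a placeholder standing in for the trainable decoder parameters, not code from the training script):
+
+ ```python
+ import torch
+ import bitsandbytes as bnb
+ from accelerate import Accelerator
+
+ # effective batch size = per-step batch (2) * gradient accumulation steps (8) = 16
+ accelerator = Accelerator(gradient_accumulation_steps=8)
+
+ decoder_params = torch.nn.Linear(4, 4).parameters()  # placeholder module standing in for the VAE decoder
+ optimizer = bnb.optim.AdamW8bit(decoder_params, lr=1e-6)
+ ```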
29
+ ## Implementation
30
+
31
+ - Base code: a simple diffusion-model training script.
32
+ - Encoder: frozen, so SDXL does not need to be retrained for the new VAE (see the sketch below).
33
+ - Training target: only the decoder, focused on image reconstruction.
34
+
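+ A minimal sketch of the frozen-encoder idea with `diffusers` (an illustration under stated assumptions, not the full training script; the checkpoint name is just an example):
+
+ ```python
+ from diffusers import AutoencoderKL
+
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
+ vae.requires_grad_(False)         # freeze everything: the encoder stays fixed, so SDXL latents remain compatible
+ vae.decoder.requires_grad_(True)  # train only the decoder
+ trainable = [p for p in vae.decoder.parameters() if p.requires_grad]
+ ```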
35
+ ## Loss Functions
36
+
37
+ - Initially used LPIPS and MSE.
38
+ - Noticed the FID score improving while images became blurrier (FID overfits to blurry images, so improving FID is not always good).
39
+ - Switched from MSE to MAE (Mean Absolute Error), though it is not clear that MSE itself was the problem.
40
+ - Balanced LPIPS and MAE at a 90/10 ratio (see the sketch below).
41
+ - Used the median perceptual_loss_weight for better balance.
42
+
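+ A minimal sketch of this MAE/LPIPS balancing, assuming the `lpips` package and reconstruction/target tensors in [-1, 1] (the random tensors below are stand-ins):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+ import lpips
+
+ lpips_net = lpips.LPIPS(net="vgg").eval()
+ lpips_ratio = 0.9  # target share of LPIPS in the total loss
+
+ recon = torch.rand(1, 3, 256, 256) * 2 - 1   # stand-in reconstruction
+ target = torch.rand(1, 3, 256, 256) * 2 - 1  # stand-in ground truth
+
+ mae = F.l1_loss(recon, target)
+ perc = lpips_net(recon, target).mean()
+ # rescale the LPIPS weight from the observed magnitudes so LPIPS contributes ~90% of the total loss
+ weight = (mae.item() / max(perc.item(), 1e-12)) * (lpips_ratio / (1.0 - lpips_ratio))
+ loss = mae + weight * perc
+ ```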
43
+ ## Results
44
+
45
+ https://imgsli.com/NDA3NTEy/1/2
46
+
47
+ ## Donations
48
+
49
+ Please contact us if you can provide GPUs or funding for training.
50
+
51
+ DOGE: DEw2DR8C7BnF8GgcrfTzUjSnGkuMeJhg83
52
+
53
+ BTC: 3JHv9Hb8kEW8zMAccdgCdZGfrHeMhH1rpN
54
+
55
+ ## Contacts
56
+
57
+ [recoilme](https://t.me/recoilme)
config.json ADDED
@@ -0,0 +1,38 @@
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.34.0",
4
+ "_name_or_path": "sdxl_vae",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "force_upcast": false,
19
+ "in_channels": 3,
20
+ "latent_channels": 4,
21
+ "latents_mean": null,
22
+ "latents_std": null,
23
+ "layers_per_block": 2,
24
+ "mid_block_add_attention": true,
25
+ "norm_num_groups": 32,
26
+ "out_channels": 3,
27
+ "sample_size": 512,
28
+ "scaling_factor": 0.13025,
29
+ "shift_factor": null,
30
+ "up_block_types": [
31
+ "UpDecoderBlock2D",
32
+ "UpDecoderBlock2D",
33
+ "UpDecoderBlock2D",
34
+ "UpDecoderBlock2D"
35
+ ],
36
+ "use_post_quant_conv": true,
37
+ "use_quant_conv": true
38
+ }
diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03f2412467f6bedce9efeddba5860b5ec0d3267931d14c500d4bd7a878e14cbd
3
+ size 334643268
eval.py ADDED
@@ -0,0 +1,167 @@
1
+ import warnings
2
+ import logging
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torch.utils.data as data
6
+ import lpips
7
+ from tqdm import tqdm
8
+ from torchvision.transforms import (
9
+ Compose,
10
+ Resize,
11
+ ToTensor,
12
+ CenterCrop,
13
+ )
14
+ from diffusers import AutoencoderKL
15
+
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+ warnings.filterwarnings(
20
+ "ignore",
21
+ ".*Found keys that are not in the model state dict but in the checkpoint.*",
22
+ )
23
+
24
+ DEVICE = "cuda"
25
+ DTYPE = torch.float16
26
+ SHORT_AXIS_SIZE = 256
27
+
28
+ NAMES = [
29
+ "madebyollin/sdxl-vae-fp16-fix",
30
+ "KBlueLeaf/EQ-SDXL-VAE ",
31
+ "AiArtLab/simplevae ",
32
+ ]
33
+ BASE_MODELS = [
34
+ "madebyollin/sdxl-vae-fp16-fix",
35
+ "KBlueLeaf/EQ-SDXL-VAE",
36
+ "AiArtLab/simplevae",
37
+ ]
38
+ SUB_FOLDERS = [None, None, "sdxl_vae"]
39
+ CKPT_PATHS = [
40
+ None,
41
+ None,
42
+ None,
43
+ ]
44
+ USE_APPROXS = [False, False, False]
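+ # NOTE: the use_approx / ckpt_path branches below reference LatentApproxDecoder and LatentTrainer,
+ # which are not defined in this script; keep these options disabled unless those helpers are available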
45
+
46
+ def process(x):
47
+ return x * 2 - 1
48
+
49
+ def deprocess(x):
50
+ return x * 0.5 + 0.5
51
+
52
+ # streaming ImageNet validation data (torch.utils.data is already imported above)
53
+ from datasets import load_dataset
54
+
55
+ class ImageNetDataset(data.IterableDataset):
56
+ def __init__(self, split, transform=None, max_len=10, streaming=True):
57
+ self.split = split
58
+ self.transform = transform
59
+ self.dataset = load_dataset("evanarlian/imagenet_1k_resized_256", split=split, streaming=streaming)
60
+ self.max_len = max_len
61
+ self.iterator = iter(self.dataset)
62
+
63
+ def __iter__(self):
64
+ for i, entry in enumerate(self.iterator):
65
+ if self.max_len and i >= self.max_len:
66
+ break
67
+ img = entry["image"]
68
+ target = entry["label"]
69
+ if self.transform is not None:
70
+ img = self.transform(img)
71
+ yield img, target
72
+
73
+ if __name__ == "__main__":
74
+ lpips_loss = torch.compile(
75
+ lpips.LPIPS(net="vgg").eval().to(DEVICE).requires_grad_(False)
76
+ )
77
+
78
+ @torch.compile
79
+ def metrics(inp, recon):
80
+ mse = F.mse_loss(inp, recon)
81
+ psnr = 10 * torch.log10(1 / mse)
82
+ return (
83
+ mse.cpu(),
84
+ psnr.cpu(),
85
+ lpips_loss(inp, recon, normalize=True).mean().cpu(),
86
+ )
87
+
88
+ transform = Compose(
89
+ [
90
+ Resize(SHORT_AXIS_SIZE),
91
+ CenterCrop(SHORT_AXIS_SIZE),
92
+ ToTensor(),
93
+ ]
94
+ )
95
+ valid_dataset = ImageNetDataset("val", transform=transform, max_len=50000, streaming=True)
96
+ valid_loader = data.DataLoader(
97
+ valid_dataset,
98
+ batch_size=4,
99
+ shuffle=False,
100
+ num_workers=2,
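+ # note: with an IterableDataset each worker iterates its own copy of the stream, so num_workers > 0 may yield duplicated samples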
101
+ pin_memory=True,
102
+ pin_memory_device=DEVICE,
103
+ )
104
+
105
+ # Quick sanity check that the data loads
106
+ for batch in valid_loader:
107
+ print("Batch shape:", batch[0].shape)
108
+ break
109
+
110
+ logger.info("Loading models...")
111
+ vaes = []
112
+ for base_model, sub_folder, ckpt_path, use_approx in zip(
113
+ BASE_MODELS, SUB_FOLDERS, CKPT_PATHS, USE_APPROXS
114
+ ):
115
+ vae = AutoencoderKL.from_pretrained(base_model, subfolder=sub_folder)
116
+ if use_approx:
117
+ vae.decoder = LatentApproxDecoder(
118
+ latent_dim=vae.config.latent_channels,
119
+ out_channels=3,
120
+ shuffle=2,
121
+ )
122
+ vae.decode = lambda x: vae.decoder(x)
123
+ vae.get_last_layer = lambda: vae.decoder.conv_out.weight
124
+ if ckpt_path:
125
+ LatentTrainer.load_from_checkpoint(
126
+ ckpt_path, vae=vae, map_location="cpu", strict=False
127
+ )
128
+ vae = vae.to(DTYPE).eval().requires_grad_(False).to(DEVICE)
129
+ vae.encoder = torch.compile(vae.encoder)
130
+ vae.decoder = torch.compile(vae.decoder)
131
+ vaes.append(torch.compile(vae))
132
+
133
+ logger.info("Running Validation")
134
+ total = 0
135
+ all_latents = [[] for _ in range(len(vaes))]
136
+ all_mse = [[] for _ in range(len(vaes))]
137
+ all_psnr = [[] for _ in range(len(vaes))]
138
+ all_lpips = [[] for _ in range(len(vaes))]
139
+
140
+ for idx, batch in enumerate(tqdm(valid_loader)):
141
+ image = batch[0].to(DEVICE)
142
+ test_inp = process(image).to(DTYPE)
143
+ batch_size = test_inp.size(0)
144
+
145
+ for i, vae in enumerate(vaes):
146
+ latent = vae.encode(test_inp).latent_dist.mode()
147
+ recon = deprocess(vae.decode(latent).sample.float())
148
+ all_latents[i].append(latent.cpu().float())
149
+ mse, psnr, lpips_ = metrics(image, recon)
150
+ all_mse[i].append(mse.cpu() * batch_size)
151
+ all_psnr[i].append(psnr.cpu() * batch_size)
152
+ all_lpips[i].append(lpips_.cpu() * batch_size)
153
+
154
+ total += batch_size
155
+
156
+ for i in range(len(vaes)):
157
+ all_latents[i] = torch.cat(all_latents[i], dim=0)
158
+ all_mse[i] = torch.stack(all_mse[i]).sum() / total
159
+ all_psnr[i] = torch.stack(all_psnr[i]).sum() / total
160
+ all_lpips[i] = torch.stack(all_lpips[i]).sum() / total
161
+
162
+ logger.info(
163
+ f" - {NAMES[i]}: MSE: {all_mse[i]:.3e}, PSNR: {all_psnr[i]:.4f}, "
164
+ f"LPIPS: {all_lpips[i]:.4f}"
165
+ )
166
+
167
+ logger.info("End")
train_sdxl_vae.py ADDED
@@ -0,0 +1,504 @@
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import math
4
+ import re
5
+ import torch
6
+ import numpy as np
7
+ import random
8
+ import gc
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+
12
+ import torchvision.transforms as transforms
13
+ import torch.nn.functional as F
14
+ from torch.utils.data import DataLoader, Dataset
15
+ from torch.optim.lr_scheduler import LambdaLR
16
+ from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
17
+ from accelerate import Accelerator
18
+ from PIL import Image, UnidentifiedImageError
19
+ from tqdm import tqdm
20
+ import bitsandbytes as bnb
21
+ import wandb
22
+ import lpips # pip install lpips
23
+
24
+ # --------------------------- Parameters ---------------------------
25
+ ds_path = "/workspace/png"
26
+ project = "sdxl_vae"
27
+ batch_size = 1
28
+ base_learning_rate = 1e-6
29
+ min_learning_rate = 8e-7
30
+ num_epochs = 8
31
+ sample_interval_share = 20
32
+ use_wandb = True
33
+ save_model = True
34
+ use_decay = True
35
+ optimizer_type = "adam8bit"
36
+ dtype = torch.float32
37
+ # model_resolution: the resolution actually fed into the VAE (the low-res input)
38
+ model_resolution = 768 # formerly `resolution`
39
+ # high_resolution: the high-res crop used for computing metrics and saving samples
40
+ high_resolution = 768 # >>> CHANGED: train on high-res crops, downsample to model_resolution for the VAE input
41
+ limit = 0
42
+ save_barrier = 1.03
43
+ warmup_percent = 0.01
44
+ percentile_clipping = 95
45
+ beta2 = 0.97
46
+ eps = 1e-6
47
+ clip_grad_norm = 1.0
48
+ mixed_precision = "no" # or "fp16"/"bf16" if supported
49
+ gradient_accumulation_steps = 16
50
+ generated_folder = "samples"
51
+ save_as = "sdxl_vae_new"
52
+ perceptual_loss_weight = 0.03 # initial weight value (recomputed every step)
53
+ num_workers = 0
54
+ device = None # the accelerator will set the device
55
+
56
+ # --- Dynamic LPIPS normalization parameters
57
+ lpips_ratio = 0.9 # target share of LPIPS in the total loss
58
+
59
+ min_perceptual_weight = 0.1 # lower bound on the weight
60
+ max_perceptual_weight = 99 # upper bound on the weight (protects against blow-ups)
61
+
62
+ # --------------------------- Preprocessing parameters ---------------------------
63
+ resize_long_side = 1280 # if None or 0, no resize; 1024 is recommended
64
+
65
+ Path(generated_folder).mkdir(parents=True, exist_ok=True)
66
+
67
+ accelerator = Accelerator(
68
+ mixed_precision=mixed_precision,
69
+ gradient_accumulation_steps=gradient_accumulation_steps
70
+ )
71
+ device = accelerator.device
72
+
73
+ # reproducibility
74
+ seed = int(datetime.now().strftime("%Y%m%d"))
75
+ torch.manual_seed(seed)
76
+ np.random.seed(seed)
77
+ random.seed(seed)
78
+
79
+ torch.backends.cudnn.benchmark = True
80
+
81
+ # --------------------------- WandB ---------------------------
82
+ if use_wandb and accelerator.is_main_process:
83
+ wandb.init(project=project, config={
84
+ "batch_size": batch_size,
85
+ "base_learning_rate": base_learning_rate,
86
+ "num_epochs": num_epochs,
87
+ "optimizer_type": optimizer_type,
88
+ "model_resolution": model_resolution,
89
+ "high_resolution": high_resolution,
90
+ "gradient_accumulation_steps": gradient_accumulation_steps,
91
+ })
92
+
93
+ # --------------------------- VAE ---------------------------
94
+ vae = AutoencoderKL.from_pretrained(project).to(dtype)
95
+ #vae = AsymmetricAutoencoderKL.from_pretrained(project).to(dtype)
96
+
97
+ # >>> CHANGED: freeze all parameters, then unfreeze mid_block + up_blocks
98
+ for p in vae.parameters():
99
+ p.requires_grad = False
100
+
101
+ decoder = getattr(vae, "decoder", None)
102
+ if decoder is None:
103
+ raise RuntimeError("vae.decoder not found — не могу применить стратегию разморозки. Проверь структуру модели.")
104
+
105
+ unfrozen_param_names = []
106
+
107
+ if not hasattr(decoder, "up_blocks"):
108
+ raise RuntimeError("decoder.up_blocks не найдены — ожидается список блоков декодера.")
109
+
110
+ # >>> CHANGED: unfreeze the decoder up_blocks (all of them, since start_idx is 0) and the mid_block
111
+ n_up = len(decoder.up_blocks)
112
+ start_idx = 0 # set to max(0, n_up - 2) to unfreeze only the last two up_blocks
113
+ for idx in range(start_idx, n_up):
114
+ block = decoder.up_blocks[idx]
115
+ for name, p in block.named_parameters():
116
+ p.requires_grad = True
117
+ unfrozen_param_names.append(f"decoder.up_blocks.{idx}.{name}")
118
+
119
+ if hasattr(decoder, "mid_block"):
120
+ for name, p in decoder.mid_block.named_parameters():
121
+ p.requires_grad = True
122
+ unfrozen_param_names.append(f"decoder.mid_block.{name}")
123
+ else:
124
+ print("[WARN] decoder.mid_block не найден — mid_block не разморожен.")
125
+
126
+ print(f"[INFO] Разморожено параметров: {len(unfrozen_param_names)}. Первые 200 имён:")
127
+ for nm in unfrozen_param_names[:200]:
128
+ print(" ", nm)
129
+
130
+ # keep the trainable module (get_param_groups only includes params with p.requires_grad)
131
+ trainable_module = vae.decoder
132
+
133
+ # --------------------------- Custom PNG Dataset (only .png, skip corrupted) -----------
134
+ class PngFolderDataset(Dataset):
135
+ def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
136
+ # >>> CHANGED: default resolution argument is high-resolution (1024)
137
+ self.root_dir = root_dir
138
+ self.resolution = resolution
139
+ self.paths = []
140
+ # collect png files recursively
141
+ for root, _, files in os.walk(root_dir):
142
+ for fname in files:
143
+ if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
144
+ self.paths.append(os.path.join(root, fname))
145
+ # optional limit
146
+ if limit:
147
+ self.paths = self.paths[:limit]
148
+ # verify images and keep only valid ones
149
+ valid = []
150
+ for p in self.paths:
151
+ try:
152
+ with Image.open(p) as im:
153
+ im.verify() # fast check for truncated/corrupted images
154
+ valid.append(p)
155
+ except (OSError, UnidentifiedImageError):
156
+ # skip corrupted image
157
+ continue
158
+ self.paths = valid
159
+ if len(self.paths) == 0:
160
+ raise RuntimeError(f"No valid PNG images found under {root_dir}")
161
+ # final shuffle for randomness
162
+ random.shuffle(self.paths)
163
+
164
+ def __len__(self):
165
+ return len(self.paths)
166
+
167
+ def __getitem__(self, idx):
168
+ p = self.paths[idx % len(self.paths)]
169
+ # open and convert to RGB; ensure file is closed promptly
170
+ with Image.open(p) as img:
171
+ img = img.convert("RGB")
172
+ # return PIL image (collate will transform)
173
+ if not resize_long_side or resize_long_side <= 0:
174
+ return img
175
+ w, h = img.size
176
+ long = max(w, h)
177
+ if long <= resize_long_side:
178
+ return img
179
+ scale = resize_long_side / float(long)
180
+ new_w = int(round(w * scale))
181
+ new_h = int(round(h * scale))
182
+ return img.resize((new_w, new_h), Image.LANCZOS)
183
+
184
+ # --------------------------- Dataset and transforms ---------------------------
185
+
186
+ def random_crop(img, sz):
187
+ w, h = img.size
188
+ if w < sz or h < sz:
189
+ img = img.resize((max(sz, w), max(sz, h)), Image.LANCZOS)
190
+ x = random.randint(0, max(0, img.width - sz))
191
+ y = random.randint(0, max(0, img.height - sz))
192
+ return img.crop((x, y, x + sz, y + sz))
193
+
194
+ tfm = transforms.Compose([
195
+ transforms.ToTensor(),
196
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
197
+ ])
198
+
199
+ # build dataset using high_resolution crops
200
+ dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit) # >>> CHANGED
201
+ if len(dataset) < batch_size:
202
+ raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
203
+
204
+ # collate_fn crops to high_resolution
205
+ def collate_fn(batch):
206
+ imgs = []
207
+ for img in batch: # img is PIL.Image
208
+ img = random_crop(img, high_resolution) # >>> CHANGED: crop high-res
209
+ imgs.append(tfm(img))
210
+ return torch.stack(imgs)
211
+
212
+ dataloader = DataLoader(
213
+ dataset,
214
+ batch_size=batch_size,
215
+ shuffle=True,
216
+ collate_fn=collate_fn,
217
+ num_workers=num_workers,
218
+ pin_memory=True,
219
+ drop_last=True
220
+ )
221
+
222
+ # --------------------------- Optimizer ---------------------------
223
+ def get_param_groups(module, weight_decay=0.001):
224
+ no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
225
+ decay_params = []
226
+ no_decay_params = []
227
+ for n, p in module.named_parameters():
228
+ if not p.requires_grad:
229
+ continue
230
+ if any(nd in n for nd in no_decay):
231
+ no_decay_params.append(p)
232
+ else:
233
+ decay_params.append(p)
234
+ return [
235
+ {"params": decay_params, "weight_decay": weight_decay},
236
+ {"params": no_decay_params, "weight_decay": 0.0},
237
+ ]
238
+
239
+ def create_optimizer(name, param_groups):
240
+ if name == "adam8bit":
241
+ return bnb.optim.AdamW8bit(
242
+ param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps
243
+ )
244
+ raise ValueError(name)
245
+
246
+ param_groups = get_param_groups(trainable_module, weight_decay=0.001)
247
+ optimizer = create_optimizer(optimizer_type, param_groups)
248
+
249
+ # --------------------------- Prepare with Accelerate ---------------------------
250
+ batches_per_epoch = len(dataloader) # number of micro-batches (dataloader steps)
251
+ steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps))) # optimizer.step() calls per epoch
252
+ total_steps = steps_per_epoch * num_epochs
253
+
254
+ def lr_lambda(step):
255
+ if not use_decay:
256
+ return 1.0
257
+ x = float(step) / float(max(1, total_steps))
258
+ warmup = float(warmup_percent)
259
+ min_ratio = float(min_learning_rate) / float(base_learning_rate)
260
+ if x < warmup:
261
+ return min_ratio + (1.0 - min_ratio) * (x / warmup)
262
+ decay_ratio = (x - warmup) / (1.0 - warmup)
263
+ return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
264
+
265
+ scheduler = LambdaLR(optimizer, lr_lambda)
266
+
267
+ # Prepare dataloader, model, optimizer and scheduler
268
+ dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
269
+
270
+ trainable_params = [p for p in vae.decoder.parameters() if p.requires_grad]
271
+
272
+ # --------------------------- Samples and LPIPS helper ---------------------------
273
+ @torch.no_grad()
274
+ def get_fixed_samples(n=3):
275
+ idx = random.sample(range(len(dataset)), min(n, len(dataset)))
276
+ pil_imgs = [dataset[i] for i in idx] # dataset returns PIL.Image
277
+ tensors = []
278
+ for img in pil_imgs:
279
+ img = random_crop(img, high_resolution) # >>> CHANGED: high-res fixed samples
280
+ tensors.append(tfm(img))
281
+ return torch.stack(tensors).to(accelerator.device, dtype)
282
+
283
+ fixed_samples = get_fixed_samples()
284
+
285
+ _lpips_net = None
286
+ def _get_lpips():
287
+ global _lpips_net
288
+ if _lpips_net is None:
289
+ # lpips uses its internal vgg, but we use it as-is.
290
+ _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device)
291
+ return _lpips_net
292
+
293
+ @torch.no_grad()
294
+ def generate_and_save_samples(step=None):
295
+ try:
296
+ temp_vae = accelerator.unwrap_model(vae).eval()
297
+ lpips_net = _get_lpips()
298
+ with torch.no_grad():
299
+ # >>> CHANGED: use high-res fixed_samples, downsample to model_res for encoding
300
+ orig_high = fixed_samples # already on device
301
+ # make low-res input for model
302
+ if model_resolution != high_resolution:  # downsample only when the model input is smaller than the crop
303
+ orig_low = F.interpolate(orig_high, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
304
+ else:
305
+ orig_low = orig_high
306
+
307
+ # ensure dtype matches model params to avoid dtype mismatch
308
+ model_dtype = next(temp_vae.parameters()).dtype
309
+ orig_low = orig_low.to(dtype=model_dtype)
310
+
311
+ latent_dist = temp_vae.encode(orig_low).latent_dist
312
+ latents = latent_dist.mean
313
+ rec = temp_vae.decode(latents).sample # expected to be upscaled to high_res
314
+
315
+ # make sure rec is float32 in range [0,1] for saving
316
+ # if rec spatial size differs from orig_high, resize rec to orig_high
317
+ if rec.shape[-2:] != orig_high.shape[-2:]:
318
+ rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
319
+
320
+ rec_img = ((rec.float() / 2.0 + 0.5).clamp(0, 1) * 255).cpu().numpy()
321
+ for i in range(rec_img.shape[0]):
322
+ arr = rec_img[i].transpose(1, 2, 0).astype(np.uint8)
323
+ Image.fromarray(arr).save(f"{generated_folder}/sample_{step if step is not None else 'init'}_{i}.jpg", quality=95)
324
+
325
+ # LPIPS on the full (high-res) image
326
+ lpips_scores = []
327
+ for i in range(rec.shape[0]):
328
+ orig_full = orig_high[i:i+1] # [B, C, H, W], in [-1,1]
329
+ rec_full = rec[i:i+1]
330
+ # ensure same spatial size/dtype
331
+ if rec_full.shape[-2:] != orig_full.shape[-2:]:
332
+ rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
333
+ rec_full = rec_full.to(torch.float32)
334
+ orig_full = orig_full.to(torch.float32)
335
+ lpips_val = lpips_net(orig_full, rec_full).item()
336
+ lpips_scores.append(lpips_val)
337
+ avg_lpips = float(np.mean(lpips_scores))
338
+ if use_wandb and accelerator.is_main_process:
339
+ wandb.log({
340
+ "generated_images": [wandb.Image(Image.fromarray(rec_img[i].transpose(1,2,0).astype(np.uint8))) for i in range(rec_img.shape[0])],
341
+ "lpips_mean": avg_lpips
342
+ }, step=step)
343
+ finally:
344
+ gc.collect()
345
+ torch.cuda.empty_cache()
346
+
347
+ if accelerator.is_main_process and save_model:
348
+ print("Генерация сэмплов до старта обучения...")
349
+ generate_and_save_samples(0)
350
+
351
+ accelerator.wait_for_everyone()
352
+
353
+ # --------------------------- Training ---------------------------
354
+
355
+ progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
356
+ global_step = 0
357
+ min_loss = float("inf")
358
+ sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
359
+
360
+ for epoch in range(num_epochs):
361
+ vae.train()
362
+ batch_losses = []
363
+ batch_losses_mae = []
364
+ batch_losses_lpips = []
365
+ batch_losses_perc = []
366
+ batch_grads = []
367
+ for imgs in dataloader:
368
+ with accelerator.accumulate(vae):
369
+ # imgs: high-res tensor from dataloader ([-1,1]), move to device
370
+ imgs = imgs.to(accelerator.device)
371
+
372
+ # >>> CHANGED: create low-res input for model by downsampling high-res crop
373
+ if model_resolution != high_resolution:  # downsample only when the model input is smaller than the crop
374
+ imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
375
+ else:
376
+ imgs_low = imgs
377
+
378
+ # ensure dtype matches model params to avoid float/half mismatch
379
+ model_dtype = next(vae.parameters()).dtype
380
+ if imgs_low.dtype != model_dtype:
381
+ imgs_low_model = imgs_low.to(dtype=model_dtype)
382
+ else:
383
+ imgs_low_model = imgs_low
384
+
385
+ # Encode/decode on low-res input
386
+ latent_dist = vae.encode(imgs_low_model).latent_dist
387
+ latents = latent_dist.mean
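+ # use the deterministic posterior mean rather than sampling; the encoder is frozen, so latents stay consistent with the base VAE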
388
+ rec = vae.decode(latents).sample # rec is expected to be high-res (upscaled)
389
+
390
+ # If rec isn't the same spatial size as original high-res input, resize to high-res
391
+ if rec.shape[-2:] != imgs.shape[-2:]:
392
+ rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
393
+
394
+ # Now compute losses **on high-res** (rec vs imgs)
395
+ rec_f32 = rec.to(torch.float32)
396
+ imgs_f32 = imgs.to(torch.float32)
397
+
398
+ # MAE
399
+ mae_loss = F.l1_loss(rec_f32, imgs_f32)
400
+
401
+ # LPIPS (ensure float32)
402
+ lpips_loss = _get_lpips()(rec_f32, imgs_f32).mean()
403
+
404
+ # dynamic perceptual weighting (same as before)
405
+ if float(mae_loss.detach().cpu().item()) > 1e-12:
406
+ desired_multiplier = lpips_ratio / max(1.0 - lpips_ratio, 1e-12)
407
+ new_weight = (mae_loss.item() / float(lpips_loss.detach().cpu().item())) * desired_multiplier
408
+ else:
409
+ new_weight = perceptual_loss_weight
410
+
411
+ perceptual_loss_weight = float(np.clip(new_weight, min_perceptual_weight, max_perceptual_weight))
412
+ batch_losses_perc.append(perceptual_loss_weight)
413
+ # average the weight over the most recent steps (the slice also works when the list is shorter)
414
+ avg_perc = float(np.mean(batch_losses_perc[-sample_interval:]))
417
+
418
+ total_loss = mae_loss + avg_perc * lpips_loss
419
+
420
+ if torch.isnan(total_loss) or torch.isinf(total_loss):
421
+ print("NaN/Inf loss – stopping")
422
+ raise RuntimeError("NaN/Inf loss")
423
+
424
+ accelerator.backward(total_loss)
425
+
426
+ grad_norm = torch.tensor(0.0, device=accelerator.device)
427
+ if accelerator.sync_gradients:
428
+ grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
429
+ optimizer.step()
430
+ scheduler.step()
431
+ optimizer.zero_grad(set_to_none=True)
432
+
433
+ global_step += 1
434
+ progress.update(1)
435
+
436
+ # --- Logging ---
437
+ if accelerator.is_main_process:
438
+ try:
439
+ current_lr = optimizer.param_groups[0]["lr"]
440
+ except Exception:
441
+ current_lr = scheduler.get_last_lr()[0]
442
+
443
+ batch_losses.append(total_loss.detach().item())
444
+ batch_losses_mae.append(mae_loss.detach().item())
445
+ batch_losses_lpips.append(lpips_loss.detach().item())
446
+ batch_grads.append(float(grad_norm if isinstance(grad_norm, (float, int)) else grad_norm.cpu().item()))
447
+
448
+ if use_wandb and accelerator.sync_gradients:
449
+ wandb.log({
450
+ "mae_loss": mae_loss.detach().item(),
451
+ "lpips_loss": lpips_loss.detach().item(),
452
+ "perceptual_loss_weight": avg_perc,
453
+ "total_loss": total_loss.detach().item(),
454
+ "learning_rate": current_lr,
455
+ "epoch": epoch,
456
+ "grad_norm": batch_grads[-1],
457
+ }, step=global_step)
458
+
459
+ # periodic samples and checkpoints
460
+ if global_step > 0 and global_step % sample_interval == 0:
461
+ # generate and log only on the main process (generation uses the high-res fixed_samples)
462
+ if accelerator.is_main_process:
463
+ generate_and_save_samples(global_step)
464
+
465
+ accelerator.wait_for_everyone()
466
+
467
+ # how many micro-batches to average over
468
+ n_micro = sample_interval * gradient_accumulation_steps
469
+ # guard against running past the start of the lists
470
+ if len(batch_losses) >= n_micro:
471
+ avg_loss = float(np.mean(batch_losses[-n_micro:]))
472
+ avg_loss_mae = float(np.mean(batch_losses_mae[-n_micro:]))
473
+ avg_loss_lpips = float(np.mean(batch_losses_lpips[-n_micro:]))
474
+ else:
475
+ avg_loss = float(np.mean(batch_losses)) if batch_losses else float("nan")
476
+ avg_loss_mae = float(np.mean(batch_losses_mae)) if batch_losses_mae else float("nan")
477
+ avg_loss_lpips = float(np.mean(batch_losses_lpips)) if batch_losses_lpips else float("nan")
478
+
479
+ avg_grad = float(np.mean(batch_grads[-n_micro:])) if batch_grads else 0.0
480
+
481
+ if accelerator.is_main_process:
482
+ print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
483
+ if save_model and avg_loss < min_loss * save_barrier:
484
+ min_loss = avg_loss
485
+ accelerator.unwrap_model(vae).save_pretrained(save_as)
486
+ if use_wandb:
487
+ wandb.log({"interm_loss": avg_loss,"interm_loss_mae": avg_loss_mae,"interm_loss_lpips": avg_loss_lpips, "interm_grad": avg_grad}, step=global_step)
488
+
489
+ if accelerator.is_main_process:
490
+ epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
491
+ print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
492
+ if use_wandb:
493
+ wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
494
+
495
+ # --------------------------- Final save ---------------------------
496
+ if accelerator.is_main_process:
497
+ print("Training finished – saving final model")
498
+ if save_model:
499
+ accelerator.unwrap_model(vae).save_pretrained(save_as)
500
+
501
+ accelerator.free_memory()
502
+ if torch.distributed.is_initialized():
503
+ torch.distributed.destroy_process_group()
504
+ print("Готово!")