File size: 5,974 Bytes
30358db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import torch
from diffusers import AsymmetricAutoencoderKL
from safetensors.torch import save_file

# Key-prefix mapping Diffusers -> A1111
def _build_key_map():
    """Assemble the Diffusers -> A1111 prefix table.

    Entries are inserted in a fixed order: encoder stem, encoder down
    blocks, encoder middle, decoder stem, decoder middle, decoder up blocks.
    """
    table = {}

    for side in ("encoder", "decoder"):
        # Input/output convolutions keep their names; the output norm is
        # renamed from conv_norm_out to norm_out.
        table[f"{side}.conv_in"] = f"{side}.conv_in"
        table[f"{side}.conv_norm_out"] = f"{side}.norm_out"
        table[f"{side}.conv_out"] = f"{side}.conv_out"

        if side == "encoder":
            # Encoder blocks: two resnets per level, with a downsampler on
            # every level except the last one (level 3).
            for level in range(4):
                for idx in range(2):
                    table[f"encoder.down_blocks.{level}.resnets.{idx}"] = (
                        f"encoder.down.{level}.block.{idx}"
                    )
                if level < 3:
                    table[f"encoder.down_blocks.{level}.downsamplers.0"] = (
                        f"encoder.down.{level}.downsample"
                    )

        # Middle section: resnet, attention, resnet.
        table[f"{side}.mid_block.resnets.0"] = f"{side}.mid.block_1"
        table[f"{side}.mid_block.attentions.0"] = f"{side}.mid.attn_1"
        table[f"{side}.mid_block.resnets.1"] = f"{side}.mid.block_2"

    # Decoder blocks (4 levels). The level index is reversed:
    # up_blocks.0 -> up.3 (deepest) ... up_blocks.3 -> up.0 (topmost).
    # Four resnets per level, with an upsampler on every level except
    # up_blocks.3.
    for level in range(4):
        target = 3 - level
        for idx in range(4):
            table[f"decoder.up_blocks.{level}.resnets.{idx}"] = (
                f"decoder.up.{target}.block.{idx}"
            )
        if level < 3:
            table[f"decoder.up_blocks.{level}.upsamplers.0"] = (
                f"decoder.up.{target}.upsample"
            )

    return table


KEY_MAP = _build_key_map()

# Additional substring renames for individual layers; convert_key applies
# them after the prefix from KEY_MAP has been substituted.
LAYER_RENAMES = {
    "conv_shortcut": "nin_shortcut",
    "group_norm": "norm",
    "to_q": "q",
    "to_k": "k",
    "to_v": "v",
    "to_out.0": "proj_out",
}

def convert_key(key):
    """Convert a state-dict key from Diffusers format to A1111 format.

    Returns None for keys containing "condition_encoder" (the original
    author's note says A1111 does not support that component). When a
    KEY_MAP prefix matches, the prefix is swapped and every LAYER_RENAMES
    substitution is applied to the result. Keys that match no prefix are
    returned unchanged.
    """
    # AsymmetricVAE-specific component: the caller treats None as "skip".
    if "condition_encoder" in key:
        return None

    # First KEY_MAP prefix (in insertion order) that the key starts with.
    matched_prefix = next(
        (prefix for prefix in KEY_MAP if key.startswith(prefix)),
        None,
    )

    # Not covered by the mapping: pass the key through as-is.
    if matched_prefix is None:
        return key

    # Swap the leading prefix only, then apply the per-layer renames.
    converted = KEY_MAP[matched_prefix] + key[len(matched_prefix):]
    for diffusers_name, a1111_name in LAYER_RENAMES.items():
        converted = converted.replace(diffusers_name, a1111_name)
    return converted

# Load the VAE from the local Diffusers checkpoint directory.
vae = AsymmetricAutoencoderKL.from_pretrained("./asymmetric_vae")
state_dict = vae.state_dict()

# Convert the keys.
converted_state_dict = {}
skipped_keys = []
# True (source key, converted key) pairs, recorded as each tensor is
# converted so the report below can never pair unrelated keys.
converted_pairs = []

# Attention projection weights that may need an extra pair of trailing
# singleton dimensions.
attn_weight_suffixes = ("q.weight", "k.weight", "v.weight", "proj_out.weight")

for key, value in state_dict.items():
    new_key = convert_key(key)

    # convert_key returns None for keys that must be left out of the output.
    if new_key is None:
        skipped_keys.append(key)
        continue

    # Attention weights: reshape [out_features, in_features] into
    # [out_features, in_features, 1, 1]; tensors that are not 2-D are
    # left untouched.
    if (
        "attn_1" in new_key
        and new_key.endswith(attn_weight_suffixes)
        and value.dim() == 2
    ):
        value = value.unsqueeze(-1).unsqueeze(-1)

    converted_state_dict[new_key] = value
    converted_pairs.append((key, new_key))

# Save the converted weights.
save_file(converted_state_dict, "sdxl_vae_asymm_a1111.safetensors")

print(f"Конвертировано {len(converted_state_dict)} ключей")
print(f"Пропущено {len(skipped_keys)} ключей (condition_encoder и др.)")

if skipped_keys:
    print("\nПропущенные ключи:")
    for key in skipped_keys[:10]:  # show only the first 10
        print(f"  - {key}")

# Show the first five real conversions. Zipping the source and converted
# key lists would misalign the pairs as soon as any earlier key was skipped.
print("\nПримеры конвертированных ключей:")
for old, new in converted_pairs[:5]:
    print(f"{old} -> {new}")

# Report the attention weight shapes after conversion.
print("\nAttention веса после конвертации:")
for key, value in converted_state_dict.items():
    if "attn_1" in key and "weight" in key:
        print(f"{key}: {value.shape}")