# File size: 5,974 bytes
# Revision: 30358db
import torch
from diffusers import AsymmetricAutoencoderKL
from safetensors.torch import save_file
# Prefix mapping from Diffusers state-dict keys to A1111 state-dict keys.
# Entries are kept in the same insertion order as a hand-written table:
# encoder stem, encoder down path, encoder middle, decoder stem,
# decoder middle, decoder up path.
KEY_MAP = {
    # Encoder input/output layers.
    "encoder.conv_in": "encoder.conv_in",
    "encoder.conv_norm_out": "encoder.norm_out",
    "encoder.conv_out": "encoder.conv_out",
    # Encoder down path: two resnets per level plus a downsampler.
    # The table has no downsampler entry for the last level (3).
    **{
        f"encoder.down_blocks.{level}.{src}": f"encoder.down.{level}.{dst}"
        for level in range(4)
        for src, dst in (
            ("resnets.0", "block.0"),
            ("resnets.1", "block.1"),
            ("downsamplers.0", "downsample"),
        )
        if (level, dst) != (3, "downsample")
    },
    # Encoder middle block.
    "encoder.mid_block.resnets.0": "encoder.mid.block_1",
    "encoder.mid_block.attentions.0": "encoder.mid.attn_1",
    "encoder.mid_block.resnets.1": "encoder.mid.block_2",
    # Decoder input/output layers.
    "decoder.conv_in": "decoder.conv_in",
    "decoder.conv_norm_out": "decoder.norm_out",
    "decoder.conv_out": "decoder.conv_out",
    # Decoder middle block.
    "decoder.mid_block.resnets.0": "decoder.mid.block_1",
    "decoder.mid_block.attentions.0": "decoder.mid.attn_1",
    "decoder.mid_block.resnets.1": "decoder.mid.block_2",
    # Decoder up path (four blocks): the index is reversed, so
    # up_blocks.0 -> up.3 (deepest) ... up_blocks.3 -> up.0 (topmost).
    # Four resnets per block plus an upsampler; the table has no
    # upsampler entry for up_blocks.3.
    **{
        f"decoder.up_blocks.{index}.{src}": f"decoder.up.{3 - index}.{dst}"
        for index in range(4)
        for src, dst in (
            ("resnets.0", "block.0"),
            ("resnets.1", "block.1"),
            ("resnets.2", "block.2"),
            ("resnets.3", "block.3"),
            ("upsamplers.0", "upsample"),
        )
        if (index, dst) != (3, "upsample")
    },
}
# Additional substring renames for individual layers inside a mapped block.
# Insertion order is: shortcut conv, attention norm, q/k/v projections,
# output projection.
LAYER_RENAMES = {
    # Residual shortcut convolution.
    "conv_shortcut": "nin_shortcut",
    # Attention normalisation layer.
    "group_norm": "norm",
    # Attention projections: "to_q" -> "q", "to_k" -> "k", "to_v" -> "v".
    **{f"to_{name}": name for name in ("q", "k", "v")},
    # Attention output projection.
    "to_out.0": "proj_out",
}
def convert_key(key, key_map=None, layer_renames=None):
    """Convert one state-dict key from the Diffusers format to the A1111 format.

    Args:
        key: Parameter name (str) as it appears in the Diffusers state dict.
        key_map: Optional dict of Diffusers key prefixes to A1111 prefixes.
            When omitted (None), the module-level ``KEY_MAP`` is used.
        layer_renames: Optional dict of substring renames applied to a key
            after its prefix has been mapped. When omitted (None), the
            module-level ``LAYER_RENAMES`` is used.

    Returns:
        ``None`` if the key contains ``"condition_encoder"`` (the caller is
        expected to drop such keys); the converted key if some prefix in
        ``key_map`` matches; otherwise ``key`` unchanged.
    """
    # Components specific to the asymmetric VAE are dropped. Per the original
    # author's note, A1111 does not support condition_encoder.
    if "condition_encoder" in key:
        return None

    if key_map is None:
        key_map = KEY_MAP
    if layer_renames is None:
        layer_renames = LAYER_RENAMES

    # The first matching prefix (in mapping order) wins.
    for diffusers_prefix, a1111_prefix in key_map.items():
        if not key.startswith(diffusers_prefix):
            continue
        # Swap only the leading prefix; the rest of the key is kept.
        new_key = a1111_prefix + key[len(diffusers_prefix):]
        # Layer renames are applied one after another, everywhere in the key.
        for old, new in layer_renames.items():
            new_key = new_key.replace(old, new)
        return new_key

    # No prefix matched: the key is passed through as-is.
    return key
# Load the VAE from the local "./asymmetric_vae" directory.
vae = AsymmetricAutoencoderKL.from_pretrained("./asymmetric_vae")
state_dict = vae.state_dict()

# Convert the keys. Keys for which convert_key() returns None are skipped.
converted_state_dict = {}
skipped_keys = []
# (source key, converted key) pairs in conversion order. Recorded here so the
# report below shows each converted key next to the key it actually came
# from, even when some source keys were skipped.
converted_pairs = []
for key, value in state_dict.items():
    new_key = convert_key(key)
    if new_key is None:
        skipped_keys.append(key)
        continue

    # Attention projection weights: a 2-D [out_features, in_features] tensor
    # gets two trailing singleton dimensions, giving
    # [out_features, in_features, 1, 1].
    # NOTE(review): presumably the target format stores these projections as
    # 1x1 convolutions - confirm against the loader.
    is_attention_weight = "attn_1" in new_key and any(
        x in new_key
        for x in ["q.weight", "k.weight", "v.weight", "proj_out.weight"]
    )
    if is_attention_weight and value.dim() == 2:
        value = value.unsqueeze(-1).unsqueeze(-1)

    converted_state_dict[new_key] = value
    converted_pairs.append((key, new_key))

# Save the converted weights.
save_file(converted_state_dict, "sdxl_vae_asymm_a1111.safetensors")

print(f"Конвертировано {len(converted_state_dict)} ключей")
print(f"Пропущено {len(skipped_keys)} ключей (condition_encoder и др.)")
if skipped_keys:
    print("\nПропущенные ключи:")
    for key in skipped_keys[:10]:  # show only the first 10
        print(f" - {key}")

# Show the first five conversions as examples.
print("\nПримеры конвертированных ключей:")
for old, new in converted_pairs[:5]:
    print(f"{old} -> {new}")

# Report the shapes of the attention weights after conversion.
print("\nAttention веса после конвертации:")
for key, value in converted_state_dict.items():
    if "attn_1" in key and "weight" in key:
        print(f"{key}: {value.shape}")