peteromallet committed
Commit f61e618 · verified · Parent: 4aaac17

Upload convert_custom_lora.py with huggingface_hub

Files changed (1): convert_custom_lora.py (+117, −0)
convert_custom_lora.py ADDED
import argparse

import torch
from safetensors.torch import load_file, save_file


def convert_comfy_to_wan_lora_final_fp16(lora_path, output_path):
    """
    Converts a ComfyUI-style LoRA to the format expected by 'wan.modules.model'.
    - Keeps the 'diffusion_model.' prefix.
    - Renames 'lora_A' to 'lora_down' and 'lora_B' to 'lora_up'.
    - Skips per-layer '.alpha' keys.
    - Skips 'img_emb.' keys under the 'diffusion_model.' prefix.
    - Converts all LoRA weight tensors to float16.

    Args:
        lora_path (str): Path to the input ComfyUI LoRA .safetensors file.
        output_path (str): Path to save the converted LoRA .safetensors file.
    """
    try:
        source_state_dict = load_file(lora_path)
    except Exception as e:
        print(f"Error loading LoRA file '{lora_path}': {e}")
        return

    converted_state_dict = {}
    print(f"Loaded {len(source_state_dict)} tensors from {lora_path}")

    source_comfy_prefix = "diffusion_model."
    target_wan_prefix = "diffusion_model."

    converted_count = 0
    skipped_alpha_keys_count = 0
    skipped_img_emb_keys_count = 0
    problematic_keys = []

    for key, tensor in source_state_dict.items():
        original_key = key

        if not key.startswith(source_comfy_prefix):
            problematic_keys.append(
                f"{original_key} (key does not start with expected prefix '{source_comfy_prefix}')"
            )
            continue

        module_and_lora_part = key[len(source_comfy_prefix):]

        if module_and_lora_part.startswith("img_emb."):
            skipped_img_emb_keys_count += 1
            continue

        if module_and_lora_part.endswith(".lora_A.weight"):
            new_key_module_base = module_and_lora_part[: -len(".lora_A.weight")]
            new_lora_suffix = ".lora_down.weight"
        elif module_and_lora_part.endswith(".lora_B.weight"):
            new_key_module_base = module_and_lora_part[: -len(".lora_B.weight")]
            new_lora_suffix = ".lora_up.weight"
        elif module_and_lora_part.endswith(".alpha"):
            # Alpha keys are skipped entirely and never reach the dtype conversion below.
            skipped_alpha_keys_count += 1
            continue
        else:
            problematic_keys.append(
                f"{original_key} (unknown LoRA suffix or non-LoRA key within "
                f"'{source_comfy_prefix}' structure: '...{module_and_lora_part[-25:]}')"
            )
            continue

        new_key = target_wan_prefix + new_key_module_base + new_lora_suffix

        # Convert weight tensors to float16; only floating-point dtypes are converted.
        if tensor.is_floating_point():
            converted_state_dict[new_key] = tensor.to(torch.float16)
        else:
            # Should not happen for LoRA weights, but kept as a safeguard.
            converted_state_dict[new_key] = tensor
            print(f"Warning: tensor {original_key} was not floating point; dtype not changed.")

        converted_count += 1

    print("\nKey conversion finished.")
    print(f"Successfully processed and converted {converted_count} LoRA weight keys (to float16).")
    if skipped_alpha_keys_count > 0:
        print(f"Skipped {skipped_alpha_keys_count} '.alpha' keys.")
    if skipped_img_emb_keys_count > 0:
        print(f"Skipped {skipped_img_emb_keys_count} 'diffusion_model.img_emb.' related keys.")
    if problematic_keys:
        print(f"Found {len(problematic_keys)} other keys that were also skipped (see details below):")
        for pkey in problematic_keys:
            print(f"  - {pkey}")

    if converted_state_dict:
        print(f"Output dictionary has {len(converted_state_dict)} keys.")
        print(f"Now saving the file to: {output_path} (this may take a while for large files)...")
        try:
            save_file(converted_state_dict, output_path)
            print(f"\nSuccessfully saved converted LoRA to: {output_path}")
        except Exception as e:
            print(f"Error saving converted LoRA file '{output_path}': {e}")
    elif source_state_dict:
        print("\nNo keys were converted. Check the input LoRA format and the skipped-key counts above.")
    else:
        print("\nInput LoRA file is empty. No conversion performed.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Convert a ComfyUI-style LoRA to 'wan.modules.model' format, converting weights to float16.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("lora_path", type=str, help="Path to the input ComfyUI LoRA (.safetensors) file.")
    parser.add_argument("output_path", type=str, help="Path to save the converted LoRA (.safetensors) file.")
    args = parser.parse_args()

    convert_comfy_to_wan_lora_final_fp16(args.lora_path, args.output_path)
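
The script is invoked with the two positional arguments defined above, e.g. "python convert_custom_lora.py my_comfy_lora.safetensors my_wan_lora_fp16.safetensors" (file names illustrative). A minimal sanity-check sketch for the output, assuming every weight in the source was floating point so the script converted it to float16:

import torch
from safetensors.torch import load_file

# Illustrative path: the output_path that was passed to the script.
state_dict = load_file("my_wan_lora_fp16.safetensors")

# Every surviving key should keep the 'diffusion_model.' prefix, use the
# renamed 'lora_down'/'lora_up' suffixes, and hold float16 weights.
assert all(k.startswith("diffusion_model.") for k in state_dict)
assert all(k.endswith((".lora_down.weight", ".lora_up.weight")) for k in state_dict)
assert all(t.dtype == torch.float16 for t in state_dict.values())
print(f"OK: {len(state_dict)} tensors passed the format check.")

One consequence of dropping the '.alpha' keys: the consumer will apply its own default LoRA scaling, which matches the source only when alpha equals the LoRA rank (a scale of 1 in the usual ΔW = (alpha / rank) · lora_up @ lora_down formulation). If a source LoRA used a different alpha, a hypothetical variant could fold the scale into 'lora_up' instead of discarding it; a sketch of that idea, not part of the script above:

import torch

def fold_alpha(lora_up: torch.Tensor, lora_down: torch.Tensor, alpha: float) -> torch.Tensor:
    """Scale lora_up by alpha / rank so that lora_up @ lora_down already
    includes the LoRA alpha scaling. lora_down has shape [rank, in_features]."""
    rank = lora_down.shape[0]
    return lora_up * (alpha / rank)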