lym00 committed
Commit 6b3b9b9 (verified) · Parent: 69c64b7

Upload 2 files

Files changed (2)
  1. quantized/combine.py +101 -11
  2. quantized/metadata.json +45 -0
quantized/combine.py CHANGED
@@ -1,11 +1,101 @@
- import safetensors.torch
-
- # Load both files
- quantized_blocks = safetensors.torch.load_file("transformer_blocks.safetensors")
- unquantized_layers = safetensors.torch.load_file("unquantized_layers.safetensors")
-
- # Combine the state dicts
- combined_state_dict = {**quantized_blocks, **unquantized_layers}
-
- # Save the combined model
- safetensors.torch.save_file(combined_state_dict, "combined.safetensors")
+ #!/usr/bin/env python
+
+ import safetensors.torch
+ import os
+ import json
+
+ def prompt_file(prompt, default):
+     path = input(f"{prompt} [{default}]: ").strip()
+     return path if path else default
+
+ # Prompt for input/output files
+ quant_file = prompt_file("Enter path to quantized blocks", "transformer_blocks.safetensors")
+ unquant_file = prompt_file("Enter path to unquantized layers", "unquantized_layers.safetensors")
+ output_file = prompt_file("Enter path to save combined model", "combined.safetensors")
+ metadata_file = prompt_file("Enter path to metadata.json (optional)", "metadata.json")
+
+ # Validate file existence
+ for f in [quant_file, unquant_file]:
+     if not os.path.isfile(f):
+         raise FileNotFoundError(f"File not found: {f}")
+
+ # Load state dicts
+ quantized_blocks = safetensors.torch.load_file(quant_file)
+ unquantized_layers = safetensors.torch.load_file(unquant_file)
+
+ # Warn about key overlaps
+ overlap = set(quantized_blocks) & set(unquantized_layers)
+ if overlap:
+     print(f"⚠️ Warning: Overlapping keys (unquantized will override): {overlap}")
+
+ # Merge state dicts
+ combined_state_dict = {**quantized_blocks, **unquantized_layers}
+
+ # Attempt to load metadata.json
+ metadata = {}
+ if os.path.isfile(metadata_file):
+     try:
+         with open(metadata_file, "r", encoding="utf-8") as f:
+             metadata = json.load(f)
+         # Convert nested objects to JSON strings as required by safetensors
+         for k, v in metadata.items():
+             if isinstance(v, dict):
+                 metadata[k] = json.dumps(v)
+     except Exception as e:
+         print(f"⚠️ Failed to load metadata from {metadata_file}: {e}")
+         print("⏳ Falling back to hardcoded metadata...")
+         metadata = {}  # Will populate below
+
+ # Fallback metadata (if file load failed or was missing)
+ if not metadata:
+     metadata = {
+         "model_class": "NunchakuFluxTransformer2dModel",
+         "comfy_config": json.dumps({
+             "model_class": "Flux",
+             "model_config": {
+                 "axes_dim": [16, 56, 56],
+                 "context_in_dim": 4096,
+                 "depth": 19,
+                 "depth_single_blocks": 38,
+                 "disable_unet_model_creation": True,
+                 "guidance_embed": True,
+                 "hidden_size": 3072,
+                 "image_model": "flux",
+                 "in_channels": 16,
+                 "mlp_ratio": 4.0,
+                 "num_heads": 24,
+                 "out_channels": 16,
+                 "patch_size": 2,
+                 "qkv_bias": True,
+                 "theta": 10000,
+                 "vec_in_dim": 768
+             }
+         }),
+         "quantization_config": json.dumps({
+             "method": "svdquant",
+             "weight": {"dtype": "int4", "scale_dtype": None, "group_size": 64},
+             "activation": {"dtype": "int4", "scale_dtype": None, "group_size": 64}
+         }),
+         "config": json.dumps({
+             "_class_name": "FluxTransformer2DModel",
+             "_diffusers_version": "0.34.0.dev0",
+             "_name_or_path": "../checkpoints/flux-dev/transformer",
+             "attention_head_dim": 128,
+             "axes_dims_rope": [16, 56, 56],
+             "guidance_embeds": True,
+             "in_channels": 64,
+             "joint_attention_dim": 4096,
+             "num_attention_heads": 24,
+             "num_layers": 19,
+             "num_single_layers": 38,
+             "out_channels": None,
+             "patch_size": 1,
+             "pooled_projection_dim": 768
+         })
+     }
+
+ # Save the combined safetensors file
+ safetensors.torch.save_file(combined_state_dict, output_file, metadata=metadata)
+
+ print(f"\n✅ Combined model saved to: {output_file}")
+ print(f"ℹ️ Metadata keys included: {', '.join(metadata.keys())}")
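
A quick way to sanity-check the output, as a minimal sketch that is not part of the commit and assumes the script was run with its default paths: reopen the combined file with safetensors' safe_open and inspect the tensor keys and metadata that combine.py wrote.

    from safetensors import safe_open

    # Assumes combine.py was run with the default output path "combined.safetensors"
    with safe_open("combined.safetensors", framework="pt") as f:
        print(f"{len(f.keys())} tensors in combined file")
        print("metadata keys:", sorted((f.metadata() or {}).keys()))
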
quantized/metadata.json ADDED
@@ -0,0 +1,45 @@
+ {
+   "model_class": "NunchakuFluxTransformer2dModel",
+   "comfy_config": {
+     "model_class": "Flux",
+     "model_config": {
+       "axes_dim": [16, 56, 56],
+       "context_in_dim": 4096,
+       "depth": 19,
+       "depth_single_blocks": 38,
+       "disable_unet_model_creation": true,
+       "guidance_embed": true,
+       "hidden_size": 3072,
+       "image_model": "flux",
+       "in_channels": 16,
+       "mlp_ratio": 4.0,
+       "num_heads": 24,
+       "out_channels": 16,
+       "patch_size": 2,
+       "qkv_bias": true,
+       "theta": 10000,
+       "vec_in_dim": 768
+     }
+   },
+   "quantization_config": {
+     "method": "svdquant",
+     "weight": {"dtype": "int4", "scale_dtype": null, "group_size": 64},
+     "activation": {"dtype": "int4", "scale_dtype": null, "group_size": 64}
+   },
+   "config": {
+     "_class_name": "FluxTransformer2DModel",
+     "_diffusers_version": "0.34.0.dev0",
+     "_name_or_path": "../checkpoints/flux-dev/transformer",
+     "attention_head_dim": 128,
+     "axes_dims_rope": [16, 56, 56],
+     "guidance_embeds": true,
+     "in_channels": 64,
+     "joint_attention_dim": 4096,
+     "num_attention_heads": 24,
+     "num_layers": 19,
+     "num_single_layers": 38,
+     "out_channels": null,
+     "patch_size": 1,
+     "pooled_projection_dim": 768
+   }
+ }
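
Note that safetensors metadata must be a flat string-to-string mapping, which is why combine.py runs json.dumps over the nested objects in this file before saving. Reading a nested value back therefore takes a json.loads; a minimal sketch, assuming the combined file produced above:

    import json
    from safetensors import safe_open

    with safe_open("combined.safetensors", framework="pt") as f:
        meta = f.metadata() or {}

    # Nested configs round-trip as JSON strings and must be re-parsed
    comfy_config = json.loads(meta["comfy_config"])
    print(comfy_config["model_config"]["hidden_size"])  # 3072 per this file
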