lym00 committed
Commit bec2369 · verified · 1 parent: 6b3b9b9

Upload 2 files

Files changed (2):
  1. quantized/combine.py +17 -9
  2. quantized/metadata.json +14 -6
quantized/combine.py CHANGED
@@ -37,16 +37,16 @@ if os.path.isfile(metadata_file):
     try:
         with open(metadata_file, "r", encoding="utf-8") as f:
             metadata = json.load(f)
-        # Convert nested objects to JSON strings as required by safetensors
+        # Convert nested objects to JSON strings
         for k, v in metadata.items():
             if isinstance(v, dict):
                 metadata[k] = json.dumps(v)
     except Exception as e:
         print(f"⚠️ Failed to load metadata from {metadata_file}: {e}")
         print("⏳ Falling back to hardcoded metadata...")
-        metadata = {}  # Will populate below
+        metadata = {}
 
-# Fallback metadata (if file load failed or was missing)
+# Hardcoded FP4 metadata fallback
 if not metadata:
     metadata = {
         "model_class": "NunchakuFluxTransformer2dModel",
@@ -71,11 +71,6 @@ if not metadata:
                 "vec_in_dim": 768
             }
         }),
-        "quantization_config": json.dumps({
-            "method": "svdquant",
-            "weight": {"dtype": "int4", "scale_dtype": None, "group_size": 64},
-            "activation": {"dtype": "int4", "scale_dtype": None, "group_size": 64}
-        }),
         "config": json.dumps({
             "_class_name": "FluxTransformer2DModel",
             "_diffusers_version": "0.34.0.dev0",
@@ -91,10 +86,23 @@
             "out_channels": None,
             "patch_size": 1,
             "pooled_projection_dim": 768
+        }),
+        "quantization_config": json.dumps({
+            "method": "svdquant",
+            "weight": {
+                "dtype": "fp4_e2m1_all",
+                "scale_dtype": [None, "fp8_e4m3_nan"],
+                "group_size": 16
+            },
+            "activation": {
+                "dtype": "fp4_e2m1_all",
+                "scale_dtype": "fp8_e4m3_nan",
+                "group_size": 16
+            }
         })
     }
 
-# Save the combined safetensors file
+# Save the combined file
 safetensors.torch.save_file(combined_state_dict, output_file, metadata=metadata)
 
 print(f"\n✅ Combined model saved to: {output_file}")
quantized/metadata.json CHANGED
@@ -21,11 +21,6 @@
       "vec_in_dim": 768
     }
   },
-  "quantization_config": {
-    "method": "svdquant",
-    "weight": {"dtype": "int4", "scale_dtype": null, "group_size": 64},
-    "activation": {"dtype": "int4", "scale_dtype": null, "group_size": 64}
-  },
   "config": {
     "_class_name": "FluxTransformer2DModel",
     "_diffusers_version": "0.34.0.dev0",
@@ -41,5 +36,18 @@
     "out_channels": null,
     "patch_size": 1,
     "pooled_projection_dim": 768
+  },
+  "quantization_config": {
+    "method": "svdquant",
+    "weight": {
+      "dtype": "fp4_e2m1_all",
+      "scale_dtype": [null, "fp8_e4m3_nan"],
+      "group_size": 16
+    },
+    "activation": {
+      "dtype": "fp4_e2m1_all",
+      "scale_dtype": "fp8_e4m3_nan",
+      "group_size": 16
+    }
   }
+}
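The JSON change mirrors combine.py: the old int4 / group-size-64 quantization_config is dropped and re-added after config with what appears to be an NVFP4-style layout (4-bit e2m1 values, fp8_e4m3 scales, 16-element groups). A small sanity check, assuming the repo-relative path quantized/metadata.json, that the edited file still parses and carries the new config:

import json

# Assumed repo-relative path from this commit.
with open("quantized/metadata.json", "r", encoding="utf-8") as f:
    meta = json.load(f)

qc = meta["quantization_config"]
assert qc["method"] == "svdquant"
assert qc["weight"]["dtype"] == "fp4_e2m1_all"
assert qc["weight"]["group_size"] == 16
# Note the asymmetry kept from the diff: weight scale_dtype is a
# two-element list, while activation scale_dtype is a single value.
assert qc["weight"]["scale_dtype"] == [None, "fp8_e4m3_nan"]
assert qc["activation"]["scale_dtype"] == "fp8_e4m3_nan"
print("✅ metadata.json carries the expected FP4 quantization_config")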