lym00 commited on
Commit
fca2d3f
·
verified ·
1 Parent(s): 409dba5

Upload merge_sharded_safetensors.py

Browse files
Files changed (1) hide show
  1. merge_sharded_safetensors.py +58 -0
merge_sharded_safetensors.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from safetensors import safe_open
3
+ from safetensors.torch import save_file
4
+ import torch # Needed for torch.cat
5
+
6
+ def merge_safetensor_files(sftsr_files, output_file="model.safetensors"):
7
+ slices_dict = {}
8
+ metadata = {}
9
+
10
+ for idx, file in enumerate(sftsr_files):
11
+ with safe_open(file, framework="pt") as sf_tsr:
12
+ if idx == 0:
13
+ metadata = sf_tsr.metadata()
14
+ for key in sf_tsr.keys():
15
+ tensor = sf_tsr.get_tensor(key)
16
+ if key not in slices_dict:
17
+ slices_dict[key] = []
18
+ slices_dict[key].append(tensor)
19
+
20
+ merged_tensors = {}
21
+ for key, slices in slices_dict.items():
22
+ if len(slices) == 1:
23
+ merged_tensors[key] = slices[0]
24
+ else:
25
+ # Simple heuristic: find dim with mismatched size
26
+ ref_shape = slices[0].shape
27
+ concat_dim = None
28
+ for dim in range(len(ref_shape)):
29
+ dim_sizes = [s.shape[dim] for s in slices]
30
+ if len(set(dim_sizes)) > 1:
31
+ concat_dim = dim
32
+ break
33
+ if concat_dim is None:
34
+ concat_dim = 0 # fallback
35
+ merged_tensors[key] = torch.cat(slices, dim=concat_dim)
36
+ print(f"Merged key '{key}' from {len(slices)} slices along dim {concat_dim}")
37
+
38
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
39
+ save_file(merged_tensors, output_file, metadata)
40
+ print(f"Merged {len(sftsr_files)} shards into {output_file}")
41
+
42
+ def get_safetensor_files(directory):
43
+ safetensors_files = []
44
+ for root, _, files in os.walk(directory):
45
+ for file in files:
46
+ if file.endswith(".safetensors"):
47
+ safetensors_files.append(os.path.join(root, file))
48
+ return safetensors_files
49
+
50
+ if __name__ == "__main__":
51
+ safetensor_files = get_safetensor_files("./shards")
52
+ print(f"The following shards/chunks will be merged: {safetensor_files}")
53
+
54
+ default_output = "./output/merged_model.safetensors"
55
+ user_output = input(f"Enter output file path [{default_output}]: ").strip()
56
+ output_file = user_output if user_output else default_output
57
+
58
+ merge_safetensor_files(safetensor_files, output_file=output_file)