Upload merge_sharded_safetensors.py
Browse files- merge_sharded_safetensors.py +58 -0
merge_sharded_safetensors.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from safetensors import safe_open
|
3 |
+
from safetensors.torch import save_file
|
4 |
+
import torch # Needed for torch.cat
|
5 |
+
|
6 |
+
def merge_safetensor_files(sftsr_files, output_file="model.safetensors"):
|
7 |
+
slices_dict = {}
|
8 |
+
metadata = {}
|
9 |
+
|
10 |
+
for idx, file in enumerate(sftsr_files):
|
11 |
+
with safe_open(file, framework="pt") as sf_tsr:
|
12 |
+
if idx == 0:
|
13 |
+
metadata = sf_tsr.metadata()
|
14 |
+
for key in sf_tsr.keys():
|
15 |
+
tensor = sf_tsr.get_tensor(key)
|
16 |
+
if key not in slices_dict:
|
17 |
+
slices_dict[key] = []
|
18 |
+
slices_dict[key].append(tensor)
|
19 |
+
|
20 |
+
merged_tensors = {}
|
21 |
+
for key, slices in slices_dict.items():
|
22 |
+
if len(slices) == 1:
|
23 |
+
merged_tensors[key] = slices[0]
|
24 |
+
else:
|
25 |
+
# Simple heuristic: find dim with mismatched size
|
26 |
+
ref_shape = slices[0].shape
|
27 |
+
concat_dim = None
|
28 |
+
for dim in range(len(ref_shape)):
|
29 |
+
dim_sizes = [s.shape[dim] for s in slices]
|
30 |
+
if len(set(dim_sizes)) > 1:
|
31 |
+
concat_dim = dim
|
32 |
+
break
|
33 |
+
if concat_dim is None:
|
34 |
+
concat_dim = 0 # fallback
|
35 |
+
merged_tensors[key] = torch.cat(slices, dim=concat_dim)
|
36 |
+
print(f"Merged key '{key}' from {len(slices)} slices along dim {concat_dim}")
|
37 |
+
|
38 |
+
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
39 |
+
save_file(merged_tensors, output_file, metadata)
|
40 |
+
print(f"Merged {len(sftsr_files)} shards into {output_file}")
|
41 |
+
|
42 |
+
def get_safetensor_files(directory):
|
43 |
+
safetensors_files = []
|
44 |
+
for root, _, files in os.walk(directory):
|
45 |
+
for file in files:
|
46 |
+
if file.endswith(".safetensors"):
|
47 |
+
safetensors_files.append(os.path.join(root, file))
|
48 |
+
return safetensors_files
|
49 |
+
|
50 |
+
if __name__ == "__main__":
|
51 |
+
safetensor_files = get_safetensor_files("./shards")
|
52 |
+
print(f"The following shards/chunks will be merged: {safetensor_files}")
|
53 |
+
|
54 |
+
default_output = "./output/merged_model.safetensors"
|
55 |
+
user_output = input(f"Enter output file path [{default_output}]: ").strip()
|
56 |
+
output_file = user_output if user_output else default_output
|
57 |
+
|
58 |
+
merge_safetensor_files(safetensor_files, output_file=output_file)
|