File size: 2,262 Bytes
fca2d3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import re

import torch  # Needed for torch.cat
from safetensors import safe_open
from safetensors.torch import save_file

def _infer_concat_dim(slices):
    """Return the first dimension whose size differs across *slices*, or 0.

    Heuristic: a tensor split across shards differs in size only along the
    split axis. If every slice has the same shape, fall back to dim 0.
    """
    for dim in range(len(slices[0].shape)):
        if len({s.shape[dim] for s in slices}) > 1:
            return dim
    return 0  # identical shapes: fallback to concatenating along dim 0


def merge_safetensor_files(sftsr_files, output_file="model.safetensors"):
    """Merge several ``.safetensors`` shards into one ``.safetensors`` file.

    Every tensor of every shard is loaded (framework ``"pt"``), grouped by
    key in the order the shards are given:

    * a key found in exactly one shard is copied through unchanged;
    * a key found in several shards is joined with ``torch.cat`` along the
      first dimension whose size differs between the slices (dim 0 if all
      slices have the same shape). Each such merge is printed.

    Only the first shard's header metadata is written to the output.

    Args:
        sftsr_files: Paths of the shard files, in concatenation order.
        output_file: Destination path. Its parent directory is created if it
            does not exist; a bare filename is written to the current
            working directory.
    """
    slices_dict = {}
    metadata = {}

    for idx, file in enumerate(sftsr_files):
        with safe_open(file, framework="pt") as sf_tsr:
            if idx == 0:
                # Presumably None when the shard has no metadata header;
                # save_file is passed whatever this returns — confirm.
                metadata = sf_tsr.metadata()
            for key in sf_tsr.keys():
                slices_dict.setdefault(key, []).append(sf_tsr.get_tensor(key))

    merged_tensors = {}
    for key, slices in slices_dict.items():
        if len(slices) == 1:
            merged_tensors[key] = slices[0]
            continue
        # NOTE: torch.cat rejects 0-dim tensors, so a scalar key repeated in
        # several shards will raise here.
        concat_dim = _infer_concat_dim(slices)
        merged_tensors[key] = torch.cat(slices, dim=concat_dim)
        print(f"Merged key '{key}' from {len(slices)} slices along dim {concat_dim}")

    # os.path.dirname() is "" for a bare filename (including the default
    # "model.safetensors"), and os.makedirs("") raises FileNotFoundError.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    save_file(merged_tensors, output_file, metadata)
    print(f"Merged {len(sftsr_files)} shards into {output_file}")

def _natural_sort_key(path):
    """Sort key comparing digit runs numerically, so "shard_2" < "shard_10".

    ``re.split`` with a capturing group alternates text (even indices) and
    digit runs (odd indices), so same positions always hold the same type.
    The full path is the tie-break (e.g. "shard_01" vs "shard_1").
    """
    parts = re.split(r"(\d+)", path)
    return [int(p) if i % 2 else p for i, p in enumerate(parts)], path


def get_safetensor_files(directory):
    """Recursively collect every ``.safetensors`` file under *directory*.

    ``os.walk`` yields entries in arbitrary, filesystem-dependent order. Shards
    are concatenated in list order when merged, so the paths are returned in a
    deterministic natural sort order (numeric runs compared as integers).

    Args:
        directory: Root directory to search.

    Returns:
        List of paths (joined onto *directory*) whose names end with
        ``.safetensors``, naturally sorted.
    """
    safetensors_files = [
        os.path.join(root, name)
        for root, _, files in os.walk(directory)
        for name in files
        if name.endswith(".safetensors")
    ]
    return sorted(safetensors_files, key=_natural_sort_key)

if __name__ == "__main__":
    safetensor_files = get_safetensor_files("./shards")
    if not safetensor_files:
        # Without this guard the merge would still run and write an output
        # file containing no tensors.
        raise SystemExit("No .safetensors files found under ./shards; nothing to merge.")
    print(f"The following shards/chunks will be merged: {safetensor_files}")

    default_output = "./output/merged_model.safetensors"
    user_output = input(f"Enter output file path [{default_output}]: ").strip()
    # input() returns the raw text; expand "~" so home-relative paths work.
    output_file = os.path.expanduser(user_output) if user_output else default_output

    merge_safetensor_files(safetensor_files, output_file=output_file)