# scripts/merge_sharded_safetensors.py
# Uploaded by lym00 ("Upload merge_sharded_safetensors.py"), commit fca2d3f (verified).
import os
from safetensors import safe_open
from safetensors.torch import save_file
import torch # Needed for torch.cat
def merge_safetensor_files(sftsr_files, output_file="model.safetensors"):
    """Merge several ``.safetensors`` shard files into a single file.

    Every tensor key found in the shards is written to *output_file*:

    * a key present in exactly one shard is copied unchanged;
    * a key present in several shards has its slices concatenated, in the
      order of *sftsr_files*, along the first dimension whose size differs
      between the slices (dim 0 when all slices have the same shape).

    Args:
        sftsr_files: Paths of the shard files, in concatenation order.
        output_file: Destination path. Missing parent directories are created;
            a bare filename is written to the current directory.

    Note:
        Only the metadata header of the first shard is copied to the output.
        All tensors are loaded into memory at the same time.
    """
    slices_dict, metadata = _collect_slices(sftsr_files)

    merged_tensors = {}
    for key, slices in slices_dict.items():
        if len(slices) == 1:
            merged_tensors[key] = slices[0]
            continue
        concat_dim = _infer_concat_dim(slices)
        merged_tensors[key] = torch.cat(slices, dim=concat_dim)
        print(f"Merged key '{key}' from {len(slices)} slices along dim {concat_dim}")

    # os.path.dirname() returns "" for a bare filename (such as the default
    # "model.safetensors"), and os.makedirs("") raises FileNotFoundError.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    save_file(merged_tensors, output_file, metadata)
    print(f"Merged {len(sftsr_files)} shards into {output_file}")


def _collect_slices(sftsr_files):
    """Load every tensor from *sftsr_files*, grouped by key in file order.

    Returns a tuple ``(slices_dict, metadata)``. ``slices_dict`` maps each
    key to the list of tensors stored under it across the files. ``metadata``
    is the first file's metadata header, or ``{}`` if no files were given.
    """
    slices_dict = {}
    metadata = {}
    for idx, file in enumerate(sftsr_files):
        with safe_open(file, framework="pt") as sf_tsr:
            if idx == 0:
                metadata = sf_tsr.metadata()
            for key in sf_tsr.keys():
                slices_dict.setdefault(key, []).append(sf_tsr.get_tensor(key))
    return slices_dict, metadata


def _infer_concat_dim(slices):
    """Return the first dimension whose size differs across *slices*, else 0."""
    # NOTE(review): evenly split tensors (all slices the same shape) and
    # replicated tensors both fall back to dim 0, so replicated copies would
    # be stacked rather than deduplicated. Confirm this matches how the
    # shards were produced.
    ref_shape = slices[0].shape
    for dim in range(len(ref_shape)):
        if len({s.shape[dim] for s in slices}) > 1:
            return dim
    return 0
def get_safetensor_files(directory):
    """Return the paths of all ``.safetensors`` files under *directory*.

    The directory tree is searched recursively. Paths are returned in
    natural sort order, so ``shard_2`` comes before ``shard_10``. Sorting is
    needed because os.walk() yields entries in an arbitrary,
    filesystem-dependent order, and merge_safetensor_files concatenates split
    tensors in the order it receives the files.

    Returns an empty list when *directory* does not exist or contains no
    matching files.
    """
    import re  # local import: only needed for the natural-sort key

    def natural_key(path):
        # re.split with a capturing group puts digit runs at odd indices,
        # so text parts compare as strings and digit runs as integers.
        parts = re.split(r"(\d+)", path)
        key = [int(part) if i % 2 else part for i, part in enumerate(parts)]
        # The raw path breaks ties such as "a01" vs "a1".
        return key, path

    safetensors_files = [
        os.path.join(root, file)
        for root, _, files in os.walk(directory)
        for file in files
        if file.endswith(".safetensors")
    ]
    return sorted(safetensors_files, key=natural_key)
if __name__ == "__main__":
    # Interactive entry point: merge every shard found under ./shards into one
    # file, prompting for the output path.
    safetensor_files = get_safetensor_files("./shards")
    if not safetensor_files:
        # Without this guard the merge would run with no inputs and save an
        # empty tensor dict to the output path.
        raise SystemExit("No .safetensors files found under ./shards; nothing to merge.")
    print(f"The following shards/chunks will be merged: {safetensor_files}")
    default_output = "./output/merged_model.safetensors"
    user_output = input(f"Enter output file path [{default_output}]: ").strip()
    output_file = user_output if user_output else default_output
    merge_safetensor_files(safetensor_files, output_file=output_file)