|
import os |
|
from safetensors import safe_open |
|
from safetensors.torch import save_file |
|
import torch |
|
|
|
def merge_safetensor_files(sftsr_files, output_file="model.safetensors"):
    """Merge sharded ``.safetensors`` files into a single file.

    A tensor key that appears in only one shard is copied as-is. A key that
    appears in several shards is treated as a split tensor: its slices are
    concatenated along the first dimension whose sizes differ between
    slices, falling back to dim 0 when all slice shapes are identical.

    Args:
        sftsr_files: Ordered list of shard file paths. Order matters —
            slices are concatenated in this order.
        output_file: Destination path for the merged file. Any parent
            directories are created if needed.

    Raises:
        ValueError: If ``sftsr_files`` is empty.
    """
    if not sftsr_files:
        raise ValueError("No safetensor files given to merge")

    slices_dict = {}
    metadata = {}

    for idx, file in enumerate(sftsr_files):
        with safe_open(file, framework="pt") as sf_tsr:
            # Carry over the metadata of the first shard into the merged file.
            if idx == 0:
                metadata = sf_tsr.metadata()
            for key in sf_tsr.keys():
                tensor = sf_tsr.get_tensor(key)
                slices_dict.setdefault(key, []).append(tensor)

    merged_tensors = {}
    for key, slices in slices_dict.items():
        if len(slices) == 1:
            merged_tensors[key] = slices[0]
        else:
            # Infer the split dimension: the first dim whose sizes
            # disagree across the slices.
            ref_shape = slices[0].shape
            concat_dim = None
            for dim in range(len(ref_shape)):
                dim_sizes = [s.shape[dim] for s in slices]
                if len(set(dim_sizes)) > 1:
                    concat_dim = dim
                    break
            if concat_dim is None:
                # All slice shapes identical — assume a dim-0 split.
                # NOTE(review): heuristic; confirm it matches how the
                # shards were produced.
                concat_dim = 0
            merged_tensors[key] = torch.cat(slices, dim=concat_dim)
            print(f"Merged key '{key}' from {len(slices)} slices along dim {concat_dim}")

    # os.path.dirname returns "" for a bare filename, and os.makedirs("")
    # raises FileNotFoundError — only create the directory when there is one.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    save_file(merged_tensors, output_file, metadata)
    print(f"Merged {len(sftsr_files)} shards into {output_file}")
|
|
|
def get_safetensor_files(directory):
    """Recursively collect all ``.safetensors`` files under *directory*.

    Returns the paths in sorted order so that downstream merging
    concatenates shard slices deterministically — ``os.walk`` yields files
    in a filesystem-dependent order, which would otherwise make the merge
    order (and therefore the merged tensors) vary between runs.

    Args:
        directory: Root directory to search.

    Returns:
        Sorted list of full paths to ``.safetensors`` files.
    """
    safetensors_files = []
    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith(".safetensors"):
                safetensors_files.append(os.path.join(root, file))
    return sorted(safetensors_files)
|
|
|
if __name__ == "__main__":
    # Discover every shard under ./shards (recursively).
    shard_paths = get_safetensor_files("./shards")
    print(f"The following shards/chunks will be merged: {shard_paths}")

    # Let the user override the destination; empty input keeps the default.
    default_output = "./output/merged_model.safetensors"
    user_output = input(f"Enter output file path [{default_output}]: ").strip()

    merge_safetensor_files(shard_paths, output_file=user_output or default_output)
|
|