import os

import torch  # needed for torch.cat
from safetensors import safe_open
from safetensors.torch import save_file


def merge_safetensor_files(sftsr_files, output_file="model.safetensors"):
    """Merge sharded .safetensors files into a single file.

    Tensors that appear under the same key in multiple shards are assumed to
    be slices of one tensor and are concatenated. The concatenation dimension
    is chosen heuristically: the first dimension whose sizes differ across
    slices; if all shapes match, dim 0 is used as a fallback.

    Args:
        sftsr_files: Iterable of paths to shard files, in the order their
            slices should be concatenated.
        output_file: Path of the merged file to write.

    Note:
        Metadata is taken from the first shard only — metadata on later
        shards is silently ignored.
    """
    slices_dict = {}
    metadata = {}
    for idx, file in enumerate(sftsr_files):
        with safe_open(file, framework="pt") as sf_tsr:
            if idx == 0:
                # Carry over metadata from the first shard only.
                metadata = sf_tsr.metadata()
            for key in sf_tsr.keys():
                tensor = sf_tsr.get_tensor(key)
                slices_dict.setdefault(key, []).append(tensor)

    merged_tensors = {}
    for key, slices in slices_dict.items():
        if len(slices) == 1:
            merged_tensors[key] = slices[0]
        else:
            # Simple heuristic: concatenate along the first dim whose sizes
            # disagree between slices; identical shapes fall back to dim 0.
            ref_shape = slices[0].shape
            concat_dim = None
            for dim in range(len(ref_shape)):
                dim_sizes = [s.shape[dim] for s in slices]
                if len(set(dim_sizes)) > 1:
                    concat_dim = dim
                    break
            if concat_dim is None:
                concat_dim = 0  # fallback
            merged_tensors[key] = torch.cat(slices, dim=concat_dim)
            print(f"Merged key '{key}' from {len(slices)} slices along dim {concat_dim}")

    # BUG FIX: os.path.dirname returns "" for a bare filename, and
    # os.makedirs("") raises FileNotFoundError — only create a directory
    # when there actually is one in the path.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    save_file(merged_tensors, output_file, metadata)
    print(f"Merged {len(sftsr_files)} shards into {output_file}")


def get_safetensor_files(directory):
    """Recursively collect all .safetensors file paths under *directory*.

    Returns the paths in sorted order so that shard/concatenation order is
    deterministic across platforms (os.walk's within-directory file order
    is OS-dependent, which previously made the merged result nondeterministic).
    """
    safetensors_files = []
    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith(".safetensors"):
                safetensors_files.append(os.path.join(root, file))
    return sorted(safetensors_files)


if __name__ == "__main__":
    safetensor_files = get_safetensor_files("./shards")
    print(f"The following shards/chunks will be merged: {safetensor_files}")
    default_output = "./output/merged_model.safetensors"
    user_output = input(f"Enter output file path [{default_output}]: ").strip()
    output_file = user_output if user_output else default_output
    merge_safetensor_files(safetensor_files, output_file=output_file)