import os

import numpy as np
import safetensors
import torch
from safetensors import safe_open
from safetensors.torch import save_file



# Write an all-zeros copy of a safetensors checkpoint: same tensor names,
# shapes and dtypes as the source, every value exactly 0. The copy is saved as
# "an.safetensors" in the same directory as the source checkpoint.

ckpt_path = "/scratch/ssd004/scratch/yashkant/stats/stats_sketch_sky_transportation_AB.safetensors"

# Read the checkpoint and build a zeroed replacement for each tensor.
# - torch.zeros_like is used instead of ``t * 0.0``: multiplying by zero turns
#   NaN/Inf entries into NaN (not 0) and promotes integer/bool tensors to a
#   floating dtype, so the old result was neither all-zero nor dtype-preserving.
# - Tensors are loaded on the CPU: nothing here needs a GPU, and the data has to
#   be in host memory to be written to disk, so device=0 only added a CUDA
#   requirement and an extra device round trip.
tensors = {}
with safe_open(ckpt_path, framework="pt", device="cpu") as f:
    # Copy the header metadata (a str -> str dict, or None) so the zeroed file
    # mirrors the source header.
    metadata = f.metadata()
    for k in f.keys():
        tensors[k] = torch.zeros_like(f.get_tensor(k))

save_path = os.path.join(os.path.dirname(ckpt_path), "an.safetensors")
save_file(tensors, save_path, metadata=metadata)
print(f"Saved {save_path}")

