File size: 389 Bytes
170b97b
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import argparse
from pathlib import Path

from safetensors import safe_open
from safetensors.torch import save_file


# Key prefix under which the wanted sub-module's tensors are stored.
DEFAULT_PREFIX = "language_model."


def strip_key_prefix(keys, prefix=DEFAULT_PREFIX):
    """Map every key that starts with *prefix* to its original name.

    Returns ``{key_without_prefix: original_key}``. Keys that do not start
    with *prefix* are skipped. Only the leading occurrence of *prefix* is
    removed, so ``"language_model.model.language_model.x"`` becomes
    ``"model.language_model.x"`` instead of being cut off at the second
    occurrence (which ``str.split(prefix)[1]`` would do).
    """
    return {key[len(prefix):]: key for key in keys if key.startswith(prefix)}


def extract_submodule(src, dst, prefix=DEFAULT_PREFIX, device="cpu"):
    """Copy the tensors stored under *prefix* in *src* into *dst*, minus the prefix.

    The header metadata of *src* is written to *dst* unchanged.

    Args:
        src: path of the input ``.safetensors`` file.
        dst: path of the output ``.safetensors`` file (overwritten if present).
        prefix: key prefix selecting which tensors to keep.
        device: device passed to ``safe_open``. The default ``"cpu"`` needs no
            GPU; tensors are only copied, never computed on. The original
            script used CUDA device 0 (pass e.g. ``"cuda:0"`` for that).

    Returns:
        The number of tensors written.

    Raises:
        ValueError: if *src* and *dst* are the same file, or if no tensor name
            in *src* starts with *prefix*.
    """
    if Path(src).resolve() == Path(dst).resolve():
        # Tensors loaded from src may still be backed by that file when
        # save_file runs, so overwriting it in place is unsafe.
        raise ValueError(f"refusing to overwrite the input file {src!r} in place")
    with safe_open(src, framework="pt", device=device) as f:
        metadata = f.metadata()
        renamed = strip_key_prefix(f.keys(), prefix)
        if not renamed:
            raise ValueError(f"no tensor names in {src!r} start with {prefix!r}")
        tensors = {new_key: f.get_tensor(old_key) for new_key, old_key in renamed.items()}
    save_file(tensors, dst, metadata=metadata)
    return len(tensors)


def main(argv=None):
    """Command-line entry point.

    With no arguments this performs the original conversion:
    ``model.safetensors`` -> ``model_fix.safetensors``, keeping the
    ``language_model.`` tensors with that prefix removed.
    """
    parser = argparse.ArgumentParser(
        description="Keep only tensors whose names start with a prefix, and drop that prefix."
    )
    parser.add_argument(
        "src", nargs="?", default="model.safetensors", help="input .safetensors file"
    )
    parser.add_argument(
        "dst", nargs="?", default="model_fix.safetensors", help="output .safetensors file"
    )
    parser.add_argument(
        "--prefix", default=DEFAULT_PREFIX, help="key prefix to select and strip"
    )
    parser.add_argument(
        "--device", default="cpu", help='device to load tensors on, e.g. "cpu" or "cuda:0"'
    )
    args = parser.parse_args(argv)
    extract_submodule(args.src, args.dst, prefix=args.prefix, device=args.device)


if __name__ == "__main__":
    main()