import glob
import os
import subprocess
import sys

from tqdm import tqdm


# Append the git-lfs 3.5.1 bin directory to PATH and set the GIT_THETA_*
# variables; child processes started by this script inherit them.
os.environ["PATH"] = (
    os.environ["PATH"] + ":/scratch/ssd004/scratch/calvinyu/git-lfs-3.5.1/bin"
)
os.environ.update(
    {
        "GIT_THETA_MANUAL_MERGE": "True",
        "GIT_THETA_CHECKPOINT_TYPE": "safetensors",
    }
)


def _select_checkpoints(paths, git_dir, excluded=("trimmed", "fisher")):
    """Return the checkpoints to commit as sorted paths relative to *git_dir*.

    Args:
        paths: Checkpoint file paths located inside ``git_dir``.
        git_dir: Directory the git commands are run from.
        excluded: Substrings; a checkpoint whose relative path contains any
            of them is skipped.

    Returns:
        Sorted list of paths relative to ``git_dir`` that passed the filter.
        Sorting makes the commit order deterministic (glob order is not).
    """
    relative = (os.path.relpath(path, git_dir) for path in paths)
    return sorted(
        rel for rel in relative if not any(token in rel for token in excluded)
    )


def _commit_checkpoint(git_dir, ckpt):
    """Track, add and commit one checkpoint from inside *git_dir*.

    Commands are passed as argument lists (no shell), so paths containing
    spaces or quotes are handled correctly. Execution stops at the first
    failing command so that nothing is committed after a failed track/add.

    Args:
        git_dir: Directory the git commands are run from.
        ckpt: Checkpoint path relative to ``git_dir``.

    Returns:
        0 if all three commands succeeded, otherwise the return code of the
        first command that failed.
    """
    commands = (
        ("theta", "track", ckpt),
        ("theta", "add", ckpt),
        ("commit", "-m", f"Add {ckpt}"),
    )
    for args in commands:
        returncode = subprocess.run(
            ["git", *args], cwd=git_dir, check=False
        ).returncode
        if returncode != 0:
            return returncode
    return 0


def main():
    """Commit every selected checkpoint under ``git_merge/stats/``, one per commit.

    Exits with a non-zero status as soon as a git command fails.
    """
    scratch_dir = "/scratch/ssd004/scratch/calvinyu/"
    git_dir = os.path.join(scratch_dir, "git_merge/")
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not os.path.isdir(git_dir):
        sys.exit(f"git directory not found: {git_dir}")

    ckpts = glob.glob(os.path.join(git_dir, "stats/", "*.safetensors"))
    print(f"total checkpoints: {len(ckpts)}")

    # Disabled: read .gitattributes and only commit files not listed there.
    # gitattributes_file = os.path.join(git_dir, ".gitattributes")
    # if os.path.exists(gitattributes_file):
    #     with open(gitattributes_file, "r") as f:
    #         gitattributes = f.readlines()
    #     gitattributes = [line.split()[0] for line in gitattributes]
    #     ckpts = [ckpt for ckpt in ckpts if ckpt not in gitattributes]

    # Disabled: also add the task arithmetic models.
    # ckpts += glob.glob(os.path.join(git_dir, f"merged/seq/task_arithmetic_*.safetensors"))

    # Further substrings that have been used as filters: "AB.", "gram", "ancestor".
    ckpts = _select_checkpoints(ckpts, git_dir, excluded=("trimmed", "fisher"))
    print(f"filtered checkpoints: {len(ckpts)}")

    # track, add, and commit
    for ckpt in tqdm(ckpts):
        exit_code = _commit_checkpoint(git_dir, ckpt)
        if exit_code != 0:
            # Report the real return code: summing os.system() wait statuses
            # (multiples of 256) used to be truncated to exit status 0.
            print(
                f"git failed for {ckpt} (exit code {exit_code})",
                file=sys.stderr,
            )
            sys.exit(exit_code)


if __name__ == "__main__":
    main()

