import argparse
import glob
import os
import shlex
import sys
from datetime import datetime

import numpy as np
import yaml
from tqdm import tqdm

# Extend PATH with the local git-lfs install and set the GIT_THETA_* variables.
# Processes launched later through os.system inherit this environment.
# NOTE(review): the GIT_THETA_* values are presumably read by the git-theta
# tooling — confirm against its documentation.
os.environ["PATH"] = os.environ["PATH"] + ":/scratch/ssd004/scratch/calvinyu/git-lfs-3.5.1/bin"
os.environ.update(
    GIT_THETA_MANUAL_MERGE="True",
    GIT_THETA_CHECKPOINT_TYPE="safetensors",
)


def build_commit_command(git_dir, ckpts):
    """Return one shell command that git-theta tracks, adds and commits each checkpoint.

    Every checkpoint gets its own commit (``git commit -m 'Add <path>'``). The
    steps are chained with ``&&`` after a leading ``sleep 1``, so the first
    failing step stops the remaining ones. Paths are shell-quoted; ordinary
    paths (letters, digits, ``_ . / -``) come out unchanged.

    Args:
        git_dir: Directory every step ``cd``s into before running git.
        ckpts: Checkpoint paths relative to ``git_dir``, committed in order.

    Returns:
        The chained command string (just ``"sleep 1"`` when ``ckpts`` is empty).
    """
    repo = shlex.quote(git_dir)
    steps = ["sleep 1"]
    for ckpt in ckpts:
        path = shlex.quote(ckpt)
        message = shlex.quote(f"Add {ckpt}")
        steps.append(
            f"cd {repo} && git theta track {path} && git theta add {path}"
            f" && git commit -m {message}"
        )
    return " && ".join(steps)


def build_mats_commands(ab, ab_gram, merge_dir, git_dir, iterations=10):
    """Build the sequential MATS merge jobs over growing prefixes of ``ab``.

    For each even prefix length ``mi`` (2, 4, ...), the merge combines
    ``ab[:mi]`` with auxiliary data ``ab_gram[:mi]``. It uses
    ``<merge_dir>/task_arithmetic_<mi>.safetensors`` as ``--ancestor`` and
    writes ``<merge_dir>/mats_<mi>.safetensors``. The paired commit command
    commits only what is new at this step: ``ab[mi-2:mi]``, that ancestor,
    and ``ab_gram[mi-2:mi]``.

    NOTE(review): with an odd number of models the last one is never merged
    (the prefix length steps by 2). Confirm this is intended.

    Args:
        ab: Model checkpoint paths (relative to ``git_dir``), in merge order.
        ab_gram: Auxiliary-data paths aligned index-by-index with ``ab``.
        merge_dir: Directory (relative to ``git_dir``) for ancestors and outputs.
        git_dir: Directory the commit commands ``cd`` into.
        iterations: Value passed as ``--x:iterations``.

    Returns:
        A list of ``(merge_cmd, commit_cmd, output)`` tuples, where ``output``
        is the merge's output path relative to ``git_dir``.
    """
    def quote_all(paths):
        return " ".join(shlex.quote(p) for p in paths)

    jobs = []
    for mi in range(2, len(ab) + 1, 2):
        ancestor = os.path.join(merge_dir, f"task_arithmetic_{mi}.safetensors")
        output = os.path.join(merge_dir, f"mats_{mi}.safetensors")
        merge_cmd = (
            f"git-theta-merge-cli --x:iterations {iterations}"
            f" --models {quote_all(ab[:mi])}"
            f" --ancestor {shlex.quote(ancestor)}"
            f" --aux_data {quote_all(ab_gram[:mi])}"
            f" --output {shlex.quote(output)}"
            " --merge covariance-mats-gpu"
        )
        commit_cmd = build_commit_command(
            git_dir, ab[mi - 2:mi] + [ancestor] + ab_gram[mi - 2:mi]
        )
        jobs.append((merge_cmd, commit_cmd, output))
    return jobs


def main():
    """Run the sequential MATS merges for the checkpoints under ``<git_dir>/stats``.

    Paths are hard-coded below. Jobs whose output already exists are skipped.
    For every remaining job, the script commits that step's inputs with
    git-theta, runs the merge, and records the result in
    ``merged/seq/logs/log.yaml``.

    Raises:
        FileNotFoundError: If the repository or any expected checkpoint is missing.
        NotImplementedError: For a merge method other than ``"mats"``.
    """
    scratch_dir = "/scratch/ssd004/scratch/calvinyu/"
    git_dir = os.path.join(scratch_dir, "git_merge/")

    ckpts = glob.glob(os.path.join(git_dir, "stats", "*.safetensors"))
    print(f"total checkpoints: {len(ckpts)}")

    if not os.path.exists(git_dir):
        raise FileNotFoundError(f"git repository not found: {git_dir}")

    # Every command starts with `cd git_dir`, so keep checkpoint paths relative to it.
    ckpts = [os.path.relpath(ckpt, git_dir) for ckpt in ckpts]
    ab = sorted(ckpt for ckpt in ckpts if "AB.safetensors" in ckpt)
    # Only ab_gram feeds the merge; the trimmed/fisher variants are only
    # required to exist.
    ab_trim = [ckpt.replace("AB.safetensors", "AB_trimmed.safetensors") for ckpt in ab]
    ab_fisher = [ckpt.replace("AB.safetensors", "AB_fisher.safetensors") for ckpt in ab]
    ab_gram = [ckpt.replace("AB.safetensors", "AB_gram.safetensors") for ckpt in ab]
    missing = [
        ckpt
        for ckpt in ab + ab_trim + ab_fisher + ab_gram
        if not os.path.exists(os.path.join(git_dir, ckpt))
    ]
    if missing:
        raise FileNotFoundError(f"missing checkpoints: {missing}")

    merge_dir_abs = os.path.join(git_dir, "merged/seq/")
    os.makedirs(merge_dir_abs, exist_ok=True)
    merge_log_path = os.path.join(merge_dir_abs, "logs", "log.yaml")
    merge_dir = os.path.relpath(merge_dir_abs, git_dir)

    merges = ["mats"]
    jobs = []
    for merge in merges:
        if merge == "mats":
            jobs.extend(
                build_mats_commands(ab, ab_gram, merge_dir, git_dir, iterations=10)
            )
            print(f"added mats, total: {len(jobs)}")
        else:
            raise NotImplementedError(f"Merge method {merge} not implemented.")

    skip_if_exists = True
    pending = []
    for merge_cmd, commit_cmd, output in jobs:
        if skip_if_exists and os.path.exists(os.path.join(git_dir, output)):
            print(f"skipping {output} \n {merge_cmd}")
            continue
        pending.append((merge_cmd, commit_cmd, output))

    # Pause so the pending commands can be inspected before anything runs
    # (continue with `c`; setting PYTHONBREAKPOINT=0 disables this pause).
    breakpoint()

    # Load any existing log so records from earlier (now skipped) runs survive.
    os.makedirs(os.path.dirname(merge_log_path), exist_ok=True)
    merge_log = {}
    if os.path.exists(merge_log_path):
        with open(merge_log_path, encoding="utf-8") as f:
            merge_log = yaml.safe_load(f) or {}

    for merge_cmd, commit_cmd, output in tqdm(pending):
        merge_cmd = f"cd {shlex.quote(git_dir)} && {merge_cmd}"
        # Commit this step's inputs first; the merge runs only if every commit succeeds.
        commit_merge_cmd = f"{commit_cmd} && {merge_cmd}"
        print(f"COMMANDS: \n {commit_merge_cmd} \n {merge_cmd}")
        # On Unix os.system returns the raw wait status (0 on success), not a plain exit code.
        status = os.system(commit_merge_cmd)
        # Key by output path so each job gets its own entry. Persist after every
        # job so an abort or crash keeps the record of completed runs.
        merge_log[output] = {"cmd": merge_cmd, "exit_code": status, "output": output}
        with open(merge_log_path, "w", encoding="utf-8") as f:
            yaml.safe_dump(merge_log, f)
        if status != 0:
            breakpoint()
        print(merge_cmd)


if __name__ == "__main__":
    main()


# Reference only: an example of the chained commit command this script builds
# for the first MATS step (mi=2): two models, the task_arithmetic_2 ancestor,
# then their two *_AB_gram files. This bare string is never executed.
"""
sleep 1 && cd /scratch/ssd004/scratch/calvinyu/git_merge/ && git theta track stats/stats_clipart_cloth_AB.safetensors && git theta add stats/stats_clipart_cloth_AB.safetensors && git commit -m 'Add stats/stats_clipart_cloth_AB.safetensors' && cd /scratch/ssd004/scratch/calvinyu/git_merge/ && git theta track stats/stats_clipart_furniture_AB.safetensors && git theta add stats/stats_clipart_furniture_AB.safetensors && git commit -m 'Add stats/stats_clipart_furniture_AB.safetensors' && cd /scratch/ssd004/scratch/calvinyu/git_merge/ && git theta track merged/seq/task_arithmetic_2.safetensors && git theta add merged/seq/task_arithmetic_2.safetensors && git commit -m 'Add merged/seq/task_arithmetic_2.safetensors' && cd /scratch/ssd004/scratch/calvinyu/git_merge/ && git theta track stats/stats_clipart_cloth_AB_gram.safetensors && git theta add stats/stats_clipart_cloth_AB_gram.safetensors && git commit -m 'Add stats/stats_clipart_cloth_AB_gram.safetensors' && cd /scratch/ssd004/scratch/calvinyu/git_merge/ && git theta track stats/stats_clipart_furniture_AB_gram.safetensors && git theta add stats/stats_clipart_furniture_AB_gram.safetensors && git commit -m 'Add stats/stats_clipart_furniture_AB_gram.safetensors'
"""