import argparse
import glob
import os
import shlex
import subprocess
import sys
from datetime import datetime

import numpy as np
import yaml
from tqdm import tqdm

# Environment for the merge commands launched by this script: append the
# local git-lfs install to PATH and set the GIT_THETA_* options. Changes to
# os.environ are inherited by the child processes started later.
os.environ["PATH"] = os.environ["PATH"] + ":/scratch/ssd004/scratch/calvinyu/git-lfs-3.5.1/bin"
os.environ.update(
    {
        "GIT_THETA_MANUAL_MERGE": "True",
        "GIT_THETA_CHECKPOINT_TYPE": "safetensors",
    }
)


# Command-line merge tool invoked for every merge step (looked up on PATH).
MERGE_CLI = "git-theta-merge-cli"

# Directory that contains the ``git_merge/`` repository.
DEFAULT_SCRATCH_DIR = "/scratch/ssd004/scratch/calvinyu/"

# Where merged checkpoints go, relative to the repo root; all merge commands
# run with the repo root as their working directory.
MERGE_SUBDIR = "merged/seq/"

# Per-method hyper-parameters ("best" values per their original names);
# DARE reuses the task-arithmetic lambda.
BEST_TA_LAMBDA = 0.3
BEST_DARE_DROPOUT = 0.9
BEST_TIES_LAMBDA = 1.0
BEST_REGMEAN_LAMBDA = 0.9
BEST_MATS_ITERATIONS = 10

SUPPORTED_MERGES = (
    "average",
    "task_arithmetic",
    "dare",
    "ties",
    "fisher",
    "regmean",
    "mats",
)


def _merge_args(merge, mi, an, ab, aux, merge_dir):
    """Return ``(argv, output)`` for merging the first ``mi`` models with ``merge``.

    Args:
        merge: merge method name, one of SUPPORTED_MERGES.
        mi: number of models to merge; the first ``mi`` entries of ``ab``
            (and of each auxiliary list) are used.
        an: path of the shared ancestor checkpoint.
        ab: sorted list of model checkpoints.
        aux: dict mapping "trimmed"/"fisher"/"gram" to per-model auxiliary
            files aligned index-by-index with ``ab``.
        merge_dir: directory (relative to the repo root) for merged outputs.

    Returns:
        ``argv`` is the full command with the program name first; ``output``
        is the merged checkpoint path.

    Raises:
        NotImplementedError: for an unknown merge method.
    """
    output = os.path.join(merge_dir, f"{merge}_{mi}.safetensors")
    # head: flags before --models; extra: flags between --models and --output.
    if merge == "average":
        head, extra, method = [], [], "average-gpu"
    elif merge == "task_arithmetic":
        head = ["--ancestor", an, "--x:merge_lambda", str(BEST_TA_LAMBDA)]
        extra, method = [], "task-arithmetic-gpu"
    elif merge == "dare":
        head = [
            "--ancestor", an,
            "--x:merge_lambda", str(BEST_TA_LAMBDA),
            "--x:dropout_probability", str(BEST_DARE_DROPOUT),
        ]
        extra, method = [], "dare-task-arithmetic-gpu"
    elif merge == "ties":
        head = ["--ancestor", an, "--x:merge_lambda", str(BEST_TIES_LAMBDA)]
        extra, method = ["--aux_data", *aux["trimmed"][:mi]], "ties-gpu"
    elif merge == "fisher":
        head = []
        extra, method = ["--aux_data", *aux["fisher"][:mi]], "fisher-gpu"
    elif merge == "regmean":
        head = ["--x:merge_lambda", str(BEST_REGMEAN_LAMBDA)]
        extra, method = ["--aux_data", *aux["gram"][:mi]], "reg-mean-gpu"
    elif merge == "mats":
        head = ["--x:iterations", str(BEST_MATS_ITERATIONS)]
        # MATS uses the task-arithmetic merge of the same mi models as its
        # --ancestor, so that output must already exist when this step runs
        # (schedule "task_arithmetic" before "mats", or run it beforehand).
        ta_output = os.path.join(merge_dir, f"task_arithmetic_{mi}.safetensors")
        extra = ["--ancestor", ta_output, "--aux_data", *aux["gram"][:mi]]
        method = "covariance-mats-gpu"
    else:
        raise NotImplementedError(f"Merge method {merge} not implemented.")

    argv = [MERGE_CLI, *head, "--models", *ab[:mi], *extra,
            "--output", output, "--merge", method]
    return argv, output


def _update_merge_log(log_path, entries):
    """Merge ``entries`` into the YAML log at ``log_path``, creating it if needed.

    Existing entries are kept (and overwritten per key) so that re-runs that
    skip already-merged outputs do not erase earlier results.
    """
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    log = {}
    if os.path.exists(log_path):
        with open(log_path, encoding="utf-8") as f:
            log = yaml.safe_load(f) or {}
    log.update(entries)
    with open(log_path, "w", encoding="utf-8") as f:
        yaml.safe_dump(log, f)


def main(scratch_dir=DEFAULT_SCRATCH_DIR, merges=("mats",), skip_if_exists=True,
         dry_run=False):
    """Build and run sequential merges over growing prefixes of the models.

    For each method in ``merges`` and each even ``mi`` in ``2..len(ab)``, the
    first ``mi`` sorted ``*AB.safetensors`` checkpoints found in
    ``<scratch_dir>/git_merge/stats/`` are merged into
    ``merged/seq/<method>_<mi>.safetensors`` (relative to the repo root).

    Args:
        scratch_dir: directory containing the ``git_merge/`` repository.
        merges: merge methods to run, in order; see SUPPORTED_MERGES.
        skip_if_exists: skip steps whose output file already exists.
        dry_run: only print and return the commands; run nothing.

    Returns:
        The command lines selected to run, after skipping existing outputs.

    Raises:
        NotImplementedError: if ``merges`` contains an unknown method.
        FileNotFoundError: if the repo, the ancestor checkpoint, or any model
            or auxiliary file is missing.
    """
    unknown = [m for m in merges if m not in SUPPORTED_MERGES]
    if unknown:
        raise NotImplementedError(f"Merge method {unknown[0]} not implemented.")

    git_dir = os.path.join(scratch_dir, "git_merge/")
    ckpts = glob.glob(os.path.join(git_dir, "stats", "*.safetensors"))
    print(f"total checkpoints: {len(ckpts)}")
    if not os.path.isdir(git_dir):
        raise FileNotFoundError(f"git merge directory not found: {git_dir}")

    # Work with repo-relative paths, since the merge CLI runs from git_dir.
    ckpts = sorted(os.path.relpath(c, git_dir) for c in ckpts)
    ancestors = [c for c in ckpts if "ancestor" in c]
    if not ancestors:
        raise FileNotFoundError(f"no ancestor checkpoint found in {git_dir}stats/")
    an = ancestors[0]
    ab = [c for c in ckpts if "AB.safetensors" in c]  # already sorted
    aux = {
        kind: [c.replace("AB.safetensors", f"AB_{kind}.safetensors") for c in ab]
        for kind in ("trimmed", "fisher", "gram")
    }
    missing = [
        c
        for c in ab + [p for paths in aux.values() for p in paths]
        if not os.path.exists(os.path.join(git_dir, c))
    ]
    if missing:
        raise FileNotFoundError(f"missing checkpoint files: {missing}")

    planned = []
    for merge in merges:
        for mi in range(2, len(ab) + 1, 2):
            planned.append(_merge_args(merge, mi, an, ab, aux, MERGE_SUBDIR))
        print(f"added {merge}, total: {len(planned)}")

    to_run = []
    for argv, output in planned:
        cmd = " ".join(shlex.quote(a) for a in argv)
        if skip_if_exists and os.path.exists(os.path.join(git_dir, output)):
            print(f"skipping {output} \n {cmd}")
        else:
            to_run.append((argv, output, cmd))

    if dry_run:
        for _, _, cmd in to_run:
            print(cmd)
        return [cmd for _, _, cmd in to_run]

    os.makedirs(os.path.join(git_dir, MERGE_SUBDIR), exist_ok=True)
    merge_log = {}
    for argv, output, cmd in tqdm(to_run):
        shell_cmd = f"cd {git_dir} && {cmd}"
        try:
            exit_code = subprocess.run(argv, cwd=git_dir).returncode
        except OSError as err:  # e.g. the CLI is not on PATH
            print(f"could not launch merge: {err}")
            exit_code = 127
        # One entry per step, keyed by its output, recording success or failure.
        merge_log[output] = {"cmd": shell_cmd, "exit_code": exit_code, "output": output}
        if exit_code != 0:
            # Keep going: a single failed merge should not abort the sweep.
            print(f"merge failed with exit code {exit_code}: {shell_cmd}")
            continue
        print(shell_cmd)

    if merge_log:
        _update_merge_log(
            os.path.join(git_dir, MERGE_SUBDIR, "logs", "log.yaml"), merge_log
        )
    return [cmd for _, _, cmd in to_run]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run sequential git-theta merges.")
    parser.add_argument("--merges", nargs="+", default=["mats"],
                        choices=list(SUPPORTED_MERGES),
                        help="merge methods to run, in order")
    parser.add_argument("--no-skip", action="store_true",
                        help="re-run merges whose output already exists")
    parser.add_argument("--dry-run", action="store_true",
                        help="print the commands without running them")
    cli_args = parser.parse_args()
    main(merges=cli_args.merges, skip_if_exists=not cli_args.no_skip,
         dry_run=cli_args.dry_run)


