import os 
import sys
import glob 
import yaml
from datetime import datetime
from tqdm import tqdm
import numpy as np
import argparse

# Put the locally installed git-lfs binary on PATH for every shell command this
# script spawns, and set the GIT_THETA_* variables read by git-theta.
os.environ["PATH"] += ":/scratch/ssd004/scratch/calvinyu/git-lfs-3.5.1/bin"
os.environ.update(
    GIT_THETA_MANUAL_MERGE="True",
    GIT_THETA_CHECKPOINT_TYPE="safetensors",
)

# Scratch root and the git-theta repository (trailing slash kept on purpose:
# later code strips `git_dir` off absolute paths to get repo-relative ones).
scratch_dir = "/scratch/ssd004/scratch/calvinyu/"
git_dir = f"{scratch_dir}git_merge/"


from huggingface_hub import HfApi, HfFileSystem
from huggingface_hub import HfApi, hf_hub_download, snapshot_download

# Authenticate with the Hugging Face Hub using a token taken from the
# environment. Never hard-code access tokens in source: anyone who can read
# the file (or its git history) can use them.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    from huggingface_hub import login

    login(token=hf_token)
else:
    print(
        "HF_TOKEN is not set; continuing without logging in to the Hugging Face Hub.",
        file=sys.stderr,
    )

api = HfApi()
fs = HfFileSystem()
# Basenames of the .safetensors files already present under seq/ in the
# yashkant/evalmerge Hub repo.
uploaded_files = fs.ls("yashkant/evalmerge/seq", detail=False)
uploaded_ckpts = [os.path.basename(f) for f in uploaded_files if f.endswith(".safetensors")]

# Paths already tracked in the git-theta repo: the first space-separated field
# of each .gitattributes line (empty list when the file does not exist).
gitattributes_path = os.path.join(git_dir, ".gitattributes")
commited_ckpts = []
if os.path.exists(gitattributes_path):
    # `with` closes the handle; the previous version leaked it.
    with open(gitattributes_path, "r") as gitattributes:
        commited_ckpts = [line.split(" ")[0] for line in gitattributes]


# # download task arithmetic checkpoints
# for file in uploaded_ckpts:
#     if "task_arithmetic" not in file:
#         continue
#     file = f"seq/{file}"
#     hf_hub_download(
#         repo_id="yashkant/evalmerge",
#         filename=file,
#         repo_type="model",
#         local_dir="/scratch/ssd004/scratch/calvinyu/git_merge/merged/",
#         local_dir_use_symlinks=False,
#     )


def commit_ckpts(ckpts_to_commit):
    """Track, stage and commit each checkpoint with git-theta, one commit per file.

    Checkpoints listed in the module-level ``commited_ckpts`` are skipped; each
    successful commit is appended to it so a repeated request in the same run
    is skipped too. Paths are relative to ``git_dir``, where the commands run.

    Args:
        ckpts_to_commit: iterable of checkpoint paths relative to ``git_dir``.

    Exits the script with a non-zero status as soon as one commit pipeline
    fails (the previous version exited with status 0, hiding the failure).
    """
    import shlex

    print(f"ADD: {ckpts_to_commit}")
    for ckpt in ckpts_to_commit:
        if ckpt in commited_ckpts:
            continue
        # Quote everything interpolated into the shell command so names with
        # spaces or quotes cannot break (or inject into) the pipeline.
        quoted_ckpt = shlex.quote(ckpt)
        commit_cmd = (
            "sleep 1"  # a one-second pause precedes each commit
            f" && cd {shlex.quote(git_dir)}"
            f" && git theta track {quoted_ckpt}"
            f" && git theta add {quoted_ckpt}"
            f" && git commit -m {shlex.quote(f'Add {ckpt}')}"
        )
        exit_code = os.system(commit_cmd)
        if exit_code != 0:
            # sys.exit(str) prints to stderr and exits with status 1.
            sys.exit(f"git-theta commit failed for {ckpt} (os.system status {exit_code})")
        commited_ckpts.append(ckpt)

def output_exists(cmd):
    """Return True if the ``--output`` target of a merge command already exists.

    The target counts as existing when it is present on disk relative to
    ``git_dir``, or when its basename is among ``uploaded_ckpts`` (checkpoints
    already uploaded to the Hub).

    Args:
        cmd: a ``git-theta-merge-cli`` command line containing ``--output PATH``.

    Returns:
        bool; ``False`` when ``cmd`` has no ``--output PATH`` argument, so the
        caller goes ahead and runs the command.
    """
    import shlex

    # Tokenize instead of slicing between "--output" and "--merge": any flag
    # placed between those two used to be glued onto the path, making an
    # existing output look missing.
    tokens = shlex.split(cmd)
    try:
        output = tokens[tokens.index("--output") + 1]
    except (ValueError, IndexError):
        return False
    local_exists = os.path.exists(os.path.join(git_dir, output))
    upload_exists = os.path.basename(output) in uploaded_ckpts
    return local_exists or upload_exists

if __name__ == "__main__":

    # Sequential merge driver: for mi = 2 .. len(ab), commit the inputs to the
    # git-theta repo in `git_dir` and merge the first mi "AB" checkpoints with
    # covariance-mats-gpu (other merge recipes are kept below, commented out).

    # Sorted so that the choice of `an` below does not depend on glob order.
    ckpts = sorted(glob.glob(os.path.join(scratch_dir, "git_merge/stats/", "*.safetensors")))
    print(f"total checkpoints: {len(ckpts)}")

    if not os.path.exists(git_dir):
        sys.exit(f"git-theta repo not found: {git_dir}")

    # get all available models, as paths relative to git_dir
    ckpts = [ckpt.replace(git_dir, "") for ckpt in ckpts]
    ancestors = [ckpt for ckpt in ckpts if "ancestor" in ckpt]
    if not ancestors:
        sys.exit("no checkpoint with 'ancestor' in its name under git_merge/stats/")
    an = ancestors[0]
    ab = sorted([ckpt for ckpt in ckpts if "AB.safetensors" in ckpt])
    # Per-model auxiliary files, index-aligned with `ab`.
    ab_trim = [ckpt.replace("AB.safetensors", "AB_trimmed.safetensors") for ckpt in ab]
    ab_fisher = [ckpt.replace("AB.safetensors", "AB_fisher.safetensors") for ckpt in ab]
    ab_gram = [ckpt.replace("AB.safetensors", "AB_gram.safetensors") for ckpt in ab]

    # Explicit check (asserts vanish under `python -O`) that names what is missing.
    missing = [
        ckpt
        for ckpt in ab + ab_trim + ab_fisher + ab_gram
        if not os.path.exists(os.path.join(git_dir, ckpt))
    ]
    if missing:
        sys.exit(f"missing checkpoints: {missing}")

    # merges = ["average", "task_arithmetic", "dare", "ties", "fisher", "regmean", "mats"]

    # merge all models dir
    merge_dir = os.path.join(git_dir, "merged/seq/")
    os.makedirs(merge_dir, exist_ok=True)

    # relative path: merge commands are run from inside git_dir
    merge_dir = merge_dir.replace(git_dir, "")
    skip_if_exists = True
    merge_cmd = "git-theta-merge-cli"

    for mi in range(2, len(ab) + 1):
        # The first iteration introduces two models, every later one a single model.
        step_size = 2 if mi == 2 else 1

        # commit the newly introduced checkpoints (and the ancestor, once)
        ckpts_to_commit = ab[mi-step_size:mi]
        if mi == 2:
            ckpts_to_commit += [an]
        commit_ckpts(ckpts_to_commit)

        # average merge
        # merge = "average"
        # _merge_cmd = merge_cmd + " --models " + " ".join(ab[:mi])
        # _merge_cmd += " --output " + os.path.join(merge_dir, f"{merge}_{mi}.safetensors")
        # _merge_cmd += " --merge average-gpu"
        # if not output_exists(_merge_cmd):
        #     os.system(f"cd {git_dir} && {_merge_cmd}")
        #     print(f"average, mi: {mi}")

        # task_arithmetic
        # merge = "task_arithmetic"
        # _merge_cmd = merge_cmd + " --ancestor " + an
        # best_ta_lambda = 0.1
        # _merge_cmd += f" --x:merge_lambda {best_ta_lambda}"
        # _merge_cmd = _merge_cmd + " --models " + " ".join(ab[:mi])
        # _merge_cmd += " --output " + os.path.join(merge_dir, f"{merge}_{mi}.safetensors")
        # _merge_cmd += " --merge task-arithmetic-gpu"
        # if not output_exists(_merge_cmd):
        #     os.system(f"cd {git_dir} && {_merge_cmd}")
        #     print(f"task_arithmetic, mi: {mi}")

        # dare-task_arithmetic
        # merge = "dare"
        # _merge_cmd = merge_cmd + " --models " + " ".join(ab[:mi])
        # _merge_cmd += " --ancestor " + an
        # best_ta_lambda = 0.1 # from task_arithmetic
        # best_dp = 0.1 # from dare
        # _merge_cmd += f" --x:merge_lambda {best_ta_lambda}"
        # _merge_cmd += f" --x:dropout_probability {best_dp}"
        # _merge_cmd += " --output " + os.path.join(merge_dir, f"{merge}_{mi}.safetensors")
        # _merge_cmd += " --merge dare-task-arithmetic-gpu"
        # if not output_exists(_merge_cmd):
        #     os.system(f"cd {git_dir} && {_merge_cmd}")
        #     print(f"dare, mi: {mi}")

        # commit all checkpoints
        # ckpts_to_commit = ab_fisher[mi-step_size:mi]
        # commit_ckpts(ckpts_to_commit)

        # fisher
        # merge = "fisher"
        # _merge_cmd = merge_cmd + " --models " + " ".join(ab[:mi])
        # _merge_cmd += " --aux_data " + " ".join(ab_fisher[:mi])
        # _merge_cmd += " --output " + os.path.join(merge_dir, f"{merge}_{mi}.safetensors")
        # _merge_cmd += " --merge fisher-gpu"
        # if not output_exists(_merge_cmd):
        #     os.system(f"cd {git_dir} && {_merge_cmd}")
        #     print(f"fisher, mi: {mi}")

        # commit all checkpoints
        # ckpts_to_commit = ab_trim[mi-step_size:mi]
        # commit_ckpts(ckpts_to_commit)

        # ties
        # merge = "ties"
        # _merge_cmd = merge_cmd + " --models " + " ".join(ab[:mi])
        # _merge_cmd += " --aux_data " + " ".join(ab_trim[:mi])
        # _merge_cmd += " --ancestor " + an
        # best_ties_lambda = .4
        # _merge_cmd += f" --x:merge_lambda {best_ties_lambda}"
        # _merge_cmd += " --output " + os.path.join(merge_dir, f"{merge}_{mi}.safetensors")
        # _merge_cmd += " --merge ties-gpu"
        # if not output_exists(_merge_cmd):
        #     os.system(f"cd {git_dir} && {_merge_cmd}")
        #     print(f"ties, mi: {mi}")

        # commit the Gram-matrix aux files for the newly introduced models
        ckpts_to_commit = ab_gram[mi-step_size:mi]
        commit_ckpts(ckpts_to_commit)

        # regmean
        # merge = "regmean"
        # _merge_cmd = merge_cmd + " --models " + " ".join(ab[:mi])
        # _merge_cmd += " --aux_data " + " ".join(ab_gram[:mi])
        # best_rm_lambda = 0.0
        # _merge_cmd += f" --x:merge_lambda {best_rm_lambda}"
        # _merge_cmd += " --output " + os.path.join(merge_dir, f"{merge}_{mi}.safetensors")
        # _merge_cmd += " --merge reg-mean-gpu"
        # if not output_exists(_merge_cmd):
        #     os.system(f"cd {git_dir} && {_merge_cmd}")
        #     print(f"regmean, mi: {mi}")

        # The task-arithmetic merge of the same mi models is used as the
        # --ancestor of the mats merge below. It must already exist under
        # merged/seq/ (e.g. fetched by the commented-out download loop near
        # the top of the file), otherwise commit_ckpts aborts the script.
        ckpts_to_commit = [f"merged/seq/task_arithmetic_{mi}.safetensors"]
        commit_ckpts(ckpts_to_commit)

        # mats (at the end)
        merge = "mats"
        best_iterations = 10
        _merge_cmd = merge_cmd + " --models " + " ".join(ab[:mi])
        _merge_cmd += f" --x:iterations {best_iterations}"
        _merge_cmd += " --aux_data " + " ".join(ab_gram[:mi])
        _merge_cmd += " --ancestor " + os.path.join(merge_dir, f"task_arithmetic_{mi}.safetensors")
        _merge_cmd += " --output " + os.path.join(merge_dir, f"{merge}_{mi}.safetensors")
        _merge_cmd += " --merge covariance-mats-gpu"
        if not (skip_if_exists and output_exists(_merge_cmd)):
            exit_code = os.system(f"cd {git_dir} && {_merge_cmd}")
            # Report the real outcome; later iterations do not depend on this
            # output, so a failure is logged and the loop continues.
            if exit_code == 0:
                print(f"mats, mi: {mi}")
            else:
                print(f"FAILED: mats, mi: {mi} (os.system status {exit_code})", file=sys.stderr)

