import glob
import os
import subprocess
import sys
from datetime import datetime

import numpy as np
import yaml
from tqdm import tqdm


# Export git-lfs onto PATH and configure git-theta. Child processes launched by
# this script (git-theta-merge-cli) inherit os.environ.
GIT_LFS_BIN = "/scratch/ssd004/scratch/calvinyu/git-lfs-3.5.1/bin"
_path = os.environ.get("PATH", "")  # PATH may be unset in minimal environments
if GIT_LFS_BIN not in _path.split(os.pathsep):
    # Only append once, so re-importing the module does not grow PATH.
    os.environ["PATH"] = f"{_path}{os.pathsep}{GIT_LFS_BIN}" if _path else GIT_LFS_BIN
os.environ["GIT_THETA_MANUAL_MERGE"] = "True"
os.environ["GIT_THETA_CHECKPOINT_TYPE"] = "safetensors"


def discover_checkpoints(git_dir):
    """Locate the ancestor and per-task ``*AB.safetensors`` checkpoints in ``<git_dir>/stats``.

    Args:
        git_dir: root directory of the git-theta repository.

    Returns:
        ``(ancestor, ab, ab_trim, ab_fisher, ab_gram)``. All paths are relative
        to ``git_dir``. ``ab`` is sorted, and the three auxiliary lists are
        index-aligned with it (``X_AB.safetensors`` -> ``X_AB_<kind>.safetensors``).

    Raises:
        FileNotFoundError: if there is no ancestor checkpoint, no ``AB``
            checkpoint, or an expected auxiliary checkpoint is missing.
    """
    pattern = os.path.join(git_dir, "stats", "*.safetensors")
    # Sorted so the choice of ancestor (and the model order) is deterministic.
    ckpts = sorted(os.path.relpath(path, git_dir) for path in glob.glob(pattern))
    print(f"total checkpoints: {len(ckpts)}")

    ancestors = [ckpt for ckpt in ckpts if "ancestor" in ckpt]
    if not ancestors:
        raise FileNotFoundError(f"no ancestor checkpoint matches {pattern}")
    ab = [ckpt for ckpt in ckpts if "AB.safetensors" in ckpt]
    if not ab:
        raise FileNotFoundError(f"no '*AB.safetensors' checkpoint matches {pattern}")

    def aux(kind):
        return [ckpt.replace("AB.safetensors", f"AB_{kind}.safetensors") for ckpt in ab]

    ab_trim, ab_fisher, ab_gram = aux("trimmed"), aux("fisher"), aux("gram")
    missing = [
        ckpt
        for ckpt in ab + ab_trim + ab_fisher + ab_gram
        if not os.path.exists(os.path.join(git_dir, ckpt))
    ]
    if missing:
        raise FileNotFoundError(f"missing checkpoints under {git_dir}: {missing}")
    return ancestors[0], ab, ab_trim, ab_fisher, ab_gram


def build_merge_jobs(merges, ancestor, ab, ab_trim, ab_fisher, ab_gram, merge_dir,
                     best_ta_lambda=0.1,
                     mats_ancestor="stats/task_arithmetic_lambda_0.1.safetensors"):
    """Build one ``git-theta-merge-cli`` invocation per (method, sweep value).

    Args:
        merges: merge-method names; see the branches below for the supported set.
        ancestor: ancestor checkpoint path (relative to the repo root).
        ab, ab_trim, ab_fisher, ab_gram: index-aligned model / auxiliary-data paths.
        merge_dir: output directory, relative to the repo root.
        best_ta_lambda: merge_lambda used for DARE (chosen from the task-arithmetic sweep).
        mats_ancestor: ancestor used for MaTS. It is presumably a task-arithmetic merge
            that was copied into ``stats/``; confirm before changing.

    Returns:
        A list of ``(method, argv, output)`` tuples. ``argv`` is the full argument list,
        and ``output`` is the output path relative to the repo root.

    Raises:
        NotImplementedError: for an unknown merge method.

    NOTE: sweep values are formatted with ``str()``, so float artifacts such as
    ``0.30000000000000004`` appear in output names. This is kept deliberately so
    that outputs already on disk are still recognised by the skip-if-exists check.
    """
    jobs = []

    def add(merge, args, name, method):
        output = os.path.join(merge_dir, f"{name}.safetensors")
        argv = ["git-theta-merge-cli", *args, "--output", output, "--merge", method]
        jobs.append((merge, argv, output))

    for merge in merges:
        if merge == "average":
            add(merge, ["--models", *ab], merge, "average-gpu")

        elif merge == "task_arithmetic":
            base = ["--models", *ab, "--ancestor", ancestor]
            for ml in np.linspace(0.1, 1.0, 10):  # sweep merge_lambda
                add(merge, [*base, "--x:merge_lambda", str(ml)],
                    f"{merge}_lambda_{ml}", "task-arithmetic-gpu")

        elif merge == "dare":
            base = ["--models", *ab, "--ancestor", ancestor,
                    "--x:merge_lambda", str(best_ta_lambda)]
            for dp in np.linspace(0.0, 0.9, 10):  # sweep dropout_probability
                add(merge, [*base, "--x:dropout_probability", str(dp)],
                    f"{merge}_lambda_{best_ta_lambda}_dropout_{dp}", "dare-task-arithmetic-gpu")

        elif merge == "ties":
            base = ["--models", *ab, "--aux_data", *ab_trim, "--ancestor", ancestor]
            for ml in np.linspace(0.1, 1.0, 10):  # sweep merge_lambda
                add(merge, [*base, "--x:merge_lambda", str(ml)],
                    f"{merge}_lambda_{ml}", "ties-gpu")

        elif merge == "fisher":
            add(merge, ["--models", *ab, "--aux_data", *ab_fisher], merge, "fisher-gpu")

        elif merge == "regmean":
            base = ["--models", *ab, "--aux_data", *ab_gram]
            for ml in np.linspace(0.0, 0.9, 10):  # sweep merge_lambda
                add(merge, [*base, "--x:merge_lambda", str(ml)],
                    f"{merge}_lambda_{ml}", "reg-mean-gpu")

        elif merge == "mats":
            base = ["--models", *ab, "--aux_data", *ab_gram, "--ancestor", mats_ancestor]
            for it in np.linspace(10, 100, 10):  # sweep iterations
                it = int(it)
                add(merge, [*base, "--x:iterations", str(it)],
                    f"{merge}_iterations_{it}", "covariance-mats-gpu")

        else:
            raise NotImplementedError(f"Merge method {merge} not implemented.")

        print(f"added {merge}, total: {len(jobs)}")
    return jobs


def _load_log(path):
    """Return the existing YAML merge log at ``path``, or ``{}`` if there is none."""
    if not os.path.exists(path):
        return {}
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f) or {}


def _write_log(log, path):
    """Overwrite ``path`` with ``log`` serialised as YAML."""
    with open(path, "w", encoding="utf-8") as f:
        yaml.safe_dump(log, f, sort_keys=False)


def run_merge_jobs(jobs, git_dir, merge_log_path):
    """Run each ``(method, argv, output)`` job from ``git_dir`` and log it.

    Each attempt is recorded under its output path in the YAML file at
    ``merge_log_path``. The file is rewritten after every job, so progress is
    kept even if a later job fails.

    Exits the process with the failing command's return code (or 1 if the
    command was killed by a signal) on the first failure.
    """
    os.makedirs(os.path.dirname(merge_log_path), exist_ok=True)
    merge_log = _load_log(merge_log_path)
    for merge, argv, output in tqdm(jobs):
        cmd = " ".join(argv)
        print(f"cd {git_dir} && {cmd}")
        # subprocess.run exposes the real exit code. os.system returns a wait
        # status, which sys.exit truncates (e.g. 256 -> 0, i.e. "success").
        returncode = subprocess.run(argv, cwd=git_dir).returncode
        merge_log[output] = {
            "merge": merge,
            "cmd": cmd,
            "exit_code": returncode,
            "output": output,
        }
        _write_log(merge_log, merge_log_path)
        if returncode != 0:
            print(f"merge failed (exit code {returncode}): {cmd}", file=sys.stderr)
            sys.exit(returncode if returncode > 0 else 1)


def main(argv=None):
    """Merge all per-task checkpoints with each configured method.

    Pass ``--dry-run`` to only print the pending commands (this replaces the old
    interactive ``breakpoint()`` before execution).
    """
    argv = sys.argv[1:] if argv is None else argv
    dry_run = "--dry-run" in argv

    scratch_dir = "/scratch/ssd004/scratch/calvinyu/"
    git_dir = os.path.join(scratch_dir, "git_merge/")
    if not os.path.isdir(git_dir):
        raise FileNotFoundError(f"git repository directory not found: {git_dir}")

    # get all available models
    an, ab, ab_trim, ab_fisher, ab_gram = discover_checkpoints(git_dir)

    # get all available merge methods
    # merges = ["average", "task_arithmetic", "dare", "ties", "fisher", "regmean", "mats"]
    merges = ["dare", "mats"]

    # Output directory, relative to git_dir (the CLI runs with git_dir as cwd).
    merge_dir = "merged/all/"
    os.makedirs(os.path.join(git_dir, merge_dir), exist_ok=True)
    merge_log_path = os.path.join(git_dir, merge_dir, "logs", "log.yaml")

    jobs = build_merge_jobs(merges, an, ab, ab_trim, ab_fisher, ab_gram, merge_dir)

    skip_if_exists = True
    pending = []
    for job in jobs:
        output = job[2]
        if skip_if_exists and os.path.exists(os.path.join(git_dir, output)):
            print(f"skipping {output}")
        else:
            pending.append(job)

    if dry_run:
        for _, job_argv, _ in pending:
            print(f"cd {git_dir} && {' '.join(job_argv)}")
        return

    run_merge_jobs(pending, git_dir, merge_log_path)

    # TODO: sequential models


if __name__ == "__main__":
    main()


"""
PATH+=:/scratch/ssd004/scratch/calvinyu/git-lfs-3.5.1/bin
export GIT_THETA_MANUAL_MERGE=True
export GIT_THETA_CHECKPOINT_TYPE=safetensors

TESTING:

git-theta-merge-cli --models stats/stats_sketch_water_transportation_AB_trimmed.safetensors stats/stats_sketch_sky_transportation_AB.safetensors --output test.safetensors --merge average-gpu

git-theta-merge-cli --models stats/stats_clipart_cloth_AB.safetensors stats/stats_clipart_furniture_AB.safetensors stats/stats_clipart_mammal_AB.safetensors --output test.safetensors --merge average-gpu

git-theta-merge-cli --models stats/stats_clipart_cloth_AB.safetensors stats/stats_clipart_furniture_AB.safetensors stats/stats_clipart_mammal_AB.safetensors --ancestor stats/ancestor.safetensors --x:merge_lambda 0.1 --output test3.safetensors --merge task-arithmetic-gpu

git-theta-merge-cli --ancestor stats/ancestor.safetensors --models stats/stats_clipart_cloth_AB.safetensors stats/stats_clipart_furniture_AB.safetensors stats/stats_clipart_mammal_AB.safetensors --aux_data stats/stats_clipart_cloth_AB_trimmed.safetensors stats/stats_clipart_furniture_AB_trimmed.safetensors stats/stats_clipart_mammal_AB_trimmed.safetensors --x:merge_lambda 0.2 --output test4.safetensors --merge ties-gpu

git-theta-merge-cli --models stats/stats_clipart_cloth_AB.safetensors stats/stats_clipart_furniture_AB.safetensors stats/stats_clipart_mammal_AB.safetensors --aux_data stats/stats_clipart_cloth_AB_fisher.safetensors stats/stats_clipart_furniture_AB_fisher.safetensors stats/stats_clipart_mammal_AB_fisher.safetensors --output test8.safetensors --merge fisher-gpu

git-theta-merge-cli --models stats/stats_clipart_cloth_AB.safetensors stats/stats_clipart_furniture_AB.safetensors stats/stats_clipart_mammal_AB.safetensors --aux_data stats/stats_clipart_cloth_AB_gram.safetensors stats/stats_clipart_furniture_AB_gram.safetensors stats/stats_clipart_mammal_AB_gram.safetensors --x:merge_lambda 0.1 --output test6.safetensors --merge reg-mean-gpu

git-theta-merge-cli --models stats/stats_clipart_cloth_AB.safetensors stats/stats_clipart_furniture_AB.safetensors stats/stats_clipart_mammal_AB.safetensors --aux_data stats/stats_clipart_cloth_AB_gram.safetensors stats/stats_clipart_furniture_AB_gram.safetensors stats/stats_clipart_mammal_AB_gram.safetensors --ancestor stats/ancestor.safetensors --x:iterations 100 --output test7.safetensors --merge covariance-mats-gpu

# inference
# hyperparam: mats_merged_model (24 fused): 24 in-domain (lambda selection)
# results: 24 in-domain and 120 out-of-domain
# 10 mins * 24 = 240 mins (4 hrs) * 5 merge methods = 240 mins * 5 = 1200 mins (20 hrs)

# first get CLIP scores (then get FID scores)
# get hyperparams (check the possibility of out-of-domain), get results (CLIP, in-domain and out-of-domain)
# get dimensions for FLOPs to Brian, get …

cd /scratch/ssd004/scratch/calvinyu/git_merge/ && git-theta-merge-cli --models stats/stats_clipart_cloth_AB.safetensors stats/stats_clipart_furniture_AB.safetensors stats/stats_clipart_mammal_AB.safetensors --aux_data stats/stats_clipart_cloth_AB_gram.safetensors stats/stats_clipart_furniture_AB_gram.safetensors stats/stats_clipart_mammal_AB_gram.safetensors --x:merge_lambda 0.2 --output merged/all/regmean_lambda_02.safetensors --merge reg-mean-gpu


'cd /scratch/ssd004/scratch/calvinyu/git_merge/ && git-theta-merge-cli --models stats/stats_clipart_cloth_AB.safetensors stats/stats_clipart_furniture_AB.safetensors stats/stats_clipart_mammal_AB.safetensors stats/stats_clipart_tool_AB.safetensors stats/stats_infograph_building_AB.safetensors stats/stats_infograph_electricity_AB.safetensors stats/stats_infograph_human_body_AB.safetensors stats/stats_infograph_office_AB.safetensors stats/stats_painting_cold_blooded_AB.safetensors stats/stats_painting_food_AB.safetensors stats/stats_painting_nature_AB.safetensors stats/stats_painting_road_transportation_AB.safetensors stats/stats_quickdraw_fruit_AB.safetensors stats/stats_quickdraw_music_AB.safetensors stats/stats_quickdraw_sport_AB.safetensors stats/stats_quickdraw_tree_AB.safetensors stats/stats_real_bird_AB.safetensors stats/stats_real_kitchen_AB.safetensors stats/stats_real_shape_AB.safetensors stats/stats_real_vegatable_AB.safetensors stats/stats_sketch_insect_AB.safetensors stats/stats_sketch_others_AB.safetensors stats/stats_sketch_sky_transportation_AB.safetensors stats/stats_sketch_water_transportation_AB.safetensors --aux_data stats/stats_clipart_cloth_AB_gram.safetensors stats/stats_clipart_furniture_AB_gram.safetensors stats/stats_clipart_mammal_AB_gram.safetensors stats/stats_clipart_tool_AB_gram.safetensors stats/stats_infograph_building_AB_gram.safetensors stats/stats_infograph_electricity_AB_gram.safetensors stats/stats_infograph_human_body_AB_gram.safetensors stats/stats_infograph_office_AB_gram.safetensors stats/stats_painting_cold_blooded_AB_gram.safetensors stats/stats_painting_food_AB_gram.safetensors stats/stats_painting_nature_AB_gram.safetensors stats/stats_painting_road_transportation_AB_gram.safetensors stats/stats_quickdraw_fruit_AB_gram.safetensors stats/stats_quickdraw_music_AB_gram.safetensors stats/stats_quickdraw_sport_AB_gram.safetensors stats/stats_quickdraw_tree_AB_gram.safetensors stats/stats_real_bird_AB_gram.safetensors 
stats/stats_real_kitchen_AB_gram.safetensors stats/stats_real_shape_AB_gram.safetensors stats/stats_real_vegatable_AB_gram.safetensors stats/stats_sketch_insect_AB_gram.safetensors stats/stats_sketch_others_AB_gram.safetensors stats/stats_sketch_sky_transportation_AB_gram.safetensors stats/stats_sketch_water_transportation_AB_gram.safetensors --x:merge_lambda 0.4 --output merged/all/regmean_lambda_0.4.safetensors --merge reg-mean-gpu'

"""