import os, sys
import glob
from tqdm import tqdm
import numpy as np
import argparse


# Known merge-method names, matched as name prefixes so that methods whose
# names contain underscores (e.g. "task_arithmetic") are parsed correctly.
_MERGE_METHODS = ("task_arithmetic", "regmean", "average", "fisher", "ties", "mats", "dare")


def _parse_score_name(score_path):
    """Split a score-file path into ``(merge_method, merge_params)``.

    The file stem is expected to look like ``<method>[_<params>]_clip_scores...``;
    everything from ``_clip_scores`` onward is dropped. A method not listed in
    ``_MERGE_METHODS`` falls back to the first ``_``-separated token. A
    checkpoint with no params (e.g. ``average``) gets params ``"default"``.
    """
    stem = os.path.splitext(os.path.basename(score_path))[0]
    ckpt = stem.split("_clip_scores")[0]
    merge_method = next(
        (m for m in _MERGE_METHODS if ckpt == m or ckpt.startswith(m + "_")),
        ckpt.split("_")[0],
    )
    # Skip the method name and the "_" separator that follows it (if any).
    merge_params = ckpt[len(merge_method) + 1:] or "default"
    return merge_method, merge_params


def read_val(
    scores_dir="/scratch/ssd004/scratch/calvinyu/git_merge/merged/all",
    split="test",
    metric="indomain_mean",
    higher_is_better=True,
):
    """Print every merge method's scores and its best hyperparameters.

    Args:
        scores_dir: directory containing ``*.npy`` score files.
        split: only files whose name contains ``f"{split}_full."`` are read.
        metric: key to read from the dict stored in each score file
            (e.g. ``"indomain_mean"``, ``"indomain_fid"``, ``"outdomain_fid"``).
        higher_is_better: select the max score as best when True, the min
            otherwise (pass False for metrics where lower is better, e.g. FID).

    Returns:
        dict mapping merge method -> {merge params -> score}.
    """
    scores = sorted(glob.glob(os.path.join(scores_dir, "*.npy")))
    scores_dict = {}

    for score in tqdm(scores):
        if f"{split}_full." not in os.path.basename(score):
            continue

        merge_method, merge_params = _parse_score_name(score)
        # Each file holds a pickled dict (hence allow_pickle=True and .item()),
        # presumably written by scripts/inference.py; never point this at
        # untrusted files.
        try:
            value = np.load(score, allow_pickle=True).item()[metric]
        except KeyError:
            print(f"error: {score} has no '{metric}' entry; skipping")
            continue
        scores_dict.setdefault(merge_method, {})[merge_params] = value

    pick_best = max if higher_is_better else min
    print(f"split: {split}")
    for mm, params_scores in scores_dict.items():
        print(f"merge method: {mm}")
        best_params = pick_best(params_scores, key=params_scores.get)
        for params, value in params_scores.items():
            print(f"params: {params} and score: {round(value, 2)}")
        print(f"\nbest params: {best_params} and score: {round(params_scores[best_params], 2)}")
        print("\n\n-----------------")

    return scores_dict


if __name__ == "__main__":
    import subprocess

    argparser = argparse.ArgumentParser()
    argparser.add_argument("--split_idx", type=int, required=True)
    args = argparser.parse_args()

    # Directory of merged checkpoints to run inference on (alternatives kept
    # as toggles).
    ckpt_dir = "/scratch/ssd004/scratch/calvinyu/git_merge/merged/seq"
    # ckpt_dir = "/scratch/ssd004/scratch/calvinyu/git_merge/merged/all"
    # ckpt_dir = "/scratch/ssd004/scratch/yashkant/merged/all"
    total_ckpts = sorted(glob.glob(os.path.join(ckpt_dir, "*.safetensors")))

    # Optional filters (uncomment to restrict which checkpoints are processed):
    # total_ckpts = [ckpt for ckpt in total_ckpts if "task_arithmetic_lambda_0.7" in ckpt or "regmean_lambda_0.4" in ckpt or "mats" in ckpt or "dare" in ckpt]
    # total_ckpts = [ckpt for ckpt in total_ckpts if "dare" not in ckpt]

    # Skip checkpoints whose score file already exists next to them.
    done_suffix = "_clip_scores_val_test_full.npy"
    total_ckpts = [
        ckpt for ckpt in total_ckpts
        if not os.path.exists(os.path.splitext(ckpt)[0] + done_suffix)
    ]

    # Per-method counts of checkpoints still needing inference (matched on
    # the file name only, so directory names cannot inflate the counts).
    for method in ("regmean", "dare", "fisher", "average", "task_arithmetic", "ties", "mats"):
        count = sum(method in os.path.basename(ckpt) for ckpt in total_ckpts)
        print(f"{method} count: {count}")
    print("----")
    print(f"total checkpoints: {len(total_ckpts)}")

    # Split the remaining checkpoints into chunks of `split_size`; this job
    # runs inference only on chunk `split_idx`.
    # NOTE(review): chunks are computed over checkpoints that still lack
    # scores, so chunk membership shifts as other jobs finish — launch all
    # jobs at the same time.
    split_idx = args.split_idx
    split_size = 5
    remain_ckpts_split = [
        total_ckpts[i:i + split_size] for i in range(0, len(total_ckpts), split_size)
    ]
    try:
        chunk = remain_ckpts_split[split_idx]
    except IndexError:
        print(f"split_idx {split_idx} out of range ({len(remain_ckpts_split)} chunks); nothing to run")
        chunk = []

    for ckpt in tqdm(chunk, desc="inference"):
        print(f"inference for {ckpt}")
        # Argument list (no shell), so checkpoint paths are passed verbatim;
        # sys.executable runs the same interpreter as this script.
        result = subprocess.run([sys.executable, "scripts/inference.py", "--ckpt", ckpt])
        if result.returncode != 0:
            # Drop into the debugger on failure (no-op with PYTHONBREAKPOINT=0).
            breakpoint()

    # read val scores and find optimal hyperparams
    # NOTE(review): read_val() reads scores from merged/all by default, while
    # this run processed ckpt_dir (merged/seq) — confirm this is intended.
    read_val()


