import os, sys
import glob
from tqdm import tqdm
import numpy as np
import argparse


def read_val(
    scores_dir="/scratch/ssd004/scratch/calvinyu/git_merge/merged/all",
    split="test",
    show_progress=True,
):
    """Report, per merge method, the hyperparameters with the best in-domain CLIP score.

    Every ``*.npy`` file in ``scores_dir`` whose stem ends in
    ``"{split}_full"`` is loaded. The stem is expected to look like
    ``<merge_method>_<merge_params>_clip_scores_val_<split>_full``: the text
    before the first underscore is the merge method, and what remains after
    removing that prefix and the ``_clip_scores_val_<split>_full`` suffix is
    used as the hyperparameter key.

    For each merge method the parameter set with the highest
    ``indomain_mean`` value is selected, and its out-of-domain FID and CLIP
    scores are printed.

    Args:
        scores_dir: Directory containing the ``.npy`` score files.
        split: Split tag that selects which files are read.
        show_progress: Wrap the file loop in ``tqdm`` when true; pass
            ``False`` for quiet / non-interactive use.

    Returns:
        A ``(scores_dict, full_scores_dict)`` tuple. ``scores_dict`` maps
        ``merge_method -> merge_params -> indomain_mean``;
        ``full_scores_dict`` maps ``merge_method -> merge_params`` to a dict
        with keys ``indomain_clip``, ``outdomain_clip``, ``indomain_fid`` and
        ``outdomain_fid``.

    Raises:
        KeyError: If a loaded file lacks one of the four expected metrics.
    """
    split_tag = f"{split}_full"
    name_suffix = f"_clip_scores_val_{split_tag}"

    # glob.escape keeps a directory name containing "[", "*" or "?" from
    # being interpreted as part of the pattern.
    pattern = os.path.join(glob.escape(scores_dir), "*.npy")
    score_paths = sorted(glob.glob(pattern))

    scores_dict = {}
    full_scores_dict = {}

    for score in (tqdm(score_paths) if show_progress else score_paths):
        # Match on the file stem only, so directory names never influence
        # which files are selected.
        ckpt = os.path.splitext(os.path.basename(score))[0]
        if not ckpt.endswith(split_tag):
            continue

        merge_method = ckpt.split("_")[0]

        # Strip the suffix and the method prefix only at the ends of the
        # name; a blanket str.replace would also delete any repetition of
        # the method name inside the hyperparameter string.
        merge_params = ckpt[:-len(name_suffix)] if ckpt.endswith(name_suffix) else ckpt
        method_prefix = f"{merge_method}_"
        if merge_params.startswith(method_prefix):
            merge_params = merge_params[len(method_prefix):]

        # load CLIP and FID for indomain and outdomain
        # NOTE(security): allow_pickle=True executes pickled payloads; only
        # point this at score files you produced yourself.
        metrics = np.load(score, allow_pickle=True).item()

        scores_dict.setdefault(merge_method, {})[merge_params] = metrics['indomain_mean']
        full_scores_dict.setdefault(merge_method, {})[merge_params] = {
            "indomain_clip": metrics['indomain_mean'],
            "outdomain_clip": metrics['outdomain_mean'],
            "indomain_fid": metrics['indomain_fid'],
            "outdomain_fid": metrics['outdomain_fid'],
        }

    print(f"split: {split}")
    for mm, method_scores in scores_dict.items():
        print(f"merge method: {mm}")
        # Model selection uses the in-domain CLIP score (higher wins); only
        # the out-of-domain metrics of that winner are reported.
        best_params = max(method_scores, key=method_scores.get)
        best_metrics = full_scores_dict[mm][best_params]

        print(f"\nbest params: {best_params} and (outdomain) FID score: {round(best_metrics['outdomain_fid'], 2)}")

        print(f"\nbest params: {best_params} and (outdomain) CLIP score: {round(best_metrics['outdomain_clip'], 2)}")

        print("\n\n-----------------")

    # Hand the collected tables back to the caller instead of stopping in a
    # debugger, so the function also works in batch jobs and from a REPL.
    return scores_dict, full_scores_dict


    # for score in tqdm(scores):
    #     if f"{split}_full." not in score:
    #         continue

    #     ckpt = os.path.basename(score).replace(".npy", "")
    #     merge_method = ckpt.split("_")[0]
        
    #     merge_params = ckpt.replace(f"_clip_scores_{split}_indomain", "")
    #     merge_params = merge_params.replace(f"{merge_method}_", "")

    #     if merge_method not in scores_dict:
    #         scores_dict[merge_method] = {}

    #     try:            
    #         scores_dict[merge_method][merge_params] = np.load(score, allow_pickle=True).item()['indomain_fid']
    #     except:
    #         print(f"error: {score}")
    #         scores_dict[merge_method][merge_params] = -1.0
    # print(f"split: {split}")

    # for mm in scores_dict:
    #     print(f"merge method: {mm}")
    #     best_params = max(scores_dict[mm], key=scores_dict[mm].get)
    #     for params in scores_dict[mm]:
    #         print(f"params: {params} and score: {round(scores_dict[mm][params], 2)}")
    #     print(f"\nbest params: {best_params} and score: {round(scores_dict[mm][best_params], 2)}")
    #     print(f"\n\n-----------------")        

    # for score in tqdm(scores):
    #     if f"{split}_full." not in score:
    #         continue

    #     ckpt = os.path.basename(score).replace(".npy", "")
    #     merge_method = ckpt.split("_")[0]
        
    #     merge_params = ckpt.replace(f"_clip_scores_{split}_indomain", "")
    #     merge_params = merge_params.replace(f"{merge_method}_", "")

    #     if merge_method not in scores_dict:
    #         scores_dict[merge_method] = {}
    #     scores_dict[merge_method][merge_params] = np.load(score, allow_pickle=True).item()['outdomain_fid']
    
    # print(f"split: {split}")
    # for mm in scores_dict:
    #     print(f"merge method: {mm}")
    #     best_params = max(scores_dict[mm], key=scores_dict[mm].get)
    #     for params in scores_dict[mm]:
    #         print(f"params: {params} and score: {round(scores_dict[mm][params], 2)}")
    #     print(f"\nbest params: {best_params} and score: {round(scores_dict[mm][best_params], 2)}")
    #     print(f"\n\n-----------------")        


if __name__ == "__main__":
    # Script entry point: when this file is executed directly (not imported),
    # load the saved score files and print, for each merge method, the
    # hyperparameters with the best score.
    read_val()


