import glob
import os

import numpy as np
from huggingface_hub import HfApi, HfFileSystem, login
from tqdm import tqdm

# Authenticate with the Hugging Face Hub.
# SECURITY: an access token used to be hard-coded here and passed on a shell
# command line (visible to other users via `ps` on a shared machine). Treat
# that token as compromised and revoke it. Provide a token through the
# HF_TOKEN environment variable instead; if it is unset, huggingface_hub falls
# back to credentials cached by a previous `huggingface-cli login`.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

api = HfApi()
fs = HfFileSystem()

# Snapshot of the files already present under `all/` in the Hub repo. Used
# below to skip re-uploads and to delete local checkpoints that are safely on
# the Hub. Sets make the per-checkpoint membership checks O(1).
uploaded_files = fs.ls("yashkant/evalmerge/all", detail=False)
uploaded_ckpts = {
    os.path.basename(f) for f in uploaded_files if f.endswith(".safetensors")
}
uploaded_eval_files = {
    os.path.basename(f)
    for f in uploaded_files
    if f.endswith("_clip_scores_val_test_full.npy")
}

# Pause in the debugger before touching the Hub (PYTHONBREAKPOINT=0 disables
# this and the other breakpoint() calls for non-interactive runs).
breakpoint()

# Push the contents of the local `seq` merge directory into the repo's
# `seq_fixed/` folder in a single upload_folder call.
seq_local_dir = "/scratch/ssd004/scratch/calvinyu/git_merge/merged/seq/"
seq_repo_folder = "seq_fixed"
api.upload_folder(
    repo_id="yashkant/evalmerge",
    repo_type="model",
    folder_path=seq_local_dir,
    path_in_repo=seq_repo_folder,
)

# Pause again before the per-checkpoint upload loop runs.
breakpoint()

# Which merged-checkpoint subdirectory to process; also the folder name used
# inside the Hub repo. (Renamed from `type`, which shadowed the builtin.)
merge_type = "all"
# NOTE(review): `uploaded_ckpts` / `uploaded_eval_files` come from listing the
# repo's `all/` folder; keep that listing in sync if `merge_type` changes.
merged_root = "/scratch/ssd004/scratch/calvinyu/git_merge/merged"
# Sorted so runs process checkpoints in a deterministic order.
ckpts = sorted(glob.glob(f"{merged_root}/{merge_type}/*.safetensors"))

for ckpt in tqdm(ckpts):
    # The eval results sit next to the checkpoint with the same stem. Swap only
    # the suffix: str.replace would also rewrite any ".safetensors" that
    # appears elsewhere in the path.
    eval_file = ckpt[: -len(".safetensors")] + "_clip_scores_val_test_full.npy"
    ckpt_name = os.path.basename(ckpt)
    eval_name = os.path.basename(eval_file)

    # Both files are already on the Hub, so the local checkpoint is redundant.
    # os.remove avoids the shell (no quoting/injection issues with odd paths)
    # and raises instead of silently ignoring a failed delete.
    if ckpt_name in uploaded_ckpts and eval_name in uploaded_eval_files:
        print(f"Skipping {ckpt} and {eval_file} as they are already uploaded")
        os.remove(ckpt)
        continue

    # Only upload checkpoints that have been evaluated...
    if not os.path.exists(eval_file):
        continue

    # ...and whose eval results include the FID score. allow_pickle is needed
    # because the .npy stores a pickled dict; only safe for trusted local files.
    eval_dict = np.load(eval_file, allow_pickle=True).item()
    if "indomain_fid" not in eval_dict:
        continue

    # Hub paths are always "/"-separated; os.path.join would emit "\" on Windows.
    for local_path, repo_name in ((ckpt, ckpt_name), (eval_file, eval_name)):
        api.upload_file(
            path_or_fileobj=local_path,
            path_in_repo=f"{merge_type}/{repo_name}",
            repo_id="yashkant/evalmerge",
            repo_type="model",
        )

    # NOTE: local files are not deleted right after upload. The checkpoint is
    # removed by the skip branch above on a later run, once both files are
    # listed on the Hub; eval files are never deleted here.

# Final debugger pause (PYTHONBREAKPOINT=0 disables it).
breakpoint()
# ckpts = glob.glob("/scratch/ssd004/scratch/calvinyu/git_merge/merged/seq/mats*.safetensors")

# for ckpt in tqdm(ckpts):
#     api.upload_file(
#         path_or_fileobj=ckpt,
#         path_in_repo=os.path.join("seq", os.path.basename(ckpt)),
#         repo_id="yashkant/evalmerge",
#         repo_type="model",
#     )

# for ckpt in tqdm(ckpts):
#     os.system(f"rm {ckpt}")

# api.upload_folder(
#     folder_path="/scratch/ssd004/scratch/calvinyu/git_merge/merged/seq",
#     path_in_repo="seq", # Upload to a specific folder
#     repo_id="yashkant/evalmerge",
#     repo_type="model",
# )

# api.upload_folder(
#     folder_path="/scratch/ssd004/scratch/calvinyu/git_merge/stats",
#     path_in_repo="stats", # Upload to a specific folder
#     repo_id="yashkant/evalmerge",
#     repo_type="model",
# )
