import os
from domainnet import *
from datasets import load_dataset, Image
from datasets import Dataset
from tqdm import tqdm
import numpy as np

# Root directory of the DomainNet images, and the directory that holds the
# per-domain "<domain>_{train,validation,test}_fold.txt" split lists.
dataset_path = "/datasets/DomainNet"
splits_dataset_path = "/h/calvinyu/evalmerge/splits/"

# In-domain assignment: the categories paired with each domain.
# NOTE(review): "vegatable" is used as a key into CATEGORIES (from the
# `domainnet` module) -- confirm the spelling there before correcting it here.
domain_categories_map = {
    "clipart": ["cloth", "furniture", "mammal", "tool"],
    "infograph": ["building", "electricity", "human_body", "office"],
    "painting": ["cold_blooded", "food", "nature", "road_transportation"],
    "quickdraw": ["fruit", "music", "sport", "tree"],
    "real": ["bird", "kitchen", "shape", "vegatable"],
    "sketch": ["insect", "others", "sky_transportation", "water_transportation"]
}


# Every category exactly once, in declaration order. dict.fromkeys removes
# duplicates like set() would, but keeps a deterministic order; a set of
# strings is ordered by randomised hashes and changes from run to run.
all_categories = list(dict.fromkeys(
    category
    for categories in domain_categories_map.values()
    for category in categories
))
all_domains = list(domain_categories_map)
# Each domain gets its own copy so that mutating one domain's list cannot
# silently change the lists of the other domains.
all_domain_categories_map = {domain: list(all_categories) for domain in all_domains}

# Out-of-domain assignment: for each domain, every category that is NOT
# listed for that domain in domain_categories_map.
outdomain_categories_map = {
    domain: [
        category
        for category in all_categories
        if category not in domain_categories_map[domain]
    ]
    for domain in all_domains
}


def indices(lst, item):
    """Return, in ascending order, every position of ``lst`` whose element equals ``item``."""
    positions = []
    for position, element in enumerate(lst):
        if element == item:
            positions.append(position)
    return positions


def _read_split_filenames(split_file):
    """Return the image paths listed in one split file.

    Only the text before the first space of each line is kept. Blank lines
    are skipped so that a stray empty line cannot break the later
    ``filename.split("/")[1]`` lookup.
    """
    with open(split_file, "r") as f:
        return [line.split(" ")[0] for line in f.read().splitlines() if line.strip()]


def _build_split_dataset(filenames, class_names, domain, num_samples):
    """Build the dataset of one split for one (domain, category) pair.

    Keeps the filenames whose second path component is in ``class_names``,
    truncates them to ``num_samples`` and returns ``(dataset, sample_count)``.
    """
    files = [
        os.path.join(dataset_path, filename)
        for filename in filenames
        if filename.split("/")[1] in class_names
    ][:num_samples]
    # The caption names the image's parent directory (the class) and the domain.
    captions = [
        f"photo of a {file.split('/')[-2].replace('_', ' ')} in {domain} style"
        for file in files
    ]
    dataset_dict = {"image": files, "text": captions, "image_paths": files}
    split_dataset = Dataset.from_dict(dataset_dict).cast_column("image", Image())
    return split_dataset, len(files)


def load_domainnet(num_samples=None, full=False, remaining_map=False):
    """Load DomainNet train/val/test datasets grouped by domain and category.

    A summary table (sample counts and split ratios per domain/category) is
    printed as a side effect.

    Args:
        num_samples: Maximum number of samples kept per split and category;
            ``None`` keeps everything.
        full: Use every category for every domain instead of only the
            categories assigned in ``domain_categories_map``.
        remaining_map: Use ``outdomain_categories_map`` (the categories NOT
            assigned to each domain). Takes precedence over ``full``.

    Returns:
        Nested dict ``datasets[domain][category][split]`` with ``split`` in
        ``{"train", "val", "test"}``. Each value is a ``Dataset`` with the
        columns ``image`` (cast to the ``Image`` feature), ``text`` (caption)
        and ``image_paths`` (the file path as a plain string).

    Raises:
        AssertionError: If the directory of a class is missing under
            ``dataset_path/<domain>/``.
    """
    print("{:<25} | {:<25} | {:<5} | {:<5} | {:<5} | {:<5}".format("Domain", "Category", "Train", "Val", "Test", "Train/Val/Test"))
    print("-" * 100)

    # Named to avoid shadowing the builtin `map`.
    if remaining_map:
        categories_by_domain = outdomain_categories_map
    elif full:
        categories_by_domain = all_domain_categories_map
    else:
        categories_by_domain = domain_categories_map

    datasets = {}
    for domain, categories in tqdm(categories_by_domain.items(), total=len(categories_by_domain), desc="reading datasets", disable=True):
        # Split files are per domain, so read them once and reuse them for
        # every category of that domain.
        split_filenames = {
            "train": _read_split_filenames(os.path.join(splits_dataset_path, f"{domain}_train_fold.txt")),
            "val": _read_split_filenames(os.path.join(splits_dataset_path, f"{domain}_validation_fold.txt")),
            "test": _read_split_filenames(os.path.join(splits_dataset_path, f"{domain}_test_fold.txt")),
        }

        datasets[domain] = {}
        for category in categories:
            # Class names use spaces in CATEGORIES but underscores on disk.
            # A set gives O(1) membership tests while filtering filenames.
            class_names = {name.replace(" ", "_") for name in CATEGORIES[category]}
            missing = sorted(
                name for name in class_names
                if not os.path.exists(os.path.join(dataset_path, domain, name))
            )
            assert not missing, f"missing class directories for domain '{domain}': {missing}"

            splits, counts = {}, {}
            for split_name, filenames in split_filenames.items():
                splits[split_name], counts[split_name] = _build_split_dataset(
                    filenames, class_names, domain, num_samples
                )
            datasets[domain][category] = splits

            # Guard against a category without any samples (for example
            # num_samples=0), which would otherwise divide by zero.
            total_samples = sum(counts.values())
            ratios = {
                split_name: (count / total_samples if total_samples else 0.0)
                for split_name, count in counts.items()
            }

            # train / val / test counts followed by their ratios
            print(f"{domain:<25} | {category:<25} | {counts['train']:<5} | {counts['val']:<5} | {counts['test']:<5} | {ratios['train']:.2f} | {ratios['val']:.2f} | {ratios['test']:.2f}")

    return datasets


def _group_first_images(captions, image_paths, per_caption, captions_threshold):
    """Pick the first ``per_caption`` image paths of each distinct caption.

    Captions are taken in first-seen order (deterministic, unlike iterating a
    set) and truncated to ``captions_threshold``; ``None`` keeps all of them.
    Returns ``(selected_captions, selected_image_paths)``.
    """
    paths_by_caption = {}
    for caption, path in zip(captions, image_paths):
        paths_by_caption.setdefault(caption, []).append(path)
    selected_captions = list(paths_by_caption)[:captions_threshold]
    selected_paths = [
        path
        for caption in selected_captions
        for path in paths_by_caption[caption][:per_caption]
    ]
    return selected_captions, selected_paths


def _write_fid_image(source_path, target_path, image_size):
    """Save a centered square crop of ``source_path``, resized, to ``target_path``."""
    from PIL import Image as PILImage

    with PILImage.open(source_path) as source:
        width, height = source.size
        side = min(width, height)
        # Integer box: with float coordinates each edge is rounded on its
        # own, which can make the "square" crop one pixel wider than tall.
        left = (width - side) // 2
        top = (height - side) // 2
        # JPEG cannot store palette or alpha modes, so normalise to RGB first.
        square = source.convert("RGB").crop((left, top, left + side, top + side))
    square.resize((image_size, image_size)).save(target_path)


def _reset_directory(directory):
    """Remove ``directory`` (or a symlink with that name) and recreate it empty."""
    import shutil

    if os.path.islink(directory):
        os.unlink(directory)
    elif os.path.isdir(directory):
        # rmtree unlinks symlinks found inside the tree without following them.
        shutil.rmtree(directory)
    os.makedirs(directory, exist_ok=True)


def get_fid_images(
    split="test",
    captions_threshold=None,
    *,
    indomain_dir="/h/calvinyu/evalmerge/inference/fid/indomain",
    outdomain_dir="/h/calvinyu/evalmerge/inference/fid/outdomain",
    image_size=512,
):
    """Export in-domain and out-of-domain reference images for FID computation.

    In-domain images come from ``load_domainnet(full=False)``: up to three
    images per distinct caption. Out-of-domain images come from
    ``load_domainnet(remaining_map=True)``: one image per distinct caption,
    topped up with random unused images when de-duplication leaves fewer
    images than captions. Every image is center-cropped to a square, resized
    to ``image_size`` x ``image_size`` and written as ``<index>.jpg``.

    WARNING: ``indomain_dir`` and ``outdomain_dir`` are deleted and recreated.

    Args:
        split: Dataset split to read ("train", "val" or "test").
        captions_threshold: Maximum number of distinct captions used per
            (domain, category); ``None`` uses all of them.
        indomain_dir: Output directory for the in-domain images.
        outdomain_dir: Output directory for the out-of-domain images.
        image_size: Side length in pixels of the saved square images.
    """
    indomain_test_captions, indomain_images = [], []
    dataset = load_domainnet(full=False)
    for domain in dataset:
        for category in dataset[domain]:
            subset = dataset[domain][category][split]
            # select three images for each caption
            captions, images = _group_first_images(
                list(subset['text']), list(subset['image_paths']), 3, captions_threshold
            )
            indomain_test_captions.extend(captions)
            indomain_images.extend(dict.fromkeys(images))

    outdomain_test_captions, outdomain_images, all_outdomain_images = [], [], []
    dataset = load_domainnet(remaining_map=True)
    for domain in dataset:
        for category in dataset[domain]:
            subset = dataset[domain][category][split]
            image_paths = list(subset['image_paths'])
            # select one image for each caption
            captions, images = _group_first_images(
                list(subset['text']), image_paths, 1, captions_threshold
            )
            outdomain_test_captions.extend(captions)
            outdomain_images.extend(images)
            all_outdomain_images.extend(image_paths)

    print(f"split {split}, indomain: {len(indomain_test_captions)}, outdomain: {len(outdomain_test_captions)}, total: {len(indomain_test_captions) + len(outdomain_test_captions)}")

    # Start from empty output directories so stale files from an earlier run
    # cannot end up in the FID statistics.
    _reset_directory(indomain_dir)
    _reset_directory(outdomain_dir)

    for i, image in tqdm(enumerate(indomain_images), total=len(indomain_images)):
        _write_fid_image(image, os.path.join(indomain_dir, f"{i}.jpg"), image_size)

    # outdomain images have repeats, so we add some random samples from the
    # remaining ones until there is one image per caption (or the pool runs out)
    unique_outdomain = list(dict.fromkeys(outdomain_images))
    spare = sorted(set(all_outdomain_images) - set(unique_outdomain))
    shortfall = min(len(outdomain_test_captions) - len(unique_outdomain), len(spare))
    if shortfall > 0:
        # replace=False: sampling with replacement could add duplicates again.
        unique_outdomain.extend(np.random.choice(spare, size=shortfall, replace=False).tolist())
    outdomain_images = unique_outdomain

    for i, image in tqdm(enumerate(outdomain_images), total=len(outdomain_images)):
        _write_fid_image(image, os.path.join(outdomain_dir, f"{i}.jpg"), image_size)

    # NOTE: the images are written as new files; the "symlinks" wording is
    # historical and kept so existing log output does not change.
    print(f"added symlinks for indomain: {len(indomain_images)}, outdomain: {len(outdomain_images)}")


def get_fig_images(split="test", captions_threshold=None):
    """Unfinished: pick example images per (class pair, domain) for a figure.

    NOTE(review): as written this function cannot run to completion.
      * Two ``breakpoint()`` calls drop into the debugger.
      * ``in_dataset[col][row[0]]`` indexes with class names such as "apple",
        while ``load_domainnet`` keys its inner dict by the category names of
        ``domain_categories_map`` (e.g. "fruit") -- this looks like a KeyError.
      * ``dataset`` and ``outdomain_dir`` are not assigned in this function or
        at module level in this file; unless ``from domainnet import *``
        provides them, using them raises NameError.
      * ``row_images``, ``sub_row1``, ``sub_row2``, ``out_dataset``,
        ``test_captions`` and ``total_images`` are assigned but never read.
    The second half mirrors ``get_fid_images``; whether it belongs here
    should be confirmed with the author.

    Args:
        split: Dataset split to read ("train", "val" or "test").
        captions_threshold: Maximum number of distinct captions used per
            (domain, category); ``None`` uses all of them.
    """
    test_captions, indomain_test_captions, outdomain_test_captions = [], [], []
    in_dataset = load_domainnet(full=False)
    out_dataset = load_domainnet(remaining_map=True)
    # NOTE(review): debugging leftover -- stops execution in the debugger.
    breakpoint()

    # PIL's Image is imported locally; inside this function it shadows the
    # `Image` feature class imported from `datasets` at module level.
    from PIL import Image

    def square_crop_image(image):
        """Return the largest centered square crop of a PIL image."""
        width, height = image.size
        # select square crop that is as big as possible
        new_width, new_height = min(width, height), min(width, height)
        # NOTE(review): these are floats when width - height is odd; how
        # `crop` rounds them decides whether the result is exactly square.
        left = (width - new_width) / 2
        top = (height - new_height) / 2
        right = (width + new_width) / 2
        bottom = (height + new_height) / 2
        return image.crop((left, top, right, bottom))
    
    def preprocess(image):
        """Open the image at path ``image`` and return it as a 512x512 center crop."""
        image = Image.open(image)
        # center crop and resize to 512x512
        image = square_crop_image(image)
        image = image.resize((512, 512))
        return image



    # Class pairs and domains; zip() below pairs the i-th class pair with the
    # i-th domain (six pairs in total), it does not build a full grid.
    rows = [("apple", "banana"), ("penguin", "owl"), ("shovel", "skateboard"), ("windmill", "bridge"), ("pizza", "donut"), ("butterfly", "ant")]
    cols = ['quickdraw', 'real', 'clipart', 'infograph', 'painting', 'sketch']

    # NOTE(review): debugging leftover -- stops execution in the debugger.
    breakpoint()

    row_images = []
    for row, col in zip(rows, cols):
        # NOTE(review): row[0] / row[1] are class names, but the dict is keyed
        # by category names -- see the docstring.
        sub_row1_images = in_dataset[col][row[0]][split]['image_paths']
        sub_row2_images = in_dataset[col][row[1]][split]['image_paths']
        
        # select the images with the largest file size on disk (bytes, not
        # pixel dimensions)
        sub_row1_image_sizes = [os.path.getsize(image) for image in sub_row1_images]
        sub_row2_image_sizes = [os.path.getsize(image) for image in sub_row2_images]

        sub_row1_images_max = [sub_row1_images[i] for i in indices(sub_row1_image_sizes, max(sub_row1_image_sizes))]
        sub_row2_images_max = [sub_row2_images[i] for i in indices(sub_row2_image_sizes, max(sub_row2_image_sizes))]

        # NOTE(review): the preprocessed images are not stored anywhere;
        # row_images stays empty.
        sub_row1 = preprocess(sub_row1_images_max[0])
        sub_row2 = preprocess(sub_row2_images_max[0])

    total_images, indomain_images, outdomain_images = [], [], []
    # NOTE(review): `dataset` is not assigned in this function (only
    # in_dataset / out_dataset are).
    for domain in dataset:
        for category in dataset[domain]:
            all_captions = dataset[domain][category][split]['text']
            captions = list(set(all_captions))[:captions_threshold]
            
            # select three images for each caption
            all_images = np.array(dataset[domain][category][split]['image_paths'])
            images = [all_images[indices(all_captions, caption)[:3]] for caption in captions]
            
            # flatten list of lists
            images = [item for sublist in images for item in sublist]
            images = list(set(images))
            indomain_test_captions.extend(captions)
            indomain_images.extend(images)

    all_outdomain_images = []
    # NOTE(review): iterates the same `dataset` name as the loop above.
    for domain in dataset:
        for category in dataset[domain]:
            all_captions = dataset[domain][category][split]['text']
            captions = list(set(all_captions))[:captions_threshold]
            outdomain_test_captions.extend(captions)
            # select one image for each caption
            all_images = dataset[domain][category][split]['image_paths']
            images = [all_images[all_captions.index(caption)] for caption in captions]
            outdomain_images.extend(images)
            all_outdomain_images.extend(all_images)

    print(f"split {split}, indomain: {len(indomain_test_captions)}, outdomain: {len(outdomain_test_captions)}, total: {len(indomain_test_captions) + len(outdomain_test_captions)}")


    # NOTE(review): `outdomain_dir` is not defined in this function.
    for i, image in tqdm(enumerate(outdomain_images), total=len(outdomain_images)):
        image = Image.open(image)
        # center crop and resize to 512x512
        image = square_crop_image(image)
        image = image.resize((512, 512))
        save_path = os.path.join(outdomain_dir, f"{i}.jpg")
        image.save(save_path)
        # os.symlink(image, save_path)
    
    # The message says "symlinks", but the loop above saves image files.
    print(f"added symlinks for indomain: {len(indomain_images)}, outdomain: {len(outdomain_images)}")


if __name__ == "__main__":
    # Script entry point. Alternative entry points, kept for reference:
    #   datasets = load_domainnet(full=True)
    #   get_fid_images()
    get_fig_images()

"""
srun -p a40 --job-name=debug --gres=gpu:1 --cpus-per-task=8 --mem-per-cpu=5G --qos=m2 --time=8:00:00 --kill-on-bad-exit=1 --pty bash
split test, indomain: 347, outdomain: 1722, total: 2081
"""