import os
from domainnet import *
from datasets import load_dataset, Image
from datasets import Dataset
from tqdm import tqdm
import numpy as np

# Root directory of the DomainNet images. load_domainnet expects one
# sub-directory per domain and, inside it, one sub-directory per class.
dataset_path = "/datasets/DomainNet"
# Directory holding the "<domain>_{train,validation,test}_fold.txt" split lists.
splits_dataset_path = "/h/calvinyu/evalmerge/splits/"

# In-domain assignment: the four category groups associated with each domain.
# NOTE: these strings are used verbatim as keys into CATEGORIES by
# load_domainnet (including the "vegatable" spelling), so do not "fix" them
# here without changing the table they index.
domain_categories_map = {
    "clipart": ["cloth", "furniture", "mammal", "tool"],
    "infograph": ["building", "electricity", "human_body", "office"],
    "painting": ["cold_blooded", "food", "nature", "road_transportation"],
    "quickdraw": ["fruit", "music", "sport", "tree"],
    "real": ["bird", "kitchen", "shape", "vegatable"],
    "sketch": ["insect", "others", "sky_transportation", "water_transportation"]
}


# Every category group exactly once, in declaration order. dict.fromkeys
# de-duplicates while keeping insertion order, so the result is identical on
# every run (the iteration order of a set of strings varies with hash
# randomisation, which made the category order non-reproducible).
all_categories = list(dict.fromkeys(
    category
    for categories in domain_categories_map.values()
    for category in categories
))
all_domains = list(domain_categories_map)
# Each domain gets its own list so mutating one entry cannot affect the others.
all_domain_categories_map = {domain: list(all_categories) for domain in all_domains}

# Out-of-domain assignment: for each domain, every category group that is NOT
# listed for that domain in domain_categories_map.
outdomain_categories_map = {
    domain: [category for category in categories if category not in domain_categories_map[domain]]
    for domain, categories in all_domain_categories_map.items()
}


def indices(lst, item):
    """Return the positions of every element of ``lst`` that equals ``item``."""
    positions = []
    for position, element in enumerate(lst):
        if element == item:
            positions.append(position)
    return positions


def _read_split_filenames(path):
    """Return the image path (first space-separated field) of each non-blank line in a split file."""
    with open(path, "r") as f:
        # Blank lines carry no path and would break the later "<domain>/<class>/..." parsing.
        return [line.split(" ")[0] for line in f.read().splitlines() if line.strip()]


def _build_split_dataset(filenames, classes, num_samples):
    """Build the Dataset for one split, keeping files whose class directory is in ``classes``.

    Returns a ``(dataset, number_of_files)`` tuple. At most ``num_samples`` files are
    kept (all of them when ``num_samples`` is None).
    """
    files = [
        os.path.join(dataset_path, filename)
        for filename in filenames
        # The second path component of a split entry is the class directory.
        if filename.split("/")[1] in classes
    ]
    files = files[:num_samples]
    # The caption is the class directory name with underscores turned back into spaces.
    captions = [file.split('/')[-2].replace('_', ' ') for file in files]
    dataset_dict = {"image": files, "text": captions, "image_paths": files}
    return Dataset.from_dict(dataset_dict).cast_column("image", Image()), len(files)


def load_domainnet(num_samples=None, full=False, remaining_map=False):
    """Load DomainNet as per-domain, per-category train/val/test datasets.

    Prints a table with one line per (domain, category) giving the number of
    files in each split and each split's share of the total.

    Args:
        num_samples: Maximum number of files kept per split of each
            (domain, category); ``None`` keeps all of them.
        full: When true (and ``remaining_map`` is false), load every category
            for every domain (``all_domain_categories_map``) instead of only
            the in-domain ones (``domain_categories_map``).
        remaining_map: When true, load only the out-of-domain categories
            (``outdomain_categories_map``); takes precedence over ``full``.

    Returns:
        ``{domain: {category: {"train": ds, "val": ds, "test": ds}}}`` where each
        ``ds`` is a ``Dataset`` with columns ``image``, ``text`` and ``image_paths``.

    Raises:
        AssertionError: If a class directory of a requested category is
            missing under ``dataset_path/<domain>``.
    """
    print("{:<25} | {:<25} | {:<5} | {:<5} | {:<5} | {:<5}".format("Domain", "Category", "Train", "Val", "Test", "Train/Val/Test"))
    print("-" * 100)
    loaded = {}

    if remaining_map:
        categories_map = outdomain_categories_map
    elif full:
        categories_map = all_domain_categories_map
    else:
        categories_map = domain_categories_map

    for domain, categories in tqdm(categories_map.items(), total=len(categories_map), desc="reading datasets", disable=True):
        # Split name in the returned dict -> name used in the split file on disk.
        split_filenames = {
            split: _read_split_filenames(os.path.join(splits_dataset_path, f"{domain}_{fold}_fold.txt"))
            for split, fold in (("train", "train"), ("val", "validation"), ("test", "test"))
        }

        loaded[domain] = {}
        for category in categories:
            classes = [_class.replace(" ", "_") for _class in CATEGORIES[category].keys()]
            missing = [_class for _class in classes if not os.path.exists(os.path.join(dataset_path, domain, _class))]
            assert not missing, f"missing class directories for {domain}/{category}: {missing}"

            # Set membership keeps the per-filename class test O(1).
            class_set = set(classes)
            splits, counts = {}, {}
            for split, filenames in split_filenames.items():
                splits[split], counts[split] = _build_split_dataset(filenames, class_set, num_samples)
            loaded[domain][category] = splits

            total_samples = sum(counts.values())
            # A (domain, category) pair with no files at all (e.g. num_samples=0)
            # must not crash the whole load with a division by zero.
            if total_samples:
                train_ratio, val_ratio, test_ratio = (counts[split] / total_samples for split in ("train", "val", "test"))
            else:
                train_ratio = val_ratio = test_ratio = 0.0

            print(f"{domain:<25} | {category:<25} | {counts['train']:<5} | {counts['val']:<5} | {counts['test']:<5} | {train_ratio:.2f} | {val_ratio:.2f} | {test_ratio:.2f}")

    return loaded


def get_fig_images(split="test", captions_threshold=None):
    """Render a 3x3 grid of DomainNet example cells and save it as a JPEG.

    Each grid row shows one pair of classes; each cell holds the two largest
    files (by size on disk) of both classes for one domain, arranged 2x2. The
    cell for the row's own domain is drawn with a solid black frame and sits on
    the diagonal; cells for the other domains get a dashed frame. The result is
    written to ``row_images_3.jpg`` in the current working directory.

    Args:
        split: Which split of the loaded datasets to draw images from
            ("train", "val" or "test").
        captions_threshold: Unused; kept so existing callers keep working.

    Returns:
        None.

    Raises:
        IndexError: If fewer than two images are available for a class in a cell.
    """
    # Imported locally under its own name: the module-level ``Image`` is the
    # ``datasets`` feature type, not PIL's image module.
    from PIL import Image as PILImage

    tile = 512       # side length of one preprocessed image
    grid_size = 3    # only the first 3 rows / domains are rendered

    in_dataset = load_domainnet(full=False)
    out_dataset = load_domainnet(remaining_map=True)

    def square_crop_image(image):
        # Largest centred square that fits inside the image.
        width, height = image.size
        side = min(width, height)
        left = (width - side) / 2
        top = (height - side) / 2
        right = (width + side) / 2
        bottom = (height + side) / 2
        return image.crop((left, top, right, bottom))

    def preprocess(image_path):
        # Centre-crop to a square and resize to tile x tile. The resized copy
        # is independent of the file, so the handle can be closed right away.
        with PILImage.open(image_path) as image:
            return square_crop_image(image).resize((tile, tile))

    def make_cell(sub_row1, sub_row2, border_style):
        # 2x2 mosaic: sub_row1 on top, sub_row2 below.
        cell = PILImage.new("RGB", (tile * 2, tile * 2))
        cell.paste(sub_row1[0], (0, 0))
        cell.paste(sub_row1[1], (tile, 0))
        cell.paste(sub_row2[0], (0, tile))
        cell.paste(sub_row2[1], (tile, tile))

        # Black cross, 8 px wide, separating the four tiles.
        cell_pixels = np.array(cell)
        half_gap = 4
        cell_pixels[tile - half_gap:tile + half_gap, :, :] = 0
        cell_pixels[:, tile - half_gap:tile + half_gap, :] = 0
        cell = PILImage.fromarray(cell_pixels)

        # Black frame: the canvas grows by 48 px per dimension, i.e. 24 px per side.
        frame = 48
        half_frame = frame // 2
        imh, imw = cell_pixels.shape[:2]
        framed = PILImage.new("RGB", (imw + frame, imh + frame), (0, 0, 0))
        framed.paste(cell, (half_frame, half_frame))

        if border_style == "dash":
            # Turn the solid frame into a dashed one: alternate 72 px black and
            # 72 px white segments along the left edge, then reuse that strip
            # for the other three edges (transposed for top/bottom, which
            # relies on the framed cell being square).
            framed_pixels = np.array(framed)
            strip = framed_pixels[:, :half_frame].copy()
            positions = np.arange(strip.shape[0])
            strip[(positions // (frame + 24)) % 2 == 1] = 255
            framed_pixels[:, :half_frame] = strip
            framed_pixels[:, -half_frame:] = strip
            framed_pixels[:half_frame, :] = np.swapaxes(strip, 0, 1)
            framed_pixels[-half_frame:, :] = np.swapaxes(strip, 0, 1)
            framed = PILImage.fromarray(framed_pixels)

        # White margin: another 48 px per dimension, i.e. 24 px per side.
        margin = 48
        imw, imh = framed.size
        padded = PILImage.new("RGB", (imw + margin, imh + margin), (255, 255, 255))
        padded.paste(framed, (margin // 2, margin // 2))
        return padded

    def get_cell_images(row, col, subcat, dataset):
        # Read both columns once; they are shared by the two class lookups.
        split_data = dataset[col][subcat][split]
        paths, captions = split_data['image_paths'], split_data['text']

        def largest_first(label):
            matches = [path for path, caption in zip(paths, captions) if caption == label]
            return sorted(matches, key=os.path.getsize, reverse=True)

        sub_row1_images_sorted = largest_first(row[0])
        sub_row2_images_sorted = largest_first(row[1])

        # Hand-picked exceptions: skip the largest file(s) for a few
        # class/domain combinations.
        if col == "clipart":
            if row[1] in ["banana", "shovel"]:
                sub_row2_images_sorted = sub_row2_images_sorted[1:]
            if row[0] in ["banana", "shovel"]:
                sub_row1_images_sorted = sub_row1_images_sorted[4:]
        elif col == "real":
            if row[0] in ["apple"]:
                sub_row1_images_sorted = sub_row1_images_sorted[1:]

        sub_row1 = [preprocess(sub_row1_images_sorted[0]), preprocess(sub_row1_images_sorted[1])]
        sub_row2 = [preprocess(sub_row2_images_sorted[0]), preprocess(sub_row2_images_sorted[1])]
        return sub_row1, sub_row2

    # rows[k] is a pair of classes from category subcats[k], whose in-domain
    # domain is cols[k].
    rows = [("apple", "banana"), ("penguin", "owl"), ("shovel", "skateboard"), ("windmill", "bridge"), ("pizza", "donut"), ("butterfly", "ant")]
    subcats = ["fruit", "bird", "tool", "building", "food", "insect"]
    cols = ['quickdraw', 'real', 'clipart', 'infograph', 'painting', 'sketch']

    shown_cols = cols[:grid_size]
    row_images = []
    for i, (row, col, subcat) in enumerate(zip(rows[:grid_size], shown_cols, subcats[:grid_size])):
        # In-domain cell.
        sub_row1, sub_row2 = get_cell_images(row, col, subcat, in_dataset)
        in_cell = make_cell(sub_row1, sub_row2, border_style="solid")

        # Out-of-domain cells: same classes, every other displayed domain.
        out_cells = []
        for od in shown_cols:
            if od == col:
                continue
            sub_row1, sub_row2 = get_cell_images(row, od, subcat, out_dataset)
            out_cells.append(make_cell(sub_row1, sub_row2, border_style="dash"))

        # Lay the cells out horizontally with the in-domain cell on the diagonal.
        cells = [np.array(out_cell) for out_cell in out_cells]
        cells.insert(i, np.array(in_cell))
        row_images.append(np.concatenate(cells, axis=1))

    # Stack the rows vertically. The numeric suffix keeps the file name this
    # function has always produced ("row_images_3.jpg").
    grid = PILImage.fromarray(np.concatenate(row_images, axis=0))
    grid.save(f"row_images_{grid_size}.jpg")



if __name__ == "__main__":
    # Script entry point: render and save the DomainNet figure grid.
    # To inspect every category of every domain instead, call
    # load_domainnet(full=True).
    get_fig_images()

"""
srun -p a40 --job-name=debug --gres=gpu:1 --cpus-per-task=8 --mem-per-cpu=5G --qos=m2 --time=8:00:00 --kill-on-bad-exit=1 --pty bash
split test, indomain: 347, outdomain: 1722, total: 2081
"""