quandao92's picture
Upload 48 files
71d05bb verified
raw
history blame
5.51 kB
import torch.utils.data as data
import json
import random
from PIL import Image
import numpy as np
import torch
import os
def generate_class_info(dataset_name, mode='train'):
    """Return the object classes used for ``dataset_name`` and their integer ids.

    Args:
        dataset_name: One of ``'mvtec'``, ``'4inlab'``, ``'task1'``, ``'task2'``
            or ``'smoke_cloud'``.
        mode: Split name, e.g. ``'train'`` or ``'test'``. ``'mvtec'`` uses the
            same class list for every mode; the other datasets only define the
            modes listed in ``class_lists`` below.

    Returns:
        ``(obj_list, class_name_map_class_id)``: ``obj_list`` is a new list of
        class names and ``class_name_map_class_id`` maps each name to its
        index in ``obj_list``.

    Raises:
        ValueError: If ``dataset_name`` is unknown or defines no class list for
            ``mode``. Previously this case surfaced as an ``UnboundLocalError``.
    """
    # Full MVTec AD category list, currently restricted to 'bottle':
    # ['carpet', 'bottle', 'hazelnut', 'leather', 'cable', 'capsule', 'grid', 'pill',
    #  'transistor', 'metal_nut', 'screw', 'toothbrush', 'zipper', 'tile', 'wood']
    # The key ``None`` marks a class list that applies to every mode.
    # The table is built per call so callers never share (and mutate) one list.
    class_lists = {
        'mvtec': {None: ['bottle']},
        '4inlab': {'train': ['shinpyung'], 'test': ['shinpyung']},
        'task1': {'train': ['cup']},
        'task2': {'train': ['fire']},
        'smoke_cloud': {'train': ['fire']},
    }
    if dataset_name not in class_lists:
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected one of {sorted(class_lists)}")
    by_mode = class_lists[dataset_name]
    obj_list = by_mode.get(mode, by_mode.get(None))
    if obj_list is None:
        available = sorted(m for m in by_mode if m is not None)
        raise ValueError(
            f"Dataset {dataset_name!r} defines no classes for mode {mode!r}; "
            f"available modes: {available}")
    class_name_map_class_id = {name: idx for idx, name in enumerate(obj_list)}
    return obj_list, class_name_map_class_id
class _MetaJsonDataset(data.Dataset):
    """Shared implementation behind ``Dataset_test`` and ``Dataset_train``.

    Samples are read from ``<root>/meta_train.json``, which must map
    ``mode`` -> class name -> list of sample records. Each record must provide
    ``img_path``, ``mask_path``, ``cls_name`` and ``anomaly``; both paths are
    relative to ``root``.

    Subclasses set ``_class_info_mode`` to choose which ``generate_class_info``
    class list supplies each sample's ``cls_id``.
    """

    # Mode passed to generate_class_info; overridden by each subclass.
    _class_info_mode = None

    def __init__(self, root, transform, target_transform, dataset_name, mode):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        # NOTE(review): every split is read from meta_train.json (selected by
        # `mode`); confirm no separate meta_test.json is expected.
        with open(os.path.join(self.root, 'meta_train.json'), 'r', encoding='utf-8') as f:
            meta_info = json.load(f)[mode]
        self.cls_names = list(meta_info.keys())
        # Flatten the per-class sample lists, preserving the JSON's class order.
        self.data_all = [sample for cls_name in self.cls_names for sample in meta_info[cls_name]]
        self.length = len(self.data_all)
        self.obj_list, self.class_name_map_class_id = generate_class_info(
            dataset_name, mode=self._class_info_mode)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            dict with keys ``img``, ``img_mask`` (0/255 mask, after
            ``target_transform``), ``cls_name``, ``anomaly``, ``img_path``
            (joined with ``root``) and ``cls_id``.

        Raises:
            KeyError: If the record's ``cls_name`` is not in this dataset's
                class list.
        """
        sample = self.data_all[index]
        cls_name = sample['cls_name']
        anomaly = sample['anomaly']
        img_path = os.path.join(self.root, sample['img_path'])
        img = Image.open(img_path)
        img_mask = self._load_mask(sample['mask_path'], anomaly, img.size)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            img_mask = self.target_transform(img_mask)
        # Never hand back None for the mask, even if target_transform returns it.
        if img_mask is None:
            img_mask = []

        return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly,
                'img_path': img_path, "cls_id": self.class_name_map_class_id[cls_name]}

    def _load_mask(self, mask_path, anomaly, size):
        """Return a mode-'L' mask with values 0/255 for an image of PIL ``size`` (width, height)."""
        if anomaly == 0:
            # Normal sample: empty mask. mask_path is deliberately not touched.
            return Image.new('L', size, 0)
        full_path = os.path.join(self.root, mask_path)
        if os.path.isdir(full_path):
            # Anomalous sample without a mask file (classification-only data):
            # use an empty mask instead of raising.
            return Image.new('L', size, 0)
        with Image.open(full_path) as mask:
            binary = np.array(mask.convert('L')) > 0
        return Image.fromarray(binary.astype(np.uint8) * 255)


class Dataset_test(_MetaJsonDataset):
    """Dataset over ``meta_train.json[mode]``.

    Class ids always come from ``generate_class_info(dataset_name, mode='test')``,
    whatever ``mode`` is.
    """

    _class_info_mode = 'test'

    def __init__(self, root, transform, target_transform, dataset_name, mode="test"):
        super().__init__(root, transform, target_transform, dataset_name, mode)


class Dataset_train(_MetaJsonDataset):
    """Dataset over ``meta_train.json[mode]``.

    Class ids always come from ``generate_class_info(dataset_name, mode='train')``,
    whatever ``mode`` is.
    """

    _class_info_mode = 'train'

    def __init__(self, root, transform, target_transform, dataset_name, mode="train"):
        super().__init__(root, transform, target_transform, dataset_name, mode)