Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OOD files added #10

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
650 changes: 650 additions & 0 deletions trust/utils/.ipynb_checkpoints/custom_dataset-checkpoint.py

Large diffs are not rendered by default.

326 changes: 326 additions & 0 deletions trust/utils/.ipynb_checkpoints/custom_dataset_medmnist-checkpoint.py

Large diffs are not rendered by default.

204 changes: 204 additions & 0 deletions trust/utils/.ipynb_checkpoints/dermamnist-checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import numpy as np
import os
import torch
import torchvision
from sklearn import datasets
from torchvision import datasets, transforms
import PIL.Image as Image
from .utils import *
np.random.seed(42)
torch.manual_seed(42)
from torchvision.datasets import cifar
from torch.utils.data import Dataset, Subset, ConcatDataset, DataLoader

class DermaDataset(Dataset):
def __init__(self, data, root='/mnt/data2/akshit/data/', transform=None):
self.root = root
self.transform = transform
self.images = data['images']
self.targets = data['labels'].flatten()


def __len__(self):
return len(self.targets)

def __getitem__(self, idx):
img = Image.fromarray(np.uint8(self.images[idx])).convert('RGB')
if self.transform:
img = self.transform(img)
label = self.targets[idx]
return img, label


def getOODtargets(targets, sel_cls_idx, ood_cls_id):

ood_targets = []
targets_list = list(targets)
for i in range(len(targets_list)):
if(targets_list[i] in list(sel_cls_idx)):
ood_targets.append(targets_list[i])
else:
ood_targets.append(ood_cls_id)
print("num ood samples: ", ood_targets.count(ood_cls_id))
return torch.Tensor(ood_targets)

def create_ood_data(fullset, testset, split_cfg, num_cls, augVal):

np.random.seed(42)
train_idx = []
val_idx = []
lake_idx = []
test_idx = []
selected_classes = np.array(list(range(split_cfg['num_cls_idc'])))
for i in range(num_cls): #all_classes
full_idx_class = list(torch.where(torch.Tensor(fullset.targets) == i)[0].cpu().numpy())
if(i in selected_classes):
test_idx_class = list(torch.where(torch.Tensor(testset.targets) == i)[0].cpu().numpy())
test_idx += test_idx_class
class_train_idx = list(np.random.choice(np.array(full_idx_class), size=split_cfg['per_idc_train'], replace=True))
train_idx += class_train_idx
remain_idx = list(set(full_idx_class) - set(class_train_idx))
class_val_idx = list(np.random.choice(np.array(remain_idx), size=split_cfg['per_idc_val'], replace=True))
remain_idx = list(set(remain_idx) - set(class_val_idx))
class_lake_idx = list(np.random.choice(np.array(remain_idx), size=split_cfg['per_idc_lake'], replace=True))
else:
class_train_idx = list(np.random.choice(np.array(full_idx_class), size=split_cfg['per_ood_train'], replace=False)) #always 0
remain_idx = list(set(full_idx_class) - set(class_train_idx))
class_val_idx = list(np.random.choice(np.array(remain_idx), size=split_cfg['per_ood_val'], replace=False)) #Only for CG ood val has samples
remain_idx = list(set(remain_idx) - set(class_val_idx))
class_lake_idx = list(np.random.choice(np.array(remain_idx), size=split_cfg['per_ood_lake'], replace=False)) #many ood samples in lake

if(augVal and (i in selected_classes)): #augment with samples only from the imbalanced classes
train_idx += class_val_idx
val_idx += class_val_idx
lake_idx += class_lake_idx

train_set = SubsetWithTargets(fullset, train_idx, torch.Tensor(fullset.targets)[train_idx])
val_set = SubsetWithTargets(fullset, val_idx, torch.Tensor(fullset.targets)[val_idx])
lake_set = SubsetWithTargets(fullset, lake_idx, getOODtargets(torch.Tensor(fullset.targets)[lake_idx], selected_classes, split_cfg['num_cls_idc']))
test_set = SubsetWithTargets(testset, test_idx, torch.Tensor(testset.targets)[test_idx])

return train_set, val_set, test_set, lake_set, selected_classes

############
# OOD TYPE 1
############

def load_dataset_custom_1(datadir, feature, split_cfg, augVal=False, dataAug=True):

num_cls = 8
path = '/mnt/data2/akshit/'
download_path = '/mnt/data2/akshit/data/cifar10'
train_data = np.load(f'{path}data/dermamnist/dm_train.npz', allow_pickle=True)
val_data = np.load(f'{path}data/dermamnist/dm_val.npz', allow_pickle=True)
test_data = np.load(f'{path}data/dermamnist/dm_test_balanced.npz', allow_pickle=True)
ptrain={
'images': np.concatenate((train_data['images'],val_data['images'])),
'labels': np.concatenate((train_data['labels'],val_data['labels']))
}

# Define the number of classes in our modified CIFAR10, which is 6. We also define our ID classes
cifar_training_transform = transforms.Compose([transforms.RandomCrop(28), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
cifar_test_transform = transforms.Compose([transforms.Resize(28), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
cifar_label_transform = lambda x: 7

# Get the dataset objects from PyTorch. Here, CIFAR10 is downloaded, and the transform is applied when points
# are retrieved.
cifar10_full_train = cifar.CIFAR10(download_path, train=True, download=False, transform=cifar_training_transform, target_transform=cifar_label_transform)
cifar10_test = cifar.CIFAR10(download_path, train=False, download=False, transform=cifar_test_transform, target_transform=cifar_label_transform)


derma_full_train = DermaDataset(data=ptrain, transform=cifar_training_transform)
derma_test = DermaDataset(data=test_data, transform=cifar_test_transform)

fullset = ConcatDataset([derma_full_train, cifar10_full_train])
fullset.targets = np.append(derma_full_train.targets, [7 for i in range(50000)])
test_set = derma_test

if(feature=="ood"):
train_set, val_set, test_set, lake_set, ood_cls_idx = create_ood_data(fullset, test_set, split_cfg, num_cls, augVal)
print("Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set), "Test set: ", len(test_set))
return train_set, val_set, test_set, lake_set, ood_cls_idx, split_cfg['num_cls_idc']

#############
## Taking examples from different classes instead of oversampling
#############
def create_ood_data_2(fullset, testset, split_cfg, num_cls, augVal):

np.random.seed(42)
train_idx = []
val_idx = []
lake_idx = []
test_idx = []
selected_classes = np.array(list(range(split_cfg['num_cls_idc'])))
for i in range(num_cls): #all_classes
full_idx_class = list(torch.where(torch.Tensor(fullset.targets) == i)[0].cpu().numpy())
if(i in selected_classes):
test_idx_class = list(torch.where(torch.Tensor(testset.targets) == i)[0].cpu().numpy())
test_idx += test_idx_class
class_train_idx = list(np.random.choice(np.array(full_idx_class), size=split_cfg['per_idc_train'], replace=False))
train_idx += class_train_idx
remain_idx = list(set(full_idx_class) - set(class_train_idx))

class_val_idx = list(np.random.choice(np.array(remain_idx), size=split_cfg['per_idc_val'], replace=False))
remain_idx = list(set(remain_idx) - set(class_val_idx))
## Taking examples from different classes instead of oversampling
if len(remain_idx)>=split_cfg['per_idc_lake']:
class_lake_idx = list(np.random.choice(np.array(remain_idx), size=split_cfg['per_idc_lake'], replace=False))
elif len(remain_idx)<split_cfg['per_idc_lake']:
class_lake_idx = list(np.random.choice(np.array(remain_idx), size=len(remain_idx), replace=False))

else:
class_train_idx = list(np.random.choice(np.array(full_idx_class), size=split_cfg['per_ood_train'], replace=False)) #always 0
remain_idx = list(set(full_idx_class) - set(class_train_idx))
class_val_idx = list(np.random.choice(np.array(remain_idx), size=split_cfg['per_ood_val'], replace=False)) #Only for CG ood val has samples
remain_idx = list(set(remain_idx) - set(class_val_idx))
class_lake_idx = list(np.random.choice(np.array(remain_idx), size=split_cfg['per_ood_lake'], replace=False)) #many ood samples in lake

if(augVal and (i in selected_classes)): #augment with samples only from the imbalanced classes
train_idx += class_val_idx
val_idx += class_val_idx
lake_idx += class_lake_idx

train_set = SubsetWithTargets(fullset, train_idx, torch.Tensor(fullset.targets)[train_idx])
val_set = SubsetWithTargets(fullset, val_idx, torch.Tensor(fullset.targets)[val_idx])
lake_set = SubsetWithTargets(fullset, lake_idx, getOODtargets(torch.Tensor(fullset.targets)[lake_idx], selected_classes, split_cfg['num_cls_idc']))
test_set = SubsetWithTargets(testset, test_idx, torch.Tensor(testset.targets)[test_idx])

return train_set, val_set, test_set, lake_set, selected_classes

def load_dataset_custom_2(datadir, feature, split_cfg, augVal=False, dataAug=True):

num_cls = 8
path = '/mnt/data2/akshit/'
download_path = '/mnt/data2/akshit/data/cifar10'
train_data = np.load(f'{path}data/dermamnist/dm_train.npz', allow_pickle=True)
val_data = np.load(f'{path}data/dermamnist/dm_val.npz', allow_pickle=True)
test_data = np.load(f'{path}data/dermamnist/dm_test_balanced.npz', allow_pickle=True)
ptrain={
'images': np.concatenate((train_data['images'],val_data['images'])),
'labels': np.concatenate((train_data['labels'],val_data['labels']))
}

# Define the number of classes in our modified CIFAR10, which is 6. We also define our ID classes
cifar_training_transform = transforms.Compose([transforms.RandomCrop(28), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
cifar_test_transform = transforms.Compose([transforms.Resize(28), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
cifar_label_transform = lambda x: 7

# Get the dataset objects from PyTorch. Here, CIFAR10 is downloaded, and the transform is applied when points
# are retrieved.
cifar10_full_train = cifar.CIFAR10(download_path, train=True, download=False, transform=cifar_training_transform, target_transform=cifar_label_transform)
cifar10_test = cifar.CIFAR10(download_path, train=False, download=False, transform=cifar_test_transform, target_transform=cifar_label_transform)


derma_full_train = DermaDataset(data=ptrain, transform=cifar_training_transform)
derma_test = DermaDataset(data=test_data, transform=cifar_test_transform)

fullset = ConcatDataset([derma_full_train, cifar10_full_train])
fullset.targets = np.append(derma_full_train.targets, [7 for i in range(50000)])
test_set = derma_test

if(feature=="ood"):
train_set, val_set, test_set, lake_set, ood_cls_idx = create_ood_data_2(fullset, test_set, split_cfg, num_cls, augVal)
print("Custom dataset stats: Train size: ", len(train_set), "Val size: ", len(val_set), "Lake size: ", len(lake_set), "Test set: ", len(test_set))
return train_set, val_set, test_set, lake_set, ood_cls_idx, split_cfg['num_cls_idc']
146 changes: 146 additions & 0 deletions trust/utils/.ipynb_checkpoints/medmnist-checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from .medmnist_info import INFO


class MedMNIST(Dataset):

flag = ...

def __init__(self,
root,
split='train',
transform=None,
target_transform=None,
download=False):
''' dataset
:param split: 'train', 'val' or 'test', select subset
:param transform: data transformation
:param target_transform: target transformation

'''

self.info = INFO[self.flag]
self.root = root

if download:
self.download()

if not os.path.exists(
os.path.join(self.root, "{}.npz".format(self.flag))):
raise RuntimeError('Dataset not found.' +
' You can use download=True to download it')

npz_file = np.load(os.path.join(self.root, "{}.npz".format(self.flag)))

self.split = split
self.transform = transform
self.target_transform = target_transform

if self.split == 'train':
self.data = npz_file['train_images']
self.targets = npz_file['train_labels']
self.targets = np.squeeze(self.targets)
elif self.split == 'val':
self.data = npz_file['val_images']
self.targets = npz_file['val_labels']
self.targets = np.squeeze(self.targets)
elif self.split == 'test':
self.data = npz_file['test_images']
self.targets = npz_file['test_labels']
self.targets = np.squeeze(self.targets)

if self.flag == 'octmnist' or self.flag == 'pneumoniamnist':
new_data = []
for i in self.data:
i = np.stack((i,)*3,axis=-1)
new_data.append(i)
self.data = np.array(new_data)

def __getitem__(self, index):
data, target = self.data[index], self.targets[index].astype(int)
data = Image.fromarray(np.uint8(data))

if self.transform is not None:
data = self.transform(data)

if self.target_transform is not None:
target = self.target_transform(target)

return data, target

def __len__(self):
return self.data.shape[0]

def __repr__(self):
'''Adapted from torchvision.
'''
_repr_indent = 4
head = "Dataset " + self.__class__.__name__

body = ["Number of datapoints: {}".format(self.__len__())]
body.append("Root location: {}".format(self.root))
body.append("Split: {}".format(self.split))
body.append("Task: {}".format(self.info["task"]))
body.append("Number of channels: {}".format(self.info["n_channels"]))
body.append("Meaning of labels: {}".format(self.info["label"]))
body.append("Number of samples: {}".format(self.info["n_samples"]))
body.append("Description: {}".format(self.info["description"]))
body.append("License: {}".format(self.info["license"]))

lines = [head] + [" " * _repr_indent + line for line in body]
return '\n'.join(lines)

def download(self):
try:
from torchvision.datasets.utils import download_url
download_url(url=self.info["url"],
root=self.root,
filename="{}.npz".format(self.flag),
md5=self.info["MD5"])
except:
raise RuntimeError('Something went wrong when downloading! ' +
'Go to the homepage to download manually. ' +
'https://github.com/MedMNIST/MedMNIST')


class PathMNIST(MedMNIST):
flag = "pathmnist"


class OCTMNIST(MedMNIST):
flag = "octmnist"


class PneumoniaMNIST(MedMNIST):
flag = "pneumoniamnist"


class ChestMNIST(MedMNIST):
flag = "chestmnist"


class DermaMNIST(MedMNIST):
flag = "dermamnist"


class RetinaMNIST(MedMNIST):
flag = "retinamnist"


class BreastMNIST(MedMNIST):
flag = "breastmnist"


class OrganMNISTAxial(MedMNIST):
flag = "organmnist_axial"


class OrganMNISTCoronal(MedMNIST):
flag = "organmnist_coronal"


class OrganMNISTSagittal(MedMNIST):
flag = "organmnist_sagittal"
Loading