From 45dd8b8e6767448ff859203c834d01bd94f68ab9 Mon Sep 17 00:00:00 2001
From: "tassilo.wald"
Date: Fri, 14 Aug 2020 14:26:58 +0200
Subject: [PATCH] Added a small sanity check to verify data is properly split.
 Also added docstrings and TODOs for later edits

---
 datasets/example_dataset/create_splits.py | 52 ++++++++++++++++++-----
 1 file changed, 42 insertions(+), 10 deletions(-)

diff --git a/datasets/example_dataset/create_splits.py b/datasets/example_dataset/create_splits.py
index a13ec2d..3a1ec32 100644
--- a/datasets/example_dataset/create_splits.py
+++ b/datasets/example_dataset/create_splits.py
@@ -23,31 +23,63 @@ def create_splits(output_dir, image_dir):
+    """Split the dataset into multiple folds, each with a train, validation and test set.
+
+    :param output_dir: Directory to write the splits file to.
+    :param image_dir: Directory where the images lie.
+    """
     npy_files = subfiles(image_dir, suffix=".npy", join=False)

     sample_size = len(npy_files)

-    trainset_size = sample_size * 50 // 100
-    valset_size = sample_size * 25 // 100
-    testset_size = sample_size * 25 // 100
+    testset_size = int(sample_size * 0.25)
+    valset_size = int(sample_size * 0.25)
+    trainset_size = sample_size - valset_size - testset_size  # Ensure all samples are used.

     if sample_size < (trainset_size + valset_size + testset_size):
         raise ValueError("Assure more total samples exist than train test and val samples combined!")

     splits = []
     sample_set = {sample[:-4] for sample in npy_files.copy()}  # Remove the file extension
-    test_set = set(random.sample(sample_set, testset_size))  # IMO the Testset should be static for all splits
-    # (otherwise we leak test samples into train batches)
+    test_samples = random.sample(sample_set, testset_size)  # IMO the Testset should be static for all splits
+
     for split in range(0, 5):
-        train_set = set(random.sample(sample_set - test_set, trainset_size))
-        val_set = set(random.sample(sample_set - train_set - test_set, valset_size))
+        train_samples = random.sample(sample_set - set(test_samples), trainset_size)
+        val_samples = list(sample_set - set(train_samples) - set(test_samples))
+
+        train_samples.sort()
+        val_samples.sort()

         split_dict = dict()
-        split_dict['train'] = list(train_set)
-        split_dict['val'] = list(val_set)
-        split_dict['test'] = list(test_set)
+        split_dict['train'] = train_samples
+        split_dict['val'] = val_samples
+        split_dict['test'] = test_samples

         splits.append(split_dict)

     # Todo: IMO it is better to write that dict as JSON. This (unlike pickle) allows the user to inspect the file with an editor
     with open(os.path.join(output_dir, 'splits.pkl'), 'wb') as f:
         pickle.dump(splits, f)
+
+    splits_sanity_check(output_dir)
+
+
+# ToDo: The name "splits.pkl" should not be distributed over multiple files. This makes changing it less clear.
+#  Instead move saving and loading to one file. (Here would be a good place.)
+#  Other usages are: spleen/create_splits.py:57 (which is redundant anyway?);
+#  UNetExperiment3D.py:55 and UNetExperiment.py:55
+def splits_sanity_check(path):
+    """Takes the path to a splits file and verifies that no samples from the test set leaked into train or validation.
+    :param path: Directory that contains the splits file.
+    """
+    with open(os.path.join(path, 'splits.pkl'), 'rb') as f:
+        splits = pickle.load(f)
+    for i in range(len(splits)):
+        samples = splits[i]
+        tr_samples = set(samples["train"])
+        vl_samples = set(samples["val"])
+        ts_samples = set(samples["test"])
+
+        assert len(tr_samples.intersection(vl_samples)) == 0, "Train and validation samples overlap!"
+ assert len(vl_samples.intersection(ts_samples)) == 0, "Validation and Test samples overlap!" + assert len(tr_samples.intersection(ts_samples)) == 0, "Train and Test samples overlap!" + return