Dataset
The dataset is generated using the scripts/create_shape_dataset.py script.
It generates training, validation, and test sets for three domains: visual (v), attributes (attr), and text (t).
Here are some visual examples:
You can configure the generation of the dataset from the config. Here are the relevant configuration entries:
- seed: seeds the generation of the dataset.
- img_size: width and height of the images.
- simple_shapes_path: where to save the dataset. See below for the structure of the generated dataset.
- domain_loaders.t.bert_latents: where to save the BERT features of the text domain. It is a .npy file.
The dataset is created in the location given by simple_shapes_path.
It generates:
- 1 000 000 training samples: only the first 500 000 are used during training. The other 500 000 were generated so that the unmatched and matched examples could be drawn from separate data, a separation that still needs to be done.
- 50 000 validation and 50 000 test examples.
It generates different files and folders:
- train, val, and test folders with the different images;
- the attributes in the files $split_labels.npy, with $split in ['train', 'val', 'test'];
- the text captions in the files $split_captions.npy;
- the text caption choices in the files $split_caption_choices.npy (choices used to create the text structure using the text generator);
- the BERT features in the files $split_${domain_loaders.t.bert_latents}.
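As a rough illustration of this layout, the sketch below builds the expected path of each file for every split. The helper name expected_files, the root folder, and the BERT file name are hypothetical placeholders; only the naming patterns are taken from the list above.

```python
from pathlib import Path

SPLITS = ("train", "val", "test")


def expected_files(root, bert_latents):
    """Map each split to the paths described above (hypothetical helper)."""
    root = Path(root)
    layout = {}
    for split in SPLITS:
        layout[split] = {
            "images": root / split,  # folder containing the images
            "attributes": root / f"{split}_labels.npy",
            "captions": root / f"{split}_captions.npy",
            "caption_choices": root / f"{split}_caption_choices.npy",
            "bert": root / f"{split}_{bert_latents}",
        }
    return layout


# Placeholder values, not the project's defaults.
layout = expected_files("/data/simple_shapes", "bert_latents.npy")
print(layout["val"]["captions"])
```

Each .npy file can then be read with numpy.load; files that store Python objects rather than plain arrays may additionally need allow_pickle=True.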
from bim_gw.datasets import load_dataset
from bim_gw.utils import get_args
from omegaconf import OmegaConf
args = get_args()
# In most cases
local_args = args.global_workspace
# If you want more control. For instance, when training a unimodal module.
local_args = OmegaConf.create({
"batch_size": 32,
"prop_labelled_images": 1., # proportion of paired examples. 1 => all dataset is paired, 0 => no paired example.
"remove_sync_domains": None, # synchronization combinations to remove, or None to keep them all.
# If [["v", "attr"]], there will be no vision and attribute pairs used for training.
"split_ood": False, # (bool) whether to train with a specific distribution split and
# test with in_dist and out-of-distribution sets.
"selected_domains": ["v"], # domains to load in the dataset. Here we only get visual data.
# ======
# Optional parameters
# ======
"use_pre_saved": True, # whether to use pre-saved latents instead of the unimodal module. Particularly useful for the visual domains.
# If True, will use the args.global_workspace.load_pre_saved_latents parameter to load the saved latents.
# More details in the next chapter.
"sync_uses_whole_dataset": False, # whether to use the 1 000 000 examples.
# If False, only the first 500 000 are used.
# Keep at False unless you know what you are doing.
})
data = load_dataset(args, local_args) # SimpleShapesDataModule instance
data.prepare_data()
data.setup()
# {test, val}_dataloader()["in_dist"] is the in-distribution set, and {test, val}_dataloader()["ood"] is the out-of-distribution set if an OOD split is set.
test_dataloader = iter(data.test_dataloader()["in_dist"])
domains = next(test_dataloader)
vision_domain = domains["v"]
vision_sub_parts = vision_domain.sub_parts
vision_available_masks = vision_domain.available_masks
available_vision_domains = vision_sub_parts[vision_available_masks] # only keep items where visual information is available.
# The unavailable items are filled with 0 values.
# Note that only the training set has masked items. All samples from the validation and test sets are matched.
In the backend, load_dataset will instantiate the bim_gw.datasets.simple_shapes.data_modules.SimpleShapesDataModule datamodule class.
This class instantiates three bim_gw.datasets.simple_shapes.datasets.SimpleShapesDataset objects for training, validation, and testing.
Note: in what follows, the data variable is an instance of SimpleShapesDataModule.
- data.train_set, data.val_set, data.test_set: the SimpleShapesDataset instances.
- data.train_dataloader(shuffle=True): shuffle is enabled by default.
- data.val_dataloader(), data.test_dataloader().
The __init__ of the DataModule class should also add unimodal domain loader callbacks associated with this dataset.
The next chapter will detail module registries and how to add domain callbacks.
dataset = data.train_set
Typical dataset access:
- len(dataset): size of the dataset;
- domain_item = dataset[k]: get the item at index k.
SimpleShapesDataset uses data fetchers in bim_gw.datasets.simple_shapes.fetchers to fetch domain-related information from the dataset.
- domain_fetcher.get_null_item(): returns the 0 item used when the item is masked.
- domain_fetcher.get_item(k): returns domain-specific data.
- domain_fetcher.get_items(k): returns transformed domain-specific data if k is not None; otherwise, domain_fetcher.get_null_item().
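The dispatch between these three methods can be sketched with a toy class. ToyFetcher, its constructor arguments, and the dictionary it returns are hypothetical and do not reproduce the classes in bim_gw.datasets.simple_shapes.fetchers; the sketch also assumes that a masked item is signalled by passing None as the index.

```python
class ToyFetcher:
    """Hypothetical stand-in for a domain fetcher; not the bim_gw class."""

    def __init__(self, values, transform=None):
        self.values = values
        self.transform = transform

    def get_null_item(self):
        # Placeholder returned for masked items: availability 0, zero data.
        return {"available": 0, "value": 0.0}

    def get_item(self, k):
        # Raw domain-specific data for index k.
        return {"available": 1, "value": self.values[k]}

    def get_items(self, k):
        # Assumption: k is None means "this domain is masked for this sample".
        if k is None:
            return self.get_null_item()
        item = self.get_item(k)
        if self.transform is not None:
            item["value"] = self.transform(item["value"])
        return item


fetcher = ToyFetcher([1.0, 2.0, 3.0], transform=lambda x: 10 * x)
print(fetcher.get_items(1))     # available item, transformed
print(fetcher.get_items(None))  # masked item
```

Returning a null item of the same shape as a real one keeps every sample in a batch structurally identical, so masked and available items can be collated together.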
Let's get data from a dataloader:
test_dataloader = iter(data.test_dataloader()["in_dist"]) # use the in_dist dataloader; the out-of-distribution one works the same way.
domains = next(test_dataloader)
domains is a dictionary where the keys are the domain names (as filled in the configuration value global_workspace.selected_domains).
Let's now look at the format of one domain:
visual_domain = domains['v']
visual_domain is a list that corresponds to the output of the visual data fetcher.
- visual_domain.available_masks is the domain mask. It is 0 if the domain is not available and 1 otherwise.
- visual_domain["img"] is the image data.
- generic_domain.sub_parts holds the raw outputs of a domain. For images, there is only one raw output (the image), but there can be several. For example, the text domain has BERT features, the raw text sentence, and a dictionary containing the choices used to generate the sentence structure. Indexing a domain is a shortcut for indexing its sub-parts: generic_domain[key] == generic_domain.sub_parts[key].
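To make this access pattern concrete, here is a minimal sketch of a container with the same three entry points. ToyDomainItems is a hypothetical class written for illustration only, and it uses plain Python lists where the real batches presumably hold tensors.

```python
class ToyDomainItems:
    """Hypothetical container mimicking the access pattern described above."""

    def __init__(self, available_masks, **sub_parts):
        self.available_masks = available_masks  # 1 = available, 0 = masked
        self.sub_parts = sub_parts              # named raw outputs

    def __getitem__(self, key):
        # Indexing is shorthand for looking up a sub-part.
        return self.sub_parts[key]


# A batch of three items; the second one is masked and filled with zeros.
visual = ToyDomainItems(
    available_masks=[1, 0, 1],
    img=[[0.2, 0.4], [0.0, 0.0], [0.9, 0.1]],
)

# Both access forms return the same object.
assert visual["img"] is visual.sub_parts["img"]

# Keep only the entries whose mask is 1.
available_imgs = [
    img for img, mask in zip(visual["img"], visual.available_masks) if mask
]
print(available_imgs)
```

The mask travels with the data, so downstream code can skip the zero-filled placeholders without needing to know how the batch was assembled.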
To add a new dataset, register a new dataset entry to the DatasetRegister:
from bim_gw.utils import registries
# Callback that instantiates the data module
@registries.register_dataset("new_dataset")
def load_new_dataset(args, local_args, **kwargs):
    # Import inside the callback to avoid loading unnecessary datasets.
    # new_dataset_module is a placeholder for the module that defines NewDataModule.
    from new_dataset_module import NewDataModule
    return NewDataModule(...) # instantiate as you need
After registering, registries.get_dataset() can load your dataset. Set the configuration value for args.current_dataset to "new_dataset".
If you prefer to register the dataset without a decorator, you can do:
from bim_gw.utils import registries
def load_new_dataset(args, local_args, **kwargs):
    # Import inside the callback to avoid loading unnecessary datasets.
    # new_dataset_module is a placeholder for the module that defines NewDataModule.
    from new_dataset_module import NewDataModule
    return NewDataModule(...) # instantiate as you need

registries.add_dataset_to_registry("new_dataset", load_new_dataset)