
Dataset

Benjamin Devillers edited this page Mar 30, 2023 · 17 revisions

Creating the dataset

The dataset is generated with the scripts/create_shape_dataset.py script. It produces training, validation, and test sets covering 3 domains: visual (v), attributes (attr), and text (t).

Here are some visual examples:

Dataset visual examples

Configuration

You can configure the generation of the dataset from the config. Here are the relevant configuration entries:

  • seed: random seed used to generate the dataset.
  • img_size: width and height of the images.
  • simple_shapes_path: where to save the dataset. The structure of the generated dataset is described below.
  • domain_loaders.t.bert_latents: name of the .npy file in which the BERT features of the text domain are saved (prefixed with the split name, see below).
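To illustrate what seed and img_size control, here is a minimal sketch. It assumes the generator draws shape parameters from a NumPy random generator seeded with seed; the function sample_shape_centers is hypothetical and is not taken from scripts/create_shape_dataset.py.

```python
import numpy as np


def sample_shape_centers(seed: int, img_size: int, n: int) -> np.ndarray:
    """Hypothetical sampler: draw n (x, y) shape centers inside an img_size x img_size image."""
    rng = np.random.default_rng(seed)  # same seed -> same sequence of draws
    # integers(low, high) excludes high, so every coordinate lies in [0, img_size).
    return rng.integers(0, img_size, size=(n, 2))


first_run = sample_shape_centers(seed=0, img_size=32, n=5)
second_run = sample_shape_centers(seed=0, img_size=32, n=5)
print((first_run == second_run).all())  # True: the same seed reproduces the same samples
```

Fixing the seed is what makes two runs of the generation script produce the same dataset under this assumption.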

Generated dataset

The dataset is created in the location given by simple_shapes_path.

It generates:

  • 1 000 000 training samples: only the first 500 000 are used during training. The remaining 500 000 were generated in case we wanted to separate the data used for the unmatched and matched examples; this separation has not been implemented yet.
  • 50 000 validation and 50 000 test examples.
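The split sizes above can be summarized in a short sketch. It assumes, as the sync_uses_whole_dataset option in the Load data section suggests, that training keeps the first 500 000 indices unless that option is enabled; the helper training_indices is hypothetical.

```python
N_TRAIN_GENERATED = 1_000_000  # training samples written to disk
N_TRAIN_USED = 500_000  # samples actually used for training by default
N_VAL = 50_000
N_TEST = 50_000


def training_indices(sync_uses_whole_dataset: bool = False) -> range:
    """Hypothetical helper: indices of the training samples that are used."""
    n = N_TRAIN_GENERATED if sync_uses_whole_dataset else N_TRAIN_USED
    return range(n)


print(len(training_indices()), len(training_indices(True)))  # 500000 1000000
```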

It generates different files and folders:

  • train, val, and test folders with the different images;
  • the attributes in the files $split_labels.npy with $split in ['train', 'val', 'test'];
  • the text captions in the files $split_captions.npy;
  • the text caption choices in the $split_caption_choices.npy (choices used to create the text structure using the text generator);
  • the BERT features in the files $split_${domain_loaders.t.bert_latents}.
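The naming scheme above can be made concrete with a small sketch that builds the expected paths for one split and loads the attributes with NumPy. The root directory, the bert_latents value "bert_latents.npy", and the dummy array written here are placeholders so the example runs without the real dataset; in practice, use simple_shapes_path and domain_loaders.t.bert_latents from your configuration.

```python
import tempfile
from pathlib import Path

import numpy as np


def split_files(root: Path, split: str, bert_latents: str) -> dict:
    """Map each piece of generated data to its expected path for one split."""
    return {
        "images": root / split,  # folder containing the images
        "labels": root / f"{split}_labels.npy",  # attributes
        "captions": root / f"{split}_captions.npy",
        "caption_choices": root / f"{split}_caption_choices.npy",
        "bert": root / f"{split}_{bert_latents}",  # BERT features of the text domain
    }


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    np.save(root / "val_labels.npy", np.zeros((4, 8)))  # dummy attributes file
    files = split_files(root, "val", "bert_latents.npy")
    labels = np.load(files["labels"])
    print(files["bert"].name, labels.shape)  # val_bert_latents.npy (4, 8)
```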

Load data

from bim_gw.datasets import load_dataset
from bim_gw.utils import get_args
from omegaconf import OmegaConf


args = get_args()

# In most cases
local_args = args.global_workspace
# If you want more control. For instance, when training a unimodal module.
local_args = OmegaConf.create({
  "batch_size": 32,
  "prop_labelled_images": 1.,  # proportion of paired examples. 1 => the whole dataset is paired, 0 => no paired example.
  "remove_sync_domains": None,  # synchronization (pairing) combinations to remove.
                                # If [["v", "attr"]], no vision-attribute pairs are used for training.
  "split_ood": False,  # (bool) whether to train with a specific distribution split and
                       # test with in_dist and out-of-distribution sets.
  "selected_domains": ["v"],  # domains to load in the dataset. Here we only get visual data.
  # ======
  # Optional parameters
  # ======
  "use_pre_saved": True,  # whether to use pre-saved latents instead of the unimodal module. Particularly useful for the visual domain.
                          # If True, uses the args.global_workspace.load_pre_saved_latents parameter to load the saved latents.
                          # More details in the next chapter.
  "sync_uses_whole_dataset": False,  # whether to use all 1 000 000 training examples.
                                     # If False, only the first 500 000 are used.
                                     # Keep at False unless you know what you are doing.
})

data = load_dataset(args, local_args)  # SimpleShapesDataModule instance
data.prepare_data()
data.setup()

# {test, val}_dataloader["in_dist"] is the "in distribution," and the {test, val}_dataloader["ood"] is the "out-of-distribution" dataset if an OOD split is set.
test_dataloader = iter(data.test_dataloader()["in_dist"])  
domains = next(test_dataloader)
vision_domain = domains["v"]

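vision_sub_parts = vision_domain.sub_parts  # dictionary of raw outputs, e.g. {"img": ...}
vision_available_masks = vision_domain.available_masks  # 1 if visual data is available for the item, 0 otherwise
# Only keep items where visual information is available.
# The mask is converted to booleans so that it selects items instead of indexing by position.
available_images = vision_sub_parts["img"][vision_available_masks.bool()]
# The unavailable items are filled with 0 values.
# Note that only the training set has masked items. All samples from the validation and test sets are matched.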
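The masking step at the end of the loading example can be reproduced with plain NumPy arrays. This sketch assumes a batch of 4 images in which the second and fourth are unavailable (zero-filled) and a mask holding 0/1 values; with PyTorch tensors, mask.bool() plays the role of astype(bool).

```python
import numpy as np

rng = np.random.default_rng(0)
images = rng.random((4, 3, 8, 8))  # batch of 4 RGB 8x8 images
available_masks = np.array([1, 0, 1, 0])  # 1 = visual data available for this item
images[available_masks == 0] = 0.0  # unavailable items are zero-filled

available_images = images[available_masks.astype(bool)]  # boolean mask: keeps items 0 and 2
print(available_images.shape)  # (2, 3, 8, 8)

wrong = images[available_masks]  # integer indexing: picks items 1, 0, 1, 0 by position
print(wrong.shape)  # (4, 3, 8, 8)
```

Indexing with the raw 0/1 mask does not filter anything: it gathers items by position, which is why the mask is converted to booleans first.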

More on the dataset classes

In the backend, load_dataset instantiates the bim_gw.datasets.simple_shapes.data_modules.SimpleShapesDataModule datamodule class, which in turn creates three bim_gw.datasets.simple_shapes.datasets.SimpleShapesDataset instances: one each for training, validation, and testing.

SimpleShapesDataModule useful properties and methods

Note: below, data refers to an instance of this class.

  • data.train_set, data.val_set, data.test_set the SimpleShapesDataset instances.
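  • data.train_dataloader(shuffle=True): the training dataloader; shuffling is enabled by default.
  • data.val_dataloader(), data.test_dataloader(): the validation and test dataloaders. Each returns a dictionary with an "in_dist" dataloader and, if an OOD split is set, an "ood" dataloader.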
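Because the validation and test dataloaders come as a dictionary, evaluation code can loop over whichever splits are present. The sketch below is hypothetical: evaluate_all is not part of bim_gw, and lists of numbers stand in for real dataloaders and batches so the example runs on its own.

```python
from typing import Callable, Dict, Iterable


def evaluate_all(dataloaders: Dict[str, Iterable], evaluate_batch: Callable) -> Dict[str, float]:
    """Average a per-batch metric over every split ("in_dist", and "ood" if present)."""
    results = {}
    for split_name, loader in dataloaders.items():
        scores = [evaluate_batch(batch) for batch in loader]
        results[split_name] = sum(scores) / len(scores)
    return results


# Stand-in for data.val_dataloader(): each "batch" is already a score here.
fake_val = {"in_dist": [1.0, 3.0], "ood": [2.0]}
print(evaluate_all(fake_val, evaluate_batch=lambda batch: batch))  # {'in_dist': 2.0, 'ood': 2.0}
```

With the real data module, the call would look like evaluate_all(data.val_dataloader(), my_metric), where my_metric is your own per-batch function.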

The __init__ of the DataModule class should also add unimodal domain loader callbacks associated with this dataset. The next chapter will detail module registries and how to add domain callbacks.

SimpleShapesDataset

dataset = data.train_set

Typical dataset access:

  • len(dataset), size of the dataset;
  • domain_item = dataset[k], get the item at index k (dataset[0] is the first item).

Data fetchers

SimpleShapesDataset uses data fetchers in bim_gw.datasets.simple_shapes.fetchers to fetch domain-related information from the dataset.

  • domain_fetcher.get_null_item() returns the zero-filled item used when the item is masked;
  • domain_fetcher.get_item(k) returns the raw domain-specific data for item k;
  • domain_fetcher.get_items(k) returns the transformed domain-specific data for item k when it is available (not masked); otherwise, it returns domain_fetcher.get_null_item().
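A minimal sketch of this fetcher interface follows. The class name ToyVisualFetcher, the image shape, the doubling transform, and the convention that a masked item is requested as k=None are all assumptions for illustration; the actual classes live in bim_gw.datasets.simple_shapes.fetchers.

```python
from typing import Callable, Optional

import numpy as np


class ToyVisualFetcher:
    """Hypothetical fetcher mirroring get_null_item / get_item / get_items."""

    def __init__(self, images: np.ndarray, transform: Callable = lambda x: x):
        self.images = images
        self.transform = transform

    def get_null_item(self) -> np.ndarray:
        # Zero-filled item with the same shape as a real one.
        return np.zeros_like(self.images[0])

    def get_item(self, k: int) -> np.ndarray:
        # Raw, untransformed data for item k.
        return self.images[k]

    def get_items(self, k: Optional[int]) -> np.ndarray:
        # Assumption: a masked item is signalled by k=None.
        if k is None:
            return self.get_null_item()
        return self.transform(self.get_item(k))


fetcher = ToyVisualFetcher(np.ones((10, 3, 4, 4)), transform=lambda x: x * 2)
print(fetcher.get_items(3).max(), fetcher.get_items(None).max())  # 2.0 0.0
```

Returning a zero item of the same shape keeps batches rectangular, so masked and unmasked items can be stacked together by the dataloader.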

Dataset output format

Let's get data from a dataloader:

test_dataloader = iter(data.test_dataloader()["in_dist"])  # use the in-distribution dataloader; the out-of-distribution one works the same way.
domains = next(test_dataloader)

domains is a dictionary whose keys are the domain names (as set in the configuration value global_workspace.selected_domains).

Let's now look at the format of one domain:

visual_domain = domains['v']

visual_domain is an object holding the output of the visual data fetcher for the whole batch.

  • visual_domain.available_masks is the domain mask: for each item in the batch, it is 0 if the domain is not available and 1 otherwise.
  • visual_domain["img"] is the image data.
  • generic_domain.sub_parts (for any domain generic_domain) holds the raw outputs of the domain's fetcher. For images, there is only one (the image), but there can be several. For example, the text domain has the BERT features, the raw text sentence, and a dictionary containing the choices used to generate the sentence structure.
  • generic_domain[key] == generic_domain.sub_parts[key]
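To make these access patterns concrete, here is a minimal stand-in for a domain batch. The class name ToyDomainItems and its constructor are hypothetical; only the available_masks and sub_parts attributes and the generic_domain[key] == generic_domain.sub_parts[key] shortcut come from the description above.

```python
from dataclasses import dataclass, field
from typing import Any, Dict

import numpy as np


@dataclass
class ToyDomainItems:
    """Hypothetical container: a per-item availability mask plus named sub-parts."""

    available_masks: np.ndarray
    sub_parts: Dict[str, Any] = field(default_factory=dict)

    def __getitem__(self, key: str) -> Any:
        # Indexing by key is a shortcut for sub_parts[key].
        return self.sub_parts[key]


visual_domain = ToyDomainItems(
    available_masks=np.array([1, 1, 0]),
    sub_parts={"img": np.zeros((3, 3, 32, 32))},
)
print(visual_domain["img"] is visual_domain.sub_parts["img"])  # True
print(list(visual_domain.sub_parts))  # ['img']
```

A text-domain batch would carry several entries in sub_parts (BERT features, raw sentences, caption choices) behind the same interface.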

Dataset Register

To add a new dataset, register a new entry in the DatasetRegister:

from bim_gw.utils import registries

# Callback that instantiates the data module

@registries.register_dataset("new_dataset")
def load_new_dataset(args, local_args, **kwargs):
  from new_dataset.data_modules import NewDataModule  # hypothetical module path; import in the callback to prevent loading unnecessary datasets

  return NewDataModule(...)  # instantiate as you need

After registering, registries.get_dataset() can load your dataset: set the configuration value args.current_dataset to "new_dataset".

To register a dataset without a decorator, you can do:

from bim_gw.utils import registries

def load_new_dataset(args, local_args, **kwargs):
  from new_dataset.data_modules import NewDataModule  # hypothetical module path; import in the callback to prevent loading unnecessary datasets

  return NewDataModule(...)  # instantiate as you need

registries.add_dataset_to_registry("new_dataset", load_new_dataset)
