This code uses Python 3.10.0.
To run this program, you need the Python packages listed in environment.yml.
The preferred installation method is conda:
conda env create -f environment.yml
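The code targets Python 3.10. A minimal sanity check to run inside the activated environment could look like the sketch below. The helper name `python_version_ok` is hypothetical, and it assumes any 3.10.x patch release is acceptable (it only compares major and minor versions).

```python
import sys


def python_version_ok(required=(3, 10), current=None):
    """Return True if the running (or given) interpreter matches the required major.minor."""
    if current is None:
        current = sys.version_info
    # Only major and minor are compared; 3.10.0 and 3.10.12 both pass.
    return tuple(current[:2]) == tuple(required)


if __name__ == "__main__":
    if not python_version_ok():
        print(f"Warning: expected Python 3.10, found {sys.version.split()[0]}")
```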
Datasets:
- bAbI tasks: downloaded automatically from http://www.thespermwhale.com/jaseweston/babi/
- Omniglot embeddings: we provide Omniglot embeddings computed with a standard CNN prototypical network; the data are located in "./datasets/omniglot_proto_emb/"
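All our tasks rely on the Hydra configuration system, so configuration values can be set on the command line with key=value arguments, as in the commands below.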
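In a Hydra override, dots address nested configuration keys, e.g. `task.generator_params.omniglot_path` sets `omniglot_path` inside `generator_params` inside `task`. The sketch below is a simplified, hypothetical illustration of that addressing on a plain dictionary; it is not Hydra's parser, which also handles value types, quoting, sweeps and config-group selection such as `task=ap_omniglot_fixed`. The example config structure is assumed for illustration only.

```python
def apply_override(config, override):
    """Set a nested value in `config` from an 'a.b.c=value' string (values stay strings)."""
    key, sep, value = override.partition("=")
    if not sep:
        raise ValueError(f"expected key=value, got {override!r}")
    *parents, leaf = key.split(".")
    node = config
    for name in parents:
        node = node[name]  # KeyError if an intermediate key does not exist
    if leaf not in node:
        # Mirror Hydra's default of rejecting keys that are not already in the config.
        raise KeyError(leaf)
    node[leaf] = value
    return config


# Hypothetical config fragment, shaped like the keys used in the RL commands.
config = {"task": {"generator_params": {"omniglot_path": None}, "switch_reward_rule": None}}
apply_override(config, "task.generator_params.omniglot_path=./datasets/omniglot_proto_emb/test_set")
print(config["task"]["generator_params"]["omniglot_path"])
```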
Single task (n is the number of the bAbI task to run):
python run_babi.py task_id=\'n\'
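The backslashes are consumed by the shell: outside quotes, `\'` produces a literal single quote, so the program receives `task_id='n'` and Hydra sees a quoted value (presumably so that `task_id` is always read as a string, consistent with the comma-separated joint form). Python's `shlex` module follows POSIX shell splitting rules and can be used to check this; task 3 below is just an example value.

```python
import shlex

# POSIX-style splitting, as a shell such as bash would perform it.
argv = shlex.split(r"python run_babi.py task_id=\'3\'")
print(argv)  # ['python', 'run_babi.py', "task_id='3'"]
```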
Joint training (reproduces the setup of the lsp joint experiment):
python run_babi.py task_id=\'1,4,5,6,7,8,9,10,11,12,13,14,15,18,20\' metric=babi_joint
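As far as we understand Hydra's override grammar, the quotes keep the comma-separated list a single string; unquoted, `1,4,5,...` would be read as a sweep over several values. When launching runs from Python instead of a shell, nothing strips the backslashes, so the quote characters go directly into the argument. A minimal sketch under that assumption; `babi_argv` and `run_babi` are hypothetical helpers, and the task list is copied from the joint command above.

```python
import subprocess
import sys

JOINT_TASKS = [1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 18, 20]


def babi_argv(task_ids, metric=None):
    """Build the run_babi.py argument list; quotes are literal because no shell is involved."""
    ids = ",".join(str(t) for t in task_ids)
    argv = [sys.executable, "run_babi.py", f"task_id='{ids}'"]
    if metric is not None:
        argv.append(f"metric={metric}")
    return argv


def run_babi(task_ids, metric=None):
    """Launch training with the current interpreter (expects run_babi.py in the working directory)."""
    return subprocess.run(babi_argv(task_ids, metric), check=True)
```

For example, `run_babi(JOINT_TASKS, metric="babi_joint")` corresponds to the joint-training command, and `run_babi([3])` to a single task.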
Pair associations:
python run_rl.py task=ap_omniglot_fixed training=ap_fixed task.generator_params.omniglot_path=./datasets/omniglot_proto_emb/test_set model=ap_pg_fa
Match-to-sample (fixed example version):
python run_rl.py task=mp_omniglot_fixed training=mp_fixed task.generator_params.omniglot_path=./datasets/omniglot_proto_emb/test_set model=mp_pg_fa
Match-to-sample (sampled examples version):
python run_rl.py task=mp_omniglot_sampled training=mp_sampled task.generator_params.omniglot_path=./datasets/omniglot_proto_emb/test_set model=mp_pg_fa
Radial maze:
python run_rl.py task=radial_omniglot_fixed training=radial_fixed task.generator_params.omniglot_path=./datasets/omniglot_proto_emb/test_set model=radial_pg_fa
Radial maze (reward switching):
python run_rl.py task=radial_omniglot_fixed training=radial_fixed task.generator_params.omniglot_path=./datasets/omniglot_proto_emb/test_set model=radial_pg_fa task.switch_reward_rule=everytime
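All RL experiments above follow one pattern: a `task` config, a matching `training` config, the Omniglot embedding path, a `model` config and, for reward switching, one extra override. The hypothetical table-driven sketch below regenerates those five command lines and could be adapted for batch runs; the experiment names used as dictionary keys are illustrative and not part of the repository.

```python
import shlex

OMNIGLOT_PATH = "./datasets/omniglot_proto_emb/test_set"

# (task config, training config, model config, extra overrides), copied from the commands above.
EXPERIMENTS = {
    "pair_associations": ("ap_omniglot_fixed", "ap_fixed", "ap_pg_fa", []),
    "match_to_sample_fixed": ("mp_omniglot_fixed", "mp_fixed", "mp_pg_fa", []),
    "match_to_sample_sampled": ("mp_omniglot_sampled", "mp_sampled", "mp_pg_fa", []),
    "radial_maze": ("radial_omniglot_fixed", "radial_fixed", "radial_pg_fa", []),
    "radial_maze_switching": (
        "radial_omniglot_fixed", "radial_fixed", "radial_pg_fa",
        ["task.switch_reward_rule=everytime"],
    ),
}


def rl_argv(name, python="python"):
    """Return the run_rl.py argument list for one named experiment."""
    task, training, model, extra = EXPERIMENTS[name]
    return [
        python, "run_rl.py",
        f"task={task}",
        f"training={training}",
        f"task.generator_params.omniglot_path={OMNIGLOT_PATH}",
        f"model={model}",
        *extra,
    ]


for name in EXPERIMENTS:
    # shlex.join adds shell quoting only where an argument would need it.
    print(f"{name}: {shlex.join(rl_argv(name))}")
```

To launch one of them, the list returned by `rl_argv(name)` can be passed to `subprocess.run`; passing a list rather than a string avoids shell quoting issues.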