Abstract

Installation

This code uses Python 3.10.0.

To run this program, you'll need to install the required Python packages.

The preferred installation method is conda, using the provided environment file:

conda env create -f environment.yml
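After activating the environment, you can quickly confirm that the interpreter matches the Python version stated above. This is a minimal sketch. The helper name `check_python_version` is hypothetical, and it compares only the major and minor version numbers.

```python
import sys

# Python version this README says the code uses (3.10.0); only major/minor are compared.
EXPECTED = (3, 10)


def check_python_version(version_info=sys.version_info, expected=EXPECTED):
    """Return True if the interpreter's (major, minor) version equals `expected`."""
    return tuple(version_info[:2]) == tuple(expected)


if __name__ == "__main__":
    if not check_python_version():
        running = sys.version.split()[0]
        print(f"Warning: running Python {running}, expected {EXPECTED[0]}.{EXPECTED[1]}")
```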

Datasets

  • bAbI tasks: downloaded automatically from http://www.thespermwhale.com/jaseweston/babi/
  • Omniglot embeddings: we provide Omniglot embeddings produced by a standard CNN prototypical network; the data are located in "./datasets/omniglot_proto_emb/".

Usage

All our tasks rely on the Hydra configuration system.
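Hydra reads `key=value` overrides from the command line and merges them into the configuration. A dotted key such as `task.generator_params.omniglot_path` reaches into nested config groups. The stdlib-only sketch below illustrates only that dotted-key idea. It is not Hydra's implementation: it keeps every value as a string, and it does not load config-group YAML files the way `task=ap_omniglot_fixed` does in real Hydra. The name `apply_overrides` is hypothetical.

```python
from typing import Any, Dict, List


def apply_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Set nested keys from 'a.b.c=value' strings (a simplified stand-in for Hydra overrides)."""
    for override in overrides:
        key, sep, value = override.partition("=")
        if not sep:
            raise ValueError(f"override must look like key=value, got {override!r}")
        *parents, leaf = key.split(".")
        node = config
        for part in parents:
            # Walk (or create) intermediate groups, e.g. task -> generator_params.
            child = node.setdefault(part, {})
            if not isinstance(child, dict):
                raise TypeError(f"{part!r} is a plain value, not a config group")
            node = child
        node[leaf] = value  # values stay strings in this sketch
    return config
```

For example, `apply_overrides(cfg, ["task.generator_params.omniglot_path=./datasets/omniglot_proto_emb/test_set"])` sets `cfg["task"]["generator_params"]["omniglot_path"]` and leaves the rest of `cfg` unchanged.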

bAbI tasks

Single task (n is the number of the desired bAbI task):

 python run_babi.py task_id=\'n\'
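To train on every bAbI task separately, you can call the single-task command in a loop. This sketch assumes the tasks are numbered 1 to 20 and that `run_babi.py` is run from the repository root. The helper names are hypothetical. Because `subprocess.run` receives an argument list rather than a shell string, the quotes around the task id need no backslash escaping. The program still receives `task_id='n'`, just as with the shell command above.

```python
import shlex
import subprocess
from typing import Iterable, List


def babi_single_task_command(task_id: int) -> List[str]:
    """argv for training on one bAbI task, mirroring the single-task command above."""
    # The literal quotes reach Hydra; in an interactive shell they are written as \'n\'.
    return ["python", "run_babi.py", f"task_id='{task_id}'"]


def run_all_single_tasks(task_ids: Iterable[int] = range(1, 21), dry_run: bool = True) -> None:
    """Print (dry run) or execute the single-task command for each task id."""
    for task_id in task_ids:
        cmd = babi_single_task_command(task_id)
        if dry_run:
            print(shlex.join(cmd))  # shell-quoted form, safe to copy into a terminal
        else:
            subprocess.run(cmd, check=True)  # stop on the first failing run


if __name__ == "__main__":
    run_all_single_tasks()
```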

Joint training (reproduces the setup of the lsp joint experiment):

 python run_babi.py task_id=\'1,4,5,6,7,8,9,10,11,12,13,14,15,18,20\' metric=babi_joint
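The joint run passes fifteen task ids as one quoted string. In recent Hydra versions, an unquoted comma-separated value is read as a sweep and is rejected outside `--multirun`, whereas the quoted form arrives as a single string. This README does not show how `run_babi.py` splits that string. The sketch below uses a hypothetical `parse_task_ids` helper to illustrate one approach, and a hypothetical `omitted_tasks` helper to list which of the 20 tasks the joint setup leaves out.

```python
from typing import List, Union


def parse_task_ids(task_id: Union[str, int]) -> List[int]:
    """Turn a comma-separated task_id such as '1,4,5' into [1, 4, 5]."""
    ids = [int(part) for part in str(task_id).split(",") if part.strip()]
    for i in ids:
        if not 1 <= i <= 20:
            raise ValueError(f"bAbI task ids range from 1 to 20, got {i}")
    return ids


def omitted_tasks(task_ids: List[int], n_tasks: int = 20) -> List[int]:
    """Task ids in 1..n_tasks that are not part of the selection."""
    chosen = set(task_ids)
    return [t for t in range(1, n_tasks + 1) if t not in chosen]


if __name__ == "__main__":
    joint = parse_task_ids("1,4,5,6,7,8,9,10,11,12,13,14,15,18,20")
    print(len(joint), "tasks; omitted:", omitted_tasks(joint))
```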

RL tasks

Pair associations:

python run_rl.py task=ap_omniglot_fixed training=ap_fixed task.generator_params.omniglot_path=./datasets/omniglot_proto_emb/test_set model=ap_pg_fa

Match-to-sample (fixed example version):

python run_rl.py task=mp_omniglot_fixed training=mp_fixed task.generator_params.omniglot_path=./datasets/omniglot_proto_emb/test_set model=mp_pg_fa

Match-to-sample (sampled examples version):

python run_rl.py task=mp_omniglot_sampled training=mp_sampled task.generator_params.omniglot_path=./datasets/omniglot_proto_emb/test_set model=mp_pg_fa

Radial maze:

python run_rl.py task=radial_omniglot_fixed training=radial_fixed task.generator_params.omniglot_path=./datasets/omniglot_proto_emb/test_set model=radial_pg_fa

Radial maze (reward switching):

python run_rl.py task=radial_omniglot_fixed training=radial_fixed task.generator_params.omniglot_path=./datasets/omniglot_proto_emb/test_set model=radial_pg_fa task.switch_reward_rule=everytime
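The five RL commands above differ only in the task, training, and model config names, plus one extra override for reward switching. Rebuilding them from a small table can help avoid copy-paste mistakes when launching several runs. This is a sketch only: `RL_EXPERIMENTS` and `rl_command` are hypothetical names, the table is copied from the commands above, and the script just prints the commands without running them.

```python
import shlex
from typing import Dict, List, Optional, Tuple

# Embedding path used by every RL command in this README.
OMNIGLOT_PATH = "./datasets/omniglot_proto_emb/test_set"

# name -> (task, training, model) config names, copied from the commands above.
RL_EXPERIMENTS: Dict[str, Tuple[str, str, str]] = {
    "pair_associations": ("ap_omniglot_fixed", "ap_fixed", "ap_pg_fa"),
    "match_to_sample_fixed": ("mp_omniglot_fixed", "mp_fixed", "mp_pg_fa"),
    "match_to_sample_sampled": ("mp_omniglot_sampled", "mp_sampled", "mp_pg_fa"),
    "radial_maze": ("radial_omniglot_fixed", "radial_fixed", "radial_pg_fa"),
}


def rl_command(name: str, extra: Optional[List[str]] = None) -> List[str]:
    """argv for one RL experiment, with optional extra Hydra overrides appended."""
    task, training, model = RL_EXPERIMENTS[name]
    cmd = [
        "python", "run_rl.py",
        f"task={task}",
        f"training={training}",
        f"task.generator_params.omniglot_path={OMNIGLOT_PATH}",
        f"model={model}",
    ]
    return cmd + list(extra or [])


if __name__ == "__main__":
    for name in RL_EXPERIMENTS:
        print(shlex.join(rl_command(name)))
    # Radial maze with reward switching adds a single override.
    print(shlex.join(rl_command("radial_maze", ["task.switch_reward_rule=everytime"])))
```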