# config.py
"""
Central configration file for the project. Acts as a storage
of global variables and various configuration settings.
"""
import os
import pprint
from datetime import datetime

from easydict import EasyDict as edict

CONFIG = edict()
"""
Evaluation of RL-agent
"""
# These settings control the evaluation runs of the saved agents.
# EVAL_RL_log selects which saved agent to use. If set to a number n, it picks
# the n:th latest available log (note: n=1 picks the penultimate available log).
# If set to a specific log name, it tries to load that log.
CONFIG.EVAL_RL_log = None
CONFIG.EVAL_RL_saved_logs = False # If enabled, picks the model from those in saved_logs
CONFIG.EVAL_RL_multiply_images = 1
CONFIG.EVAL_save_vis_iter = 10
CONFIG.EVAL_RL_use_val_set = True
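# Illustrative examples (values are hypothetical): CONFIG.EVAL_RL_log = 2 would
# select a log by recency as described above, while a directory name such as
# CONFIG.EVAL_RL_log = '2023-01-31_12-00-00-123456' (the naming scheme used by
# STATS_log_dir below) would be loaded directly.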
"""
RL-agent
"""
######################### This is where the important settings start #########################
# Batch and episode settings
CONFIG.RL_nbr_epochs = 10000
CONFIG.RL_batch_size = 32
CONFIG.RL_multiply_images = 2
CONFIG.RL_max_episode_length = 10
CONFIG.MISC_priv = False
# Architecture
CONFIG.RL_agent_network = 'LSTMAgent' # AiRLoc agent
CONFIG.RL_patch_embedder = 'ShareNet'
CONFIG.RL_freeze_patch_embedder = True
CONFIG.RL_priv_pretrained = True
CONFIG.EE_temporal = True
CONFIG.EE_residual = True
# Optimizer
CONFIG.RL_learning_rate = 1e-4
CONFIG.RL_nbr_eps_update = CONFIG.RL_batch_size * CONFIG.RL_multiply_images
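# With the defaults above this evaluates to 32 * 2 = 64 episodes between
# parameter updates (as the name suggests).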
CONFIG.RL_weight_decay = 0
CONFIG.RL_momentum = 0.90
CONFIG.RL_optimizer = 'adam'
CONFIG.RL_beta1 = 0.9
CONFIG.RL_beta2 = 0.999
# Env setup
CONFIG.RL_agent_allowed_outside = True
CONFIG.RL_normalize_weights = True
CONFIG.RL_eval_deterministic = True
CONFIG.RL_priv_grid_location = False
CONFIG.RL_priv_use_seg = True # Set to True when training sem seg-based RL-agent (but False during inference -- should not use ground truth then!)
"""
RL Rewards
"""
CONFIG.RL_reward_goal = 3
CONFIG.RL_reward_failed = 0
CONFIG.RL_reward_closer = 0
CONFIG.RL_reward_iou_scale = 0
CONFIG.RL_reward_step_outside = 0
CONFIG.RL_reward_distance = False
CONFIG.RL_reward_step = -1
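# Illustrative arithmetic under the defaults above (assuming the per-step
# penalty also applies on the final step): an episode that reaches the goal in
# 4 steps accrues 4 * (-1) + 3 = -1 in undiscounted reward.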
# LSTM Agent settings
CONFIG.RL_LSTM_pos_emb = True
# Pretrained doerch
#CONFIG.RL_pretrained_doerch_net = 'doerchnet/logs/without-sem-seg' # without sem-seg
CONFIG.RL_pretrained_doerch_net = 'doerchnet/logs/with-sem-seg' # with sem-seg
######################### This is where they end #########################
CONFIG.RL_max_start_goal_dist = 999 # Since CONFIG.MISC_grid_game=True by default --> actual max distance will become min(RL_max_start_goal_dist, grid-size - 1), i.e. 4 in 5x5, 6 in 7x7
CONFIG.RL_min_start_goal_iou = None # Maximum allowed IoU between a start and goal patch (this flag is not used when CONFIG.MISC_grid_game=True, as is default)
CONFIG.RL_done_iou = 0.40 # Since CONFIG.MISC_grid_game=True by default --> the agent is done if and only if it overlaps 100% with the goal. Thus any RL_done_iou in (0, 1] works here.
CONFIG.RL_discount_factor = 0.9
CONFIG.RL_softmax_step_size = 1.1 # Set to 1.1 because 48x48 patches --> 1.1*48 = 52.8, with int(52.8)=52, which implies a grid setup of 48x48 patches with 4 pixel distance in between
CONFIG.RL_entropy = None
CONFIG.RL_entropy_lower = None
# Pretrained segmenter
CONFIG.RL_pretrained_segmentation_net = 'segmentations/logs/sem-seg-model'
CONFIG.RL_predict_seg_mask = False # Set to True during inference if using a sem-seg based RL-agent
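# Assumed pairing of the two segmentation flags, inferred from the comments
# above: training a sem seg-based agent uses RL_priv_use_seg=True with
# RL_predict_seg_mask=False, while inference flips both so that predicted
# masks replace the ground-truth ones.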
"""
Random Search baseline agent
"""
CONFIG.RANDOM_batch_size = 1
CONFIG.RANDOM_using_memory = True # If True, the agent cannot visit the same patch twice
CONFIG.RANDOM_stop_iou = 0.2 # Not used in grid game setup
CONFIG.RANDOM_min_iou_visited = 0.3 # At what IoU a location is considered already visited (not used in grid game setup)
CONFIG.RANDOM_WARNING_steps = 500 # Warn the user if the agent takes this many steps without finding the goal
"""
Statistics / Logging / Plotting
"""
CONFIG.STATS_dir_base = os.path.dirname(os.path.abspath(__file__))
CONFIG.STATS_log_dir_base = os.path.join(CONFIG.STATS_dir_base, 'logs')
CONFIG.STATS_log_dir = os.path.join(CONFIG.STATS_log_dir_base,
str(datetime.now()).replace(' ', '_')
.replace(':', '-').replace('.', '-'))
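# With this naming scheme, a str(datetime.now()) such as
# '2023-01-31 12:00:00.123456' becomes a log directory named
# '2023-01-31_12-00-00-123456'.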
"""
Plotting
"""
# The option below lets the user choose which LOG directory to plot information from.
# An integer n selects the n:th most recent log; a specific log name tries to find that directory.
CONFIG.PLOT_log_dir = 1
# The option below lets the user choose which EVAL directory to plot information from,
# i.e., which eval session to plot from given a specific training session.
CONFIG.PLOT_eval_dir = None
"""
Miscellaneous
"""
CONFIG.MISC_include_baseline = True
CONFIG.MISC_use_gpu = True
CONFIG.MISC_dataset = 'masa_filt'
CONFIG.MISC_dataset_split_file = None
CONFIG.MISC_grid_game = True
CONFIG.MISC_random_seed = 0
#CONFIG.MISC_rnd_crop = True
CONFIG.MISC_rgb_max = 255
#CONFIG.MISC_im_size = (256, 256)
CONFIG.MISC_step_sz = int(48*CONFIG.RL_softmax_step_size)
CONFIG.MISC_game_size = 5
CONFIG.MISC_im_size = (int(CONFIG.MISC_step_sz*(CONFIG.MISC_game_size-1)+48),
int(CONFIG.MISC_step_sz*(CONFIG.MISC_game_size-1)+48))
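# Worked out with the defaults: step_sz = int(1.1 * 48) = 52 and game_size = 5,
# so im_size = 52 * (5 - 1) + 48 = 256, i.e. (256, 256), matching the
# commented-out MISC_im_size above.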
CONFIG.MISC_patch_size = (48, 48)
CONFIG.MISC_print_iter = 50
CONFIG.MISC_save_vis_iter = 400 # How often we save a visualization
CONFIG.MISC_vis_per_batch = 12
CONFIG.MISC_save_model_iter = 5000 # How often to save the model weights
CONFIG.MISC_project_root_path = os.path.dirname(__file__)
CONFIG.MISC_main_pid = os.getpid()
CONFIG.MISC_dataset_path = "data" # Set accordingly
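
# A minimal sketch (assumed usage, not part of the original project flow):
# dump the resolved settings when this module is run directly. pprint is
# imported above, which suggests the config is meant to be printable.
if __name__ == '__main__':
    pprint.pprint(dict(CONFIG))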