-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathconfig.py
133 lines (111 loc) · 4.3 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
# Code adapted from:
# https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py
Source License
# Copyright (c) 2017-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
#
# Based on:
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
"""
##############################################################################
#Config
##############################################################################
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch
import os
from utils.attr_dict import AttrDict
# Global configuration object. `cfg` is a public alias of the same object
# (presumably the name other modules import); this file uses `__C` internally.
__C = AttrDict()
cfg = __C
# Training-progress counters, both starting at 0 (whatever updates them is not
# shown in this file).
__C.ITER = 0
__C.EPOCH = 0
# Random seed (its consumer is not shown in this file).
__C.RANDOM_SEED = 304
# Class-uniform sampling setting, used to give each class proper sampling.
# In train mode, assert_and_infer_cfg() overwrites it with
# args.class_uniform_pct when that value is truthy.
__C.CLASS_UNIFORM_PCT = 0.0
# Use a class-weighted loss per batch to increase the loss for classes with a
# low pixel count in that batch. In train mode, assert_and_infer_cfg() sets it
# to True when args.batch_weighting is truthy.
__C.BATCH_WEIGHTING = False
# Border relaxation count.
__C.BORDER_WINDOW = 1
# Iteration (REDUCE_BORDER_ITER) or epoch (REDUCE_BORDER_EPOCH) at which to
# turn off border relaxation; -1 leaves it unset. assert_and_infer_cfg() sets
# REDUCE_BORDER_ITER from args.rlx_off_iter when that value is > -1.
__C.REDUCE_BORDER_ITER = -1
__C.REDUCE_BORDER_EPOCH = -1
# Class ids to relax: None, or a list of ints that assert_and_infer_cfg()
# parses from the comma-separated string args.strict_bdr_cls.
__C.STRICTBORDERCLASS = None
# Root directory under which the datasets below are located; the leading `~`
# is expanded to the current user's home directory.
datasetroot = os.path.expanduser('~/dg_seg_dataset/')
# Attribute dictionary holding the dataset locations.
__C.DATASET = AttrDict()
# Cityscapes dataset directory.
__C.DATASET.CITYSCAPES_DIR = os.path.join(datasetroot, 'cityscapes')
# SDC-augmented Cityscapes directory (left empty here).
__C.DATASET.CITYSCAPES_AUG_DIR = ''
# IDD dataset directory.
__C.DATASET.IDD_DIR = os.path.join(datasetroot, 'idd')
# SDC-augmented IDD directory (left empty here).
__C.DATASET.IDD_AUG_DIR = ''
# Mapillary dataset directory.
__C.DATASET.MAPILLARY_DIR = os.path.join(datasetroot, 'mapillary')
# GTAV and BDD100K dataset directories.
__C.DATASET.GTAV_DIR = os.path.join(datasetroot, 'GTAV')
__C.DATASET.BDD_DIR = os.path.join(datasetroot, 'bdd-100k')
# Synthia dataset directory.
__C.DATASET.SYNTHIA_DIR = os.path.join(datasetroot, 'synthia')
# KITTI dataset directory (left empty here).
__C.DATASET.KITTI_DIR = ''
# SDC-augmented KITTI directory (left empty here).
__C.DATASET.KITTI_AUG_DIR = ''
# CamVid dataset directory (left empty here).
__C.DATASET.CAMVID_DIR = ''
# Number of (presumably cross-validation) splits to support.
__C.DATASET.CV_SPLITS = 3
# Model settings. The default normalization layer is PyTorch's synchronized
# batch norm; assert_and_infer_cfg() re-selects it from args.syncbn.
_model_cfg = AttrDict()
_model_cfg.BN = 'pytorch-syncnorm'  # label of the selected norm layer
_model_cfg.BNFUNC = torch.nn.SyncBatchNorm  # class used to build norm layers
__C.MODEL = _model_cfg
del _model_cfg  # temporary builder only; keep the module namespace unchanged
def assert_and_infer_cfg(args, make_immutable=True, train_mode=True):
    """Finalize the global cfg from parsed arguments.

    Call this function in your script after you have finished setting all cfg
    values that are necessary (e.g., merging a config from a file, merging
    command line config options, etc.). By default, this function will also
    mark the global cfg as immutable to prevent changing the global cfg settings
    during script execution (which can lead to hard to debug errors or code
    that's harder to understand than is necessary).

    Args:
        args: parsed arguments object. ``syncbn`` is optional. In train mode
            ``class_uniform_pct``, ``batch_weighting`` and ``jointwtborder``
            are read, and ``strict_bdr_cls`` / ``rlx_off_iter`` are read when
            ``jointwtborder`` is truthy.
        make_immutable: if True, freeze ``cfg`` before returning (in both
            train and non-train mode).
        train_mode: if False, only the normalization layer is selected and
            the training-only options are skipped.
    """
    # Select the normalization layer. MODEL.BN is updated in both branches so
    # the label always matches the class stored in MODEL.BNFUNC.
    if getattr(args, 'syncbn', False):
        cfg.MODEL.BN = 'pytorch-syncnorm'
        cfg.MODEL.BNFUNC = torch.nn.SyncBatchNorm
        print('Using pytorch sync batch norm')
    else:
        cfg.MODEL.BN = 'regularnorm'
        cfg.MODEL.BNFUNC = torch.nn.BatchNorm2d
        print('Using regular batch norm')

    if train_mode:
        if args.class_uniform_pct:
            cfg.CLASS_UNIFORM_PCT = args.class_uniform_pct
        if args.batch_weighting:
            cfg.BATCH_WEIGHTING = True
        if args.jointwtborder:
            # strict_bdr_cls is a comma-separated string of class ids such as
            # "5,12". Skip it when empty/None, and ignore empty tokens so a
            # trailing comma does not raise ValueError from int('').
            if args.strict_bdr_cls:
                cfg.STRICTBORDERCLASS = [
                    int(token)
                    for token in args.strict_bdr_cls.split(',')
                    if token.strip()
                ]
            # -1 means "not requested": keep the default REDUCE_BORDER_ITER.
            if args.rlx_off_iter > -1:
                cfg.REDUCE_BORDER_ITER = args.rlx_off_iter

    if make_immutable:
        cfg.immutable(True)