-
Notifications
You must be signed in to change notification settings - Fork 25
/
config.py
97 lines (82 loc) · 2.73 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import json
from collections import OrderedDict
class Config(object):
    """Typed access to a JSON configuration file that may contain # comments.

    Comments start with ``#`` outside of JSON string literals and run to the
    end of the line. Parsed objects are OrderedDicts, so key order from the
    file is preserved.
    """

    def __init__(self, filename, update_config_string=""):
        """Load *filename* and optionally overlay extra settings.

        Args:
            filename: path of the JSON-with-comments config file, read as UTF-8.
            update_config_string: optional JSON object text; its top-level keys
                shallowly overwrite the ones loaded from the file. Empty (or
                None) means no overlay.

        Raises:
            ValueError: if either text is not valid JSON (json raises a
                ValueError subclass).
        """
        # JSON text is UTF-8 by specification, so don't depend on the locale
        # encoding; the with-block also guarantees the file gets closed.
        with open(filename, encoding="utf-8") as f:
            s = self._strip_comments(f.read())
        # Echo the effective (comment-free) config to stdout.
        print(s)
        self._entries = json.loads(s, object_pairs_hook=OrderedDict)
        if update_config_string:
            config_string_entries = json.loads(update_config_string, object_pairs_hook=OrderedDict)
            print("Updating given config with dict", config_string_entries)
            self._entries.update(config_string_entries)

    @staticmethod
    def _strip_comments(text):
        """Return *text* with ``#`` comments removed.

        A ``#`` starts a comment only outside a double-quoted string, so values
        such as ``"#ff0000"`` survive. Newlines are kept so that JSON error
        positions still match the file's line numbers.
        """
        out = []
        in_string = False
        escaped = False
        in_comment = False
        for ch in text:
            if in_comment:
                if ch == "\n":
                    in_comment = False
                    out.append(ch)
                continue
            if in_string:
                if escaped:
                    escaped = False
                elif ch == "\\":
                    escaped = True
                elif ch == '"':
                    in_string = False
            elif ch == '"':
                in_string = True
            elif ch == "#":
                in_comment = True
                continue
            out.append(ch)
        return "".join(out)

    def has(self, key):
        """Return True if *key* is present in the loaded configuration."""
        return key in self._entries

    @staticmethod
    def _promotable(val, dtype):
        """True if *val* is an integer that may stand in for a requested float.

        JSON has a single number type, so ``1`` in a config file should satisfy
        a float setting. bool is excluded because it is a subclass of int.
        """
        return dtype is float and isinstance(val, int) and not isinstance(val, bool)

    @classmethod
    def _accepts(cls, val, dtype):
        """True if *val* may be returned for a setting of type *dtype*."""
        return isinstance(val, dtype) or cls._promotable(val, dtype)

    @classmethod
    def _promote(cls, val, dtype):
        """Convert *val* to float when int->float promotion applies, else return it unchanged."""
        return float(val) if cls._promotable(val, dtype) else val

    def _value(self, key, dtype, default):
        """Return setting *key* checked against *dtype*, falling back to *default*.

        A *default* of None means the key is required.

        Raises:
            TypeError: the stored value is not a *dtype*.
            AssertionError: *default* has the wrong type, or *key* is absent
                and no default was given.
        """
        if default is not None:
            assert self._accepts(default, dtype), (
                "default for config key %r must be %s, got %r" % (key, dtype.__name__, default))
        if key in self._entries:
            val = self._entries[key]
            if not self._accepts(val, dtype):
                raise TypeError("config key %r must be %s, got %r" % (key, dtype.__name__, val))
            return self._promote(val, dtype)
        assert default is not None, "config key %r is missing and has no default" % (key,)
        return self._promote(default, dtype)

    def _list_value(self, key, dtype, default):
        """Return list setting *key*, whose elements must all be *dtype*.

        Missing-key and default handling mirror `_value`. Every type violation
        (not a list, or a wrong element type) raises AssertionError. For float
        lists, integer elements are promoted; a new list is built only when
        such a promotion happens.
        """
        if default is not None:
            assert isinstance(default, list), (
                "default for config key %r must be a list, got %r" % (key, default))
            for x in default:
                assert self._accepts(x, dtype), (
                    "default for config key %r: element %r is not %s" % (key, x, dtype.__name__))
        if key in self._entries:
            val = self._entries[key]
            assert isinstance(val, list), "config key %r must be a list, got %r" % (key, val)
            for x in val:
                assert self._accepts(x, dtype), (
                    "config key %r: element %r is not %s" % (key, x, dtype.__name__))
        else:
            assert default is not None, "config key %r is missing and has no default" % (key,)
            val = default
        if val is not None and any(self._promotable(x, dtype) for x in val):
            val = [self._promote(x, dtype) for x in val]
        return val

    def bool(self, key, default=None):
        """Return bool setting *key* (required unless *default* is given)."""
        return self._value(key, bool, default)

    def str(self, key, default=None):
        """Return str setting *key* (required unless *default* is given)."""
        return self._value(key, str, default)

    def int(self, key, default=None):
        """Return int setting *key* (required unless *default* is given).

        Note: JSON true/false also pass this check, because bool subclasses int.
        """
        return self._value(key, int, default)

    def float(self, key, default=None):
        """Return float setting *key*; integer values and defaults are promoted to float."""
        return self._value(key, float, default)

    def dict(self, key, default=None):
        """Return object setting *key*; values parsed from JSON are OrderedDicts."""
        return self._value(key, dict, default)

    def int_key_dict(self, key, default=None):
        """Return a dict with int keys, stored under *key* as a Python-literal string.

        JSON object keys are always strings, so such a dict is written in the
        config as a string, e.g. ``"some_key": "{1: 0.1, 10: 0.01}"``. An absent
        or empty string falls back to *default* (required if None).
        """
        if default is not None:
            assert isinstance(default, dict), (
                "default for config key %r must be a dict, got %r" % (key, default))
            for k in default:
                assert isinstance(k, int), "default for config key %r has non-int key %r" % (key, k)
        dict_str = self.str(key, "")
        if dict_str == "":
            assert default is not None, "config key %r is missing and has no default" % (key,)
            return default
        # SECURITY: eval executes arbitrary code taken from the config file, so
        # only load trusted configs. ast.literal_eval would be the safe
        # replacement if no config relies on non-literal expressions here.
        res = eval(dict_str)
        assert isinstance(res, dict), "config key %r must evaluate to a dict, got %r" % (key, res)
        for k in res:
            assert isinstance(k, int), "config key %r has non-int key %r" % (key, k)
        return res

    def int_list(self, key, default=None):
        """Return a list-of-int setting *key*."""
        return self._list_value(key, int, default)

    def float_list(self, key, default=None):
        """Return a list-of-float setting *key*; integer elements are promoted."""
        return self._list_value(key, float, default)

    def str_list(self, key, default=None):
        """Return a list-of-str setting *key*."""
        return self._list_value(key, str, default)

    def dir(self, key, default=None):
        """Return directory setting *key* with a trailing "/" appended if missing.

        An empty string is returned unchanged, so an unset directory is never
        turned into "/" (the filesystem root).
        """
        p = self.str(key, default)
        if p and not p.endswith("/"):
            return p + "/"
        return p