-
Notifications
You must be signed in to change notification settings - Fork 0
/
mapl_cirup.py
366 lines (303 loc) · 14.2 KB
/
mapl_cirup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
"""
MArkov PLanning with CIRcuit bellman UPdates (mapl-cirup)
"""
import copy
import time
from typing import Dict, List, Optional

import numpy as np
from problog import get_evaluatable
from problog.clausedb import ClauseDB
from problog.formula import LogicFormula
from problog.logic import Term, Constant, Clause, Not, And
from problog.program import PrologFile
from problog.engine import DefaultEngine
from problog.sdd_formula_explicit import SDDExplicit, x_constrained_named

from ddc import DDC
class MaplCirup:
    """
    Class for mapl-cirup.

    The circuit is symbolically representing both the utility function U(s) and the policy π(s), for each explicit
    state s.
    """

    _engine = DefaultEngine(label_all=True, keep_order=True)  # Grounding engine
    _next_state_functor = 'x'  # Name of the functor used to indicate the next state
    _true_term = Term('true')  # True term used for querying
    _decision_term = Term("?")  # Decision term used for querying
    _horizon = 0  # Default future lookahead
    _discount = 1  # Default discount factor
    _error = 0.01  # Default maximum error allowed for value iteration
    # Added (expected) utility parameters with their state. Only annotated here and created per instance in
    # __init__: a dict stored on the class would be shared, and mutated, by every instance.
    _utilities: Dict[Term, List[Term]]
    _iterations_count: int = 0  # Number of iterations for policy convergence
    _vi_time: Optional[float] = None  # Time required for value iteration
    _minimisation_on = False  # Activate the SDD minimisation
    _minimize_time = 0  # Default minimisation time

    def __init__(self, filename: str, minimisation: bool = False):
        """
        Build the circuit for the model stored in 'filename'. The overall steps are the following:
        1. Parse the input model
        2. Fill in some class parameters to perform inference later
        3. Grounding
        4. Knowledge compilation (KC)
        5. SDD minimisation (only when requested)
        6. Creation of the DDC from the compiled circuit
        Notice that some operations must be in this order because they may depend on something retrieved earlier.

        :param filename: Input file in the ProbLog format.
        :param minimisation: If True (or if the class flag '_minimisation_on' is set), minimise the SDD before
            building the DDC.
        """
        # Per-instance storage, so that two models never see each other's utility parameters.
        self._utilities = {}

        prog = self._parsing(filename)
        self._rewards = self._get_reward_func(prog)
        self._decisions = self._get_decisions(prog)
        self._state_vars = MaplCirup._get_state_vars(prog)
        self._add_state_priors(prog)
        self._add_utility_parameters(prog)
        grounded_prog = self._grounding(prog)

        print("Compiling...")
        starttime_compilation = time.time()
        sdd = self._compilation(grounded_prog)
        # TODO Print the size of the SDD after compiling (also after minimising maybe)
        if minimisation or self._minimisation_on:
            print("Minimizing...")
            starttime_minimization = time.time()
            self._minimize(sdd)
            self._minimize_time = time.time() - starttime_minimization
        self._ddc: DDC = DDC.create_from(sdd, self._state_vars, self._rewards)
        # The compile time spans everything since 'starttime_compilation', minimisation included.
        self._compile_time = time.time() - starttime_compilation
        print("Compilation done! (circuit size: %s)" % self.size())

        self._remove_impossible_states()
        self._ddc.print_info()

    def _remove_impossible_states(self) -> None:
        """
        Print and drop the utility parameters that the DDC reports as impossible.
        """
        imp_util = self._ddc.impossible_utilities()
        print("\nImpossible states (%s/%s):" % (len(imp_util), 2 ** len(self._state_vars)))
        # Collect the keys first: the dictionary cannot be resized while it is being iterated.
        to_remove = [u for u in self._utilities if str(u) in imp_util]
        for u in to_remove:
            print(str(self._utilities.pop(u)))
        print()

    def _parsing(self, file_path="") -> ClauseDB:
        """
        Parse the input model located at 'file_path'.

        :param file_path: Path of the ProbLog file to parse.
        :return: The parsed program.
        """
        return self._engine.prepare(PrologFile(file_path))

    def _get_reward_func(self, program: ClauseDB) -> Dict[Term, Constant]:
        """
        Retrieve utilities from the parsed program by querying 'utility/2'.

        :param program: Parsed program.
        :return: A dictionary {reward: val}.
        """
        return dict(self._engine.query(program, Term('utility', None, None)))

    def _get_decisions(self, program: ClauseDB) -> List[Term]:
        """
        Retrieve decisions from the parsed program, i.e. the nodes whose probability is the decision term '?'.

        :param program: Parsed program.
        :return: A list containing all the decisions in the model (without duplicates).
        """
        decisions = set()
        for _, node in program.enum_nodes():
            if not node or not hasattr(node, 'probability'):
                continue
            if node.probability != self._decision_term:
                continue
            if type(node).__name__ == 'choice':  # unpack from the choice node
                decisions.add(node.functor.args[2])
            else:
                decisions.add(Term(node.functor, *node.args))
        return list(decisions)

    @staticmethod
    def _get_state_vars(program: ClauseDB) -> List[Term]:
        """
        Retrieve the variables representing the state. Notice that it returns only the state variables and not the
        decision ones.

        :param program: Parsed program.
        :return: A list of terms, each is a state variable. Empty if the program has no 'state_variables' term.
        """
        for rule in program:
            # Exact type check on purpose: only a plain Term is accepted, not an instance of a Term subclass.
            if type(rule) is Term and rule.functor == 'state_variables':
                return list(rule.args)
        return []

    def _add_state_priors(self, parsed_prog: ClauseDB) -> None:
        """
        Add every state variable to the program as a fact. Variables without a probability get probability 1.0.

        :param parsed_prog: Parsed program.
        :return: Void.
        """
        for var in self._state_vars:
            # Work on a copy so that the terms stored in '_state_vars' are left untouched.
            new_var = copy.deepcopy(var)
            if var.probability is None:
                new_var.probability = Constant(1.0)
            parsed_prog.add_fact(new_var)

    def _add_utility_parameters(self, parsed_prog: ClauseDB) -> None:
        """
        Add parameters representing the future expected utility. They must be connected to primed variables, i.e. the
        next state in the transition function. One clause 'u<i> :- <next state i>' is added per enumerated state.
        TODO: Optionally add a utility parameter if the corresponding state has probability > 0.

        :param parsed_prog: Parsed program.
        :return: Void.
        """
        utility_idx: int = 0
        state = self._state_vars
        while state:
            utility_term = Term('u' + str(utility_idx))
            self._utilities[utility_term] = state
            body = MaplCirup.big_and(self._wrap_in_next_state_functor(state))
            parsed_prog.add_clause(Clause(utility_term, body))
            utility_idx += 1
            state = MaplCirup.enumeration_next(state)

    def _wrap_in_next_state_functor(self, state: List[Term]) -> List[Term]:
        """
        Wrap each term of 'state' in the next-state functor, keeping negations on the outside.

        :param state: State, represented as a list of (possibly negated) terms.
        :return: A new list with the wrapped terms, in the same order.
        """
        wrapped_state = []
        for term in state:
            if isinstance(term, Not):
                wrapped_state.append(~Term(self._next_state_functor, term.args[0]))
            else:
                wrapped_state.append(Term(self._next_state_functor, term))
        return wrapped_state

    def _grounding(self, parsed_prog: ClauseDB) -> LogicFormula:
        """
        Ground the parsed program with respect to decisions, rewards, utility parameters and next-state variables.

        :param parsed_prog: Parsed program.
        :return: Grounded program.
        """
        queries = self._decisions + list(self._rewards) + list(self._utilities)
        queries += [Term(self._next_state_functor, v) for v in self._state_vars]
        # fix an order to have the same circuit size at each execution
        # (for some reason the reverse order leads to smaller circuits)
        queries.sort(key=repr, reverse=True)
        queries.append(self._true_term)
        return self._engine.ground_all(parsed_prog, queries=queries)

    def _compilation(self, grounded_prog: LogicFormula) -> SDDExplicit:
        """
        Knowledge compilation into X-constrained SDDs of the model. The decisions are the constrained variables.

        :param grounded_prog: Grounded model.
        :return: The circuit for the given model.
        """
        kc_class = get_evaluatable(name='sddx')
        constraints = x_constrained_named(X_named=self._decisions)
        circuit: SDDExplicit = kc_class.create_from(grounded_prog, var_constraint=constraints)
        return circuit

    def _minimize(self, sdd: SDDExplicit) -> None:
        """
        SDD Minimization (Sec 5.5 of the advanced manual).

        :param sdd: Circuit to be minimised in place.
        :return: Void.
        """
        # If one wants to limit times more strictly. Default parameters should be: 180, 60, 30, 10.
        # sdd.get_manager().get_manager().set_vtree_search_time_limit(60)
        # sdd.get_manager().get_manager().set_vtree_fragment_time_limit(20)
        # sdd.get_manager().get_manager().set_vtree_operation_time_limit(10)
        # sdd.get_manager().get_manager().set_vtree_apply_time_limit(5)

        # The following call to 'ref()' is required otherwise the minimization removes necessary nodes
        sdd.get_root_inode().ref()
        sdd.get_manager().get_manager().minimize_limited()

    @staticmethod
    def enumeration_next(state: List[Term]) -> List[Term]:
        """
        Takes in input a state (as a list of terms), and return the next in the enumeration order.
        Assumptions:
        - all terms are binary variables
        - the first state is when all terms are true
        - the last state is when all terms are false
        Enumeration order example: [x1,x2], [¬x1,x2], [x1,¬x2], [¬x1,¬x2].

        :param state: Current state, represented as a list of Terms. It is not modified.
        :return: The next state in the enumeration order. Returns an empty list when the final state is given in input.
        """
        next_state: List[Term] = copy.deepcopy(state)
        # Binary counter with "true" as 0 and "false" as 1: reset the leading negated terms
        # and negate the first positive one.
        for idx, term in enumerate(state):
            if isinstance(term, Not):
                next_state[idx] = term.args[0]
            else:
                next_state[idx] = ~term
                return next_state
        # Every term was negated: this was the last state.
        return []

    @staticmethod
    def big_and(terms: List[Term]) -> Term:
        """
        Transform a list of terms into a concatenation of logical ands, nested from the last term outwards.
        For example, [x,y,z] -> And(z,And(y,x)).

        :param terms: List of terms to be concatenated. It is not modified.
        :return: The big and concatenation. If only one term is in the list, it returns the term itself.
        :raises IndexError: If 'terms' is empty.
        """
        if not terms:
            raise IndexError("big_and requires at least one term")
        conjunction = terms[0]
        for term in terms[1:]:
            conjunction = And(term, conjunction)
        return conjunction

    def value_iteration(self, discount: Optional[float] = None, error: Optional[float] = None,
                        horizon: Optional[int] = None) -> None:
        """
        Run value iteration on the circuit, feeding the discounted utilities back into the DDC at each step.
        The loop runs for a fixed number of iterations when the discount is 1 or when 'horizon' is passed,
        otherwise until the max-norm difference between two consecutive utility vectors is at most the error.
        Notice that a horizon configured only through 'set_horizon' is used only when the discount is 1.
        The iteration counter is not reset, so it accumulates over repeated calls.

        :param discount: Discount factor. If None, the current one is kept.
        :param error: Maximum error allowed for convergence. If None, the current one is kept.
        :param horizon: Number of iterations to perform. If None, the current one is kept.
        :raises ValueError: If the discount is greater than 1 and no horizon is passed (the loop would never end).
        :return: Void.
        """
        starttime_vi = time.time()
        if discount is not None:
            self._discount = discount
        if error is not None:
            self._error = error
        if horizon is not None:
            self._horizon = horizon

        fixed_length = self._discount == 1 or horizon is not None
        if not fixed_length and self._discount > 1:
            raise ValueError("value iteration with a discount factor greater than 1 requires a horizon")

        old_utility = np.zeros(2 ** len(self._state_vars))
        while True:
            if fixed_length and self._iterations_count >= self._horizon:
                break
            new_utility = self._ddc.max_eu()
            for u_idx, u in enumerate(new_utility):
                self._ddc.set_utility_label('u' + str(u_idx), self._discount * u)
            self._iterations_count += 1
            if not fixed_length:
                # loop until convergence
                delta = np.linalg.norm(new_utility - old_utility, ord=np.inf)
                old_utility = new_utility
                if delta <= self._error:  # * (1-self._discount) / self._discount:
                    break

        self._vi_time = time.time() - starttime_vi

    def print_explicit_policy(self) -> None:
        """
        Print, for each enumerated state, the best decisions and the corresponding expected utility.
        """
        print("\nPOLICY FUNCTION:\n")
        state = self._state_vars
        while state:
            # collect state evidence: variable name -> truth value in this state
            state_evidence: Dict[str, bool] = {}
            for term in state:
                negated = isinstance(term, Not)
                state_evidence[str(term.args[0] if negated else term)] = not negated
            _, eu, decisions = self._ddc.best_dec(state_evidence)
            print(f"{state!s} -> {decisions!s} (eu: {eu!s})")
            state = MaplCirup.enumeration_next(state)

    def set_horizon(self, horizon: int) -> None:
        """
        Set the future lookahead used by value iteration.
        """
        self._horizon = horizon

    def set_discount_factor(self, discount: float) -> None:
        """
        Set the discount factor used by value iteration.
        """
        self._discount = discount

    def size(self) -> int:
        """
        Returns the size of the circuit.
        """
        return self._ddc.size()

    def iterations(self) -> int:
        """
        Returns the number of value iteration steps performed so far.
        """
        return self._iterations_count

    def variables_number(self) -> int:
        """
        Returns the number of state variables.
        """
        return len(self._state_vars)

    def compile_time(self) -> float:
        """
        Returns the amount of time required for compilation. It spans the knowledge compilation, the minimisation
        (when performed) and the creation of the DDC.
        """
        return self._compile_time

    def minimize_time(self) -> float:
        """
        Returns the amount of time required for minimisation (0 if it was not performed).
        """
        return self._minimize_time

    def value_iteration_time(self) -> Optional[float]:
        """
        Returns the amount of time required by the last value iteration, or None if it was never run.
        """
        return self._vi_time

    def tot_time(self) -> float:
        """
        Returns the total time: compilation plus value iteration (if it was run).
        """
        # The minimisation time is already part of the compile time, adding it again would count it twice.
        vi_time = self._vi_time if self._vi_time is not None else 0
        return self._compile_time + vi_time

    def view_dot(self) -> None:
        """
        View the dot representation of the transition circuit.
        """
        self._ddc.view_dot()