diff --git a/dist/macq-0.3.3.tar.gz b/dist/macq-0.3.3.tar.gz
deleted file mode 100644
index 8a3c1e7e..00000000
Binary files a/dist/macq-0.3.3.tar.gz and /dev/null differ
diff --git a/dist/macq-0.3.3-py3-none-any.whl b/dist/macq-0.3.4-py3-none-any.whl
similarity index 74%
rename from dist/macq-0.3.3-py3-none-any.whl
rename to dist/macq-0.3.4-py3-none-any.whl
index 4a880c6f..eba4b776 100644
Binary files a/dist/macq-0.3.3-py3-none-any.whl and b/dist/macq-0.3.4-py3-none-any.whl differ
diff --git a/dist/macq-0.3.4.tar.gz b/dist/macq-0.3.4.tar.gz
new file mode 100644
index 00000000..92602495
Binary files /dev/null and b/dist/macq-0.3.4.tar.gz differ
diff --git a/docs/macq.html b/docs/macq.html
index 0f74d5b0..1f0e588b 100644
--- a/docs/macq.html
+++ b/docs/macq.html
@@ -3,7 +3,7 @@
- +167 def to_pddl_lifted( -168 self, -169 domain_name: str, -170 problem_name: str, -171 domain_filename: str, -172 problem_filename: str, -173 ): -174 """Dumps a Model with typed lifted actions & fluents to PDDL files. -175 -176 Args: -177 domain_name (str): -178 The name of the domain to be generated. -179 problem_name (str): -180 The name of the problem to be generated. -181 domain_filename (str): -182 The name of the domain file to be generated. -183 problem_filename (str): -184 The name of the problem file to be generated. -185 """ -186 self.fluents: Set[LearnedLiftedFluent] -187 self.actions: Set[LearnedLiftedAction] -188 -189 lang = tarski.language(domain_name) -190 problem = tarski.fstrips.create_fstrips_problem( -191 domain_name=domain_name, problem_name=problem_name, language=lang -192 ) -193 sorts = set() -194 -195 if self.fluents: -196 for f in self.fluents: -197 for sort in f.param_sorts: -198 if sort not in sorts: -199 lang.sort(sort) -200 sorts.add(sort) -201 -202 lang.predicate(f.name, *f.param_sorts) -203 -204 if self.actions: -205 for a in self.actions: -206 vars = [lang.variable(f"x{i}", s) for i, s in enumerate(a.param_sorts)] -207 -208 if len(a.precond) == 1: -209 precond = lang.get(list(a.precond)[0].name)(*[vars[i] for i in a.precond[0].param_act_inds]) # type: ignore -210 else: -211 precond = CompoundFormula( -212 Connective.And, -213 [ -214 lang.get(f.name)(*[vars[i] for i in f.param_act_inds]) # type: ignore -215 for f in a.precond -216 ], -217 ) -218 -219 adds = [lang.get(f.name)(*[vars[i] for i in f.param_act_inds]) for f in a.add] # type: ignore -220 dels = [lang.get(f.name)(*[vars[i] for i in f.param_act_inds]) for f in a.delete] # type: ignore -221 effects = [fs.AddEffect(e) for e in adds] + [fs.DelEffect(e) for e in dels] # fmt: skip -222 -223 problem.action( -224 a.name, -225 parameters=vars, -226 precondition=precond, -227 effects=effects, -228 ) -229 -230 problem.init = tarski.model.create(lang) # type: ignore -231 problem.goal = land() # type: ignore -232 writer = iofs.FstripsWriter(problem) -233 writer.write(domain_filename, problem_filename) +@@ -1855,61 +2093,61 @@164 def to_pddl_lifted( +165 self, +166 domain_name: str, +167 problem_name: str, +168 domain_filename: str, +169 problem_filename: str, +170 ): +171 """Dumps a Model with typed lifted actions & fluents to PDDL files. +172 +173 Args: +174 domain_name (str): +175 The name of the domain to be generated. +176 problem_name (str): +177 The name of the problem to be generated. +178 domain_filename (str): +179 The name of the domain file to be generated. +180 problem_filename (str): +181 The name of the problem file to be generated. 
+182 """ +183 self.fluents: Set[LearnedLiftedFluent] +184 self.actions: Set[LearnedLiftedAction] +185 +186 lang = tarski.language(domain_name) +187 problem = tarski.fstrips.create_fstrips_problem( +188 domain_name=domain_name, problem_name=problem_name, language=lang +189 ) +190 sorts = set() +191 +192 if self.fluents: +193 for f in self.fluents: +194 for sort in f.param_sorts: +195 if sort not in sorts: +196 lang.sort(sort) +197 sorts.add(sort) +198 +199 lang.predicate(f.name, *f.param_sorts) +200 +201 if self.actions: +202 for a in self.actions: +203 vars = [lang.variable(f"x{i}", s) for i, s in enumerate(a.param_sorts)] +204 +205 if len(a.precond) == 1: +206 precond = lang.get(list(a.precond)[0].name)(*[vars[i] for i in a.precond[0].param_act_inds]) # type: ignore +207 else: +208 precond = CompoundFormula( +209 Connective.And, +210 [ +211 lang.get(f.name)(*[vars[i] for i in f.param_act_inds]) # type: ignore +212 for f in a.precond +213 ], +214 ) +215 +216 adds = [lang.get(f.name)(*[vars[i] for i in f.param_act_inds]) for f in a.add] # type: ignore +217 dels = [lang.get(f.name)(*[vars[i] for i in f.param_act_inds]) for f in a.delete] # type: ignore +218 effects = [fs.AddEffect(e) for e in adds] + [fs.DelEffect(e) for e in dels] # fmt: skip +219 +220 problem.action( +221 a.name, +222 parameters=vars, +223 precondition=precond, +224 effects=effects, +225 ) +226 +227 problem.init = tarski.model.create(lang) # type: ignore +228 problem.goal = land() # type: ignore +229 writer = iofs.FstripsWriter(problem) +230 writer.write(domain_filename, problem_filename)API Documentation
235 def to_pddl_grounded( -236 self, -237 domain_name: str, -238 problem_name: str, -239 domain_filename: str, -240 problem_filename: str, -241 ): -242 """Dumps a Model to two PDDL files. The conversion only uses 0-arity predicates, and no types, objects, -243 or parameters of any kind are used. Actions are represented as ground actions with no parameters. -244 -245 Args: -246 domain_name (str): -247 The name of the domain to be generated. -248 problem_name (str): -249 The name of the problem to be generated. -250 domain_filename (str): -251 The name of the domain file to be generated. -252 problem_filename (str): -253 The name of the problem file to be generated. -254 """ -255 lang = tarski.language(domain_name) -256 problem = tarski.fstrips.create_fstrips_problem( -257 domain_name=domain_name, problem_name=problem_name, language=lang -258 ) -259 if self.fluents: -260 # create 0-arity predicates -261 for f in self.fluents: -262 # NOTE: want there to be no brackets in any fluents referenced as tarski adds these later. -263 # fluents (their string conversion) must be in the following format: (on object a object b) -264 test = str(f) -265 lang.predicate(str(f)[1:-1].replace(" ", "_")) -266 if self.actions: -267 for a in self.actions: -268 # fetch all the relevant 0-arity predicates and create formulas to set up the ground actions -269 preconds = self.__to_tarski_formula({a[1:-1] for a in a.precond}, lang) -270 adds = [lang.get(f"{e.replace(' ', '_')[1:-1]}")() for e in a.add] -271 dels = [lang.get(f"{e.replace(' ', '_')[1:-1]}")() for e in a.delete] -272 effects = [fs.AddEffect(e) for e in adds] -273 effects.extend([fs.DelEffect(e) for e in dels]) -274 # set up action -275 problem.action( -276 name=a.details() -277 .replace("(", "") -278 .replace(")", "") -279 .replace(" ", "_"), -280 parameters=[], -281 precondition=preconds, -282 effects=effects, -283 ) -284 # create empty init and goal -285 problem.init = tarski.model.create(lang) -286 problem.goal = land() -287 # write to files -288 writer = iofs.FstripsWriter(problem) -289 writer.write(domain_filename, problem_filename) +@@ -1941,18 +2179,18 @@232 def to_pddl_grounded( +233 self, +234 domain_name: str, +235 problem_name: str, +236 domain_filename: str, +237 problem_filename: str, +238 ): +239 """Dumps a Model to two PDDL files. The conversion only uses 0-arity predicates, and no types, objects, +240 or parameters of any kind are used. Actions are represented as ground actions with no parameters. +241 +242 Args: +243 domain_name (str): +244 The name of the domain to be generated. +245 problem_name (str): +246 The name of the problem to be generated. +247 domain_filename (str): +248 The name of the domain file to be generated. +249 problem_filename (str): +250 The name of the problem file to be generated. +251 """ +252 lang = tarski.language(domain_name) +253 problem = tarski.fstrips.create_fstrips_problem( +254 domain_name=domain_name, problem_name=problem_name, language=lang +255 ) +256 if self.fluents: +257 # create 0-arity predicates +258 for f in self.fluents: +259 # NOTE: want there to be no brackets in any fluents referenced as tarski adds these later. 
+260 # fluents (their string conversion) must be in the following format: (on object a object b) +261 test = str(f) +262 lang.predicate(str(f)[1:-1].replace(" ", "_")) +263 if self.actions: +264 for a in self.actions: +265 # fetch all the relevant 0-arity predicates and create formulas to set up the ground actions +266 preconds = self.__to_tarski_formula({a[1:-1] for a in a.precond}, lang) +267 adds = [lang.get(f"{e.replace(' ', '_')[1:-1]}")() for e in a.add] +268 dels = [lang.get(f"{e.replace(' ', '_')[1:-1]}")() for e in a.delete] +269 effects = [fs.AddEffect(e) for e in adds] +270 effects.extend([fs.DelEffect(e) for e in dels]) +271 # set up action +272 problem.action( +273 name=a.details() +274 .replace("(", "") +275 .replace(")", "") +276 .replace(" ", "_"), +277 parameters=[], +278 precondition=preconds, +279 effects=effects, +280 ) +281 # create empty init and goal +282 problem.init = tarski.model.create(lang) +283 problem.goal = land() +284 # write to files +285 writer = iofs.FstripsWriter(problem) +286 writer.write(domain_filename, problem_filename)API Documentation
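Similarly, a hedged usage sketch of to_pddl_grounded; per the code above, fluents are flattened into underscore-joined 0-arity predicate names and actions take no parameters. The names below are placeholders.

# Minimal sketch (assumed usage): every fluent becomes a 0-arity predicate such
# as "on_a_b", and every action is written as a parameter-free ground action.
model.to_pddl_grounded(
    domain_name="blocks-grounded",
    problem_name="blocks-grounded-prob",
    domain_filename="domain_grounded.pddl",
    problem_filename="problem_grounded.pddl",
)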
294 @staticmethod -295 def deserialize(string: str): -296 """Deserializes a json string into a Model. -297 -298 Args: -299 string (str): -300 The json string representing a model. -301 -302 Returns: -303 A Model object matching the one specified by `string`. -304 """ -305 return Model._from_json(loads(string)) +@@ -2259,6 +2497,7 @@291 @staticmethod +292 def deserialize(string: str): +293 """Deserializes a json string into a Model. +294 +295 Args: +296 string (str): +297 The json string representing a model. +298 +299 Returns: +300 A Model object matching the one specified by `string`. +301 """ +302 return Model._from_json(loads(string))Inherited Members
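A hedged round-trip sketch for deserialize; the serialize() counterpart and the import path are assumptions and may differ in the installed package.

# Sketch only: Model.serialize() is assumed to be the counterpart producing the
# JSON string that Model.deserialize() reads back.
from macq.extract import Model  # import path assumed

json_str = model.serialize()           # assumed serializer on an existing Model
restored = Model.deserialize(json_str)
print(restored.fluents, restored.actions)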
@@ -4110,6 +4349,7 @@
- builtins.BaseException
- with_traceback
+- args
Inherited Members
@@ -4901,665 +5141,700 @@
- builtins.BaseException
- with_traceback
+- args
Inherited Members
166class LOCM: -167 """LOCM""" -168 -169 zero_obj = PlanningObject("zero", "zero") -170 -171 def __new__( -172 cls, -173 obs_tracelist: ObservedTraceList, -174 statics: Optional[Statics] = None, -175 viz: bool = False, -176 view: bool = False, -177 debug: Union[bool, Dict[str, bool], List[str]] = False, -178 ): -179 """Creates a new Model object. -180 Args: -181 observations (ObservationList): -182 The state observations to extract the model from. -183 statics (Dict[str, List[str]]): -184 A dictionary mapping an action name and its arguments to the -185 list of static preconditions of the action. A precondition should -186 be a tuple, where the first element is the predicate name and the -187 rest correspond to the arguments of the action (1-indexed). -188 E.g. static( next(C1, C2), put_on_card_in_homecell(C2, C1, _) ) -189 should is provided as: {"put_on_card_in_homecell": [("next", 2, 1)]} -190 viz (bool): -191 Whether to visualize the FSM. -192 view (bool): -193 Whether to view the FSM visualization. -194 -195 Raises: -196 IncompatibleObservationToken: -197 Raised if the observations are not identity observation. -198 """ -199 if obs_tracelist.type is not ActionObservation: -200 raise IncompatibleObservationToken(obs_tracelist.type, LOCM) -201 -202 if len(obs_tracelist) != 1: -203 warn("LOCM only supports a single trace, using first trace only") -204 -205 if isinstance(debug, bool) and debug: -206 debug = defaultdict(lambda: True) -207 elif isinstance(debug, dict): -208 debug = defaultdict(lambda: False, debug) -209 elif isinstance(debug, list): -210 debug = defaultdict(lambda: False, {k: True for k in debug}) -211 else: -212 debug = defaultdict(lambda: False) -213 -214 obs_trace = obs_tracelist[0] -215 fluents, actions = None, None -216 -217 sorts = LOCM._get_sorts(obs_trace, debug=debug["get_sorts"]) +@@ -5577,76 +5852,77 @@183class LOCM: +184 """LOCM""" +185 +186 zero_obj = PlanningObject("zero", "zero") +187 +188 def __new__( +189 cls, +190 obs_tracelist: ObservedTraceList, +191 statics: Optional[Statics] = None, +192 viz: bool = False, +193 view: bool = False, +194 debug: Union[bool, Dict[str, bool], List[str]] = False, +195 ): +196 """Creates a new Model object. +197 Args: +198 observations (ObservationList): +199 The state observations to extract the model from. +200 statics (Dict[str, List[str]]): +201 A dictionary mapping an action name and its arguments to the +202 list of static preconditions of the action. A precondition should +203 be a tuple, where the first element is the predicate name and the +204 rest correspond to the arguments of the action (1-indexed). +205 E.g. static( next(C1, C2), put_on_card_in_homecell(C2, C1, _) ) +206 should is provided as: {"put_on_card_in_homecell": [("next", 2, 1)]} +207 viz (bool): +208 Whether to visualize the FSM. +209 view (bool): +210 Whether to view the FSM visualization. +211 +212 Raises: +213 IncompatibleObservationToken: +214 Raised if the observations are not identity observation. 
+215 """ +216 if obs_tracelist.type is not ActionObservation: +217 raise IncompatibleObservationToken(obs_tracelist.type, LOCM) 218 -219 if debug["sorts"]: -220 print(f"Sorts:\n{sorts}", end="\n\n") +219 if len(obs_tracelist) != 1: +220 warn("LOCM only supports a single trace, using first trace only") 221 -222 TS, ap_state_pointers, OS = LOCM._step1(obs_trace, sorts, debug["step1"]) -223 HS = LOCM._step3(TS, ap_state_pointers, OS, sorts, debug["step3"]) -224 bindings = LOCM._step4(HS, debug["step4"]) -225 bindings = LOCM._step5(HS, bindings, debug["step5"]) -226 fluents, actions = LOCM._step7( -227 OS, -228 ap_state_pointers, -229 sorts, -230 bindings, -231 statics if statics is not None else {}, -232 debug["step7"], -233 ) -234 -235 if viz: -236 state_machines = LOCM.get_state_machines(ap_state_pointers, OS, bindings) -237 for sm in state_machines: -238 sm.render(view=view) -239 -240 return Model(fluents, actions) -241 -242 @staticmethod -243 def _get_sorts(obs_trace: List[Observation], debug=False) -> Sorts: -244 sorts = [] # initialize list of sorts for this trace -245 # track actions seen in the trace, and the sort each actions params belong to -246 ap_sort_pointers: Dict[str, List[int]] = {} -247 # track objects seen in the trace, and the sort each belongs to -248 # obj_sort_pointers: Dict[str, int] = {} -249 sorted_objs = [] -250 -251 def get_obj_sort(obj: PlanningObject) -> int: -252 """Returns the sort index of the object""" -253 for i, sort in enumerate(sorts): -254 if obj in sort: -255 return i -256 raise ValueError(f"Object {obj} not in any sort") +222 if isinstance(debug, bool) and debug: +223 debug = defaultdict(lambda: True) +224 elif isinstance(debug, dict): +225 debug = defaultdict(lambda: False, debug) +226 elif isinstance(debug, list): +227 debug = defaultdict(lambda: False, {k: True for k in debug}) +228 else: +229 debug = defaultdict(lambda: False) +230 +231 obs_trace = obs_tracelist[0] +232 fluents, actions = None, None +233 +234 sorts = LOCM._get_sorts(obs_trace, debug=debug["get_sorts"]) +235 +236 if debug["sorts"]: +237 sortid2objs = {v: [] for v in set(sorts.values())} +238 for k, v in sorts.items(): +239 sortid2objs[v].append(k) +240 print("\nSorts:\n") +241 pprint(sortid2objs) +242 print("\n") +243 +244 TS, ap_state_pointers, OS = LOCM._step1(obs_trace, sorts, debug["step1"]) +245 HS = LOCM._step3(TS, ap_state_pointers, OS, sorts, debug["step3"]) +246 bindings = LOCM._step4(HS, debug["step4"]) +247 bindings = LOCM._step5(HS, bindings, debug["step5"]) +248 fluents, actions = LOCM._step7( +249 OS, +250 ap_state_pointers, +251 sorts, +252 bindings, +253 statics if statics is not None else {}, +254 debug["step7"], +255 viz, +256 ) 257 -258 for obs in obs_trace: -259 action = obs.action -260 if action is None: -261 continue -262 -263 if debug: -264 print("\n\naction:", action.name, action.obj_params) -265 -266 if action.name not in ap_sort_pointers: # new action -267 if debug: -268 print("new action") -269 -270 ap_sort_pointers[action.name] = [] -271 -272 # for each parameter of the action -273 for obj in action.obj_params: -274 if obj.name not in sorted_objs: # unsorted object -275 # append a sort (set) containing the object -276 sorts.append({obj}) -277 -278 # record the object has been sorted and the index of the sort it belongs to -279 obj_sort = len(sorts) - 1 -280 sorted_objs.append(obj.name) -281 ap_sort_pointers[action.name].append(obj_sort) -282 -283 if debug: -284 print("new object", obj.name) -285 print("sorts:", sorts) -286 -287 else: # object already 
sorted -288 # look up the sort of the object -289 obj_sort = get_obj_sort(obj) -290 ap_sort_pointers[action.name].append(obj_sort) -291 -292 if debug: -293 print("sorted object", obj.name) -294 print("sorts:", sorts) +258 return Model(fluents, actions) +259 +260 @staticmethod +261 def _get_sorts(obs_trace: List[Observation], debug=False) -> Sorts: +262 sorts = [] # initialize list of sorts for this trace +263 # track actions seen in the trace, and the sort each actions params belong to +264 ap_sort_pointers: Dict[str, List[int]] = {} +265 # track objects seen in the trace, and the sort each belongs to +266 # obj_sort_pointers: Dict[str, int] = {} +267 sorted_objs = [] +268 +269 def get_obj_sort(obj: PlanningObject) -> int: +270 """Returns the sort index of the object""" +271 for i, sort in enumerate(sorts): +272 if obj in sort: +273 return i +274 raise ValueError(f"Object {obj} not in any sort") +275 +276 for obs in obs_trace: +277 action = obs.action +278 if action is None: +279 continue +280 +281 if debug: +282 print("\n\naction:", action.name, action.obj_params) +283 +284 if action.name not in ap_sort_pointers: # new action +285 if debug: +286 print("new action") +287 +288 ap_sort_pointers[action.name] = [] +289 +290 # for each parameter of the action +291 for obj in action.obj_params: +292 if obj.name not in sorted_objs: # unsorted object +293 # append a sort (set) containing the object +294 sorts.append({obj}) 295 -296 if debug: -297 print("ap sorts:", ap_sort_pointers) -298 -299 else: # action seen before -300 if debug: -301 print("seen action") -302 -303 for ap_sort, obj in zip( -304 ap_sort_pointers[action.name], action.obj_params -305 ): -306 if debug: -307 print("checking obj", obj.name) -308 print("ap sort:", ap_sort) +296 # record the object has been sorted and the index of the sort it belongs to +297 obj_sort = len(sorts) - 1 +298 sorted_objs.append(obj.name) +299 ap_sort_pointers[action.name].append(obj_sort) +300 +301 if debug: +302 print("new object", obj.name) +303 print("sorts:", sorts) +304 +305 else: # object already sorted +306 # look up the sort of the object +307 obj_sort = get_obj_sort(obj) +308 ap_sort_pointers[action.name].append(obj_sort) 309 -310 if obj.name not in sorted_objs: # unsorted object -311 if debug: -312 print("unsorted object", obj.name) -313 print("sorts:", sorts) -314 -315 # add the object to the sort of current action parameter -316 sorts[ap_sort].add(obj) -317 sorted_objs.append(obj.name) -318 -319 else: # object already has a sort -320 # retrieve the sort the object belongs to -321 obj_sort = get_obj_sort(obj) -322 -323 if debug: -324 print(f"retrieving sorted obj {obj.name}") -325 print(f"obj_sort_idx: {obj_sort}") -326 print(f"seq_sorts: {sorts}") +310 if debug: +311 print("sorted object", obj.name) +312 print("sorts:", sorts) +313 +314 if debug: +315 print("ap sorts:", ap_sort_pointers) +316 +317 else: # action seen before +318 if debug: +319 print("seen action") +320 +321 for ap_sort, obj in zip( +322 ap_sort_pointers[action.name], action.obj_params +323 ): +324 if debug: +325 print("checking obj", obj.name) +326 print("ap sort:", ap_sort) 327 -328 # check if the object's sort matches the action paremeter's -329 # if so, do nothing and move on to next step -330 # otherwise, unite the two sorts -331 if obj_sort == ap_sort: -332 if debug: -333 print("obj sort matches action") -334 else: -335 if debug: -336 print( -337 f"obj sort {obj_sort} doesn't match action {ap_sort}" -338 ) -339 print(f"seq_sorts: {sorts}") +328 if obj.name not in 
sorted_objs: # unsorted object +329 if debug: +330 print("unsorted object", obj.name) +331 print("sorts:", sorts) +332 +333 # add the object to the sort of current action parameter +334 sorts[ap_sort].add(obj) +335 sorted_objs.append(obj.name) +336 +337 else: # object already has a sort +338 # retrieve the sort the object belongs to +339 obj_sort = get_obj_sort(obj) 340 -341 # unite the action parameter's sort and the object's sort -342 sorts[obj_sort] = sorts[obj_sort].union(sorts[ap_sort]) -343 -344 # drop the not unionized sort -345 sorts.pop(ap_sort) -346 -347 old_obj_sort = obj_sort -348 -349 obj_sort = get_obj_sort(obj) -350 -351 if debug: -352 print( -353 f"united seq_sorts[{ap_sort}] and seq_sorts[{obj_sort}]" -354 ) -355 print(f"seq_sorts: {sorts}") -356 print(f"ap_sort_pointers: {ap_sort_pointers}") -357 print("updating pointers...") +341 if debug: +342 print(f"retrieving sorted obj {obj.name}") +343 print(f"obj_sort_idx: {obj_sort}") +344 print(f"seq_sorts: {sorts}") +345 +346 # check if the object's sort matches the action paremeter's +347 # if so, do nothing and move on to next step +348 # otherwise, unite the two sorts +349 if obj_sort == ap_sort: +350 if debug: +351 print("obj sort matches action") +352 else: +353 if debug: +354 print( +355 f"obj sort {obj_sort} doesn't match action {ap_sort}" +356 ) +357 print(f"seq_sorts: {sorts}") 358 -359 min_idx = min(ap_sort, obj_sort) -360 -361 # update all outdated records of which sort the affected objects belong to -362 for action_name, ap_sorts in ap_sort_pointers.items(): -363 for p, sort in enumerate(ap_sorts): -364 if sort == ap_sort or sort == old_obj_sort: -365 ap_sort_pointers[action_name][p] = obj_sort -366 elif sort > min_idx: -367 ap_sort_pointers[action_name][p] -= 1 +359 # unite the action parameter's sort and the object's sort +360 sorts[obj_sort] = sorts[obj_sort].union(sorts[ap_sort]) +361 +362 # drop the not unionized sort +363 sorts.pop(ap_sort) +364 +365 old_obj_sort = obj_sort +366 +367 obj_sort = get_obj_sort(obj) 368 369 if debug: -370 print(f"ap_sort_pointers: {ap_sort_pointers}") -371 -372 obj_sorts = {} -373 for i, sort in enumerate(sorts): -374 for obj in sort: -375 # NOTE: object sorts are 1-indexed so the zero-object can be sort 0 -376 obj_sorts[obj.name] = i + 1 -377 -378 return obj_sorts -379 -380 @staticmethod -381 def _pointer_to_set(states: List[Set], pointer, pointer2=None) -> Tuple[int, int]: -382 state1, state2 = None, None -383 for i, state_set in enumerate(states): -384 if pointer in state_set: -385 state1 = i -386 if pointer2 is None or pointer2 in state_set: -387 state2 = i -388 if state1 is not None and state2 is not None: -389 break -390 -391 assert state1 is not None, f"Pointer ({pointer}) not in states: {states}" -392 assert state2 is not None, f"Pointer ({pointer2}) not in states: {states}" -393 return state1, state2 -394 -395 @staticmethod -396 def _step1( -397 obs_trace: List[Observation], sorts: Sorts, debug: bool = False -398 ) -> Tuple[TSType, APStatePointers, OSType]: -399 """Step 1: Create a state machine for each object sort -400 Implicitly includes Step 2 (zero analysis) by including the zero-object throughout -401 """ -402 -403 # create the zero-object for zero analysis (step 2) -404 zero_obj = LOCM.zero_obj -405 -406 # collect action sequences for each object -407 obj_traces: Dict[PlanningObject, List[AP]] = defaultdict(list) -408 for obs in obs_trace: -409 action = obs.action -410 if action is not None: -411 # add the step for the zero-object -412 
obj_traces[zero_obj].append(AP(action, pos=0, sort=0)) -413 # for each combination of action name A and argument pos P -414 for j, obj in enumerate(action.obj_params): -415 # create transition A.P -416 ap = AP(action, pos=j + 1, sort=sorts[obj.name]) -417 obj_traces[obj].append(ap) -418 -419 # initialize the state set OS and transition set TS -420 OS: OSType = defaultdict(list) -421 TS: TSType = defaultdict(dict) -422 # track pointers mapping A.P to its start and end states -423 ap_state_pointers = defaultdict(dict) -424 # iterate over each object and its action sequence -425 for obj, seq in obj_traces.items(): -426 state_n = 1 # count current (new) state id -427 sort = sorts[obj.name] if obj != zero_obj else 0 -428 TS[sort][obj] = seq # add the sequence to the transition set -429 prev_states: StatePointers = None # type: ignore -430 # iterate over each transition A.P in the sequence -431 for ap in seq: -432 # if the transition has not been seen before for the current sort -433 if ap not in ap_state_pointers[sort]: -434 ap_state_pointers[sort][ap] = StatePointers(state_n, state_n + 1) -435 -436 # add the start and end states to the state set as unique states -437 OS[sort].append({state_n}) -438 OS[sort].append({state_n + 1}) -439 -440 state_n += 2 -441 -442 ap_states = ap_state_pointers[sort][ap] -443 -444 if prev_states is not None: -445 # get the state ids (indecies) of the state sets containing -446 # start(A.P) and the end state of the previous transition -447 start_state, prev_end_state = LOCM._pointer_to_set( -448 OS[sort], ap_states.start, prev_states.end -449 ) -450 -451 # if not the same state set, merge the two -452 if start_state != prev_end_state: -453 OS[sort][start_state] = OS[sort][start_state].union( -454 OS[sort][prev_end_state] -455 ) -456 OS[sort].pop(prev_end_state) -457 -458 prev_states = ap_states -459 -460 # remove the zero-object sort if it only has one state -461 if len(OS[0]) == 1: -462 ap_state_pointers[0] = {} -463 OS[0] = [] +370 print( +371 f"united seq_sorts[{ap_sort}] and seq_sorts[{obj_sort}]" +372 ) +373 print(f"seq_sorts: {sorts}") +374 print(f"ap_sort_pointers: {ap_sort_pointers}") +375 print("updating pointers...") +376 +377 min_idx = min(ap_sort, obj_sort) +378 +379 # update all outdated records of which sort the affected objects belong to +380 for action_name, ap_sorts in ap_sort_pointers.items(): +381 for p, sort in enumerate(ap_sorts): +382 if sort == ap_sort or sort == old_obj_sort: +383 ap_sort_pointers[action_name][p] = obj_sort +384 elif sort > min_idx: +385 ap_sort_pointers[action_name][p] -= 1 +386 +387 if debug: +388 print(f"ap_sort_pointers: {ap_sort_pointers}") +389 +390 obj_sorts = {} +391 for i, sort in enumerate(sorts): +392 for obj in sort: +393 # NOTE: object sorts are 1-indexed so the zero-object can be sort 0 +394 obj_sorts[obj.name] = i + 1 +395 +396 return obj_sorts +397 +398 @staticmethod +399 def _pointer_to_set(states: List[Set], pointer, pointer2=None) -> Tuple[int, int]: +400 state1, state2 = None, None +401 for i, state_set in enumerate(states): +402 if pointer in state_set: +403 state1 = i +404 if pointer2 is None or pointer2 in state_set: +405 state2 = i +406 if state1 is not None and state2 is not None: +407 break +408 +409 assert state1 is not None, f"Pointer ({pointer}) not in states: {states}" +410 assert state2 is not None, f"Pointer ({pointer2}) not in states: {states}" +411 return state1, state2 +412 +413 @staticmethod +414 def _step1( +415 obs_trace: List[Observation], sorts: Sorts, debug: bool = False +416 ) -> 
Tuple[TSType, APStatePointers, OSType]: +417 """Step 1: Create a state machine for each object sort +418 Implicitly includes Step 2 (zero analysis) by including the zero-object throughout +419 """ +420 +421 # create the zero-object for zero analysis (step 2) +422 zero_obj = LOCM.zero_obj +423 +424 # collect action sequences for each object +425 obj_traces: Dict[PlanningObject, List[AP]] = defaultdict(list) +426 for obs in obs_trace: +427 action = obs.action +428 if action is not None: +429 # add the step for the zero-object +430 obj_traces[zero_obj].append(AP(action, pos=0, sort=0)) +431 # for each combination of action name A and argument pos P +432 for j, obj in enumerate(action.obj_params): +433 # create transition A.P +434 ap = AP(action, pos=j + 1, sort=sorts[obj.name]) +435 obj_traces[obj].append(ap) +436 +437 # initialize the state set OS and transition set TS +438 OS: OSType = defaultdict(list) +439 TS: TSType = defaultdict(dict) +440 # track pointers mapping A.P to its start and end states +441 ap_state_pointers = defaultdict(dict) +442 # iterate over each object and its action sequence +443 for obj, seq in obj_traces.items(): +444 sort = sorts[obj.name] if obj != zero_obj else 0 +445 TS[sort][obj] = seq # add the sequence to the transition set +446 # max of the states already in OS[sort], plus 1 +447 state_n = ( +448 max([max(s) for s in OS[sort]] + [0]) + 1 +449 ) # count current (new) state id +450 prev_states: StatePointers = None # type: ignore +451 # iterate over each transition A.P in the sequence +452 for ap in seq: +453 # if the transition has not been seen before for the current sort +454 if ap not in ap_state_pointers[sort]: +455 ap_state_pointers[sort][ap] = StatePointers(state_n, state_n + 1) +456 +457 # add the start and end states to the state set as unique states +458 OS[sort].append({state_n}) +459 OS[sort].append({state_n + 1}) +460 +461 state_n += 2 +462 +463 ap_states = ap_state_pointers[sort][ap] 464 -465 return dict(TS), dict(ap_state_pointers), dict(OS) -466 -467 @staticmethod -468 def _step3( -469 TS: TSType, -470 ap_state_pointers: APStatePointers, -471 OS: OSType, -472 sorts: Sorts, -473 debug: bool = False, -474 ) -> Hypotheses: -475 """Step 3: Induction of parameterised FSMs""" -476 -477 zero_obj = LOCM.zero_obj -478 -479 # indexed by B.k and C.l for 3.2 matching hypotheses against transitions -480 HS: Dict[HSIndex, Set[HSItem]] = defaultdict(set) +465 if prev_states is not None: +466 # get the state ids (indecies) of the state sets containing +467 # start(A.P) and the end state of the previous transition +468 start_state, prev_end_state = LOCM._pointer_to_set( +469 OS[sort], ap_states.start, prev_states.end +470 ) +471 +472 # if not the same state set, merge the two +473 if start_state != prev_end_state: +474 OS[sort][start_state] = OS[sort][start_state].union( +475 OS[sort][prev_end_state] +476 ) +477 OS[sort].pop(prev_end_state) +478 assert len(set.union(*OS[sort])) == sum([len(s) for s in OS[sort]]) +479 +480 prev_states = ap_states 481 -482 # 3.1: Form hypotheses from state machines -483 for G, sort_ts in TS.items(): -484 # for each O ∈ O_u (not including the zero-object) -485 for obj, seq in sort_ts.items(): -486 if obj == zero_obj: -487 continue -488 # for each pair of transitions B.k and C.l consecutive for O -489 for B, C in zip(seq, seq[1:]): -490 # skip if B or C only have one parameter, since there is no k' or l' to match on -491 if len(B.action.obj_params) == 1 or len(C.action.obj_params) == 1: -492 continue -493 -494 k = B.pos -495 l = 
C.pos -496 -497 # check each pair B.k' and C.l' -498 for i, Bk_ in enumerate(B.action.obj_params): -499 k_ = i + 1 -500 if k_ == k: -501 continue -502 G_ = sorts[Bk_.name] -503 for j, Cl_ in enumerate(C.action.obj_params): -504 l_ = j + 1 -505 if l_ == l: -506 continue -507 -508 # check that B.k' and C.l' are of the same sort -509 if sorts[Cl_.name] == G_: -510 # check that end(B.P) = start(C.P) -511 # NOTE: just a sanity check, should never fail -512 S, S2 = LOCM._pointer_to_set( -513 OS[G], -514 ap_state_pointers[G][B].end, -515 ap_state_pointers[G][C].start, -516 ) -517 assert ( -518 S == S2 -519 ), f"end(B.P) != start(C.P)\nB.P: {B}\nC.P: {C}" -520 -521 # save the hypothesis in the hypothesis set -522 HS[HSIndex(B, k, C, l)].add( -523 HSItem(S, k_, l_, G, G_, supported=False) -524 ) -525 -526 # 3.2: Test hypotheses against sequence -527 for G, sort_ts in TS.items(): -528 # for each O ∈ O_u (not including the zero-object) -529 for obj, seq in sort_ts.items(): -530 if obj == zero_obj: -531 continue -532 # for each pair of transitions Ap.m and Aq.n consecutive for O -533 for Ap, Aq in zip(seq, seq[1:]): -534 m = Ap.pos -535 n = Aq.pos -536 # Check if we have a hypothesis matching Ap=B, m=k, Aq=C, n=l -537 BkCl = HSIndex(Ap, m, Aq, n) -538 if BkCl in HS: -539 # check each matching hypothesis -540 for H in HS[BkCl].copy(): -541 # if Op,k' = Oq,l' then mark the hypothesis as supported -542 if ( -543 Ap.action.obj_params[H.k_ - 1] -544 == Aq.action.obj_params[H.l_ - 1] -545 ): -546 H.supported = True -547 else: # otherwise remove the hypothesis -548 HS[BkCl].remove(H) -549 -550 # Remove any unsupported hypotheses (but yet undisputed) -551 for hind, hs in HS.copy().items(): -552 for h in hs: -553 if not h.supported: -554 hs.remove(h) -555 if len(hs) == 0: -556 del HS[hind] -557 -558 # Converts HS {HSIndex: HSItem} to a mapping of hypothesis for states of a sort {sort: {state: Hypothesis}} -559 return Hypothesis.from_dict(HS) -560 -561 @staticmethod -562 def _step4( -563 HS: Dict[int, Dict[int, Set[Hypothesis]]], debug: bool = False -564 ) -> Bindings: -565 """Step 4: Creation and merging of state parameters""" -566 -567 # bindings = {sort: {state: [(hypothesis, state param)]}} -568 bindings: Bindings = defaultdict(dict) -569 for sort, hs_sort in HS.items(): -570 for state, hs_sort_state in hs_sort.items(): -571 # state_bindings = {hypothesis (h): state param (v)} -572 state_bindings: Dict[Hypothesis, int] = {} -573 -574 # state_params = [set(v)]; params in the same set are the same -575 state_params: List[Set[int]] = [] -576 -577 # state_param_pointers = {v: P}; maps state param to the state_params set index -578 # i.e. 
map hypothesis state param v -> actual state param P -579 state_param_pointers: Dict[int, int] = {} -580 -581 # for each hypothesis h, -582 hs_sort_state = list(hs_sort_state) -583 for v, h in enumerate(hs_sort_state): -584 # add the <h, v> binding pair -585 state_bindings[h] = v -586 # add a param v as a unique state parameter -587 state_params.append({v}) -588 state_param_pointers[v] = v -589 -590 # for each (unordered) pair of hypotheses h1, h2 -591 for i, h1 in enumerate(hs_sort_state): -592 for h2 in hs_sort_state[i + 1 :]: -593 # check if hypothesis parameters (v1 & v2) need to be unified -594 if ( -595 (h1.B == h2.B and h1.k == h2.k and h1.k_ == h2.k_) -596 or -597 (h1.C == h2.C and h1.l == h2.l and h1.l_ == h2.l_) # fmt: skip -598 ): -599 v1 = state_bindings[h1] -600 v2 = state_bindings[h2] -601 -602 # get the parameter sets P1, P2 that v1, v2 belong to -603 P1, P2 = LOCM._pointer_to_set(state_params, v1, v2) -604 -605 if P1 != P2: -606 # merge P1 and P2 -607 state_params[P1] = state_params[P1].union( -608 state_params[P2] -609 ) -610 state_params.pop(P2) -611 state_param_pointers[v2] = P1 -612 -613 # add state bindings for the sort to the output bindings -614 # replacing hypothesis params with actual state params -615 bindings[sort][state] = [ -616 Binding(h, LOCM._pointer_to_set(state_params, v)[0]) -617 for h, v in state_bindings.items() -618 ] -619 -620 return dict(bindings) -621 -622 @staticmethod -623 def _step5( -624 HS: Dict[int, Dict[int, Set[Hypothesis]]], -625 bindings: Bindings, -626 debug: bool = False, -627 ) -> Bindings: -628 """Step 5: Removing parameter flaws""" -629 -630 # check each bindings[G][S] -> (h, P) -631 for sort, hs_sort in HS.items(): -632 for state in hs_sort: -633 # track all the h.Bs that occur in bindings[G][S] -634 all_hB = set() -635 # track the set of h.B that set parameter P -636 sets_P = defaultdict(set) -637 for h, P in bindings[sort][state]: -638 sets_P[P].add(h.B) -639 all_hB.add(h.B) -640 -641 # for each P, check if there is a transition h.B that never sets parameter P -642 # i.e. 
if sets_P[P] != all_hB -643 for P, setby in sets_P.items(): -644 if not setby == all_hB: # P is a flawed parameter -645 # remove all bindings referencing P -646 for h, P_ in bindings[sort][state].copy(): -647 if P_ == P: -648 bindings[sort][state].remove(Binding(h, P_)) -649 if len(bindings[sort][state]) == 0: -650 del bindings[sort][state] -651 -652 for k, v in bindings.copy().items(): -653 if not v: -654 del bindings[k] -655 -656 return bindings -657 -658 @staticmethod -659 def get_state_machines( -660 ap_state_pointers: APStatePointers, -661 OS: OSType, -662 bindings: Optional[Bindings] = None, -663 ): -664 from graphviz import Digraph -665 -666 state_machines = [] -667 for (sort, trans), states in zip(ap_state_pointers.items(), OS.values()): -668 graph = Digraph(f"LOCM-step1-sort{sort}") -669 for state in range(len(states)): -670 label = f"state{state}" -671 if ( -672 bindings is not None -673 and sort in bindings -674 and state in bindings[sort] -675 ): -676 label += f"\n[" -677 params = [] -678 for binding in bindings[sort][state]: -679 params.append(f"{binding.hypothesis.G_}") -680 label += f",".join(params) -681 label += f"]" -682 graph.node(str(state), label=label, shape="oval") -683 for ap, apstate in trans.items(): -684 start_idx, end_idx = LOCM._pointer_to_set( -685 states, apstate.start, apstate.end -686 ) -687 graph.edge( -688 str(start_idx), str(end_idx), label=f"{ap.action.name}.{ap.pos}" -689 ) +482 # remove the zero-object sort if it only has one state +483 if len(OS[0]) == 1: +484 ap_state_pointers[0] = {} +485 OS[0] = [] +486 +487 return dict(TS), dict(ap_state_pointers), dict(OS) +488 +489 @staticmethod +490 def _step3( +491 TS: TSType, +492 ap_state_pointers: APStatePointers, +493 OS: OSType, +494 sorts: Sorts, +495 debug: bool = False, +496 ) -> Hypotheses: +497 """Step 3: Induction of parameterised FSMs""" +498 +499 zero_obj = LOCM.zero_obj +500 +501 # indexed by B.k and C.l for 3.2 matching hypotheses against transitions +502 HS: Dict[HSIndex, Set[HSItem]] = defaultdict(set) +503 +504 # 3.1: Form hypotheses from state machines +505 for G, sort_ts in TS.items(): +506 # for each O ∈ O_u (not including the zero-object) +507 for obj, seq in sort_ts.items(): +508 if obj == zero_obj: +509 continue +510 # for each pair of transitions B.k and C.l consecutive for O +511 for B, C in zip(seq, seq[1:]): +512 # skip if B or C only have one parameter, since there is no k' or l' to match on +513 if len(B.action.obj_params) == 1 or len(C.action.obj_params) == 1: +514 continue +515 +516 k = B.pos +517 l = C.pos +518 +519 # check each pair B.k' and C.l' +520 for i, Bk_ in enumerate(B.action.obj_params): +521 k_ = i + 1 +522 if k_ == k: +523 continue +524 G_ = sorts[Bk_.name] +525 for j, Cl_ in enumerate(C.action.obj_params): +526 l_ = j + 1 +527 if l_ == l: +528 continue +529 +530 # check that B.k' and C.l' are of the same sort +531 if sorts[Cl_.name] == G_: +532 # check that end(B.P) = start(C.P) +533 # NOTE: just a sanity check, should never fail +534 S, S2 = LOCM._pointer_to_set( +535 OS[G], +536 ap_state_pointers[G][B].end, +537 ap_state_pointers[G][C].start, +538 ) +539 assert ( +540 S == S2 +541 ), f"end(B.P) != start(C.P)\nB.P: {B}\nC.P: {C}" +542 +543 # save the hypothesis in the hypothesis set +544 HS[HSIndex(B, k, C, l)].add( +545 HSItem(S, k_, l_, G, G_, supported=False) +546 ) +547 +548 # 3.2: Test hypotheses against sequence +549 for G, sort_ts in TS.items(): +550 # for each O ∈ O_u (not including the zero-object) +551 for obj, seq in sort_ts.items(): +552 if obj == 
zero_obj: +553 continue +554 # for each pair of transitions Ap.m and Aq.n consecutive for O +555 for Ap, Aq in zip(seq, seq[1:]): +556 m = Ap.pos +557 n = Aq.pos +558 # Check if we have a hypothesis matching Ap=B, m=k, Aq=C, n=l +559 BkCl = HSIndex(Ap, m, Aq, n) +560 if BkCl in HS: +561 # check each matching hypothesis +562 for H in HS[BkCl].copy(): +563 # if Op,k' = Oq,l' then mark the hypothesis as supported +564 if ( +565 Ap.action.obj_params[H.k_ - 1] +566 == Aq.action.obj_params[H.l_ - 1] +567 ): +568 H.supported = True +569 else: # otherwise remove the hypothesis +570 HS[BkCl].remove(H) +571 +572 # Remove any unsupported hypotheses (but yet undisputed) +573 for hind, hs in HS.copy().items(): +574 for h in hs.copy(): +575 if not h.supported: +576 hs.remove(h) +577 if len(hs) == 0: +578 del HS[hind] +579 +580 # Converts HS {HSIndex: HSItem} to a mapping of hypothesis for states of a sort {sort: {state: Hypothesis}} +581 return Hypothesis.from_dict(HS) +582 +583 @staticmethod +584 def _step4(HS: Hypotheses, debug: bool = False) -> Bindings: +585 """Step 4: Creation and merging of state parameters""" +586 # bindings = {sort: {state: [(hypothesis, state param)]}} +587 bindings: Bindings = defaultdict(dict) +588 for sort, hs_sort in HS.items(): +589 for state, hs_sort_state in hs_sort.items(): +590 # state_bindings = {hypothesis (h): state param (v)} +591 state_bindings: Dict[Hypothesis, int] = {} +592 +593 # state_params = [set(v)]; params in the same set are the same +594 state_params: List[Set[int]] = [] +595 +596 # state_param_pointers = {v: P}; maps state param to the state_params set index +597 # i.e. map hypothesis state param v -> actual state param P +598 state_param_pointers: Dict[int, int] = {} +599 +600 # for each hypothesis h, +601 hs_sort_state = list(hs_sort_state) +602 for v, h in enumerate(hs_sort_state): +603 # add the <h, v> binding pair +604 state_bindings[h] = v +605 # add a param v as a unique state parameter +606 state_params.append({v}) +607 state_param_pointers[v] = v +608 +609 # for each (unordered) pair of hypotheses h1, h2 +610 for i, h1 in enumerate(hs_sort_state): +611 for h2 in hs_sort_state[i + 1 :]: +612 # check if hypothesis parameters (v1 & v2) need to be unified +613 if ( +614 (h1.B == h2.B and h1.k == h2.k and h1.k_ == h2.k_) +615 and # See https://github.com/AI-Planning/macq/discussions/200 +616 (h1.C == h2.C and h1.l == h2.l and h1.l_ == h2.l_) # fmt: skip +617 ): +618 v1 = state_bindings[h1] +619 v2 = state_bindings[h2] +620 +621 # get the parameter sets P1, P2 that v1, v2 belong to +622 P1, P2 = LOCM._pointer_to_set(state_params, v1, v2) +623 +624 if P1 != P2: +625 # merge P1 and P2 +626 state_params[P1] = state_params[P1].union( +627 state_params[P2] +628 ) +629 state_params.pop(P2) +630 state_param_pointers[v2] = P1 +631 +632 # fix state_param_pointers after v2 +633 for ind in range(v2 + 1, len(state_param_pointers)): +634 state_param_pointers[ind] -= 1 +635 +636 # add state bindings for the sort to the output bindings +637 # replacing hypothesis params with actual state params +638 bindings[sort][state] = [ +639 Binding(h, LOCM._pointer_to_set(state_params, v)[0]) +640 for h, v in state_bindings.items() +641 ] +642 +643 return dict(bindings) +644 +645 @staticmethod +646 def _step5( +647 HS: Hypotheses, +648 bindings: Bindings, +649 debug: bool = False, +650 ) -> Bindings: +651 """Step 5: Removing parameter flaws""" +652 +653 # check each bindings[G][S] -> (h, P) +654 for sort, hs_sort in HS.items(): +655 for state_id in hs_sort: +656 # track all 
the h.Bs that occur in bindings[G][S] +657 all_hB = set() +658 # track the set of h.B that set parameter P +659 sets_P = defaultdict(set) +660 for h, P in bindings[sort][state_id]: +661 sets_P[P].add(h.B) +662 all_hB.add(h.B) +663 +664 # for each P, check if there is a transition h.B that never sets parameter P +665 # i.e. if sets_P[P] != all_hB +666 for P, setby in sets_P.items(): +667 if not setby == all_hB: # P is a flawed parameter +668 # remove all bindings referencing P +669 for h, P_ in bindings[sort][state_id].copy(): +670 if P_ == P: +671 bindings[sort][state_id].remove(Binding(h, P_)) +672 if len(bindings[sort][state_id]) == 0: +673 del bindings[sort][state_id] +674 +675 # do the same for checking h.C reading parameter P +676 # See https://github.com/AI-Planning/macq/discussions/200 +677 all_hC = set() +678 reads_P = defaultdict(set) +679 if state_id in bindings[sort]: +680 for h, P in bindings[sort][state_id]: +681 reads_P[P].add(h.C) +682 all_hC.add(h.C) +683 for P, readby in reads_P.items(): +684 if not readby == all_hC: +685 for h, P_ in bindings[sort][state_id].copy(): +686 if P_ == P: +687 bindings[sort][state_id].remove(Binding(h, P_)) +688 if len(bindings[sort][state_id]) == 0: +689 del bindings[sort][state_id] 690 -691 state_machines.append(graph) -692 -693 return state_machines +691 for k, v in bindings.copy().items(): +692 if not v: +693 del bindings[k] 694 -695 @staticmethod -696 def _step7( -697 OS: OSType, -698 ap_state_pointers: APStatePointers, -699 sorts: Sorts, -700 bindings: Bindings, -701 statics: Statics, -702 debug: bool = False, -703 ) -> Tuple[Set[LearnedLiftedFluent], Set[LearnedLiftedAction]]: -704 """Step 7: Formation of PDDL action schema -705 Implicitly includes Step 6 (statics) by including statics as an argument -706 and adding to the relevant actions while being constructed. 
-707 """ -708 -709 # delete zero-object if it's state machine was discarded -710 if not OS[0]: -711 del OS[0] -712 del ap_state_pointers[0] -713 -714 if debug: -715 print("ap state pointers") -716 pprint(ap_state_pointers) -717 print() -718 -719 print("OS:") -720 pprint(OS) -721 print() -722 -723 print("bindings:") -724 pprint(bindings) -725 print() -726 -727 bound_param_sorts = { -728 sort: { -729 state: [ -730 binding.hypothesis.G_ -731 for binding in bindings.get(sort, {}).get(state, []) -732 ] -733 for state in range(len(states)) -734 } -735 for sort, states in OS.items() -736 } +695 return bindings +696 +697 @staticmethod +698 def _debug_state_machines(OS, ap_state_pointers, state_params): +699 import os +700 +701 import networkx as nx +702 +703 for sort in OS: +704 G = nx.DiGraph() +705 for n in range(len(OS[sort])): +706 lbl = f"state{n}" +707 if ( +708 state_params is not None +709 and sort in state_params +710 and n in state_params[sort] +711 ): +712 lbl += str( +713 [ +714 state_params[sort][n][v] +715 for v in sorted(state_params[sort][n].keys()) +716 ] +717 ) +718 G.add_node(n, label=lbl, shape="oval") +719 for ap, apstate in ap_state_pointers[sort].items(): +720 start_idx, end_idx = LOCM._pointer_to_set( +721 OS[sort], apstate.start, apstate.end +722 ) +723 # check if edge is already in graph +724 if G.has_edge(start_idx, end_idx): +725 # append to the edge label +726 G.edges[start_idx, end_idx][ +727 "label" +728 ] += f"\n{ap.action.name}.{ap.pos}" +729 else: +730 G.add_edge(start_idx, end_idx, label=f"{ap.action.name}.{ap.pos}") +731 # write to dot file +732 nx.drawing.nx_pydot.write_dot(G, f"LOCM-step7-sort{sort}.dot") +733 os.system( +734 f"dot -Tpng LOCM-step7-sort{sort}.dot -o LOCM-step7-sort{sort}.png" +735 ) +736 os.system(f"rm LOCM-step7-sort{sort}.dot") 737 -738 actions = {} -739 fluents = defaultdict(dict) -740 -741 all_aps: Dict[str, List[AP]] = defaultdict(list) -742 for aps in ap_state_pointers.values(): -743 for ap in aps: -744 all_aps[ap.action.name].append(ap) -745 -746 for action, aps in all_aps.items(): -747 actions[action] = LearnedLiftedAction( -748 action, [f"sort{ap.sort}" for ap in aps] -749 ) -750 -751 @dataclass -752 class TemplateFluent: -753 name: str -754 param_sorts: List[str] -755 -756 def __hash__(self) -> int: -757 return hash(self.name + "".join(self.param_sorts)) -758 -759 for sort, state_bindings in bound_param_sorts.items(): -760 for state, bound_sorts in state_bindings.items(): -761 fluents[sort][state] = TemplateFluent( -762 f"sort{sort}_state{state}", -763 [f"sort{sort}"] + [f"sort{s}" for s in bound_sorts], -764 ) -765 -766 for (sort, aps), states in zip(ap_state_pointers.items(), OS.values()): -767 for ap, pointers in aps.items(): -768 start_state, end_state = LOCM._pointer_to_set( -769 states, pointers.start, pointers.end -770 ) -771 -772 # preconditions += fluent for origin state -773 start_fluent_temp = fluents[sort][start_state] -774 -775 bound_param_inds = [] -776 -777 # for each bindings on the start state (if there are any) -778 # then add each binding.hypothesis.l_ -779 if sort in bindings and start_state in bindings[sort]: -780 bound_param_inds = [ -781 b.hypothesis.l_ - 1 for b in bindings[sort][start_state] -782 ] +738 @staticmethod +739 def _step7( +740 OS: OSType, +741 ap_state_pointers: APStatePointers, +742 sorts: Sorts, +743 bindings: Bindings, +744 statics: Statics, +745 debug: bool = False, +746 viz: bool = False, +747 ) -> Tuple[Set[LearnedLiftedFluent], Set[LearnedLiftedAction]]: +748 """Step 7: Formation of PDDL 
action schema +749 Implicitly includes Step 6 (statics) by including statics as an argument +750 and adding to the relevant actions while being constructed. +751 """ +752 +753 # delete zero-object if it's state machine was discarded +754 if not OS[0]: +755 del OS[0] +756 del ap_state_pointers[0] +757 +758 # all_aps = {action_name: [AP]} +759 all_aps: Dict[str, List[AP]] = defaultdict(list) +760 for aps in ap_state_pointers.values(): +761 for ap in aps: +762 all_aps[ap.action.name].append(ap) +763 +764 state_params = defaultdict(dict) +765 state_params_to_hyps = defaultdict(dict) +766 for sort in bindings: +767 state_params[sort] = defaultdict(dict) +768 state_params_to_hyps[sort] = defaultdict(dict) +769 for state in bindings[sort]: +770 keys = {b.param for b in bindings[sort][state]} +771 typ = None +772 for key in keys: +773 hyps = [ +774 b.hypothesis for b in bindings[sort][state] if b.param == key +775 ] +776 # assert that all are the same G_ +777 assert len(set([h.G_ for h in hyps])) == 1 +778 state_params[sort][state][key] = hyps[0].G_ +779 state_params_to_hyps[sort][state][key] = hyps +780 +781 if viz: +782 LOCM._debug_state_machines(OS, ap_state_pointers, state_params) 783 -784 start_fluent = LearnedLiftedFluent( -785 start_fluent_temp.name, -786 start_fluent_temp.param_sorts, -787 [ap.pos - 1] + bound_param_inds, -788 ) -789 fluents[sort][start_state] = start_fluent -790 actions[ap.action.name].update_precond(start_fluent) -791 -792 if start_state != end_state: -793 # del += fluent for origin state -794 actions[ap.action.name].update_delete(start_fluent) -795 -796 # add += fluent for destination state -797 end_fluent_temp = fluents[sort][end_state] -798 bound_param_inds = [] -799 if sort in bindings and end_state in bindings[sort]: -800 bound_param_inds = [ -801 b.hypothesis.l_ - 1 for b in bindings[sort][end_state] -802 ] -803 end_fluent = LearnedLiftedFluent( -804 end_fluent_temp.name, -805 end_fluent_temp.param_sorts, -806 [ap.pos - 1] + bound_param_inds, -807 ) -808 fluents[sort][end_state] = end_fluent -809 actions[ap.action.name].update_add(end_fluent) +784 fluents = defaultdict(dict) +785 actions = {} +786 for sort in ap_state_pointers: +787 sort_str = f"sort{sort}" +788 for ap in ap_state_pointers[sort]: +789 if ap.action.name not in actions: +790 actions[ap.action.name] = LearnedLiftedAction( +791 ap.action.name, +792 [None for _ in range(len(all_aps[ap.action.name]))], # type: ignore +793 ) +794 a = actions[ap.action.name] +795 a.param_sorts[ap.pos - 1] = sort_str +796 +797 start_pointer, end_pointer = ap_state_pointers[sort][ap] +798 start_state, end_state = LOCM._pointer_to_set( +799 OS[sort], start_pointer, end_pointer +800 ) +801 +802 start_fluent_name = f"sort{sort}_state{start_state}" +803 if start_fluent_name not in fluents[ap.action.name]: +804 start_fluent = LearnedLiftedFluent( +805 start_fluent_name, +806 param_sorts=[sort_str], +807 param_act_inds=[ap.pos - 1], +808 ) +809 fluents[ap.action.name][start_fluent_name] = start_fluent 810 -811 fluents = set(fluent for sort in fluents.values() for fluent in sort.values()) -812 actions = set(actions.values()) -813 -814 # Step 6: Extraction of static preconditions -815 for action in actions: -816 if action.name in statics: -817 for static in statics[action.name]: -818 action.update_precond(static) -819 -820 if debug: -821 pprint(fluents) -822 pprint(actions) -823 -824 return fluents, actions +811 start_fluent = fluents[ap.action.name][start_fluent_name] +812 +813 if ( +814 sort in state_params_to_hyps +815 and 
start_state in state_params_to_hyps[sort] +816 ): +817 for param in state_params_to_hyps[sort][start_state]: +818 psort = None +819 pind = None +820 for hyp in state_params_to_hyps[sort][start_state][param]: +821 if hyp.C == ap: +822 assert psort is None or psort == hyp.G_ +823 assert pind is None or pind == hyp.l_ +824 psort = hyp.G_ +825 pind = hyp.l_ +826 assert psort is not None +827 assert pind is not None +828 start_fluent.param_sorts.append(f"sort{psort}") +829 start_fluent.param_act_inds.append(pind - 1) +830 +831 a.update_precond(start_fluent) +832 +833 if end_state != start_state: +834 end_fluent_name = f"sort{sort}_state{end_state}" +835 if end_fluent_name not in fluents[ap.action.name]: +836 end_fluent = LearnedLiftedFluent( +837 end_fluent_name, +838 param_sorts=[sort_str], +839 param_act_inds=[ap.pos - 1], +840 ) +841 fluents[ap.action.name][end_fluent_name] = end_fluent +842 +843 end_fluent = fluents[ap.action.name][end_fluent_name] +844 +845 if ( +846 sort in state_params_to_hyps +847 and end_state in state_params_to_hyps[sort] +848 ): +849 for param in state_params_to_hyps[sort][end_state]: +850 psort = None +851 pind = None +852 for hyp in state_params_to_hyps[sort][end_state][param]: +853 if hyp.B == ap: +854 assert psort is None or psort == hyp.G_ +855 assert pind is None or pind == hyp.k_ +856 psort = hyp.G_ +857 pind = hyp.k_ +858 assert psort is not None +859 assert pind is not None +860 end_fluent.param_sorts.append(f"sort{psort}") +861 end_fluent.param_act_inds.append(pind - 1) +862 +863 a.update_delete(start_fluent) +864 a.update_add(end_fluent) +865 +866 # Step 6: Extraction of static preconditions +867 for action in actions.values(): +868 if action.name in statics: +869 for static in statics[action.name]: +870 action.update_precond(static) +871 +872 return set( +873 fluent +874 for action_fluents in fluents.values() +875 for fluent in action_fluents.values() +876 ), set(actions.values())Inherited Members
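Putting the LOCM pipeline above in context, a hedged end-to-end sketch of how the class is typically driven; the trace setup and import paths are assumptions, not confirmed by this diff.

# Hedged sketch: run the steps implemented above (sorts, steps 1-7) via the
# extraction front end. `traces` is an assumed TraceList containing one trace.
from macq.extract import Extract, modes          # names assumed
from macq.observation import ActionObservation

obs_tracelist = traces.tokenize(ActionObservation)    # LOCM requires action-only observations
model = Extract(obs_tracelist, modes.LOCM, viz=False)
model.to_pddl_lifted("dom", "prob", "locm_domain.pddl", "locm_problem.pddl")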
171 def __new__( -172 cls, -173 obs_tracelist: ObservedTraceList, -174 statics: Optional[Statics] = None, -175 viz: bool = False, -176 view: bool = False, -177 debug: Union[bool, Dict[str, bool], List[str]] = False, -178 ): -179 """Creates a new Model object. -180 Args: -181 observations (ObservationList): -182 The state observations to extract the model from. -183 statics (Dict[str, List[str]]): -184 A dictionary mapping an action name and its arguments to the -185 list of static preconditions of the action. A precondition should -186 be a tuple, where the first element is the predicate name and the -187 rest correspond to the arguments of the action (1-indexed). -188 E.g. static( next(C1, C2), put_on_card_in_homecell(C2, C1, _) ) -189 should is provided as: {"put_on_card_in_homecell": [("next", 2, 1)]} -190 viz (bool): -191 Whether to visualize the FSM. -192 view (bool): -193 Whether to view the FSM visualization. -194 -195 Raises: -196 IncompatibleObservationToken: -197 Raised if the observations are not identity observation. -198 """ -199 if obs_tracelist.type is not ActionObservation: -200 raise IncompatibleObservationToken(obs_tracelist.type, LOCM) -201 -202 if len(obs_tracelist) != 1: -203 warn("LOCM only supports a single trace, using first trace only") -204 -205 if isinstance(debug, bool) and debug: -206 debug = defaultdict(lambda: True) -207 elif isinstance(debug, dict): -208 debug = defaultdict(lambda: False, debug) -209 elif isinstance(debug, list): -210 debug = defaultdict(lambda: False, {k: True for k in debug}) -211 else: -212 debug = defaultdict(lambda: False) -213 -214 obs_trace = obs_tracelist[0] -215 fluents, actions = None, None -216 -217 sorts = LOCM._get_sorts(obs_trace, debug=debug["get_sorts"]) +@@ -5673,57 +5949,15 @@188 def __new__( +189 cls, +190 obs_tracelist: ObservedTraceList, +191 statics: Optional[Statics] = None, +192 viz: bool = False, +193 view: bool = False, +194 debug: Union[bool, Dict[str, bool], List[str]] = False, +195 ): +196 """Creates a new Model object. +197 Args: +198 observations (ObservationList): +199 The state observations to extract the model from. +200 statics (Dict[str, List[str]]): +201 A dictionary mapping an action name and its arguments to the +202 list of static preconditions of the action. A precondition should +203 be a tuple, where the first element is the predicate name and the +204 rest correspond to the arguments of the action (1-indexed). +205 E.g. static( next(C1, C2), put_on_card_in_homecell(C2, C1, _) ) +206 should is provided as: {"put_on_card_in_homecell": [("next", 2, 1)]} +207 viz (bool): +208 Whether to visualize the FSM. +209 view (bool): +210 Whether to view the FSM visualization. +211 +212 Raises: +213 IncompatibleObservationToken: +214 Raised if the observations are not identity observation. 
+215 """ +216 if obs_tracelist.type is not ActionObservation: +217 raise IncompatibleObservationToken(obs_tracelist.type, LOCM) 218 -219 if debug["sorts"]: -220 print(f"Sorts:\n{sorts}", end="\n\n") +219 if len(obs_tracelist) != 1: +220 warn("LOCM only supports a single trace, using first trace only") 221 -222 TS, ap_state_pointers, OS = LOCM._step1(obs_trace, sorts, debug["step1"]) -223 HS = LOCM._step3(TS, ap_state_pointers, OS, sorts, debug["step3"]) -224 bindings = LOCM._step4(HS, debug["step4"]) -225 bindings = LOCM._step5(HS, bindings, debug["step5"]) -226 fluents, actions = LOCM._step7( -227 OS, -228 ap_state_pointers, -229 sorts, -230 bindings, -231 statics if statics is not None else {}, -232 debug["step7"], -233 ) -234 -235 if viz: -236 state_machines = LOCM.get_state_machines(ap_state_pointers, OS, bindings) -237 for sm in state_machines: -238 sm.render(view=view) -239 -240 return Model(fluents, actions) +222 if isinstance(debug, bool) and debug: +223 debug = defaultdict(lambda: True) +224 elif isinstance(debug, dict): +225 debug = defaultdict(lambda: False, debug) +226 elif isinstance(debug, list): +227 debug = defaultdict(lambda: False, {k: True for k in debug}) +228 else: +229 debug = defaultdict(lambda: False) +230 +231 obs_trace = obs_tracelist[0] +232 fluents, actions = None, None +233 +234 sorts = LOCM._get_sorts(obs_trace, debug=debug["get_sorts"]) +235 +236 if debug["sorts"]: +237 sortid2objs = {v: [] for v in set(sorts.values())} +238 for k, v in sorts.items(): +239 sortid2objs[v].append(k) +240 print("\nSorts:\n") +241 pprint(sortid2objs) +242 print("\n") +243 +244 TS, ap_state_pointers, OS = LOCM._step1(obs_trace, sorts, debug["step1"]) +245 HS = LOCM._step3(TS, ap_state_pointers, OS, sorts, debug["step3"]) +246 bindings = LOCM._step4(HS, debug["step4"]) +247 bindings = LOCM._step5(HS, bindings, debug["step5"]) +248 fluents, actions = LOCM._step7( +249 OS, +250 ap_state_pointers, +251 sorts, +252 bindings, +253 statics if statics is not None else {}, +254 debug["step7"], +255 viz, +256 ) +257 +258 return Model(fluents, actions)Inherited Members
658 @staticmethod -659 def get_state_machines( -660 ap_state_pointers: APStatePointers, -661 OS: OSType, -662 bindings: Optional[Bindings] = None, -663 ): -664 from graphviz import Digraph -665 -666 state_machines = [] -667 for (sort, trans), states in zip(ap_state_pointers.items(), OS.values()): -668 graph = Digraph(f"LOCM-step1-sort{sort}") -669 for state in range(len(states)): -670 label = f"state{state}" -671 if ( -672 bindings is not None -673 and sort in bindings -674 and state in bindings[sort] -675 ): -676 label += f"\n[" -677 params = [] -678 for binding in bindings[sort][state]: -679 params.append(f"{binding.hypothesis.G_}") -680 label += f",".join(params) -681 label += f"]" -682 graph.node(str(state), label=label, shape="oval") -683 for ap, apstate in trans.items(): -684 start_idx, end_idx = LOCM._pointer_to_set( -685 states, apstate.start, apstate.end -686 ) -687 graph.edge( -688 str(start_idx), str(end_idx), label=f"{ap.action.name}.{ap.pos}" -689 ) -690 -691 state_machines.append(graph) -692 -693 return state_machines -
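A small hedged illustration of the _pointer_to_set helper used throughout the steps above: it maps raw state pointers to the indices of the (possibly merged) state sets that contain them.

# Illustrative only; values follow the helper's definition shown earlier.
states = [{1, 2}, {3}, {4, 5, 6}]
LOCM._pointer_to_set(states, 2)      # -> (0, 0): pointer 2 sits in states[0]
LOCM._pointer_to_set(states, 2, 5)   # -> (0, 2): pointers 2 and 5 resolve to sets 0 and 2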
Samples goals by randomly generating candidate goal states k (steps_deep) steps deep, then running planners on those goal states to ensure the goals are complex enough (i.e. cannot be reached in too few steps). Candidate goal states are generated for a set amount of time indicated by MAX_GOAL_SEARCH_TIME, and the goals with the longest plans (the most complex goals) are selected.
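A hedged sketch of the sampling strategy described above; the generator and planner interfaces (random_walk, plan_to) are illustrative stand-ins, not the library's actual API.

import time

# Sketch only: repeatedly propose goal states k steps deep, plan to each, and
# keep the candidate whose plan is longest within the search-time budget.
def sample_complex_goal(generator, steps_deep, max_goal_search_time):
    best_goal, best_len = None, -1
    start = time.time()
    while time.time() - start < max_goal_search_time:
        candidate = generator.random_walk(steps_deep)   # candidate goal state, k steps deep
        plan = generator.plan_to(candidate)             # run a planner on the candidate goal
        if plan is not None and len(plan) > best_len:
            best_goal, best_len = candidate, len(plan)
    return best_goal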
Grounded action.
An Action represents a grounded action in a Trace or a Model. The action's precond, add, and delete attributes characterize a Model, and are found during model extraction.
Attributes:
@@ -527,7 +632,7 @@
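As a brief hedged illustration of the attributes described above, applied to a hypothetical extracted model:

# Sketch only: `model` is a hypothetical extracted Model; each learned action
# exposes the precond, add, and delete sets found during extraction.
for act in model.actions:
    print(act.name, act.precond, act.add, act.delete)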