diff --git a/.gitignore b/.gitignore
index eae90acb..ac6e0101 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,4 +12,6 @@ __init__.pyc
generated_testing_files/
new_domain.pddl
new_prob.pddl
-test_model.json
\ No newline at end of file
+test_model.json
+results.txt
+html/
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index eeda14a5..b13c3d30 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -5,11 +5,13 @@
### Installing
Install macq for development by cloning the repository and running
-`pip install .[dev]`
+`pip install -e .[dev]`
We recommend installing in a virtual environment to avoid package version
conflicts.
+**Note: `tarski` requires [`clingo`](https://potassco.org/clingo/) to be installed to work.**
+
### Formatting
We use [black](https://black.readthedocs.io/en/stable/) for easy and consistent
@@ -34,3 +36,12 @@ report, run `pytest --cov=macq --cov-report=html`, and open `htmlcov/index.html`
in a browser. This will provide detailed line by line test coverage information,
so you can identify what specifically still needs testing.
+### Generating Docs
+To generate the HTML documentation, run `pdoc --html macq --config latex_math=True`.
+
+During development, you can run a local HTTP server to preview live changes
+to the documentation: `pdoc --http : macq --config latex_math=True`.
+
+*Note: `--config latex_math=True` is required to properly render the LaTeX found
+in many extraction techniques' documentation.*
+
diff --git a/README.md b/README.md
index 0bee1b8a..b37f55e9 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ This library is a collection of tools for planning-like action model acquisition
## Usage
```python
from macq import generate, extract
-from macq.observation import IdentityObservation
+from macq.observation import IdentityObservation, AtomicPartialObservation
# get a domain-specific generator: uses api.planning.domains problem_id/
# generate 100 traces of length 20 using vanilla sampling
@@ -39,50 +39,6 @@ trace.actions
trace.get_pre_states(action) # get the state before each occurrence of action
trace.get_post_states(action) # state after each occurrence of action
trace.get_total_cost()
-
-######################################################################
-# Model Extraction - OBSERVER Technique
-######################################################################
-observations = traces.tokenize(IdentityObservation)
-model = extract.Extract(observations, extract.modes.OBSERVER)
-model.details()
-
-Model:
- Fluents: at stone stone-03 location pos-04-06, at stone stone-01 location pos-04-06, at stone stone-02 location pos-05-06, at stone stone-06 location pos-07-04, at stone stone-11 ...
- Actions:
- push-to-goal stone stone-04 location pos-04-05 location pos-04-06 direction dir-up location pos-04-04 player player-01:
- precond:
- at player player-01 location pos-04-06
- at stone stone-04 location pos-04-05
- clear location pos-05-06
- ...
- add:
- at stone stone-04 location pos-04-04
- clear location pos-04-06
- at-goal stone stone-04
- at player player-01 location pos-04-05
- delete:
- at stone stone-04 location pos-04-05
- clear location pos-04-04
- at player player-01 location pos-04-06
- ...
-######################################################################
-# Model Extraction - SLAF Technique
-######################################################################
-traces = generate.pddl.VanillaSampling(problem_id = 123, plan_len = 2, num_traces = 1).traces
-observations = traces.tokenize(PartialObservabilityToken, method=PartialObservabilityToken.random_subset, percent_missing=0.10)
-model = Extract(observations, modes.SLAF)
-model.details()
-
-Model:
- Fluents: clear location pos-06-09, clear location pos-02-05, clear location pos-08-08, clear location pos-10-05, clear location pos-02-06, clear location pos-10-02, clear location pos-01-01, at stone stone-05 location pos-08-05, at stone stone-07 location pos-08-06, at stone stone-03 location pos-07-04, clear location pos-03-06, clear location pos-10-06, clear location pos-10-10, clear location pos-05-09, clear location pos-05-07, clear location pos-02-07, clear location pos-09-01, at stone stone-06 location pos-04-06, clear location pos-02-03, clear location pos-07-05, clear location pos-09-10, clear location pos-06-05, at stone stone-01 location pos-05-04, clear location pos-02-10, clear location pos-06-10, clear location pos-11-03, at stone stone-11 location pos-06-08, at stone stone-08 location pos-04-07, clear location pos-01-10, clear location pos-07-03, clear location pos-02-11, clear location pos-03-01, clear location pos-06-02, clear location pos-03-02, clear location pos-11-01, clear location pos-06-03, clear location pos-08-04, clear location pos-09-11, at stone stone-09 location pos-08-07, clear location pos-09-07, clear location pos-06-07, clear location pos-10-01, clear location pos-11-09, clear location pos-03-05, clear location pos-07-06, clear location pos-05-05, at stone stone-12 location pos-07-08, clear location pos-10-03, clear location pos-11-11, clear location pos-10-09, clear location pos-02-01, clear location pos-02-02, clear location pos-01-02, at stone stone-02 location pos-06-04, clear location pos-03-10, clear location pos-05-10, clear location pos-07-10, clear location pos-09-05, clear location pos-07-09, clear location pos-05-03, clear location pos-10-11, clear location pos-01-03, at stone stone-04 location pos-04-05, clear location pos-07-02, clear location pos-09-06, clear location pos-10-07, clear location pos-01-09, clear location pos-03-07, clear location pos-04-04, clear location pos-01-11
- Actions:
- move player player-01 direction dir-left location pos-05-02 location pos-06-02:
- precond:
- add:
- delete:
- (clear location pos-05-02)
- (at player player-01 location pos-06-02)
```
## Coverage
@@ -92,9 +48,9 @@ Model:
- [x] [Learning Planning Operators by Observation and Practice](https://aaai.org/Papers/AIPS/1994/AIPS94-057.pdf) (AIPS'94)
- [ ] [Learning by Experimentation: Incremental Refinement of Incomplete Planning Domains](https://www.sciencedirect.com/science/article/pii/B9781558603356500192) (ICML'94)
- [ ] [Learning Probabilistic Relational Planning Rules](https://people.csail.mit.edu/lpk/papers/2005/zpk-aaai05.pdf) (ICAPS'04)
-- [ ] [Learning Action Models from Plan Examples with Incomplete Knowledge](https://www.aaai.org/Papers/ICAPS/2005/ICAPS05-025.pdf) (ICAPS'05)
+- [x] [Learning Action Models from Plan Examples with Incomplete Knowledge](https://www.aaai.org/Papers/ICAPS/2005/ICAPS05-025.pdf) (ICAPS'05)
- [ ] [Learning Planning Rules in Noisy Stochastic Worlds](https://people.csail.mit.edu/lpk/papers/2005/zpk-aaai05.pdf) (AAAI'05)
-- [ ] [Learning action models from plan examples using weighted MAX-SAT](https://www.sciencedirect.com/science/article/pii/S0004370206001408) (AIJ'07)
+- [x] [Learning action models from plan examples using weighted MAX-SAT](https://www.sciencedirect.com/science/article/pii/S0004370206001408) (AIJ'07)
- [ ] [Learning Symbolic Models of Stochastic Domains](https://www.aaai.org/Papers/JAIR/Vol29/JAIR-2910.pdf) (JAIR'07)
- [x] [Learning Partially Observable Deterministic Action Models](https://www.aaai.org/Papers/JAIR/Vol33/JAIR-3310.pdf) (JAIR'08)
- [ ] [Acquisition of Object-Centred Domain Models from Planning Examples](https://ojs.aaai.org/index.php/ICAPS/article/view/13391) (ICAPS'09)
@@ -113,7 +69,7 @@ Model:
- [ ] [Learning STRIPS Action Models with Classical Planning](https://arxiv.org/abs/1903.01153) (ICAPS'18)
- [ ] [Learning Planning Operators from Episodic Traces](https://aaai.org/ocs/index.php/SSS/SSS18/paper/view/17594/15530) (AAAI-SS'18)
- [ ] [Learning action models with minimal observability](https://www.sciencedirect.com/science/article/abs/pii/S0004370218304259) (AIJ'19)
-- [ ] [Learning Action Models from Disordered and Noisy Plan Traces](https://arxiv.org/abs/1908.09800) (arXiv'19)
+- [x] [Learning Action Models from Disordered and Noisy Plan Traces](https://arxiv.org/abs/1908.09800) (arXiv'19)
- [ ] [Bridging the Gap: Providing Post-Hoc Symbolic Explanations for Sequential Decision-Making Problems with Black Box Simulators](https://arxiv.org/abs/2002.01080) (ICML-WS'20)
- [ ] [STRIPS Action Discovery](https://arxiv.org/abs/2001.11457) (arXiv'20)
- [ ] [Learning First-Order Symbolic Representations for Planning from the Structure of the State Space](https://arxiv.org/abs/1909.05546) (ECAI'20)
diff --git a/docs/extract/amdn.md b/docs/extract/amdn.md
new file mode 100644
index 00000000..a31a4f8f
--- /dev/null
+++ b/docs/extract/amdn.md
@@ -0,0 +1,13 @@
+# Usage
+
+```python
+from macq import generate, extract
+
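+# NOTE: a hedged sketch -- AMDN consumes traces tokenized with
+# NoisyPartialDisorderedParallelObservation, and Extract accepts an
+# occ_threshold parameter for its noise constraints; the generator
+# settings and tokenize() arguments here are illustrative assumptions.
+from macq.observation import NoisyPartialDisorderedParallelObservation
+
+traces = generate.pddl.VanillaSampling(problem_id=123, plan_len=10, num_traces=1).traces
+observations = traces.tokenize(NoisyPartialDisorderedParallelObservation)
+model = extract.Extract(observations, extract.modes.AMDN, occ_threshold=2)
+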
+print(model.details())
+```
+
+**Output:**
+```text
+```
+
+# API Documentation
diff --git a/docs/extract/arms.md b/docs/extract/arms.md
new file mode 100644
index 00000000..aa4f7a61
--- /dev/null
+++ b/docs/extract/arms.md
@@ -0,0 +1,77 @@
+# Usage
+
+```python
+from macq import generate, extract
+from macq.trace import PlanningObject, Fluent, TraceList
+from macq.observation import PartialObservation
+
+def get_fluent(name: str, objs: list[str]):
+ objects = [PlanningObject(o.split()[0], o.split()[1]) for o in objs]
+ return Fluent(name, objects)
+
+traces = TraceList()
+generator = generate.pddl.TraceFromGoal(problem_id=1801)
+
+generator.change_goal(
+ {
+ get_fluent("communicated_soil_data", ["waypoint waypoint2"]),
+ get_fluent("communicated_rock_data", ["waypoint waypoint3"]),
+ get_fluent(
+ "communicated_image_data", ["objective objective1", "mode high_res"]
+ ),
+ }
+)
+traces.append(generator.generate_trace())
+
+generator.change_goal(
+ {
+ get_fluent("communicated_soil_data", ["waypoint waypoint2"]),
+ get_fluent("communicated_rock_data", ["waypoint waypoint3"]),
+ get_fluent(
+ "communicated_image_data", ["objective objective1", "mode high_res"]
+ ),
+ }
+)
+traces.append(generator.generate_trace())
+
+observations = traces.tokenize(PartialObservation, percent_missing=0.60)
+model = extract.Extract(
+ observations,
+ extract.modes.ARMS,
+ upper_bound=2,
+ min_support=2,
+ action_weight=110,
+ info_weight=100,
+ threshold=0.6,
+ info3_default=30,
+ plan_default=30,
+)
+
+print(model.details())
+```
+
+**Output:**
+```text
+Model:
+ Fluents: (at rover rover0 waypoint waypoint2), (have_soil_analysis rover rover0 waypoint waypoint2), (have_soil_analysis rover rover0 waypoint waypoint3), ...
+ Actions:
+ (communicate_image_data rover waypoint mode objective lander waypoint):
+ precond:
+ calibrated camera rover
+ have_rock_analysis rover waypoint
+ communicated_rock_data waypoint
+ channel_free lander
+ at_soil_sample waypoint
+ at_rock_sample waypoint
+ add:
+ calibrated camera rover
+ at rover waypoint
+ have_image rover objective mode
+ channel_free lander
+ communicated_image_data objective mode
+ delete:
+ calibrated camera rover
+ ...
+```
+
+# API Documentation
diff --git a/docs/extract/extract.md b/docs/extract/extract.md
new file mode 100644
index 00000000..88b7e77e
--- /dev/null
+++ b/docs/extract/extract.md
@@ -0,0 +1,9 @@
+# Usage
+
+## Debugging
+Include the argument `debug=True` to `Extract` to enable debugging for any
+extraction technique.
+
+*Note: debugging output and interfaces are unique to each method.*
+
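+For example, a minimal sketch (assuming `observations` already holds tokenized
+traces, as in the other usage pages):
+
+```python
+from macq import extract
+
+model = extract.Extract(observations, extract.modes.SLAF, debug=True)
+```
+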
+# API Documentation
diff --git a/docs/extract/observer.md b/docs/extract/observer.md
new file mode 100644
index 00000000..2bcef454
--- /dev/null
+++ b/docs/extract/observer.md
@@ -0,0 +1,37 @@
+# Usage
+
+```python
+from macq import generate, extract
+from macq.observation import IdentityObservation
+
+traces = generate.pddl.VanillaSampling(problem_id=123, plan_len=20, num_traces=100).traces
+observations = traces.tokenize(IdentityObservation)
+model = extract.Extract(observations, extract.modes.OBSERVER)
+
+print(model.details())
+```
+
+**Output:**
+```text
+Model:
+ Fluents: at stone stone-03 location pos-04-06, at stone stone-01 location pos-04-06, at stone stone-02 location pos-05-06, at stone stone-06 location pos-07-04, at stone stone-11 ...
+ Actions:
+ push-to-goal stone stone-04 location pos-04-05 location pos-04-06 direction dir-up location pos-04-04 player player-01:
+ precond:
+ at player player-01 location pos-04-06
+ at stone stone-04 location pos-04-05
+ clear location pos-05-06
+ ...
+ add:
+ at stone stone-04 location pos-04-04
+ clear location pos-04-06
+ at-goal stone stone-04
+ at player player-01 location pos-04-05
+ delete:
+ at stone stone-04 location pos-04-05
+ clear location pos-04-04
+ at player player-01 location pos-04-06
+ ...
+```
+
+# API Documentation
diff --git a/docs/extract/slaf.md b/docs/extract/slaf.md
new file mode 100644
index 00000000..49acbf64
--- /dev/null
+++ b/docs/extract/slaf.md
@@ -0,0 +1,28 @@
+# Usage
+
+```python
+from macq import generate, extract
+from macq.observation import AtomicPartialObservation
+
+traces = generate.pddl.VanillaSampling(problem_id=123, plan_len=2, num_traces=1).traces
+observations = traces.tokenize(AtomicPartialObservation, percent_missing=0.10)
+model = extract.Extract(observations, extract.modes.SLAF)
+print(model.details())
+```
+
+**Output:**
+```text
+Model:
+ Fluents: clear location pos-06-09, clear location pos-02-05, clear location pos-08-08, clear location pos-10-05, ...
+ Actions:
+ move player player-01 direction dir-left location pos-05-02 location pos-06-02:
+ precond:
+ add:
+ delete:
+ (clear location pos-05-02)
+ (at player player-01 location pos-06-02)
+ ...
+ ...
+```
+
+# API Documentation
diff --git a/docs/index.md b/docs/index.md
new file mode 100644
index 00000000..ae954d30
--- /dev/null
+++ b/docs/index.md
@@ -0,0 +1,17 @@
+# Usage Documentation
+
+## Trace Generation
+- [VanillaSampling](extract/observer#usage)
+
+## Tokenization
+- [IdentityObservation](extract/observer#usage)
+- [AtomicPartialObservation](extract/slaf#usage)
+
+## Extraction Techniques
+- [Observer](extract/observer#usage)
+- [SLAF](extract/slaf#usage)
+- [ARMS](extract/arms#usage)
+- [AMDN](extract/amdn#usage)
+
+
+# API Documentation
diff --git a/macq/__init__.py b/macq/__init__.py
index e69de29b..0fe09475 100644
--- a/macq/__init__.py
+++ b/macq/__init__.py
@@ -0,0 +1,3 @@
+"""
+.. include:: ../docs/index.md
+"""
diff --git a/macq/extract/__init__.py b/macq/extract/__init__.py
index 20142941..8a236f86 100644
--- a/macq/extract/__init__.py
+++ b/macq/extract/__init__.py
@@ -1,14 +1,15 @@
-from .model import Model
-from .extract import Extract, modes, IncompatibleObservationToken, SLAF
from .learned_fluent import LearnedFluent
from .learned_action import LearnedAction
+from .model import Model
+from .extract import Extract, modes
+from .exceptions import IncompatibleObservationToken
__all__ = [
+ "LearnedAction",
+ "LearnedFluent",
"Model",
"Extract",
"modes",
"IncompatibleObservationToken",
- "LearnedAction",
- "SLAF",
- "LearnedFluent",
-]
+]
\ No newline at end of file
diff --git a/macq/extract/amdn.py b/macq/extract/amdn.py
index 3981f357..034c859e 100644
--- a/macq/extract/amdn.py
+++ b/macq/extract/amdn.py
@@ -1,95 +1,683 @@
+""".. include:: ../../docs/extract/amdn.md"""
+
+from macq.trace import Fluent, Action # for typing
+from macq.extract.learned_action import LearnedAction
+from nnf.operators import implies
+
import macq.extract as extract
+from typing import Dict, List, Optional, Union, Hashable
+from nnf import Aux, Var, And, Or
+from bauhaus import Encoding # only used for pretty printing in debug mode
+from .exceptions import (
+ IncompatibleObservationToken,
+)
from .model import Model
-from ..trace import ObservationLists
-from ..observation import NoisyPartialDisorderedParallelObservation
+from ..trace import ActionPair
+from ..observation import NoisyPartialDisorderedParallelObservation, ObservationLists
+from ..utils.pysat import to_wcnf, extract_raw_model
+
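+# NOTE: this binds the Encoding class itself (not an instance); the debug
+# pretty-printers below pass it along as e.pprint(e, ...)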
+e = Encoding
+
+def pre(r: Fluent, act: Action):
+ """Create a Var that enforces that the given fluent is a precondition of the given action.
+
+ Args:
+ r (Fluent):
+ The precondition to be added.
+ act (Action):
+ The action that the precondition will be added to.
+ Returns:
+ The Var that enforces that the given fluent is a precondition of the given action.
+ """
+ return Var("(" + str(r)[1:-1] + " is a precondition of " + act.details() + ")")
+
+def add(r: Fluent, act: Action):
+ """Create a Var that enforces that the given fluent is an add effect of the given action.
+
+ Args:
+ r (Fluent):
+ The add effect to be added.
+ act (Action):
+ The action that the add effect will be added to.
+ Returns:
+ The Var that enforces that the given fluent is an add effect of the given action.
+ """
+ return Var("(" + str(r)[1:-1] + " is added by " + act.details() + ")")
+
+def delete(r: Fluent, act: Action):
+ """Create a Var that enforces that the given fluent is a delete effect of the given action.
+
+ Args:
+ r (Fluent):
+ The delete effect to be added.
+ act (Action):
+ The action that the delete effect will be added to.
+ Returns:
+ The Var that enforces that the given fluent is a delete effect of the given action.
+ """
+ return Var("(" + str(r)[1:-1] + " is deleted by " + act.details() + ")")
+
+WMAX = 1
class AMDN:
- def __new__(cls, obs_lists: ObservationLists):
+ def __new__(cls, obs_lists: ObservationLists, debug: bool = False, occ_threshold: int = 1):
"""Creates a new Model object.
Args:
obs_lists (ObservationList):
The state observations to extract the model from.
+ debug (bool):
+ Optional debugging mode.
+ occ_threshold (int):
+ Threshold to be used for noise constraints.
+
Raises:
IncompatibleObservationToken:
Raised if the observations are not `NoisyPartialDisorderedParallelObservation`.
"""
if obs_lists.type is not NoisyPartialDisorderedParallelObservation:
- raise extract.IncompatibleObservationToken(obs_lists.type, AMDN)
-
- # TODO:
- # iterate through all tokens and create two base sets for all actions and propositions; store as attributes
- # iterate through all pairs of parallel action sets and create a dictionary of the probability of ax and ay being disordered -
- # (this will make it easy and efficient to refer to later, and prevents unnecessary recalculations). store as attribute
- # also create a list of all tuples, store as attribute
-
- #return Model(fluents, actions)
-
- def _build_disorder_constraints(self):
- # TODO:
- # iterate through all pairs of parallel action sets
- # for each pair, iterate through all possible action combinations
- # calculate the probability of the actions being disordered (p)
- # for each action combination, iterate through all possible propositions
- # for each action x action x proposition combination, enforce the following constraint if the actions are ordered:
- # enforce all [constraint 1] with weight (1 - p) x wmax
-
- # likewise, enforce the following constraint if the actions are disordered:
- # enforce all [constraint 2] with weight p x wmax
- pass
-
- def _build_hard_parallel_constraints(self):
- # TODO:
- # iterate through the list of tuples
- # for each action x proposition pair, enforce the two hard constraints with weight wmax
- pass
-
- def _build_soft_parallel_constraints(self):
- # TODO:
- # iterate through all parallel action sets
- # within each parallel action set, iterate through the same action set again to compare
- # each action to every other action in the set; assuming none are disordered
- # enforce all [constraint 4] with weight (1 - p) x wmax
-
- # then, iterate through all pairs of action sets
- # assuming the actions are disordered, check ay against EACH action in ax, for each pair
- # enforce all [constraint 5] with weight p x wmax
- pass
-
- def _build_parallel_constraints(self):
- # TODO:
- # call the above two functions
- pass
-
- def _build_noise_constraints(self):
- # TODO:
- # iterate through all tuples
- # for each tuple: iterate through each step over ALL the plan traces
- # count the number of occurrences; if higher than the user-provided parameter delta,
- # store this tuple as a dictionary entry in a list of dictionaries (something like
- # [{"action and proposition": , "occurrences of r:" 5}]).
- # after all iterations are through, iterate through all the tuples in this dictionary,
- # and set [constraint 6] with the calculated weight.
- # TODO: Ask - what "occurrences of all propositions" refers to - is it the total number of steps...?
+ raise IncompatibleObservationToken(obs_lists.type, AMDN)
+
+ return AMDN._amdn(obs_lists, debug, occ_threshold)
+
+ @staticmethod
+ def _amdn(obs_lists: ObservationLists, debug: bool, occ_threshold: int):
+ """Main driver for the entire AMDN algorithm.
+ The first line of the body performs steps 1-4, the second line performs step 5,
+ and the last line corresponds to step 6 (returning the model).
- # store the initial state s0
- # iterate through every step in the plan trace
- # at each step, check all the propositions r in the current state
- # if r is not in s0, enforce [constraint 7] with the calculated weight
- # TODO: Ask - what happens when you find the first r? I assume you keep iterating through the rest of the trace,
- # continuing the process with different propositions? Do we still count the occurrences of each proposition through
- # the entire trace to use when we calculate the weight?
+ Args:
+ obs_lists (ObservationLists):
+ The tokens to be fed into the algorithm.
+ debug (bool):
+ Optional debugging mode.
+ occ_threshold (int):
+ Threshold to be used for noise constraints.
+
+ Returns:
+ The extracted `Model`.
+ """
+ wcnf, decode = AMDN._solve_constraints(obs_lists, occ_threshold, debug)
+ raw_model = extract_raw_model(wcnf, decode)
+ return AMDN._extract_model(obs_lists, raw_model)
+
+ @staticmethod
+ def _or_refactor(maybe_lit: Union[Or, Var]):
+ """Converts a "Var" fluent to an "Or" fluent.
+
+ Args:
+ maybe_lit (Union[Or, Var]):
+ Fluent that is either type "Or" or "Var."
+
+ Returns:
+ A corresponding fluent of type "Or."
+ """
+ return Or([maybe_lit]) if isinstance(maybe_lit, Var) else maybe_lit
+
+ @staticmethod
+ def _extract_aux_set_weights(cnf_formula: And[Or[Var]], constraints: Dict, prob_disordered: float):
+ """Sets each clause in a CNF formula as a hard constraint, then sets any auxiliary variables to
+ the appropriate weight detailed in the "Constraint DC" section of the AMDN paper.
+ Used to help create disorder constraints.
+
+ Args:
+ cnf_formula (And[Or[Var]]):
+ The CNF formula to extract the clauses and auxiliary variables from.
+ constraints (Dict):
+ The existing dictionary of disorder constraints.
+ prob_disordered (float):
+ The probability that the two actions relevant for this constraint are disordered.
+ """
+ # find all the auxiliary variables
+ for clause in cnf_formula.children:
+ for var in clause.children:
+ if isinstance(var.name, Aux) and var.true:
+ # aux variables are the soft clauses that get the original weight
+ constraints[AMDN._or_refactor(var)] = prob_disordered * WMAX
+ # set each original constraint to be a hard clause
+ constraints[clause] = "HARD"
+
+ @staticmethod
+ def _get_observe(obs_lists: ObservationLists):
+ """Gets from the user which fluents they want to observe (for debug mode).
+
+ Args:
+ obs_lists (ObservationLists):
+ The tokens that contain the fluents.
+
+ Returns:
+ A list of the fluents the user wants to observe.
+ """
+ print("Select a proposition to observe:")
+ sorted_f = [str(f) for f in obs_lists.propositions]
+ sorted_f.sort()
+ for f in sorted_f:
+ print(f)
+ to_obs = []
+ user_input = ""
+ while user_input != "x":
+ user_input = input(
+ "Which fluents do you want to observe? Enter 'x' when you are finished.\n"
+ )
+ if user_input in sorted_f:
+ to_obs.append(user_input[1:-1])
+ print(user_input + " added to the debugging list.")
+ else:
+ if user_input != "x":
+ print("The fluent you entered is invalid.")
+ return to_obs
+
+ @staticmethod
+ def _debug_is_observed(constraint: Or, to_obs: List[str]):
+ """Determines if the given constraint contains a fluent that is being observed in debug mode.
+
+ Args:
+ constraint (Or):
+ The constraint to be analyzed.
+ to_obs (List[str]):
+ The list of fluents being observed.
+
+ Returns:
+ A bool that determines if the constraint should be observed or not.
+ """
+ for c in constraint.children:
+ for v in to_obs:
+ if v in str(c):
+ return True
+ return False
+
+ @staticmethod
+ def _debug_simple_pprint(constraints: Dict, to_obs: List[str]):
+ """Pretty print used for simple formulas in debug mode.
+
+ Args:
+ constraints (Dict):
+ The constraints/weights to be pretty printed.
+ to_obs (List[str]):
+ The fluents being observed.
+ """
+ for c in constraints:
+ observe = AMDN._debug_is_observed(c, to_obs)
+ if observe:
+ e.pprint(e, c)
+ print(constraints[c])
+
+ @staticmethod
+ def _debug_aux_pprint(constraints: Dict, to_obs: List[str]):
+ """Pretty print used for formulas with auxiliary variables in debug mode.
+
+ Args:
+ constraints (Dict):
+ The constraints/weights to be pretty printed.
+ to_obs (List[str]):
+ The fluents being observed.
+ """
+ aux_map = {}
+ index = 0
+ for c in constraints:
+ if AMDN._debug_is_observed(c, to_obs):
+ for var in c.children:
+ if isinstance(var.name, Aux) and var.name not in aux_map:
+ aux_map[var.name] = f"aux {index}"
+ index += 1
+
+ all_pretty_c = {}
+ for c in constraints:
+ if AMDN._debug_is_observed(c, to_obs):
+ pretty_c = []
+ for var in c.children:
+ if isinstance(var.name, Aux):
+ if var.true:
+ pretty_c.append(Var(aux_map[var.name]))
+ all_pretty_c[AMDN._or_refactor(var)] = Or([Var(aux_map[var.name])])
+ else:
+ pretty_c.append(~Var(aux_map[var.name]))
+ else:
+ pretty_c.append(var)
+ # map disorder constraints to pretty disorder constraints
+ all_pretty_c[c] = Or(pretty_c)
+
+ for aux in aux_map.values():
+ for c, v in all_pretty_c.items():
+ for child in v.children:
+ if aux == child.name:
+ e.pprint(e, v)
+ print(constraints[c])
+ break
+ print()
+
+ @staticmethod
+ def _build_disorder_constraints(obs_lists: ObservationLists):
+ """Builds disorder constraints. Corresponds to step 1 of the AMDN algorithm.
+
+ Args:
+ obs_lists (ObservationLists):
+ The tokens to be analyzed.
+
+ Returns:
+ The disorder constraints to be used in the algorithm.
+ """
+ disorder_constraints = {}
+
+ # iterate through all traces
+ for i in range(len(obs_lists.all_par_act_sets)):
+ # get the parallel action sets for this trace
+ par_act_sets = obs_lists.all_par_act_sets[i]
+ # iterate through all pairs of parallel action sets for this trace
+ # use -1 since we will be referencing the current parallel action set and the following one
+ for j in range(len(par_act_sets) - 1):
+ # for each action in psi_i+1
+ for act_y in par_act_sets[j + 1]:
+ # for each action in psi_i
+ # NOTE: we do not use an existential here as the paper describes (for each act_y in psi_i+1,
+ # there exists an act_x in psi_i such that the condition holds), because the weights
+ # must be set for each action pair individually.
+ for act_x in par_act_sets[j]:
+ if act_x != act_y:
+ # calculate the probability of the actions being disordered (p)
+ p = obs_lists.probabilities[ActionPair({act_x, act_y})]
+ # each constraint only needs to hold for one proposition to be true
+ constraint_1 = []
+ constraint_2 = []
+ for r in obs_lists.propositions:
+ constraint_1.append(Or([
+ And([pre(r, act_x), ~delete(r, act_x), delete(r, act_y)]),
+ And([add(r, act_x), pre(r, act_y)]),
+ And([add(r, act_x), delete(r, act_y)]),
+ And([delete(r, act_x), add(r, act_y)])
+ ]))
+ constraint_2.append(Or([
+ And([pre(r, act_y), ~delete(r, act_y), delete(r, act_x)]),
+ And([add(r, act_y), pre(r, act_x)]),
+ And([add(r, act_y), delete(r, act_x)]),
+ And([delete(r, act_y), add(r, act_x)])
+ ]))
+ disjunct_all_constr_1 = Or(constraint_1).to_CNF()
+ disjunct_all_constr_2 = Or(constraint_2).to_CNF()
+ AMDN._extract_aux_set_weights(disjunct_all_constr_1, disorder_constraints, (1 - p))
+ AMDN._extract_aux_set_weights(disjunct_all_constr_2, disorder_constraints, p)
+ return disorder_constraints
+
+ @staticmethod
+ def _build_hard_parallel_constraints(obs_lists: ObservationLists):
+ """Builds hard parallel constraints.
+
+ Args:
+ obs_lists (ObservationLists):
+ The tokens to be analyzed.
+
+ Returns:
+ The hard parallel constraints to be used in the algorithm.
+ """
+ hard_constraints = {}
+ # create a list of all tuples
+ for act in obs_lists.actions:
+ for r in obs_lists.propositions:
+ # for each action x proposition pair, enforce the two hard constraints with weight wmax
+ hard_constraints[implies(add(r, act), ~pre(r, act))] = WMAX
+ hard_constraints[implies(delete(r, act), pre(r, act))] = WMAX
+ return hard_constraints
+
+ @staticmethod
+ def _build_soft_parallel_constraints(obs_lists: ObservationLists):
+ """Builds soft parallel constraints.
+
+ Args:
+ obs_lists (ObservationLists):
+ The tokens to be analyzed.
+
+ Returns:
+ The soft parallel constraints to be used in the algorithm.
+ """
+ soft_constraints = {}
- # [constraint 8] is almost identical to [constraint 6]. Watch the order of the tuples.
+ # NOTE: the paper does not take into account possible conflicts between the preconditions of actions
+ # and the add/delete effects of other actions (similar to the hard constraints, but with other actions
+ # in the parallel action set).
- pass
+ # iterate through all traces
+ for i in range(len(obs_lists.all_par_act_sets)):
+ par_act_sets = obs_lists.all_par_act_sets[i]
+ # iterate through all parallel action sets for this trace
+ for j in range(len(par_act_sets)):
+ # within each parallel action set, iterate through the same action set again to compare
+ # each action to every other action in the set; setting constraints assuming actions are not disordered
+ for act_x in par_act_sets[j]:
+ for act_x_prime in par_act_sets[j] - {act_x}:
+ p = obs_lists.probabilities[ActionPair({act_x, act_x_prime})]
+ # iterate through all propositions
+ for r in obs_lists.propositions:
+ soft_constraints[implies(add(r, act_x), ~delete(r, act_x_prime))] = (1 - p) * WMAX
- def _solve_constraints(self):
- # TODO:
- # call the MAXSAT solver
- pass
+ # iterate through all traces
+ for i in range(len(obs_lists.all_par_act_sets)):
+ par_act_sets = obs_lists.all_par_act_sets[i]
+ # then, iterate through all pairs of parallel action sets for each trace
+ for j in range(len(par_act_sets) - 1):
+ # for each pair, compare every action in act_y to every action in act_x_prime; setting constraints assuming actions are disordered
+ for act_y in par_act_sets[j + 1]:
+ for act_x_prime in par_act_sets[j] - {act_y}:
+ p = obs_lists.probabilities[ActionPair({act_y, act_x_prime})]
+ # iterate through all propositions and similarly set the constraint
+ for r in obs_lists.propositions:
+ soft_constraints[implies(add(r, act_y), ~delete(r, act_x_prime))] = p * WMAX
+
+ return soft_constraints
- def _convert_to_model(self):
- # TODO:
+ @staticmethod
+ def _build_parallel_constraints(obs_lists: ObservationLists, debug: bool, to_obs: Optional[List[str]]):
+ """Main driver for building parallel constraints. Corresponds to step 2 of the AMDN algorithm.
+
+ Args:
+ obs_lists (ObservationLists):
+ The tokens that were analyzed.
+ debug (bool):
+ Optional debugging mode.
+ to_obs (Optional[List[str]]):
+ If in the optional debugging mode, the list of fluents to observe.
+
+ Returns:
+ The parallel constraints.
+ """
+ hard_constraints = AMDN._build_hard_parallel_constraints(obs_lists)
+ soft_constraints = AMDN._build_soft_parallel_constraints(obs_lists)
+ if debug:
+ print("\nHard parallel constraints:")
+ AMDN._debug_simple_pprint(hard_constraints, to_obs)
+ print("\nSoft parallel constraints:")
+ AMDN._debug_simple_pprint(soft_constraints, to_obs)
+ return {**hard_constraints, **soft_constraints}
+
+ @staticmethod
+ def _calculate_all_r_occ(obs_lists: ObservationLists):
+ """Calculates the total number of (true) propositions in the provided traces/tokens.
+
+ Args:
+ obs_lists (ObservationLists):
+ The tokens to be analyzed.
+
+ Returns:
+ The total number of (true) propositions in the provided traces/tokens.
+ """
+ # tracks occurrences of all propositions
+ all_occ = 0
+ for trace in obs_lists:
+ for step in trace:
+ all_occ += len([f for f in step.state if step.state[f]])
+ return all_occ
+
+ @staticmethod
+ def _set_up_occurrences_dict(obs_lists: ObservationLists):
+ """Helper function used when constructing noise constraints.
+ Sets up an "occurrence" dictionary used to track the occurrences of propositions
+ before or after actions.
+
+ Args:
+ obs_lists (ObservationLists):
+ The tokens to be analyzed.
+
+ Returns:
+ The blank "occurrences" dictionary.
+ """
+ # set up dict
+ occurrences = {}
+ for a in obs_lists.actions:
+ occurrences[a] = {}
+ for r in obs_lists.propositions:
+ occurrences[a][r] = 0
+ return occurrences
+
+ @staticmethod
+ def _noise_constraints_6(obs_lists: ObservationLists, all_occ: int, occ_threshold: int):
+ """Noise constraints (6) in the AMDN paper.
+
+ Args:
+ obs_lists (ObservationLists):
+ The tokens that were analyzed.
+ all_occ (int):
+ The number of occurrences of all (true) propositions in the given observation list.
+ occ_threshold (int):
+ Threshold to be used for noise constraints.
+
+ Returns:
+ The noise constraints.
+ """
+ noise_constraints_6 = {}
+ occurrences = AMDN._set_up_occurrences_dict(obs_lists)
+
+ # iterate over ALL the plan traces, adding occurrences accordingly
+ for i in range(len(obs_lists)):
+ # iterate through each step in each trace, omitting the last step because the last action is None/we access the state in the next step
+ for j in range(len(obs_lists[i]) - 1):
+ true_prop = [f for f in obs_lists[i][j + 1].state if obs_lists[i][j + 1].state[f]]
+ for r in true_prop:
+ # count the number of occurrences of each action and its following proposition
+ occurrences[obs_lists[i][j].action][r] += 1
+
+ # iterate through actions
+ for a in occurrences:
+ # iterate through all propositions for this action
+ for r in occurrences[a]:
+ occ_r = occurrences[a][r]
+ # if the # of occurrences is higher than the user-provided threshold:
+ if occ_r > occ_threshold:
+ # set constraint 6 with the calculated weight
+ noise_constraints_6[AMDN._or_refactor(~delete(r, a))] = (occ_r / all_occ) * WMAX
+ return noise_constraints_6
+
+ @staticmethod
+ def _noise_constraints_7(obs_lists: ObservationLists, all_occ: int):
+ """Noise constraints (7) in the AMDN paper.
+
+ Args:
+ obs_lists (ObservationLists):
+ The tokens that were analyzed.
+ all_occ (int):
+ The number of occurrences of all (true) propositions in the given observation list.
+
+ Returns:
+ The noise constraints.
+ """
+ noise_constraints_7 = {}
+ # set up dict
+ occurrences = {}
+ for r in obs_lists.propositions:
+ occurrences[r] = 0
+
+ for trace in obs_lists.all_states:
+ for state in trace:
+ true_prop = [r for r in state if state[r]]
+ for r in true_prop:
+ occurrences[r] += 1
+
+ # iterate through all traces
+ for i in range(len(obs_lists.all_par_act_sets)):
+ # get the next trace/states
+ par_act_sets = obs_lists.all_par_act_sets[i]
+ states = obs_lists.all_states[i]
+ # iterate through all parallel action sets within the trace
+ for j in range(len(par_act_sets)):
+ # examine the states before and after each parallel action set; set constraints accordingly
+ true_prop = [r for r in states[j + 1] if states[j + 1][r]]
+ for r in true_prop:
+ if not states[j][r]:
+ noise_constraints_7[Or([add(r, act) for act in par_act_sets[j]])] = (occurrences[r]/all_occ) * WMAX
+ return noise_constraints_7
+
+ @staticmethod
+ def _noise_constraints_8(obs_lists, all_occ: int, occ_threshold: int):
+ """Noise constraints (8) in the AMDN paper.
+
+ Args:
+ obs_lists (ObservationLists):
+ The tokens that were analyzed.
+ all_occ (int):
+ The number of occurrences of all (true) propositions in the given observation list.
+ occ_threshold (int):
+ Threshold to be used for noise constraints.
+
+ Returns:
+ The noise constraints.
+ """
+ noise_constraints_8 = {}
+ occurrences = AMDN._set_up_occurrences_dict(obs_lists)
+
+ # iterate over ALL the plan traces, adding occurrences accordingly
+ for i in range(len(obs_lists)):
+ # iterate through each step in each trace
+ for j in range(len(obs_lists[i])):
+ # if the action is not None
+ if obs_lists[i][j].action:
+ true_prop = [f for f in obs_lists[i][j].state if obs_lists[i][j].state[f]]
+ for r in true_prop:
+ # count the number of occurrences of each action and its previous proposition
+ occurrences[obs_lists[i][j].action][r] += 1
+
+ # iterate through actions
+ for a in occurrences:
+ # iterate through all propositions for this action
+ for r in occurrences[a]:
+ occ_r = occurrences[a][r]
+ # if the # of occurrences is higher than the user-provided threshold:
+ if occ_r > occ_threshold:
+ # set constraint 8 with the calculated weight
+ noise_constraints_8[AMDN._or_refactor(pre(r, a))] = (occ_r / all_occ) * WMAX
+ return noise_constraints_8
+
+ @staticmethod
+ def _build_noise_constraints(obs_lists: ObservationLists, occ_threshold: int, debug: bool, to_obs: Optional[List[str]]):
+ """Driver for building all noise constraints. Corresponds to step 3 of the AMDN algorithm.
+
+ Args:
+ obs_lists (ObservationLists):
+ The tokens that were analyzed.
+ occ_threshold (int):
+ Threshold to be used for noise constraints.
+ debug (bool):
+ Optional debugging mode.
+ to_obs (Optional[List[str]]):
+ If in the optional debugging mode, the list of fluents to observe.
+
+ Returns:
+ The noise constraints, merged into a single dictionary.
+ """
+ # calculate all occurrences for use in weights
+ all_occ = AMDN._calculate_all_r_occ(obs_lists)
+ nc_6 = AMDN._noise_constraints_6(obs_lists, all_occ, occ_threshold)
+ nc_7 = AMDN._noise_constraints_7(obs_lists, all_occ)
+ nc_8 = AMDN._noise_constraints_8(obs_lists, all_occ, occ_threshold)
+ if debug:
+ print("\nNoise constraints 6:")
+ AMDN._debug_simple_pprint(nc_6, to_obs)
+ print("\nNoise constraints 7:")
+ AMDN._debug_simple_pprint(nc_7, to_obs)
+ print("\nNoise constraints 8:")
+ AMDN._debug_simple_pprint(nc_8, to_obs)
+ return {**nc_6, **nc_7, **nc_8}
+
+ @staticmethod
+ def _set_all_constraints(obs_lists: ObservationLists, occ_threshold: int, debug: bool):
+ """Main driver for generating all constraints in the AMDN algorithm.
+
+ Args:
+ obs_lists (ObservationLists):
+ The tokens that were analyzed.
+ occ_threshold (int):
+ Threshold to be used for noise constraints.
+ debug (bool):
+ Optional debugging mode.
+
+ Returns:
+ A dictionary that contains all of the constraints set and all of their weights.
+ """
+ to_obs = None
+ if debug:
+ to_obs = AMDN._get_observe(obs_lists)
+ disorder_constraints = AMDN._build_disorder_constraints(obs_lists)
+ if debug:
+ print("\nDisorder constraints:")
+ AMDN._debug_aux_pprint(disorder_constraints, to_obs)
+ parallel_constraints = AMDN._build_parallel_constraints(obs_lists, debug, to_obs)
+ noise_constraints = AMDN._build_noise_constraints(obs_lists, occ_threshold, debug, to_obs)
+ return {**disorder_constraints, **parallel_constraints, **noise_constraints}
+
+ @staticmethod
+ def _solve_constraints(obs_lists: ObservationLists, occ_threshold: int, debug: bool):
+ """Returns the WCNF and the decoder according to the constraints generated.
+ Corresponds to step 4 of the AMDN algorithm.
+
+ Args:
+ obs_lists (ObservationLists):
+ The tokens that were analyzed.
+ occ_threshold (int):
+ Threshold to be used for noise constraints.
+ debug (bool):
+ Optional debugging mode.
+
+ Returns:
+ The WCNF and corresponding decode dictionary.
+ """
+ constraints = AMDN._set_all_constraints(obs_lists, occ_threshold, debug)
+ # extract hard constraints
+ hard_constraints = []
+ for c, weight in constraints.items():
+ if weight == "HARD":
+ hard_constraints.append(c)
+ for c in hard_constraints:
+ del constraints[c]
+
+ wcnf, decode = to_wcnf(soft_clauses=And(constraints.keys()), hard_clauses=And(hard_constraints), weights=list(constraints.values()))
+ return wcnf, decode
+
+ @staticmethod
+ def _split_raw_fluent(raw_f: Hashable, learned_actions: Dict[str, LearnedAction]):
+ """Helper function for `_extract_model` that updates takes raw fluents to update
+ a dictionary of `LearnedActions`.
+
+ Args:
+ raw_f (Hashable):
+ The raw fluent to parse.
+ learned_actions (Dict[str, LearnedAction]):
+ The dictionary of learned actions that will be used to create the model.
+ """
+ raw_f = str(raw_f)[1:-1]
+ pre_str = " is a precondition of "
+ add_str = " is added by "
+ del_str = " is deleted by "
+ if pre_str in raw_f:
+ f, act = raw_f.split(pre_str)
+ learned_actions[act].update_precond({f})
+ elif add_str in raw_f:
+ f, act = raw_f.split(add_str)
+ learned_actions[act].update_add({f})
+ else:
+ f, act = raw_f.split(del_str)
+ learned_actions[act].update_delete({f})
+
+ @staticmethod
+ def _extract_model(obs_lists: ObservationLists, model: Dict[Hashable, bool]):
+ """Converts a raw model generated from the pysat module into a macq `Model`.
+ Corresponds to step 5 of the AMDN algorithm.
+
+ Args:
+ obs_lists (ObservationLists):
+ The tokens that were analyzed.
+ model (Dict[Hashable, bool]):
+ The raw model to parse and analyze.
+
+ Returns:
+ The macq action `Model`.
+ """
# convert the result to a Model
- pass
\ No newline at end of file
+ fluents = obs_lists.propositions
+ # set up LearnedActions
+ learned_actions = {}
+ for a in obs_lists.actions:
+ # set up a base LearnedAction with the known information
+ learned_actions[a.details()] = extract.LearnedAction(a.name, a.obj_params, cost=a.cost)
+ # iterate through all fluents
+ for raw_f in model:
+ # update learned_actions (ignore auxiliary variables)
+ if not isinstance(raw_f, Aux) and model[raw_f]:
+ AMDN._split_raw_fluent(raw_f, learned_actions)
+
+ return Model(fluents, learned_actions.values())
diff --git a/macq/extract/arms.py b/macq/extract/arms.py
new file mode 100644
index 00000000..615ec377
--- /dev/null
+++ b/macq/extract/arms.py
@@ -0,0 +1,897 @@
+""".. include:: ../../docs/extract/arms.md"""
+
+from collections import defaultdict, Counter
+from dataclasses import dataclass
+from warnings import warn
+from typing import Set, List, Dict, Tuple, Hashable
+from nnf import Var, And, Or, false as nnffalse
+from . import LearnedAction, Model, LearnedFluent
+from .exceptions import (
+ IncompatibleObservationToken,
+ InvalidMaxSATModel,
+)
+from ..observation import PartialObservation as Observation, ObservationLists
+from ..trace import Fluent, Action # Action only used for typing
+from ..utils.pysat import to_wcnf, RC2, WCNF
+
+
+@dataclass
+class Relation:
+ """Fluents with the parameters replaced by their types."""
+
+ name: str
+ types: list
+
+ def var(self):
+ """Generates the variable representation for NNF."""
+ return f"{self.name} {' '.join(list(self.types))}"
+
+ def matches(self, action: LearnedAction):
+ """Determines if a relation is related to a given action."""
+ action_types = set(action.obj_params)
+ self_counts = Counter(self.types)
+ action_counts = Counter(action.obj_params)
+ return all([t in action_types for t in self.types]) and all(
+ [self_counts[t] <= action_counts[t] for t in self.types]
+ )
+
+ def __hash__(self):
+ return hash(self.var())
+
+
+@dataclass
+class ARMSConstraints:
+ """A dataclass to hold all the constraints and weight information."""
+
+ action: List[Or[Var]]
+ info: List[Or[Var]]
+ info3: Dict[Or[Var], int]
+ plan: Dict[Or[Var], int]
+
+
+class ARMS:
+ """ARMS model extraction method.
+
+ Extracts a Model from state observations using the ARMS technique. Fluents
+ are retrieved from the initial state. Actions are learned using the
+ algorithm.
+ """
+
+ class InvalidThreshold(Exception):
+ def __init__(self, threshold):
+ super().__init__(
+ f"Invalid threshold value: {threshold}. Threshold must be a float between 0-1 (inclusive)."
+ )
+
+ def __new__(
+ cls,
+ obs_lists: ObservationLists,
+ debug: bool,
+ upper_bound: int,
+ min_support: int = 2,
+ action_weight: int = 110,
+ info_weight: int = 100,
+ threshold: float = 0.6,
+ info3_default: int = 30,
+ plan_default: int = 30,
+ ):
+ """
+ Arguments:
+ obs_lists (ObservationLists):
+ The observations to extract the model from.
+ upper_bound (int):
+ The upper bound for the maximum size of an action's preconditions and
+ add/delete lists. Determines when an action schemata is fully learned.
+ min_support (int):
+ The minimum support count for an action pair to be considered frequent.
+ action_weight (int):
+ The constant weight W_A(a) to assign to each action constraint.
+ Should be set higher than the weight of information constraints.
+ info_weight (int):
+ The constant weight W_I(r) to assign to each information constraint.
+ Determined empirically; generally the highest of all constraint weights.
+ threshold (float):
+ (0-1). The probability threshold θ to determine if an I3/plan constraint
+ is weighted by its probability or set to a default value.
+ info3_default (int):
+ The default weight for I3 constraints with probability below the threshold.
+ plan_default (int):
+ The default weight for plan constraints with probability below the threshold.
+ """
+ if obs_lists.type is not Observation:
+ raise IncompatibleObservationToken(obs_lists.type, ARMS)
+
+ if not (threshold >= 0 and threshold <= 1):
+ raise ARMS.InvalidThreshold(threshold)
+
+ fluents = obs_lists.get_fluents()
+ # get fluents from initial state
+ # call algorithm to get actions
+ actions = ARMS._arms(
+ obs_lists,
+ upper_bound,
+ fluents,
+ min_support,
+ action_weight,
+ info_weight,
+ threshold,
+ info3_default,
+ plan_default,
+ debug,
+ )
+
+ learned_fluents = set(map(lambda f: LearnedFluent(f.name, f.objects), fluents))
+ return Model(learned_fluents, actions)
+
+ @staticmethod
+ def _arms(
+ obs_lists: ObservationLists,
+ upper_bound: int,
+ fluents: Set[Fluent],
+ min_support: int,
+ action_weight: int,
+ info_weight: int,
+ threshold: float,
+ info3_default: int,
+ plan_default: int,
+ debug: bool,
+ ) -> Set[LearnedAction]:
+ """The main driver for the ARMS algorithm."""
+ learned_actions = set() # The set of learned action models Θ
+ # pointers to the earliest unlearned action for each observation list
+ early_actions = [0] * len(obs_lists)
+
+ debug1 = ARMS.debug_menu("Debug step 1?") if debug else False
+ connected_actions, action_map = ARMS.step1(obs_lists, debug1)
+ if debug1:
+ input("Press enter to continue...")
+
+ action_map_rev: Dict[LearnedAction, List[Action]] = defaultdict(list)
+ for obs_action, learned_action in action_map.items():
+ action_map_rev[learned_action].append(obs_action)
+
+ count = 1
+ while action_map_rev:
+ if debug:
+ print("Iteration", count)
+ count += 1
+
+ debug2 = ARMS.debug_menu("Debug step 2?") if debug else False
+ constraints, relation_map = ARMS.step2(
+ obs_lists, connected_actions, action_map, fluents, min_support, debug2
+ )
+ if debug2:
+ input("Press enter to continue...")
+
+ relation_map_rev: Dict[Relation, List[Fluent]] = defaultdict(list)
+ for fluent, relation in relation_map.items():
+ relation_map_rev[relation].append(fluent)
+
+ debug3 = ARMS.debug_menu("Debug step 3?") if debug else False
+ max_sat, decode = ARMS.step3(
+ constraints,
+ action_weight,
+ info_weight,
+ threshold,
+ info3_default,
+ plan_default,
+ debug3,
+ )
+ if debug3:
+ input("Press enter to continue...")
+
+ model = ARMS.step4(max_sat, decode)
+
+ debug5 = ARMS.debug_menu("Debug step 5?") if debug else False
+ # Mutates the LearnedAction (keys) of action_map_rev
+ ARMS.step5(
+ model,
+ list(action_map_rev.keys()),
+ debug5,
+ )
+
+ # Step 5 updates
+
+ # Progress observed states if early actions have been learned
+ setA = set()
+ for action in action_map_rev.keys():
+ for i, obs_list in enumerate(obs_lists):
+ obs_action: Action = obs_list[early_actions[i]].action
+ # if current action is the early action for obs_list i,
+ # update the next state with the effects and update the
+ # early action pointer
+ if obs_action in action_map and action == action_map[obs_action]:
+ # Set add effects true
+ for add in action.add:
+ # get candidate fluents from add relation
+ # get fluent by cross referencing obs_list.action params
+ candidates = relation_map_rev[add]
+ for fluent in candidates:
+ if set(fluent.objects).issubset(obs_action.obj_params):
+ obs_list[early_actions[i] + 1].state[fluent] = True
+ early_actions[i] += 1
+ # Set del effects false
+ for delete in action.delete:
+ candidates = relation_map_rev[delete]
+ for fluent in candidates:
+ if set(fluent.objects).issubset(obs_action.obj_params):
+ obs_list[early_actions[i] + 1].state[fluent] = False
+ early_actions[i] += 1
+
+ if debug:
+ print()
+ print(action.details())
+ print("precond:", action.precond)
+ print("add:", action.add)
+ print("delete:", action.delete)
+
+ if (
+ max([len(action.precond), len(action.add), len(action.delete)])
+ >= upper_bound
+ ):
+ if debug:
+ print(
+ f"Action schemata for {action.details()} has been fully learned."
+ )
+ setA.add(action)
+
+ # Update Λ by Λ − A
+ for action in setA:
+ action_keys = action_map_rev[action]
+ for obs_action in action_keys:
+ del action_map[obs_action]
+ del action_map_rev[action]
+ del connected_actions[action]
+ action_keys = [
+ a1 for a1 in connected_actions if action in connected_actions[a1]
+ ]
+ for a1 in action_keys:
+ del connected_actions[a1][action]
+
+ # Update Θ by adding A
+ learned_actions.add(action)
+
+ if debug5:
+ input("Press enter to continue...")
+
+ return learned_actions
+
+ @staticmethod
+ def step1(
+ obs_lists: ObservationLists, debug: bool
+ ) -> Tuple[
+ Dict[LearnedAction, Dict[LearnedAction, Set[str]]],
+ Dict[Action, LearnedAction],
+ ]:
+ """(Step 1) Substitute instantiated objects in each action instance with the object type."""
+
+ learned_actions: Set[LearnedAction] = set()
+ action_map: Dict[Action, LearnedAction] = {}
+ for obs_action in obs_lists.get_actions():
+ # We don't support objects with multiple types right now, so no
+ # multiple type clauses need to be generated.
+
+ # Create LearnedActions for each action, replacing instantiated
+ # objects with the object type.
+ types = [obj.obj_type for obj in obs_action.obj_params]
+ learned_action = LearnedAction(obs_action.name, types)
+ learned_actions.add(learned_action)
+ action_map[obs_action] = learned_action
+
+ connected_actions: Dict[LearnedAction, Dict[LearnedAction, Set[str]]] = {}
+ for a1 in learned_actions:
+ connected_actions[a1] = {}
+ for a2 in learned_actions.difference({a1}): # an action does not connect to itself
+ intersection = set(a1.obj_params).intersection(set(a2.obj_params))
+ if intersection:
+ connected_actions[a1][a2] = intersection
+ if debug:
+ print(
+ f"{a1.details()} is connected to {a2.details()} by {intersection}"
+ )
+
+ return connected_actions, action_map
+
+ @staticmethod
+ def step2(
+ obs_lists: ObservationLists,
+ connected_actions: Dict[LearnedAction, Dict[LearnedAction, Set[str]]],
+ action_map: Dict[Action, LearnedAction],
+ fluents: Set[Fluent],
+ min_support: int,
+ debug: bool,
+ ) -> Tuple[ARMSConstraints, Dict[Fluent, Relation]]:
+ """(Step 2) Generate action constraints, information constraints, and plan constraints.
+
+ For the unexplained actions, build a set of information and action
+ constraints based on individual actions. Apply a frequent-set
+ mining algorithm to find the frequent sets of connected actions and
+ relations. Here connected means the actions and relations must share
+ some common parameters.
+ """
+
+ # Map fluents to relations
+ # relations are fluents but with instantiated objects replaced by the object type
+ relations: Dict[Fluent, Relation] = dict(
+ map(
+ lambda f: (
+ f,
+ Relation(
+ f.name,
+ [obj.obj_type for obj in f.objects],
+ ),
+ ),
+ fluents,
+ )
+ )
+
+ debuga = ARMS.debug_menu("Debug action constraints?") if debug else False
+
+ action_constraints = ARMS.step2A(
+ connected_actions, set(relations.values()), debuga
+ )
+
+ debugi = ARMS.debug_menu("Debug info constraints?") if debug else False
+ info_constraints, info_support_counts = ARMS.step2I(
+ obs_lists, relations, action_map, debugi
+ )
+
+ debugp = ARMS.debug_menu("Debug plan constraints?") if debug else False
+ plan_constraints = ARMS.step2P(
+ obs_lists,
+ connected_actions,
+ action_map,
+ set(relations.values()),
+ min_support,
+ debugp,
+ )
+
+ return (
+ ARMSConstraints(
+ action_constraints,
+ info_constraints,
+ info_support_counts,
+ plan_constraints,
+ ),
+ relations,
+ )
+
+ @staticmethod
+ def step2A(
+ connected_actions: Dict[LearnedAction, Dict[LearnedAction, Set]],
+ relations: Set[Relation],
+ debug: bool,
+ ) -> List[Or[Var]]:
+ """Action constraints.
+
+ A1. The intersection of the precondition and add lists of all actions must be empty.
+
+ A2. In addition, if an action’s delete list includes a relation, this relation is
+ in the action’s precondition list. Thus, for every action, we require that
+ the delete list is a subset of the precondition list.
+ """
+
+ if debug:
+ print("\nBuilding action constraints...\n")
+
+ def implication(a: Var, b: Var):
+ return Or([a.negate(), b])
+
+ constraints: List[Or[Var]] = []
+ actions = set(connected_actions.keys())
+ for action in actions:
+ for relation in relations:
+ # A relation is relevant to an action if they share parameter types
+ if relation.matches(action):
+ if debug:
+ print(
+ f'relation ({relation.var()}) is relevant to action "{action.details()}"\n'
+ "A1:\n"
+ f" {relation.var()}∈ add ⇒ {relation.var()}∉ pre\n"
+ f" {relation.var()}∈ pre ⇒ {relation.var()}∉ add\n"
+ "A2:\n"
+ f" {relation.var()}∈ del ⇒ {relation.var()}∈ pre\n"
+ )
+
+ # A1
+ # relation in action.add => relation not in action.precond
+ # relation in action.precond => relation not in action.add
+
+ # (BREAK) marks unambiguous breakpoints for parsing later
+ constraints.append(
+ implication(
+ Var(
+ f"{relation.var()} (BREAK) in (BREAK) add (BREAK) {action.details()}"
+ ),
+ Var(
+ f"{relation.var()} (BREAK) in (BREAK) pre (BREAK) {action.details()}"
+ ).negate(),
+ )
+ )
+ constraints.append(
+ implication(
+ Var(
+ f"{relation.var()} (BREAK) in (BREAK) pre (BREAK) {action.details()}"
+ ),
+ Var(
+ f"{relation.var()} (BREAK) in (BREAK) add (BREAK) {action.details()}"
+ ).negate(),
+ )
+ )
+
+ # A2
+ # relation in action.del => relation in action.precond
+ constraints.append(
+ implication(
+ Var(
+ f"{relation.var()} (BREAK) in (BREAK) del (BREAK) {action.details()}"
+ ),
+ Var(
+ f"{relation.var()} (BREAK) in (BREAK) pre (BREAK) {action.details()}"
+ ),
+ )
+ )
+
+ return constraints
+
+ @staticmethod
+ def step2I(
+ obs_lists: ObservationLists,
+ relations: Dict[Fluent, Relation],
+ actions: Dict[Action, LearnedAction],
+ debug: bool,
+ ) -> Tuple[List[Or[Var]], Dict[Or[Var], int]]:
+ """Information constraints.
+
+ Suppose we observe a relation p to be true between two actions
+ \(a_n\) and \(a_{n+1}\) , and \(p, a_{i_1} , ... ,\) and \(a_{i_k}\) share
+ the same parameter types. We can represent this fact by the following clauses,
+ given that \(a_{i_1} , ... ,\) and \(a_{i_k}\) appear in that order.
+
+ I1. The relation \(p\) must be generated by an action \(a_{i_k} (0 \le i_k \le n)\),
+ that is, \(p\) is selected to be in the add-list of \(a_{i_k}\).
+ \(p∈ (add_{i_1} ∪ add_{i_2} ∪ \dots ∪ add_{i_k}) \), where ∪ means logical “or”.
+
+ I2. The last action \(a_{i_k}\) must not delete the relation p; that is,
+ \(p\) must not be selected to be in the delete list of \(a_{i_k}\): \(p \\not\in del_{i_k}\).
+
+ I3. We define the weight value of a relation-action pair \((p, a)\) as the
+ occurrence probability of this pair in all plan examples. If the probability
+ of a relation-action pair is higher than the probability threshold θ ,
+ then we set a corresponding relation constraint \(p ∈ \\text{PRECOND}_a\), which
+ receives a weight value equal to its prior probability.
+ """
+ if debug:
+ print("\nBuilding information constraints...")
+ constraints: List[Or[Var]] = []
+ support_counts: Dict[Or[Var], int] = defaultdict(int)
+ obs_list: List[Observation]
+ for obs_list_i, obs_list in enumerate(obs_lists):
+ for i, obs in enumerate(obs_list):
+ if obs.state is not None and i > 0:
+ n = i - 1
+ if debug:
+ print(
+ f"\nStep {i} of observation list {obs_list_i} contains state information."
+ )
+ for fluent, val in obs.state.items():
+ relation = relations[fluent]
+ # Information constraints only apply to true relations
+ if val:
+ if debug:
+ print(
+ f" Fluent {fluent} is true.\n"
+ f" ({relation.var()})∈ ("
+ f"{' ∪ '.join([f'add_{{ {actions[obs_list[ik].action].details()} }}' for ik in range(0,n+1) if obs_list[ik].action in actions] )}" # type: ignore
+ ")"
+ )
+ # I1
+ # relation in the add list of an action <= n (i-1)
+ i1: List[Var] = []
+ for obs_i in obs_list[: i - 1]:
+ if obs_i.action in actions and obs_i.action is not None:
+ ai = actions[obs_i.action]
+ if relation.matches(ai):
+ i1.append(
+ Var(
+ f"{relation.var()} (BREAK) in (BREAK) add (BREAK) {ai.details()}"
+ )
+ )
+
+ # I2
+ # relation not in del list of action n (i-1)
+ i2 = None
+ a_n = obs_list[i - 1].action
+ if a_n in actions and a_n is not None:
+ i2 = Var(
+ f"{relation.var()} (BREAK) in (BREAK) del (BREAK) {actions[a_n].details()}"
+ ).negate()
+
+ if i1:
+ constraints.append(Or(i1))
+ if i2:
+ constraints.append(Or([i2]))
+
+ # I3
+ # count occurrences
+ if (
+ i < len(obs_list) - 1
+ and obs.action in actions
+ and obs.action is not None # for the linter
+ and relation.matches(actions[obs.action])
+ ):
+ # corresponding constraint is related to the current action's precondition list
+ support_counts[
+ Or(
+ [
+ Var(
+ f"{relation.var()} (BREAK) in (BREAK) pre (BREAK) {actions[obs.action].details()}"
+ )
+ ]
+ )
+ ] += 1
+ elif (
+ a_n in actions
+ and a_n is not None
+ and relation.matches(actions[a_n])
+ ):
+ # corresponding constraint is related to the previous action's add list
+ support_counts[
+ Or(
+ [
+ Var(
+ f"{relation.var()} (BREAK) in (BREAK) add (BREAK) {actions[a_n].details()}"
+ )
+ ]
+ )
+ ] += 1
+
+ return constraints, support_counts
+
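+ # Illustrative example (hypothetical trace): if fluent p is observed true at
+ # step 3 and the actions at steps 0-2 all match p's parameter types, I1 adds
+ # Or([add_var(p, a0), add_var(p, a1), add_var(p, a2)]) and I2 adds
+ # Or([del_var(p, a2).negate()]), where add_var/del_var abbreviate the
+ # (BREAK)-encoded Vars used above: p was added earlier and not deleted last.
+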
+ @staticmethod
+ def _apriori(
+ action_lists: List[List[LearnedAction]], minsup: int
+ ) -> Dict[Tuple[LearnedAction, LearnedAction], int]:
+ """An implementation of the Apriori algorithm to find frequent ordered pairs of actions."""
+ counts = Counter(
+ [action for action_list in action_lists for action in action_list]
+ )
+ # L1 = {actions that appear >= minsup}
+ L1 = set(
+ frozenset([action])
+ for action in filter(lambda k: counts[k] >= minsup, counts.keys())
+ ) # large 1-itemsets
+
+ # Only going up to L2, so no loop or generalized algorithm needed
+ # apriori-gen step
+ C2 = set([i.union(j) for i in L1 for j in L1 if len(i.union(j)) == 2])
+ # Since L1 contains 1-itemsets where each item is frequent, C2 can
+ # only contain valid sets and pruning is not required
+
+ # Get all possible ordered action pairs
+ C2_ordered = set()
+ for pair in C2:
+ pair = list(pair)
+ C2_ordered.add((pair[0], pair[1]))
+ C2_ordered.add((pair[1], pair[0]))
+
+ # Count pair occurrences and generate L2
+ frequent_pairs = {}
+ for ai, aj in C2_ordered:
+ count = 0
+ for action_list in action_lists:
+ a1_indices = [i for i, e in enumerate(action_list) if e == ai]
+ if a1_indices:
+ for i in a1_indices:
+ if aj in action_list[i + 1 :]:
+ count += 1
+ if count >= minsup:
+ frequent_pairs[(ai, aj)] = count
+
+ return frequent_pairs
+
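+ # Worked example (hypothetical): with minsup=2 and the action lists
+ # [[a, b, c], [a, c, b], [b, a]], counts are a: 3, b: 3, c: 2, so
+ # L1 = {{a}, {b}, {c}}. The ordered pair (a, b) occurs in the first two
+ # lists (count 2 >= minsup) and is kept, as is (a, c); (b, a) occurs only
+ # in the third list (count 1) and is dropped.
+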
+ @staticmethod
+ def step2P(
+ obs_lists: ObservationLists,
+ connected_actions: Dict[LearnedAction, Dict[LearnedAction, Set]],
+ action_map: Dict[Action, LearnedAction],
+ relations: Set[Relation],
+ min_support: int,
+ debug: bool,
+ ) -> Dict[Or[Var], int]:
+ """Plan constraints.
+
+ P1. Every precondition \(p\) of every action \(b\) must be in the add
+ list of a preceding action \(a\) and is not deleted by any actions between
+ \(a\) and \(b\).
+
+ P2. In addition, at least one relation \(r\) in the add list of an action
+ must be useful in achieving a precondition of a later action. That is, for
+ every action \(a\), an add list relation \(r\) must be in the precondition of a
+ later action \(b\), and there is no other action between \(a\) and \(b\)
+ that either adds or deletes \(r\).
+
+ "While constraints P1 and P2 provide the general guiding principle for
+ ensuring a plan’s correctness, in practice there are too many
+ instantiations of these constraints." (more information on the
+ rationelle in the paper) Thus, they are replaced with the following
+ constraints:
+
+ Let there be an action pair \(\langle a_i, a_j \\rangle, 0 \le i < j \le n-1\).
+
+ P3. One of the relevant relations \(p\) must be chosen to be in the
+ preconditions of both \(a_i\) and \(a_j\), but not in the delete list
+ of \(a_i\).
+
+ P4. The first action \(a_i\) adds a relevant relation that is in the
+ precondition list of the second action \(a_j\) in the pair.
+
+ P5. A relevant relation \(p\) that is deleted by the first action
+ \(a_i\) is added by \(a_j\). This clause handles the event in which an
+ action re-establishes a fact that was deleted by a previous action.
+
+ The above constraints can be combined into one constraint:
+ $$\exists p ((p \in (pre_i \cap pre_j) \land p \\not\in del_i) \lor (p \in (add_i \cap pre_j)) \lor (p \in (del_i \cap add_j)))$$
+ """
+ frequent_pairs = ARMS._apriori(
+ [
+ [
+ action_map[obs.action]
+ for obs in obs_list
+ if obs.action is not None and obs.action in action_map
+ ]
+ for obs_list in obs_lists
+ ],
+ min_support,
+ )
+ if debug:
+ print("Frequent pairs:")
+ print(frequent_pairs)
+
+ constraints: Dict[Or[Var], int] = {}
+ for ai, aj in frequent_pairs.keys():
+ connectors = set()
+ # get list of relevant relations from connected_actions
+ if ai in connected_actions and aj in connected_actions[ai]:
+ connectors.update(connected_actions[ai][aj])
+ if aj in connected_actions and ai in connected_actions[aj]:
+ connectors.update(connected_actions[aj][ai])
+
+ # if the actions are not related they are not a valid pair for a plan constraint.
+ if not connectors:
+ continue
+
+ # for each relation, save constraint
+ relevant_relations = {p for p in relations if connectors.issubset(p.types)}
+ relation_constraints: List[Var] = []
+ for relation in relevant_relations:
+
+ relation_constraints.append(
+ Var(
+ f"{relation.var()} (BREAK) relevant (BREAK) {ai.details()} (BREAK) {aj.details()}"
+ )
+ )
+ if debug:
+ print(
+ f"{relation.var()} might explain action pair ({ai.details()}, {aj.details()})"
+ )
+ constraints[Or(relation_constraints)] = frequent_pairs[(ai, aj)]
+
+ return constraints
+
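+ # Illustrative example (hypothetical): if (pickup, stack) is a frequent pair
+ # whose connecting parameter types are {block}, then every relation over
+ # {block}, e.g. "(holding ?x)", contributes one "relevant" Var, and the
+ # resulting Or clause is weighted by the pair's support count.
+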
+ @staticmethod
+ def step3(
+ constraints: ARMSConstraints,
+ action_weight: int,
+ info_weight: int,
+ threshold: float,
+ info3_default: int,
+ plan_default: int,
+ debug: bool,
+ ) -> Tuple[WCNF, Dict[int, Hashable]]:
+ """(Step 3) Construct the weighted MAX-SAT problem based on the constraints and weight information found in Step 2."""
+
+ action_weights = [action_weight] * len(constraints.action)
+ info_weights = [info_weight] * len(constraints.info)
+ info3_weights = ARMS._calculate_support_rates(
+ list(constraints.info3.values()), threshold, info3_default
+ )
+ plan_weights = ARMS._calculate_support_rates(
+ list(constraints.plan.values()), threshold, plan_default
+ )
+ all_weights = action_weights + info_weights + info3_weights + plan_weights
+
+ info3_constraints = list(constraints.info3.keys())
+ plan_constraints = list(constraints.plan.keys())
+ all_constraints = (
+ constraints.action + constraints.info + info3_constraints + plan_constraints
+ )
+
+ constraints_w_weights: Dict[Or[Var], int] = {}
+ for constraint, weight in zip(all_constraints, all_weights):
+ if constraint == nnffalse:
+ continue
+ if constraint not in constraints_w_weights:
+ constraints_w_weights[constraint] = weight
+ elif weight != constraints_w_weights[constraint]:
+ if debug:
+ warn(
+ f"The constraint {constraint} has conflicting weights ({weight} and {constraints_w_weights[constraint]}). Choosing the smaller weight."
+ )
+ constraints_w_weights[constraint] = min(
+ weight, constraints_w_weights[constraint]
+ )
+
+ problem: And[Or[Var]] = And(list(constraints_w_weights.keys()))
+ weights = list(constraints_w_weights.values())
+
+ wcnf, decode = to_wcnf(problem, weights)
+ return wcnf, decode
+
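+ # Weighting sketch (values hypothetical): with action_weight=110 and
+ # info_weight=100, every A- and I-constraint receives that constant weight,
+ # while each I3 and plan constraint receives its support rate (see
+ # _calculate_support_rates below) or the corresponding default.
+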
+ @staticmethod
+ def _calculate_support_rates(
+ support_counts: List[int], threshold: float, default: int
+ ) -> List[int]:
+ # NOTE:
+ # In the paper, Z_Σ_P (the denominator of the support rate formula) is
+ # defined as the "total pairs" in the set of plans. However, the paper's
+ # examples appear to use the max support count as the denominator, so that
+ # interpretation is used here.
+
+ # guard against an empty list of support counts
+ z_sigma_p = max(support_counts) if support_counts else 1
+
+ def get_support_rate(count):
+ probability = count / z_sigma_p
+ # round so the weights stay integers, as the annotation promises
+ return round(probability * 100) if probability > threshold else default
+
+ return list(map(get_support_rate, support_counts))
+
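+ # Worked example: support_counts=[10, 5, 1], threshold=0.6, default=30 gives
+ # z_sigma_p=10 and probabilities [1.0, 0.5, 0.1], so the returned weights are
+ # [100, 30, 30] (0.5 and 0.1 fall below the threshold).
+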
+ @staticmethod
+ def step4(max_sat: WCNF, decode: Dict[int, Hashable]) -> Dict[Hashable, bool]:
+ """(Step 4) Solve the MAX-SAT problem built in Step 3."""
+ solver = RC2(max_sat)
+
+ encoded_model = solver.compute()
+
+ if not isinstance(encoded_model, list):
+ # should never be reached
+ raise InvalidMaxSATModel(encoded_model)
+
+ # decode the model (back to nnf vars)
+ model: Dict[Hashable, bool] = {
+ decode[abs(clause)]: clause > 0 for clause in encoded_model
+ }
+
+ return model
+
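+ # Illustrative decoding: if the solver returns [1, -2] and
+ # decode = {1: var_a, 2: var_b}, the resulting model is
+ # {var_a: True, var_b: False}.
+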
+ @staticmethod
+ def step5(
+ model: Dict[Hashable, bool],
+ actions: List[LearnedAction],
+ debug: bool,
+ ):
+ """(Step 5) Extract the learned action effects from the solved model."""
+ action_map = {a.details(): a for a in actions}
+ negative_constraints = defaultdict(set)
+ plan_constraints: List[Tuple[str, LearnedAction, LearnedAction]] = []
+
+ # NOTE: only taking the top n (optimal number varies, determine
+ # empirically) constraints usually results in more accurate action
+ # models, however this is not a part of the paper and therefore not
+ # implemented.
+ for constraint, val in model.items():
+ constraint = str(constraint).split(" (BREAK) ")
+ relation = constraint[0]
+ ctype = constraint[1] # constraint type
+ if ctype == "in":
+ effect = constraint[2]
+ action = action_map[constraint[3]]
+ if debug:
+ print(
+ f"Learned constraint: {relation} in {effect}_{action.details()}"
+ )
+ if val:
+ action_update = {
+ "pre": action.update_precond,
+ "add": action.update_add,
+ "del": action.update_delete,
+ }[effect]
+ action_update({relation})
+ else:
+ action_effect = {
+ "pre": action.precond,
+ "add": action.add,
+ "del": action.delete,
+ }[effect]
+ if relation in action_effect:
+ if debug:
+ warn(
+ f"Removing {relation} from {effect} of {action.details()}"
+ )
+ action_effect.remove(relation)
+ negative_constraints[(relation, action)].add(effect)
+
+ else: # store plan constraint
+ ai = action_map[constraint[2]]
+ aj = action_map[constraint[3]]
+ plan_constraints.append((relation, ai, aj))
+ if debug:
+ print(f"{relation} possibly explains action pair ({ai}, {aj})")
+
+ for p, ai, aj in plan_constraints:
+ # repair only if none of P3, P4, and P5 already holds
+ if not (
+ (p in ai.precond.intersection(aj.precond) and p not in ai.delete)  # P3
+ or (p in ai.add.intersection(aj.precond))  # P4
+ or (p in ai.delete.intersection(aj.add))  # P5
+ ):
+ # check if either P3 or P4 is partially fulfilled and can be satisfied
+ if p in ai.precond.union(aj.precond):
+ if p in aj.precond:
+ # if P3 isn't contradicted, add p to ai.precond
+ if p not in ai.delete and not (
+ (p, ai) in negative_constraints
+ and "pre" in negative_constraints[(p, ai)]
+ ):
+ ai.update_precond({p})
+
+ # if P4 isn't contradicted, add p to ai.add
+ if not (
+ (p, ai) in negative_constraints
+ and "add" in negative_constraints[(p, ai)]
+ ):
+ ai.update_add({p})
+
+ # p in ai.precond and P3 not contradicted, add p to aj.precond
+ elif p not in ai.delete and not (
+ (p, aj) in negative_constraints
+ and "pre" in negative_constraints[(p, aj)]
+ ):
+ aj.update_precond({p})
+
+ # check if either P3 or P4 can be satisfied
+ elif not (
+ (p, aj) in negative_constraints
+ and "pre" in negative_constraints[(p, aj)]
+ ):
+ # if P3 isn't contradicted, add p to both ai and aj preconds
+ if p not in ai.delete and not (
+ (p, ai) in negative_constraints
+ and "pre" in negative_constraints[(p, ai)]
+ ):
+ ai.update_precond({p})
+ aj.update_precond({p})
+
+ # if P4 isn't contradicted, add p to ai.add and aj.precond
+ if not (
+ (p, ai) in negative_constraints
+ and "add" in negative_constraints[(p, ai)]
+ ):
+ ai.update_add({p})
+ aj.update_precond({p})
+
+ # check if P5 can be satisfied
+ # if P5 isn't contradicted, add p wherever it is missing
+ if not (
+ (p, ai) in negative_constraints
+ and "del" in negative_constraints[(p, ai)]
+ ) and not (
+ (p, aj) in negative_constraints
+ and "add" in negative_constraints[(p, aj)]
+ ):
+ if p not in ai.delete:
+ ai.update_delete({p})
+ if p not in aj.add:
+ aj.update_add({p})
+
+ @staticmethod
+ def debug_menu(prompt: str):
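+ """Prompts the user with a yes/no question until a valid answer is given; returns True for "y"."""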
+ choice = input(prompt + " (y/n): ").lower()
+ while choice not in ["y", "n"]:
+ choice = input(prompt + " (y/n): ").lower()
+ return choice == "y"
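+
+ # Usage sketch (keyword values are illustrative; only names appearing in the
+ # step functions above are assumed):
+ #   model = Extract(obs_lists, modes.ARMS, debug=False, min_support=2,
+ #                   action_weight=110, info_weight=100, threshold=0.6,
+ #                   info3_default=30, plan_default=30)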
diff --git a/macq/extract/exceptions.py b/macq/extract/exceptions.py
new file mode 100644
index 00000000..7407cb23
--- /dev/null
+++ b/macq/extract/exceptions.py
@@ -0,0 +1,26 @@
+class IncompatibleObservationToken(Exception):
+ def __init__(self, token, technique, message=None):
+ if message is None:
+ message = f"Observations of type {token.__name__} are not compatible with the {technique.__name__} extraction technique."
+ super().__init__(message)
+
+
+class InconsistentConstraintWeights(Exception):
+ def __init__(self, constraint, weight1, weight2, message=None):
+ if message is None:
+ message = f"Tried to assign the constraint {constraint} conflicting weights ({weight1} and {weight2})."
+ super().__init__(message)
+
+
+class InvalidMaxSATModel(Exception):
+ def __init__(self, model, message=None):
+ if message is None:
+ message = f"The MAX-SAT solver generated an invalid model. Model should be a list of integers. model = {model}."
+ super().__init__(message)
+
+
+class ConstraintContradiction(Exception):
+ def __init__(self, relation, effect, action, message=None):
+ if message is None:
+ message = f"Action model has contradictory constraints for {relation}'s presence in the {effect} list of {action.details()}."
+ super().__init__(message)
diff --git a/macq/extract/extract.py b/macq/extract/extract.py
index ea2e84a4..b2f8310f 100644
--- a/macq/extract/extract.py
+++ b/macq/extract/extract.py
@@ -1,9 +1,14 @@
+""".. include:: ../../docs/extract/extract.md"""
+
from dataclasses import dataclass
from enum import Enum, auto
-from ..trace import ObservationLists, Action, State
+from ..trace import Action, State
+from ..observation import ObservationLists
from .model import Model
from .observer import Observer
from .slaf import SLAF
+from .amdn import AMDN
+from .arms import ARMS
@dataclass
@@ -13,13 +18,6 @@ class SAS:
post_state: State
-class IncompatibleObservationToken(Exception):
- def __init__(self, token, technique, message=None):
- if message is None:
- message = f"Observations of type {token.__name__} are not compatible with the {technique.__name__} extraction technique."
- super().__init__(message)
-
-
class modes(Enum):
"""Model extraction techniques.
@@ -28,6 +26,8 @@ class modes(Enum):
OBSERVER = auto()
SLAF = auto()
+ AMDN = auto()
+ ARMS = auto()
class Extract:
@@ -37,7 +37,9 @@ class Extract:
from state observations.
"""
- def __new__(cls, obs_lists: ObservationLists, mode: modes, **kwargs) -> Model:
+ def __new__(
+ cls, obs_lists: ObservationLists, mode: modes, debug: bool = False, **kwargs
+ ) -> Model:
"""Extracts a Model object.
Extracts a model from the observations using the specified extraction
@@ -58,11 +60,11 @@ def __new__(cls, obs_lists: ObservationLists, mode: modes, **kwargs) -> Model:
techniques = {
modes.OBSERVER: Observer,
modes.SLAF: SLAF,
+ modes.AMDN: AMDN,
+ modes.ARMS: ARMS,
}
if mode == modes.SLAF:
- # only allow one trace
- assert (
- len(obs_lists) == 1
- ), "The SLAF extraction technique only takes one trace."
+ if len(obs_lists) != 1:
+ raise Exception("The SLAF extraction technique only takes one trace.")
- return techniques[mode](obs_lists, **kwargs)
+ return techniques[mode](obs_lists, debug, **kwargs)
diff --git a/macq/extract/learned_action.py b/macq/extract/learned_action.py
index d11d5f7f..6d040b89 100644
--- a/macq/extract/learned_action.py
+++ b/macq/extract/learned_action.py
@@ -1,10 +1,9 @@
from __future__ import annotations
-from typing import Set, List
-from ..trace import Fluent
+from typing import List, Set
class LearnedAction:
- def __init__(self, name: str, obj_params: List, **kwargs):
+ def __init__(self, name: str, obj_params: List[str], **kwargs):
self.name = name
self.obj_params = obj_params
if "cost" in kwargs:
@@ -15,9 +14,11 @@ def __init__(self, name: str, obj_params: List, **kwargs):
self.delete = set() if "delete" not in kwargs else kwargs["delete"]
def __eq__(self, other):
- if not isinstance(other, LearnedAction):
- return False
- return self.name == other.name and self.obj_params == other.obj_params
+ return (
+ isinstance(other, LearnedAction)
+ and self.name == other.name
+ and self.obj_params == other.obj_params
+ )
def __hash__(self):
# Order of obj_params is important!
@@ -26,7 +27,7 @@ def __hash__(self):
def details(self):
# obj_params can be either a list of strings or a list of PlanningObject depending on the token type and extraction method used to learn the action
try:
- string = f"({self.name} {' '.join([o for o in self.obj_params])})"
+ string = f"({self.name} {' '.join(self.obj_params)})"
except TypeError:
string = f"({self.name} {' '.join([o.details() for o in self.obj_params])})"
@@ -59,6 +60,11 @@ def update_delete(self, fluents: Set[str]):
"""
self.delete.update(fluents)
+ def clear(self):
+ self.precond = set()
+ self.add = set()
+ self.delete = set()
+
def compare(self, orig_action: LearnedAction):
"""Compares the learned action to an original, ground truth action."""
precond_diff = orig_action.precond.difference(self.precond)
diff --git a/macq/extract/learned_fluent.py b/macq/extract/learned_fluent.py
index c80d3545..7093bd73 100644
--- a/macq/extract/learned_fluent.py
+++ b/macq/extract/learned_fluent.py
@@ -1,14 +1,13 @@
from typing import List
+
class LearnedFluent:
def __init__(self, name: str, objects: List):
self.name = name
self.objects = objects
def __eq__(self, other):
- if not isinstance(other, LearnedFluent):
- return False
- return hash(self) == hash(other)
+ return isinstance(other, LearnedFluent) and hash(self) == hash(other)
def __hash__(self):
# Order of objects is important!
diff --git a/macq/extract/model.py b/macq/extract/model.py
index 6a2fc578..04af2c05 100644
--- a/macq/extract/model.py
+++ b/macq/extract/model.py
@@ -27,9 +27,7 @@ class Model:
action attributes characterize the model.
"""
- def __init__(
- self, fluents: Set[LearnedFluent], actions: Set[LearnedAction]
- ):
+ def __init__(self, fluents: Set[LearnedFluent], actions: Set[LearnedAction]):
"""Initializes a Model with a set of fluents and a set of actions.
Args:
@@ -129,7 +127,13 @@ def __to_tarski_formula(self, attribute: Set[str], lang: FirstOrderLanguage):
Connective.And, [lang.get(a.replace(" ", "_"))() for a in attribute]
)
- def to_pddl(self, domain_name: str, problem_name: str, domain_filename: str, problem_filename: str):
+ def to_pddl(
+ self,
+ domain_name: str,
+ problem_name: str,
+ domain_filename: str,
+ problem_filename: str,
+ ):
"""Dumps a Model to two PDDL files. The conversion only uses 0-arity predicates, and no types, objects,
or parameters of any kind are used. Actions are represented as ground actions with no parameters.
@@ -192,4 +196,4 @@ def deserialize(string: str):
@classmethod
def _from_json(cls, data: dict):
actions = set(map(LearnedAction._deserialize, data["actions"]))
- return cls(set(data["fluents"]), actions)
\ No newline at end of file
+ return cls(set(data["fluents"]), actions)
diff --git a/macq/extract/observer.py b/macq/extract/observer.py
index e93d0bdd..f881fbb5 100644
--- a/macq/extract/observer.py
+++ b/macq/extract/observer.py
@@ -1,12 +1,14 @@
+""".. include:: ../../docs/extract/observer.md"""
+
from typing import List, Set
from collections import defaultdict
from attr import dataclass
-import macq.extract as extract
+from . import LearnedAction, Model
+from .exceptions import IncompatibleObservationToken
from .model import Model
from .learned_fluent import LearnedFluent
-from ..trace import ObservationLists
-from ..observation import IdentityObservation
+from ..observation import IdentityObservation, ObservationLists
@dataclass
@@ -30,7 +32,7 @@ class Observer:
fluents that went from True to False.
"""
- def __new__(cls, obs_lists: ObservationLists):
+ def __new__(cls, obs_lists: ObservationLists, debug: bool):
"""Creates a new Model object.
Args:
@@ -41,7 +43,7 @@ def __new__(cls, obs_lists: ObservationLists):
Raised if the observations are not identity observation.
"""
if obs_lists.type is not IdentityObservation:
- raise extract.IncompatibleObservationToken(obs_lists.type, Observer)
+ raise IncompatibleObservationToken(obs_lists.type, Observer)
fluents = Observer._get_fluents(obs_lists)
actions = Observer._get_actions(obs_lists)
return Model(fluents, actions)
@@ -54,7 +56,10 @@ def _get_fluents(obs_lists: ObservationLists):
for obs_list in obs_lists:
for obs in obs_list:
# Update fluents with the fluents in this observation
- fluents.update(LearnedFluent(f.name, [o.details() for o in f.objects]) for f in obs.state.keys())
+ fluents.update(
+ LearnedFluent(f.name, [o.details() for o in f.objects])
+ for f in obs.state.keys()
+ )
return fluents
@staticmethod
@@ -74,7 +79,7 @@ def _get_actions(obs_lists: ObservationLists):
action_transitions = obs_lists.get_all_transitions()
for action, transitions in action_transitions.items():
# Create a LearnedAction for the current action
- model_action = extract.LearnedAction(
+ model_action = LearnedAction(
action.name, action.obj_params, cost=action.cost
)
diff --git a/macq/extract/slaf.py b/macq/extract/slaf.py
index ada95794..0e78413f 100644
--- a/macq/extract/slaf.py
+++ b/macq/extract/slaf.py
@@ -1,11 +1,13 @@
+""".. include:: ../../docs/extract/slaf.md"""
+
import macq.extract as extract
from typing import Set, Union
from nnf import Var, Or, And, true, false, config
from bauhaus import Encoding
+from .exceptions import IncompatibleObservationToken
from .model import Model
from .learned_fluent import LearnedFluent
-from ..observation import AtomicPartialObservation
-from ..trace import ObservationLists
+from ..observation import AtomicPartialObservation, ObservationLists
# only used for pretty printing in debug mode
e = Encoding()
@@ -52,7 +54,7 @@ def __new__(cls, o_list: ObservationLists, debug_mode: bool = False):
Raised if the observations are not identity observation.
"""
if o_list.type is not AtomicPartialObservation:
- raise extract.IncompatibleObservationToken(o_list.type, SLAF)
+ raise IncompatibleObservationToken(o_list.type, SLAF)
SLAF.debug_mode = debug_mode
entailed = SLAF.__as_strips_slaf(o_list)
# return the Model
@@ -162,12 +164,13 @@ def __sort_results(observations: ObservationLists, entailed: Set):
The extracted `Model`.
"""
learned_actions = {}
- base_fluents = {}
model_fluents = set()
# iterate through each step
for o in observations:
for token in o:
- model_fluents.update([LearnedFluent(name=f, objects=[]) for f in token.state])
+ model_fluents.update(
+ [LearnedFluent(name=f, objects=[]) for f in token.state]
+ )
# if an action was taken on this step
if token.action:
# set up a base LearnedAction with the known information
@@ -356,7 +359,7 @@ def __as_strips_slaf(o_list: ObservationLists):
phi["pos expl"] = set()
phi["neg expl"] = set()
- """Steps 1 (a-c) - Update every fluent in the fluent-factored transition belief formula
+ """Steps 1 (a-c) - Update every fluent in the fluent-factored transition belief formula
with information from the last step."""
"""Step 1 (a) - update the neutral effects."""
diff --git a/macq/generate/pddl/generator.py b/macq/generate/pddl/generator.py
index 95250395..2523298a 100644
--- a/macq/generate/pddl/generator.py
+++ b/macq/generate/pddl/generator.py
@@ -1,3 +1,4 @@
+from time import sleep
from typing import Set, List, Union
from tarski.io import PDDLReader
from tarski.search import GroundForwardSearchModel
@@ -21,15 +22,21 @@
from ...trace import Action, State, PlanningObject, Fluent, Trace, Step
+class PlanningDomainsAPIError(Exception):
+ """Raised when a valid response cannot be obtained from the planning.domains solver."""
+
+ def __init__(self, message):
+ super().__init__(message)
+
+
class InvalidGoalFluent(Exception):
"""
Raised when the user attempts to supply a new goal with invalid fluent(s).
"""
- def __init__(
- self,
- message="The fluents provided contain one or more fluents not available in this problem.",
- ):
+ def __init__(self, fluent, message=None):
+ if message is None:
+ message = f"{fluent} is not available in this problem."
super().__init__(message)
@@ -57,10 +64,16 @@ class Generator:
op_dict (dict):
The problem's ground operators, formatted to a dictionary for easy access during plan generation.
observe_pres_effs (bool):
- Option to observe action preconditions and effects upon generation.
+ Option to observe action preconditions and effects upon generation.
"""
- def __init__(self, dom: str = None, prob: str = None, problem_id: int = None, observe_pres_effs: bool = False):
+ def __init__(
+ self,
+ dom: str = None,
+ prob: str = None,
+ problem_id: int = None,
+ observe_pres_effs: bool = False,
+ ):
"""Creates a basic PDDL state trace generator. Takes either the raw filenames
of the domain and problem, or a problem ID.
@@ -72,7 +85,7 @@ def __init__(self, dom: str = None, prob: str = None, problem_id: int = None, ob
problem_id (int):
The ID of the problem to access.
observe_pres_effs (bool):
- Option to observe action preconditions and effects upon generation.
+ Option to observe action preconditions and effects upon generation.
"""
# get attributes
self.pddl_dom = dom
@@ -146,8 +159,12 @@ def __get_op_dict(self):
"""
op_dict = {}
for o in self.instance.operators:
- # reformat so that operators can be referenced by the same string format the planner uses for actions
- op_dict["".join(["(", o.name.replace("(", " ").replace(",", "")])] = o
+ # special case for actions that don't take parameters
+ if "()" in o.name:
+ op_dict["".join(["(", o.name[:-2], ")"])] = o
+ else:
+ # reformat so that operators can be referenced by the same string format the planner uses for actions
+ op_dict["".join(["(", o.name.replace("(", " ").replace(",", "")])] = o
return op_dict
def __get_all_grounded_fluents(self):
@@ -248,9 +265,7 @@ def tarski_act_to_macq(self, tarski_act: PlainOperator):
raw_precond = tarski_act.precondition.subformulas
for raw_p in raw_precond:
if isinstance(raw_p, CompoundFormula):
- precond.add(
- self.__tarski_atom_to_macq_fluent(raw_p.subformulas[0])
- )
+ precond.add(self.__tarski_atom_to_macq_fluent(raw_p.subformulas[0]))
else:
precond.add(self.__tarski_atom_to_macq_fluent(raw_p))
else:
@@ -263,7 +278,17 @@ def tarski_act_to_macq(self, tarski_act: PlainOperator):
for fluent in precond:
objs.update(set(fluent.objects))
- return Action(name=name, obj_params=list(objs), precond=precond, add=add, delete=delete) if self.observe_pres_effs else Action(name=name, obj_params=list(objs))
+ return (
+ Action(
+ name=name,
+ obj_params=list(objs),
+ precond=precond,
+ add=add,
+ delete=delete,
+ )
+ if self.observe_pres_effs
+ else Action(name=name, obj_params=list(objs))
+ )
def change_init(
self,
@@ -285,7 +310,10 @@ def change_init(
init = create(self.lang)
for f in init_fluents:
# convert fluents to tarski Atoms
- atom = Atom(self.lang.get_predicate(f.name), [self.lang.get(o.name) for o in f.objects])
+ atom = Atom(
+ self.lang.get_predicate(f.name),
+ [self.lang.get(o.name) for o in f.objects],
+ )
init.add(atom.predicate, *atom.subterms)
self.problem.init = init
@@ -320,7 +348,7 @@ def change_goal(
available_f = self.grounded_fluents
for f in goal_fluents:
if f not in available_f:
- raise InvalidGoalFluent()
+ raise InvalidGoalFluent(f)
# convert the given set of fluents into a formula
if not goal_fluents:
@@ -344,7 +372,7 @@ def change_goal(
self.pddl_dom = new_domain
self.pddl_prob = new_prob
- def generate_plan(self, from_ipc_file:bool=False, filename:str=None):
+ def generate_plan(self, from_ipc_file: bool = False, filename: str = None):
"""Generates a plan. If reading from an IPC file, the `Plan` is read directly. Otherwise, if the initial state or
goal was changed, these changes are taken into account through the updated PDDL files. If no changes were made, the
default initial state/goal in the initial problem file is used.
@@ -369,14 +397,30 @@ def generate_plan(self, from_ipc_file:bool=False, filename:str=None):
"domain": open(self.pddl_dom, "r").read(),
"problem": open(self.pddl_prob, "r").read(),
}
- resp = requests.post(
- "http://solver.planning.domains/solve", verify=False, json=data
- ).json()
- plan = [act["name"] for act in resp["result"]["plan"]]
+
+ def get_api_response(delays: List[int]):
+ # stop recursing once the retry schedule is exhausted
+ if not delays:
+ return None
+ sleep(delays[0])
+ try:
+ resp = requests.post(
+ "http://solver.planning.domains/solve",
+ verify=False,
+ json=data,
+ ).json()
+ return [act["name"] for act in resp["result"]["plan"]]
+ except TypeError:
+ return get_api_response(delays[1:])
+
+ plan = get_api_response([0, 1, 3, 5, 10])
+ if plan is None:
+ raise PlanningDomainsAPIError(
+ "Could not get a valid response from the planning.domains solver after 5 attempts.",
+ )
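+ # The retry schedule above makes 5 attempts, sleeping 0, 1, 3, 5, and
+ # 10 seconds before each request (about 19 seconds of waiting in total).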
+
else:
f = open(filename, "r")
- plan = list(filter(lambda x: ';' not in x, f.read().splitlines()))
-
+ plan = list(filter(lambda x: ";" not in x, f.read().splitlines()))
+
# convert to a list of tarski PlainOperators (actions)
return Plan([self.op_dict[p] for p in plan if p in self.op_dict])
@@ -406,4 +450,4 @@ def generate_single_trace_from_plan(self, plan: Plan):
state = progress(state, act)
else:
trace.append(Step(macq_state, None, i + 1))
- return trace
\ No newline at end of file
+ return trace
diff --git a/macq/generate/pddl/random_goal_sampling.py b/macq/generate/pddl/random_goal_sampling.py
index 17ea53e2..0bc025c0 100644
--- a/macq/generate/pddl/random_goal_sampling.py
+++ b/macq/generate/pddl/random_goal_sampling.py
@@ -4,18 +4,13 @@
from collections import OrderedDict
from . import VanillaSampling
from ...trace import TraceList, State
-from ...utils import PercentError
-from ...utils.timer import basic_timer
-
-
-
-MAX_GOAL_SEARCH_TIME = 30.0
+from ...utils import PercentError, basic_timer, progress
class RandomGoalSampling(VanillaSampling):
"""Random Goal State Trace Sampler - inherits the VanillaSampling class and its attributes.
- A state trace generator that generates traces by randomly generating some candidate states/goals k steps deep,
+ A state trace generator that generates traces by randomly generating some candidate states/goals k steps deep,
then running a planner on a random subset of the fluents to get plans. The longest plans (those closest to k, thus representing
goal states that are somewhat complex and take longer to reach) are taken and used to generate traces.
@@ -30,8 +25,9 @@ class RandomGoalSampling(VanillaSampling):
The percentage of fluents to extract to use as a goal state from the generated states.
goals_inits_plans (List[Dict]):
A list of dictionaries, where each dictionary stores the generated goal state as the key and the initial state and plan used to
- reach the goal as values.
+ reach the goal as values.
"""
+
def __init__(
self,
steps_deep: int,
@@ -41,6 +37,7 @@ def __init__(
dom: str = None,
prob: str = None,
problem_id: int = None,
+ max_time: float = 30,
observe_pres_effs: bool = False,
):
"""
@@ -64,98 +61,116 @@ def __init__(
The problem filename.
problem_id (int):
The ID of the problem to access.
+ max_time (float):
+ The maximum time allowed for a trace to be generated.
observe_pres_effs (bool):
- Option to observe action preconditions and effects upon generation.
+ Option to observe action preconditions and effects upon generation.
"""
if subset_size_perc < 0 or subset_size_perc > 1:
raise PercentError()
- self.steps_deep = steps_deep
+ self.steps_deep = steps_deep
self.enforced_hill_climbing_sampling = enforced_hill_climbing_sampling
self.subset_size_perc = subset_size_perc
self.goals_inits_plans = []
- super().__init__(dom=dom, prob=prob, problem_id=problem_id, observe_pres_effs=observe_pres_effs, num_traces=num_traces)
+ super().__init__(
+ dom=dom,
+ prob=prob,
+ problem_id=problem_id,
+ num_traces=num_traces,
+ observe_pres_effs=observe_pres_effs,
+ max_time=max_time
+ )
def goal_sampling(self):
"""Samples goals by randomly generating candidate goal states k (`steps_deep`) steps deep, then running planners on those
- goal states to ensure the goals are complex enough (i.e. cannot be reached in too few steps). Candidate
- goal states are generated for a set amount of time indicated by MAX_GOAL_SEARCH_TIME, and the goals with the
+ goal states to ensure the goals are complex enough (i.e. cannot be reached in too few steps). Candidate
+ goal states are generated for a set amount of time indicated by `max_time`, and the goals with the
longest plans (the most complex goals) are selected.
Returns: An OrderedDict holding the longest goal states along with the initial state and plans used to reach them.
"""
goal_states = {}
- self.generate_goals(goal_states=goal_states)
+ self.generate_goals_setup(num_seconds=self.max_time, goal_states=goal_states)()
# sort the results by plan length and get the k largest ones
- filtered_goals = OrderedDict(sorted(goal_states.items(), key=lambda x : len(x[1]["plan"].actions)))
- to_del = list(filtered_goals.keys())[:len(filtered_goals) - self.num_traces]
+ filtered_goals = OrderedDict(
+ sorted(goal_states.items(), key=lambda x: len(x[1]["plan"].actions))
+ )
+ to_del = list(filtered_goals.keys())[: len(filtered_goals) - self.num_traces]
for d in to_del:
del filtered_goals[d]
return filtered_goals
- @basic_timer(num_seconds=MAX_GOAL_SEARCH_TIME)
- def generate_goals(self, goal_states: Dict):
- """Helper function for `goal_sampling`. Generates as many goals as possible within MAX_GOAL_SEARCH_TIME seconds.
- Given the specified number of traces `num_traces`, if `num_traces` plans of length k (`steps_deep`) are found before
- the time is up, exit early.
-
- Args:
- goal_states (Dict):
- The dictionary to fill with the values of each goal state, initial state, and plan.
- """
- # create a sampler to test the complexity of the new goal by running a planner on it
- k_length_plans = 0
- while True:
- # generate a trace of the specified length and retrieve the state of the last step
- state = self.generate_single_trace(self.steps_deep)[-1].state
-
- # get all positive fluents (only positive fluents can be used for a goal)
- goal_f = [f for f in state if state[f]]
- # get next initial state (only used for enforced hill climbing sampling)
- next_init_f = goal_f.copy()
- # get the subset size
- subset_size = int(len(state.fluents) * self.subset_size_perc)
- # if necessary, take a subset of the fluents
- if len(goal_f) > subset_size:
- random.shuffle(goal_f)
- goal_f = goal_f[:subset_size]
-
- self.change_goal(goal_fluents=goal_f)
-
- # ensure that the goal doesn't hold in the initial state; restart if it does
- init_state = {
- str(a) for a in self.problem.init.as_atoms()
- }
- goal = {
- str(a) for a in self.problem.goal.subformulas
- }
-
- if goal.issubset(init_state):
- continue
-
- try:
- # attempt to generate a plan, and find a new goal if a plan can't be found
- # should only crash if there are server issues
- test_plan = self.generate_plan()
- except KeyError as e:
- continue
-
- # create a State and add it to the dictionary
- state_dict = {}
- for f in goal_f:
- state_dict[f] = True
- # map each goal to the initial state and plan used to achieve it
- goal_states[State(state_dict)] = {"plan": test_plan, "initial state": self.problem.init}
-
- # optionally change the initial state of the sampler for the next iteration to the goal state just generated (ensures more diversity in goals/plans)
- # use the full state the goal was extracted from as the initial state to prevent planning errors from incomplete initial states
- if self.enforced_hill_climbing_sampling:
- self.change_init(next_init_f)
-
- # keep track of the number of plans of length k; if we get enough of them, exit early
- if len(test_plan.actions) >= self.steps_deep:
- k_length_plans += 1
- if k_length_plans >= self.num_traces:
- break
+ def generate_goals_setup(self, num_seconds: float, goal_states: Dict):
+ @basic_timer(num_seconds=num_seconds)
+ def generate_goals(self=self, goal_states=goal_states):
+ """Helper function for `goal_sampling`. Generates as many goals as possible within the specified max_time seconds (timing is
+ enforced by the basic_timer wrapper).
+
+ The outside function is a wrapper that provides parameters for both the timer
+ wrapper and the function.
+
+ Given the specified number of traces `num_traces`, if `num_traces` plans of length k (`steps_deep`) are found before
+ the time is up, exit early.
+
+ Args:
+ goal_states (Dict):
+ The dictionary to fill with the values of each goal state, initial state, and plan.
+ """
+ # create a sampler to test the complexity of the new goal by running a planner on it
+ k_length_plans = 0
+ while True:
+ # generate a trace of the specified length and retrieve the state of the last step
+ state = self.generate_single_trace_setup(num_seconds, self.steps_deep)()[-1].state
+
+ # get all positive fluents (only positive fluents can be used for a goal)
+ goal_f = [f for f in state if state[f]]
+ # get next initial state (only used for enforced hill climbing sampling)
+ next_init_f = goal_f.copy()
+ # get the subset size
+ subset_size = int(len(state.fluents) * self.subset_size_perc)
+ # if necessary, take a subset of the fluents
+ if len(goal_f) > subset_size:
+ random.shuffle(goal_f)
+ goal_f = goal_f[:subset_size]
+
+ self.change_goal(goal_fluents=goal_f)
+
+ # ensure that the goal doesn't hold in the initial state; restart if it does
+ init_state = {
+ str(a) for a in self.problem.init.as_atoms()
+ }
+ goal = {
+ str(a) for a in self.problem.goal.subformulas
+ }
+
+ if goal.issubset(init_state):
+ continue
+
+ try:
+ # attempt to generate a plan, and find a new goal if a plan can't be found
+ # should only crash if there are server issues
+ test_plan = self.generate_plan()
+ except KeyError:
+ continue
+
+ # create a State and add it to the dictionary
+ state_dict = {}
+ for f in goal_f:
+ state_dict[f] = True
+ # map each goal to the initial state and plan used to achieve it
+ goal_states[State(state_dict)] = {"plan": test_plan, "initial state": self.problem.init}
+
+ # optionally change the initial state of the sampler for the next iteration to the goal state just generated (ensures more diversity in goals/plans)
+ # use the full state the goal was extracted from as the initial state to prevent planning errors from incomplete initial states
+ if self.enforced_hill_climbing_sampling:
+ self.change_init(next_init_f)
+
+ # keep track of the number of plans of length k; if we get enough of them, exit early
+ if len(test_plan.actions) >= self.steps_deep:
+ k_length_plans += 1
+ if k_length_plans >= self.num_traces:
+ break
+ return generate_goals
def generate_traces(self):
"""Generates traces based on the sampled goals. Traces are generated using the initial state and plan used to achieve the goal.
@@ -167,12 +182,10 @@ def generate_traces(self):
# retrieve goals and their respective plans
self.goals_inits_plans = self.goal_sampling()
# iterate through all plans corresponding to the goals, generating traces
- for goal in self.goals_inits_plans.values():
+ for goal in progress(self.goals_inits_plans.values()):
# update the initial state if necessary
if self.enforced_hill_climbing_sampling:
self.problem.init = goal["initial state"]
# generate a plan based on the new goal/initial state, then generate a trace based on that plan
- traces.append(
- self.generate_single_trace_from_plan(goal["plan"])
- )
- return traces
\ No newline at end of file
+ traces.append(self.generate_single_trace_from_plan(goal["plan"]))
+ return traces
diff --git a/macq/generate/pddl/vanilla_sampling.py b/macq/generate/pddl/vanilla_sampling.py
index 84d0cbab..b4e23010 100644
--- a/macq/generate/pddl/vanilla_sampling.py
+++ b/macq/generate/pddl/vanilla_sampling.py
@@ -1,19 +1,21 @@
from tarski.search.operations import progress
import random
from . import Generator
-from ...utils import set_timer_throw_exc, TraceSearchTimeOut, basic_timer, set_num_traces, set_plan_length
-from ...observation.partial_observation import PercentError
+from ...utils import (
+ set_timer_throw_exc,
+ TraceSearchTimeOut,
+ InvalidTime,
+ set_num_traces,
+ set_plan_length,
+ progress as print_progress,
+)
from ...trace import (
Step,
- State,
Trace,
TraceList,
)
-MAX_TRACE_TIME = 30.0
-
-
class VanillaSampling(Generator):
"""Vanilla State Trace Sampler - inherits the base Generator class and its attributes.
@@ -21,6 +23,8 @@ class VanillaSampling(Generator):
of the given length.
Attributes:
+ max_time (float):
+ The maximum time allowed for a trace to be generated.
plan_len (int):
The length of the traces to be generated.
num_traces (int):
@@ -30,7 +34,7 @@ class VanillaSampling(Generator):
"""
def __init__(
- self,
+ self,
dom: str = None,
prob: str = None,
problem_id: int = None,
@@ -38,6 +42,7 @@ def __init__(
plan_len: int = 1,
num_traces: int = 1,
seed: int = None,
+ max_time: float = 30,
):
"""
Initializes a vanilla state trace sampler using the plan length, number of traces,
@@ -50,15 +55,24 @@ def __init__(
The problem filename.
problem_id (int):
The ID of the problem to access.
+ max_time (float):
+ The maximum time allowed for a trace to be generated.
observe_pres_effs (bool):
- Option to observe action preconditions and effects upon generation.
+ Option to observe action preconditions and effects upon generation.
plan_len (int):
The length of each generated trace. Defaults to 1.
num_traces (int):
The number of traces to generate. Defaults to 1.
-
"""
- super().__init__(dom=dom, prob=prob, problem_id=problem_id, observe_pres_effs=observe_pres_effs)
+ super().__init__(
+ dom=dom,
+ prob=prob,
+ problem_id=problem_id,
+ observe_pres_effs=observe_pres_effs,
+ )
+ if max_time <= 0:
+ raise InvalidTime()
+ self.max_time = max_time
self.plan_len = set_plan_length(plan_len)
self.num_traces = set_num_traces(num_traces)
self.traces = self.generate_traces()
@@ -73,52 +87,60 @@ def generate_traces(self):
A TraceList object with the list of traces generated.
"""
traces = TraceList()
- traces.generator = self.generate_single_trace
- for _ in range(self.num_traces):
- traces.append(self.generate_single_trace())
+ traces.generator = self.generate_single_trace_setup(
+ num_seconds=self.max_time, plan_len=self.plan_len
+ )
+ for _ in print_progress(range(self.num_traces)):
+ traces.append(traces.generator())
return traces
- @set_timer_throw_exc(num_seconds=MAX_TRACE_TIME, exception=TraceSearchTimeOut)
- def generate_single_trace(self, plan_len: int = None):
- """Generates a single trace using the uniform random sampling technique.
- Loops until a valid trace is found. Wrapper does not allow the function
- to run past the time specified.
-
- Returns:
- A Trace object (the valid trace generated).
- """
-
- if not plan_len:
- plan_len = self.plan_len
-
- trace = Trace()
-
- state = self.problem.init
- valid_trace = False
- while not valid_trace:
- trace.clear()
- # add more steps while the trace has not yet reached the desired length
- for j in range(plan_len):
- # if we have not yet reached the last step
- if len(trace) < plan_len - 1:
- # find the next applicable actions
- app_act = list(self.instance.applicable(state))
- # if the trace reaches a dead lock, disregard this trace and try again
- if not app_act:
- break
- # pick a random applicable action and apply it
- act = random.choice(app_act)
- # create the trace and progress the state
- macq_action = self.tarski_act_to_macq(act)
- macq_state = self.tarski_state_to_macq(state)
- step = Step(macq_state, macq_action, j + 1)
- trace.append(step)
- state = progress(state, act)
- else:
- macq_state = self.tarski_state_to_macq(state)
- step = Step(state=macq_state, action=None, index=j + 1)
- trace.append(step)
- valid_trace = True
- return trace
-
-
+ def generate_single_trace_setup(self, num_seconds: float, plan_len: int = None):
+ @set_timer_throw_exc(
+ num_seconds=num_seconds, exception=TraceSearchTimeOut, max_time=num_seconds
+ )
+ def generate_single_trace(self=self, plan_len=plan_len):
+ """Generates a single trace using the uniform random sampling technique.
+ Loops until a valid trace is found. The timer wrapper does not allow the function
+ to run past the time specified.
+
+ The outside function is a wrapper that provides parameters for both the timer
+ wrapper and the function.
+
+ Returns:
+ A Trace object (the valid trace generated).
+ """
+
+ if not plan_len:
+ plan_len = self.plan_len
+
+ trace = Trace()
+
+ state = self.problem.init
+ valid_trace = False
+ while not valid_trace:
+ trace.clear()
+ # add more steps while the trace has not yet reached the desired length
+ for j in range(plan_len):
+ # if we have not yet reached the last step
+ if len(trace) < plan_len - 1:
+ # find the next applicable actions
+ app_act = list(self.instance.applicable(state))
+ # if the trace reaches a dead lock, disregard this trace and try again
+ if not app_act:
+ break
+ # pick a random applicable action and apply it
+ act = random.choice(app_act)
+ # create the trace and progress the state
+ macq_action = self.tarski_act_to_macq(act)
+ macq_state = self.tarski_state_to_macq(state)
+ step = Step(macq_state, macq_action, j + 1)
+ trace.append(step)
+ state = progress(state, act)
+ else:
+ macq_state = self.tarski_state_to_macq(state)
+ step = Step(state=macq_state, action=None, index=j + 1)
+ trace.append(step)
+ valid_trace = True
+ return trace
+
+ return generate_single_trace
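+
+ # Example (illustrative): build a sampler bounded to 5 seconds that searches
+ # for a 10-step trace, then call it:
+ #   trace = sampler.generate_single_trace_setup(num_seconds=5.0, plan_len=10)()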
diff --git a/macq/generate/plan.py b/macq/generate/plan.py
index 8837c60a..6b8838a4 100644
--- a/macq/generate/plan.py
+++ b/macq/generate/plan.py
@@ -1,6 +1,7 @@
from typing import List
from tarski.fstrips.action import PlainOperator
+
class Plan:
"""A Plan.
@@ -11,6 +12,7 @@ class Plan:
actions (List[PlainOperator]):
The list of actions that make up the plan.
"""
+
def __init__(self, actions: List[PlainOperator]):
"""Creates a Plan by instantiating it with the list of actions (of tarski type `PlainOperator`).
@@ -42,5 +44,4 @@ def __str__(self):
return "\n".join(string)
def __eq__(self, other):
- if isinstance(other, Plan):
- return self.actions == other.actions
\ No newline at end of file
+ return isinstance(other, Plan) and self.actions == other.actions
diff --git a/macq/observation/__init__.py b/macq/observation/__init__.py
index 0ec84398..a1dc1c99 100644
--- a/macq/observation/__init__.py
+++ b/macq/observation/__init__.py
@@ -1,19 +1,23 @@
from .observation import Observation, InvalidQueryParameter
+from .observation_lists import ObservationLists
from .identity_observation import IdentityObservation
from .partial_observation import PartialObservation
from .atomic_partial_observation import AtomicPartialObservation
from .noisy_observation import NoisyObservation
from .noisy_partial_observation import NoisyPartialObservation
-from .noisy_partial_disordered_parallel_observation import NoisyPartialDisorderedParallelObservation
+from .noisy_partial_disordered_parallel_observation import (
+ NoisyPartialDisorderedParallelObservation,
+)
__all__ = [
"Observation",
+ "ObservationLists",
"InvalidQueryParameter",
"IdentityObservation",
"PartialObservation",
"AtomicPartialObservation",
"NoisyObservation",
"NoisyPartialObservation",
- "NoisyPartialDisorderedParallelObservation"
+ "NoisyPartialDisorderedParallelObservation",
]
diff --git a/macq/observation/atomic_partial_observation.py b/macq/observation/atomic_partial_observation.py
index eb72b925..b711f9e5 100644
--- a/macq/observation/atomic_partial_observation.py
+++ b/macq/observation/atomic_partial_observation.py
@@ -1,9 +1,7 @@
+from logging import warning
from ..trace import Step, Fluent
-from ..trace import PartialState
-from . import Observation, InvalidQueryParameter
-from typing import Callable, Union, Set, List, Optional
-from dataclasses import dataclass
-import random
+from . import PartialObservation, Observation
+from typing import Set
class PercentError(Exception):
@@ -16,134 +14,50 @@ def __init__(
super().__init__(message)
-class AtomicPartialObservation(Observation):
+class AtomicPartialObservation(PartialObservation):
"""The Atomic Partial Observability Token.
-
The atomic partial observability token stores the step where some of the values of
the fluents in the step's state are unknown. Inherits the base Observation
class. Unlike the partial observability token, the atomic partial observability token
stores everything in strings.
"""
- # used these to store action and state info with just strings
- class IdentityState(dict):
- def __hash__(self):
- return hash(tuple(sorted(self.items())))
-
- @dataclass
- class IdentityAction:
- name: str
- obj_params: List[str]
- cost: Optional[int]
-
- def __str__(self):
- objs_str = ""
- for o in self.obj_params:
- objs_str += o + " "
- return " ".join([self.name, objs_str]) + "[" + str(self.cost) + "]"
-
- def __hash__(self):
- return hash(str(self))
-
def __init__(
- self,
- step: Step,
- method: Union[Callable[[int], Step], Callable[[Set[Fluent]], Step]],
- **method_kwargs,
+ self, step: Step, percent_missing: float = 0, hide: Set[Fluent] = None
):
"""
- Creates an PartialObservation object, storing the step.
+ Creates an AtomicPartialObservation object, storing the step.
Args:
step (Step):
The step associated with this observation.
- method (function reference):
- The method to be used to tokenize the step.
- **method_kwargs (keyword arguments):
- The arguments to be passed to the corresponding method function.
- """
- super().__init__(index=step.index)
- step = method(self, step, **method_kwargs)
- self.state = self.IdentityState(
- {str(fluent): value for fluent, value in step.state.items()}
- )
- self.action = (
- None
- if step.action is None
- else self.IdentityAction(
- step.action.name,
- list(map(lambda o: o.details(), step.action.obj_params)),
- step.action.cost,
- )
- )
-
- def __eq__(self, other):
- if not isinstance(other, AtomicPartialObservation):
- return False
- return self.state == other.state and self.action == other.action
-
- # and here is the old matches function
-
- def _matches(self, key: str, value: str):
- if key == "action":
- if self.action is None:
- return value is None
- return str(self.action) == value
- elif key == "fluent_holds":
- return self.state[value]
- else:
- raise InvalidQueryParameter(AtomicPartialObservation, key)
-
- def details(self):
- return f"Obs {str(self.index)}.\n State: {str(self.state)}\n Action: {str(self.action)}"
-
- def random_subset(self, step: Step, percent_missing: float):
- """Method of tokenization that picks a random subset of fluents to hide.
-
- Args:
- step (Step):
- The step to tokenize.
percent_missing (float):
- The percentage of fluents to hide.
-
- Returns:
- The new step created using a PartialState that takes the hidden fluents into account.
+ The percentage of fluents to randomly hide in the observation.
+ hide (Set[Fluent]):
+ The set of fluents to explicitly hide in the observation.
"""
if percent_missing > 1 or percent_missing < 0:
raise PercentError()
- fluents = step.state.fluents
- num_new_fluents = int(len(fluents) * (percent_missing))
+ if percent_missing == 0 and not hide:
+ warning("Creating a PartialObseration with no missing information.")
- new_fluents = {}
- # shuffle keys and take an appropriate subset of them
- hide_fluents_ls = list(fluents)
- random.shuffle(hide_fluents_ls)
- hide_fluents_ls = hide_fluents_ls[:num_new_fluents]
- # get new dict
- for f in fluents:
- if f in hide_fluents_ls:
- new_fluents[f] = None
- else:
- new_fluents[f] = step.state[f]
- return Step(PartialState(new_fluents), step.action, step.index)
+ Observation.__init__(self, index=step.index)
- def same_subset(self, step: Step, hide_fluents: Set[Fluent]):
- """Method of tokenization that hides the same subset of fluents every time.
+ if percent_missing < 1:
+ step = self.hide_random_subset(step, percent_missing)
+ if hide:
+ step = self.hide_subset(step, hide)
- Args:
- step (Step):
- The step to tokenize.
- hide_fluents (Set[Fluent]):
- The set of fluents that will be hidden each time.
+ self.state = None if percent_missing == 1 else step.state.clone(atomic=True)
+ self.action = None if step.action is None else step.action.clone(atomic=True)
- Returns:
- The new step created using a PartialState that takes the hidden fluents into account.
- """
- new_fluents = {}
- for f in step.state.fluents:
- if f in hide_fluents:
- new_fluents[f] = None
- else:
- new_fluents[f] = step.state[f]
- return Step(PartialState(new_fluents), step.action, step.index)
+ def __eq__(self, other):
+ return (
+ isinstance(other, AtomicPartialObservation)
+ and self.state == other.state
+ and self.action == other.action
+ )
+
+ def details(self):
+ return f"Obs {str(self.index)}.\n State: {str(self.state)}\n Action: {str(self.action)}"
diff --git a/macq/observation/identity_observation.py b/macq/observation/identity_observation.py
index 16b3dd5c..65bc5ee5 100644
--- a/macq/observation/identity_observation.py
+++ b/macq/observation/identity_observation.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from typing import Optional, List
-from ..trace import Step
+from ..trace import Step, State
from . import Observation, InvalidQueryParameter
@@ -11,6 +11,8 @@ class IdentityObservation(Observation):
class.
"""
+ state: State
+
class IdentityState(dict):
def __hash__(self):
return hash(tuple(sorted(self.items())))
@@ -42,15 +44,14 @@ def __init__(self, step: Step, **kwargs):
self.action = None if step.action is None else step.action.clone()
def __hash__(self):
- return hash(self.details())
+ return super().__hash__()
def __eq__(self, other):
- if not isinstance(other, IdentityObservation):
- return False
- return self.state == other.state and self.action == other.action
-
- def details(self):
- return f"Obs {str(self.index)}.\n State: {str(self.state)}\n Action: {str(self.action)}"
+ return (
+ isinstance(other, IdentityObservation)
+ and self.state == other.state
+ and self.action == other.action
+ )
def _matches(self, key: str, value: str):
if key == "action":
diff --git a/macq/observation/noisy_observation.py b/macq/observation/noisy_observation.py
index ab09ebdd..340eed7d 100644
--- a/macq/observation/noisy_observation.py
+++ b/macq/observation/noisy_observation.py
@@ -1,6 +1,7 @@
+import random
from . import Observation
from ..trace import Step
-from ..utils import PercentError#, extract_fluent_subset
+from ..utils import PercentError
class NoisyObservation(Observation):
@@ -13,7 +14,7 @@ class NoisyObservation(Observation):
"""
def __init__(
- self, step: Step, percent_noisy: float = 0):
+ self, step: Step, percent_noisy: float = 0, replace: bool = False):
"""
Creates an NoisyObservation object, storing the state and action.
@@ -22,19 +23,22 @@ def __init__(
The step associated with this observation.
percent_noisy (float):
The percentage of fluents to randomly make noisy in the observation.
+ replace (bool):
+ Option to replace noisy fluents with the values of other existing fluents instead
+ of just flipping their values.
"""
super().__init__(index=step.index)
if percent_noisy > 1 or percent_noisy < 0:
- raise PercentError()
+ raise PercentError()
- step = self.random_noisy_subset(step, percent_noisy)
+ step = self.random_noisy_subset(step, percent_noisy, replace)
self.state = step.state.clone()
self.action = None if step.action is None else step.action.clone()
- def random_noisy_subset(self, step: Step, percent_noisy: float):
+ def random_noisy_subset(self, step: Step, percent_noisy: float, replace: bool = False):
"""Generates a random subset of fluents corresponding to the percent provided
and flips their value to create noise.
@@ -43,6 +47,9 @@ def random_noisy_subset(self, step: Step, percent_noisy: float):
The step associated with this observation.
percent_noisy (float):
The percentage of fluents to randomly make noisy in the observation.
+ replace (bool):
+ Option to replace noisy fluents with the values of other existing fluents instead
+ of just flipping their values.
Returns:
A new `Step` with the noisy fluents in place.
@@ -54,8 +61,10 @@ def random_noisy_subset(self, step: Step, percent_noisy: float):
for f in invisible_f:
del visible_f[f]
noisy_f = self.extract_fluent_subset(visible_f, percent_noisy)
- for f in state:
- state[f] = not state[f] if f in noisy_f else state[f]
+ if not replace:
+ for f in state:
+ state[f] = not state[f] if f in noisy_f else state[f]
+ else:
+ for f in state:
+ state[f] = state[random.choice(list(visible_f.keys()))] if f in noisy_f else state[f]
return Step(state, step.action, step.index)
-
-
diff --git a/macq/observation/noisy_partial_disordered_parallel_observation.py b/macq/observation/noisy_partial_disordered_parallel_observation.py
index d5d22943..48821f36 100644
--- a/macq/observation/noisy_partial_disordered_parallel_observation.py
+++ b/macq/observation/noisy_partial_disordered_parallel_observation.py
@@ -10,7 +10,7 @@ class NoisyPartialDisorderedParallelObservation(NoisyPartialObservation):
disordered with an action in another parallel action set. Finally, a "parallel action set ID" is stored which indicates
which parallel action set the token is a part of. Inherits the NoisyPartialObservation token class.
"""
- def __init__(self, step: Step, par_act_set_ID: int, percent_missing: float = 0, hide: Set[Fluent] = None, percent_noisy: float = 0):
+ def __init__(self, step: Step, par_act_set_ID: int, percent_missing: float = 0, hide: Set[Fluent] = None, percent_noisy: float = 0, replace: bool = False):
"""
        Creates a NoisyPartialDisorderedParallelObservation object.
@@ -25,6 +25,9 @@ def __init__(self, step: Step, par_act_set_ID: int, percent_missing: float = 0,
The set of fluents to explicitly hide in the observation.
percent_noisy (float):
The percentage of fluents to randomly make noisy in the observation.
+ replace (bool):
+ Option to replace noisy fluents with the values of other existing fluents instead
+ of just flipping their values.
"""
- super().__init__(step=step, percent_missing=percent_missing, hide=hide, percent_noisy=percent_noisy)
+ super().__init__(step=step, percent_missing=percent_missing, hide=hide, percent_noisy=percent_noisy, replace=replace)
self.par_act_set_ID = par_act_set_ID
diff --git a/macq/observation/noisy_partial_observation.py b/macq/observation/noisy_partial_observation.py
index 0b82e025..f3056cd3 100644
--- a/macq/observation/noisy_partial_observation.py
+++ b/macq/observation/noisy_partial_observation.py
@@ -13,7 +13,7 @@ class NoisyPartialObservation(PartialObservation, NoisyObservation):
"""
def __init__(
- self, step: Step, percent_missing: float = 0, hide: Set[Fluent] = None, percent_noisy: float = 0):
+ self, step: Step, percent_missing: float = 0, hide: Set[Fluent] = None, percent_noisy: float = 0, replace: bool = False):
"""
        Creates a NoisyPartialObservation object.
@@ -26,8 +26,11 @@ def __init__(
The set of fluents to explicitly hide in the observation.
percent_noisy (float):
The percentage of fluents to randomly make noisy in the observation.
+ replace (bool):
+ Option to replace noisy fluents with the values of other existing fluents instead
+ of just flipping their values.
"""
# get state and action with missing fluents (updates self.state and self.action)
PartialObservation.__init__(self, step=step, percent_missing=percent_missing, hide=hide)
# get state and action with noisy fluents, using the updated state and action (then updates self.state and self.action)
- NoisyObservation.__init__(self, step=Step(self.state, self.action, step.index), percent_noisy=percent_noisy)
\ No newline at end of file
+ NoisyObservation.__init__(self, step=Step(self.state, self.action, step.index), percent_noisy=percent_noisy, replace=replace)
\ No newline at end of file
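A short sketch combining partial observability with replacement noise (same `traces` assumption as above):

```python
from macq.observation import NoisyPartialObservation

# hide 20% of the fluents, then make 10% of the rest noisy by replacement
observations = traces.tokenize(
    NoisyPartialObservation,
    percent_missing=0.20,
    percent_noisy=0.10,
    replace=True,
)
```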
diff --git a/macq/observation/observation.py b/macq/observation/observation.py
index fb4cfeab..bfcc84fa 100644
--- a/macq/observation/observation.py
+++ b/macq/observation/observation.py
@@ -1,5 +1,8 @@
-from logging import warning
+from warnings import warn
from json import dumps
+from typing import Union
+
+from ..trace import State, Action
import random
-from ..trace import State
@@ -23,6 +26,10 @@ class Observation:
The index of the associated step in the trace it is a part of.
"""
+ index: int
+ state: Union[State, None]
+ action: Union[Action, None]
+
def __init__(self, **kwargs):
"""
Creates an Observation object, storing the step as a token, as well as its index/"place"
@@ -35,12 +42,35 @@ def __init__(self, **kwargs):
if "index" in kwargs.keys():
self.index = kwargs["index"]
else:
- warning("Creating an Observation token without an index.")
+ warn("Creating an Observation token without an index.")
+
+ def __hash__(self):
+ string = str(self)
+ if string == "Observation\n":
+ warn("Observation has no unique information. Generating a generic hash.")
+ return hash(string)
+
+ def __str__(self):
+ out = "Observation\n"
+ if self.index is not None:
+ out += f" Index: {str(self.index)}\n"
+ if self.state:
+ out += f" State: {str(self.state)}\n"
+ if self.action:
+ out += f" Action: {str(self.action)}\n"
+
+ return out
+
+ def get_details(self):
+ ind = str(self.index) if self.index else "-"
+ state = self.state.details() if self.state else "-"
+ action = self.action.details() if self.action else ""
+ return (ind, state, action)
def _matches(self, *_):
raise NotImplementedError()
- def extract_fluent_subset(self, fluents: State, percent: float):
+ def extract_fluent_subset(self, state: State, percent: float):
"""Randomly extracts a subset of fluents from a state, according to the percentage given.
Args:
@@ -52,10 +82,10 @@ def extract_fluent_subset(self, fluents: State, percent: float):
Returns:
The random subset of fluents.
"""
- num_new_f = int(len(fluents) * (percent))
+ num_new_f = int(len(state) * (percent))
# shuffle keys and take an appropriate subset of them
- extracted_f = list(fluents)
+ extracted_f = list(state)
random.shuffle(extracted_f)
return extracted_f[:num_new_f]
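Since `int()` truncates toward zero, small states can yield an empty subset; a tiny sketch of the sizing arithmetic:

```python
# with 7 fluents and percent = 0.25, int(7 * 0.25) == 1: one fluent is selected
assert int(7 * 0.25) == 1
# with 3 fluents and percent = 0.30, the subset is empty (never rounded up)
assert int(3 * 0.30) == 0
```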
diff --git a/macq/observation/observation_lists.py b/macq/observation/observation_lists.py
new file mode 100644
index 00000000..228e1ae9
--- /dev/null
+++ b/macq/observation/observation_lists.py
@@ -0,0 +1,312 @@
+from __future__ import annotations
+from collections import defaultdict
+from collections.abc import MutableSequence
+from warnings import warn
+from typing import Callable, Dict, List, Type, Set, TYPE_CHECKING
+from inspect import cleandoc
+from rich.console import Console
+from rich.table import Table
+from rich.text import Text
+
+from . import Observation
+from ..trace import Action, Fluent
+
+# Prevents circular importing
+if TYPE_CHECKING:
+ from macq.trace import TraceList
+
+
+class MissingToken(Exception):
+ def __init__(self, message=None):
+ if message is None:
+            message = "Cannot create ObservationLists from a TraceList without a Token."
+ super().__init__(message)
+
+
+class TokenTypeMismatch(Exception):
+ def __init__(self, token, obs_type, message=None):
+ if message is None:
+ message = (
+                "Token type does not match observation tokens. "
+                f"Token type: {token}; "
+                f"Observation type: {obs_type}"
+ )
+ super().__init__(message)
+
+
+class ObservationLists(MutableSequence):
+ """A sequence of observations.
+
+ A `list`-like object, where each element is a list of `Observation`s.
+
+ Attributes:
+ observations (List[List[Observation]]):
+ The internal list of lists of `Observation` objects.
+ type (Type[Observation]):
+ The type (class) of the observations.
+ """
+
+ observations: List[List[Observation]]
+ type: Type[Observation]
+
+ def __init__(
+ self,
+ trace_list: TraceList = None,
+ Token: Type[Observation] = None,
+ observations: List[List[Observation]] = None,
+ **kwargs,
+ ):
+ if trace_list is not None:
+ if not Token and not observations:
+ raise MissingToken()
+
+            # Infer the token type from the observations if it was not provided
+            self.type = Token if Token else type(observations[0][0])
+
+            self.observations = []
+            self.tokenize(trace_list, **kwargs)
+
+            if observations:
+                self.extend(observations)
+                # Check that the observations are of the specified token type
+                if type(observations[0][0]) != self.type:
+                    raise TokenTypeMismatch(self.type, type(observations[0][0]))
+
+ elif observations:
+ self.observations = observations
+ self.type = type(observations[0][0])
+
+ else:
+ self.observations = []
+ self.type = Observation
+
+ def __getitem__(self, key: int):
+ return self.observations[key]
+
+ def __setitem__(self, key: int, value: List[Observation]):
+ self.observations[key] = value
+ if self.type == Observation:
+ self.type = type(value[0])
+ elif type(value[0]) != self.type:
+ raise TokenTypeMismatch(self.type, type(value[0]))
+
+ def __delitem__(self, key: int):
+ del self.observations[key]
+
+ def __iter__(self):
+ return iter(self.observations)
+
+ def __len__(self):
+ return len(self.observations)
+
+ def insert(self, key: int, value: List[Observation]):
+ self.observations.insert(key, value)
+ if self.type == Observation:
+ self.type = type(value[0])
+ elif type(value[0]) != self.type:
+ raise TokenTypeMismatch(self.type, type(value[0]))
+
+ def get_actions(self) -> Set[Action]:
+ actions: Set[Action] = set()
+ for obs_list in self:
+ for obs in obs_list:
+ action = obs.action
+ if action is not None:
+ actions.add(action)
+ return actions
+
+ def get_fluents(self) -> Set[Fluent]:
+ fluents: Set[Fluent] = set()
+ for obs_list in self:
+ for obs in obs_list:
+ if obs.state:
+ fluents.update(list(obs.state.keys()))
+ return fluents
+
+ def tokenize(self, trace_list: TraceList, **kwargs):
+ for trace in trace_list:
+ tokens = trace.tokenize(self.type, **kwargs)
+ self.append(tokens)
+
+ def fetch_observations(self, query: dict) -> List[Set[Observation]]:
+ matches: List[Set[Observation]] = []
+ for i, obs_list in enumerate(self.observations):
+ matches.append(set())
+ for obs in obs_list:
+ if obs.matches(query):
+ matches[i].add(obs)
+ return matches
+
+ def fetch_observation_windows(
+ self, query: dict, left: int, right: int
+ ) -> List[List[Observation]]:
+ windows = []
+ matches = self.fetch_observations(query)
+ for i, obs_set in enumerate(matches):
+ for obs in obs_set:
+ # NOTE: obs.index starts at 1
+ start = obs.index - left - 1
+ end = obs.index + right
+ windows.append(self[i][start:end])
+ return windows
+
+ def get_transitions(self, action: str) -> List[List[Observation]]:
+ query = {"action": action}
+ return self.fetch_observation_windows(query, 0, 1)
+
+ def get_all_transitions(self) -> Dict[Action, List[List[Observation]]]:
+ actions = self.get_actions()
+ try:
+ return {
+ action: self.get_transitions(action.details()) for action in actions
+ }
+ except AttributeError:
+ return {action: self.get_transitions(str(action)) for action in actions}
+
+ def print(self, view="details", filter_func=lambda _: True, wrap=None):
+ """Pretty prints the trace list in the specified view.
+
+ Arguments:
+ view ("details" | "color"):
+ Specifies the view format to print in. "details" provides a
+ detailed summary of each step in a trace. "color" provides a
+ color grid, mapping fluents in a step to either red or green
+ corresponding to the truth value.
+ filter_func (function):
+ Optional; Used to filter which fluents are printed in the
+ colorgrid display.
+ wrap (bool):
+ Determines if the output is wrapped or cut off. Details defaults
+ to cut off (wrap=False), color defaults to wrap (wrap=True).
+ """
+ console = Console()
+
+ views = ["details", "color"]
+ if view not in views:
+ warn(f'Invalid view {view}. Defaulting to "details".')
+ view = "details"
+
+ obs_lists = []
+ if view == "details":
+ if wrap is None:
+ wrap = False
+ obs_lists = [self._details(obs_list, wrap=wrap) for obs_list in self]
+
+ elif view == "color":
+ if wrap is None:
+ wrap = True
+ obs_lists = [
+ self._colorgrid(obs_list, filter_func=filter_func, wrap=wrap)
+ for obs_list in self
+ ]
+
+ for obs_list in obs_lists:
+ console.print(obs_list)
+ print()
+
+ def _details(self, obs_list: List[Observation], wrap: bool):
+ indent = " " * 2
+ # Summarize class attributes
+ details = Table.grid(expand=True)
+ details.title = "Trace"
+ details.add_column()
+ details.add_row(
+ cleandoc(
+ f"""
+ Attributes:
+ {indent}{len(obs_list)} steps
+ {indent}{len(self.get_fluents())} fluents
+ """
+ )
+ )
+ steps = Table(
+ title="Steps", box=None, show_edge=False, pad_edge=False, expand=True
+ )
+ steps.add_column("Step", justify="right", width=8)
+ steps.add_column(
+ "State",
+ justify="center",
+ overflow="ellipsis",
+ max_width=100,
+ no_wrap=(not wrap),
+ )
+ steps.add_column("Action", overflow="ellipsis", no_wrap=(not wrap))
+
+ for obs in obs_list:
+ ind, state, action = obs.get_details()
+ steps.add_row(ind, state, action)
+
+ details.add_row(steps)
+
+ return details
+
+ @staticmethod
+ def _colorgrid(obs_list: List[Observation], filter_func: Callable, wrap: bool):
+ colorgrid = Table(
+ title="Trace", box=None, show_edge=False, pad_edge=False, expand=False
+ )
+ colorgrid.add_column("Fluent", justify="right")
+ colorgrid.add_column(
+ header=Text("Step", justify="center"), overflow="fold", no_wrap=(not wrap)
+ )
+ colorgrid.add_row(
+ "",
+ "".join(
+ [
+ "|" if i < len(obs_list) and (i + 1) % 5 == 0 else " "
+ for i in range(len(obs_list))
+ ]
+ ),
+ )
+
+ static = ObservationLists.get_obs_static_fluents(obs_list)
+ fluents = list(
+ filter(
+ filter_func,
+ sorted(
+ ObservationLists.get_obs_fluents(obs_list),
+ key=lambda f: float("inf") if f in static else len(str(f)),
+ ),
+ )
+ )
+
+ for fluent in fluents:
+ step_str = ""
+ for obs in obs_list:
+ if obs.state and obs.state[fluent]:
+ step_str += "[green]"
+ else:
+ step_str += "[red]"
+ step_str += "■"
+
+ colorgrid.add_row(str(fluent), step_str)
+
+ return colorgrid
+
+ @staticmethod
+ def get_obs_fluents(obs_list: List[Observation]):
+ fluents = set()
+ for obs in obs_list:
+ if obs.state:
+ fluents.update(list(obs.state.keys()))
+ return fluents
+
+ @staticmethod
+ def get_obs_static_fluents(obs_list: List[Observation]):
+ fstates = defaultdict(list)
+ for obs in obs_list:
+ if obs.state:
+ for f, v in obs.state.items():
+ fstates[f].append(v)
+
+ static = set()
+ for f, states in fstates.items():
+ if all(states) or not any(states):
+ static.add(f)
+
+ return static
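A minimal usage sketch of the relocated `ObservationLists` (the `traces` TraceList is assumed, as above):

```python
from macq.observation import IdentityObservation, ObservationLists

# tokenize a TraceList directly...
obs_lists = ObservationLists(trace_list=traces, Token=IdentityObservation)

# ...or wrap pre-built observations; the token type is inferred
wrapped = ObservationLists(observations=obs_lists.observations)

# the query helpers carry over from the old macq.trace.ObservationLists
transitions = obs_lists.get_all_transitions()  # Dict[Action, List[List[Observation]]]
obs_lists.print(view="color")
```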
diff --git a/macq/observation/partial_observation.py b/macq/observation/partial_observation.py
index 4b37b812..41e7fa01 100644
--- a/macq/observation/partial_observation.py
+++ b/macq/observation/partial_observation.py
@@ -1,14 +1,13 @@
-from logging import warning
-from ..utils import PercentError#, extract_fluent_subset
+from warnings import warn
+from typing import Set
+from ..utils import PercentError
from ..trace import Step, Fluent
from ..trace import PartialState
from . import Observation, InvalidQueryParameter
-from typing import Set
class PartialObservation(Observation):
"""The Partial Observability Token.
-
The partial observability token stores the step where some of the values of
the fluents in the step's state are unknown. Inherits the base Observation
class.
@@ -18,8 +17,7 @@ def __init__(
self, step: Step, percent_missing: float = 0, hide: Set[Fluent] = None
):
"""
- Creates an PartialObservation object, storing the step.
-
+ Creates a PartialObservation object, storing the step.
Args:
step (Step):
The step associated with this observation.
@@ -32,9 +30,9 @@ def __init__(
raise PercentError()
if percent_missing == 0 and not hide:
- warning("Creating a PartialObseration with no missing information.")
+            warn("Creating a PartialObservation with no missing information.")
- # necessary because multiple inheritance can change the parent of this class
+        # NOTE: Can't use super due to multiple inheritance (NoisyPartialObservation)
Observation.__init__(self, index=step.index)
# If percent_missing == 1 -> self.state = None (below).
@@ -58,33 +56,27 @@ def __eq__(self, other):
def hide_random_subset(self, step: Step, percent_missing: float):
"""Hides a random subset of the fluents in the step.
-
Args:
step (Step):
The step to tokenize.
percent_missing (float):
The percentage of fluents to hide (0-1).
-
Returns:
A Step whose state is a PartialState with the random fluents hidden.
"""
new_fluents = {}
- fluents = step.state.fluents
- hidden_f = self.extract_fluent_subset(fluents, percent_missing)
- # get new dict
- for f in fluents:
+ hidden_f = self.extract_fluent_subset(step.state, percent_missing)
+ for f in step.state:
new_fluents[f] = None if f in hidden_f else step.state[f]
return Step(PartialState(new_fluents), step.action, step.index)
def hide_subset(self, step: Step, hide: Set[Fluent]):
"""Hides the specified set of fluents in the observation.
-
Args:
step (Step):
The step to tokenize.
hide (Set[Fluent]):
The set of fluents that will be hidden.
-
Returns:
A Step whose state is a PartialState with the specified fluents hidden.
"""
@@ -99,6 +91,8 @@ def _matches(self, key: str, value: str):
return value is None
return self.action.details() == value
elif key == "fluent_holds":
+ if self.state is None:
+ return value is None
return self.state.holds(value)
else:
raise InvalidQueryParameter(PartialObservation, key)
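A sketch of the `hide` parameter, which masks an explicit fluent set rather than a random percentage (assuming `traces` as above):

```python
from macq.observation import PartialObservation

# hide one specific fluent in every step
some_fluent = next(iter(traces.get_fluents()))
observations = traces.tokenize(PartialObservation, hide={some_fluent})
```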
diff --git a/macq/trace/__init__.py b/macq/trace/__init__.py
index 5d25627d..08b41978 100644
--- a/macq/trace/__init__.py
+++ b/macq/trace/__init__.py
@@ -5,8 +5,9 @@
from .step import Step
from .trace import Trace, SAS
from .trace_list import TraceList
-from .observation_lists import ObservationLists
-from .disordered_parallel_actions_observation_lists import DisorderedParallelActionsObservationLists
+from .disordered_parallel_actions_observation_lists import (
+ DisorderedParallelActionsObservationLists, ActionPair
+)
__all__ = [
@@ -19,6 +20,6 @@
"Trace",
"SAS",
"TraceList",
- "ObservationLists",
- "DisorderedParallelActionsObservationLists"
+ "DisorderedParallelActionsObservationLists",
+ "ActionPair"
]
diff --git a/macq/trace/action.py b/macq/trace/action.py
index d34bc52f..33760be7 100644
--- a/macq/trace/action.py
+++ b/macq/trace/action.py
@@ -1,5 +1,5 @@
from typing import List, Set
-from .fluent import Fluent, PlanningObject
+from .fluent import PlanningObject, Fluent
class Action:
@@ -12,16 +12,16 @@ class Action:
Attributes:
name (str):
The name of the action.
- obj_params (list):
- The list of objects the action acts on.
+ obj_params (List[PlanningObject]):
+            The list of objects the action acts on.
cost (int):
The cost to perform the action.
precond (Set[Fluent]):
- Optional; The set of Fluents that make up the precondition.
+ The set of Fluents that make up the precondition.
add (Set[Fluent]):
- Optional; The set of Fluents that make up the add effects.
+ The set of Fluents that make up the add effects.
delete (Set[Fluent]):
- Optional; The set of Fluents that make up the delete effects.
+ The set of Fluents that make up the delete effects.
"""
def __init__(
@@ -31,12 +31,11 @@ def __init__(
cost: int = 0,
precond: Set[Fluent] = None,
add: Set[Fluent] = None,
- delete: Set[Fluent] = None
+ delete: Set[Fluent] = None,
):
"""Initializes an Action with the parameters provided.
The `precond`, `add`, and `delete` args should only be provided in
Model deserialization.
-
Args:
name (str):
The name of the action.
@@ -58,7 +57,7 @@ def __init__(
self.add = add
self.delete = delete
- def __str__(self):
+ def __repr__(self):
string = f"{self.name} {' '.join(map(str, self.obj_params))}"
return string
@@ -77,17 +76,22 @@ def details(self):
string = f"{self.name} {' '.join([o.details() for o in self.obj_params])}"
return string
- def clone(self):
- return Action(self.name, self.obj_params, self.cost)
+ def clone(self, atomic=False):
+ if atomic:
+ return AtomicAction(
+ self.name, list(map(lambda o: o.details(), self.obj_params)), self.cost
+ )
- def add_parameter(self, obj: PlanningObject):
- """Adds an object to the action's parameters.
-
- Args:
- obj (PlanningObject):
- The object to be added to the action's object parameters.
- """
- self.obj_params.append(obj)
+ return Action(self.name, self.obj_params.copy(), self.cost)
def _serialize(self):
return self.name
+
+
+class AtomicAction(Action):
+ """An Action where the objects are represented by strings."""
+
+ def __init__(self, name: str, obj_params: List[str], cost: int = 0):
+ self.name = name
+ self.obj_params = obj_params
+ self.cost = cost
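A small sketch of the reworked `clone`, with a hypothetical `move` action:

```python
from macq.trace import Action, PlanningObject

loc_a = PlanningObject("location", "a")
loc_b = PlanningObject("location", "b")
move = Action("move", [loc_a, loc_b])

copy = move.clone()               # obj_params are PlanningObjects (copied list)
atomic = move.clone(atomic=True)  # AtomicAction; obj_params are strings
atomic.obj_params                 # ["location a", "location b"]
```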
diff --git a/macq/trace/disordered_parallel_actions_observation_lists.py b/macq/trace/disordered_parallel_actions_observation_lists.py
index fbfa7582..451f5d5e 100644
--- a/macq/trace/disordered_parallel_actions_observation_lists.py
+++ b/macq/trace/disordered_parallel_actions_observation_lists.py
@@ -3,15 +3,17 @@
from numpy import dot
from random import random
from typing import Callable, Type, Set, List
-from . import ObservationLists, TraceList, Step, Action
-from ..observation import Observation
+from . import TraceList, Step, Action, PartialState, State
+from ..observation import Observation, ObservationLists
+
@dataclass
class ActionPair:
"""dataclass that allows a pair of actions to be referenced regardless of order
(that is, {action1, action2} is equivalent to {action2, action1}.)
"""
- actions : Set[Action]
+
+ actions: Set[Action]
def tup(self):
actions = list(self.actions)
@@ -24,6 +26,12 @@ def __hash__(self):
sum += hash(a.details())
return sum
+ def __repr__(self):
+ string = ""
+ for a in self.actions:
+ string += a.details() + ", "
+        return string[:-2]  # drop the trailing ", "
+
def default_theta_vec(k : int):
"""Generate the default theta vector to be used in the calculation that extracts the probability of
actions being disordered; used to "weight" the features.
@@ -35,7 +43,8 @@ def default_theta_vec(k : int):
Returns:
The default theta vector.
"""
- return [(1/k)] * k
+ return [(1 / k)] * k
+
def objects_shared_feature(act_x: Action, act_y: Action):
"""Corresponds to default feature 1 from the AMDN paper.
@@ -51,11 +60,12 @@ def objects_shared_feature(act_x: Action, act_y: Action):
"""
num_shared = 0
for obj in act_x.obj_params:
- for other_obj in act_y. obj_params:
+ for other_obj in act_y.obj_params:
if obj == other_obj:
num_shared += 1
return num_shared
+
def num_parameters_feature(act_x: Action, act_y: Action):
"""Corresponds to default feature 2 from the AMDN paper.
@@ -70,6 +80,7 @@ def num_parameters_feature(act_x: Action, act_y: Action):
"""
return 1 if len(act_x.obj_params) == len(act_y.obj_params) else 0
+
def _decision(probability: float):
"""Makes a decision based on the given probability.
@@ -82,30 +93,49 @@ def _decision(probability: float):
"""
return random() < probability
+
class DisorderedParallelActionsObservationLists(ObservationLists):
- """Alternate ObservationLists type that enforces appropriate actions to be disordered and/or parallel.
+ """Alternate ObservationLists type that enforces appropriate actions to be disordered and/or parallel.
Inherits the base ObservationLists class.
The default feature functions and theta vector described in the AMDN paper are available for use in this module.
-
+
Attributes:
- traces (List[List[Token]]):
+ observations (List[List[Token]]):
The trace list converted to a list of lists of tokens.
+ type (Type[Observation]):
+ The type of token to be used.
all_par_act_sets (List[List[Set[Action]]]):
Holds the parallel action sets for all traces.
+        all_states (List[List[State]]):
+ Holds the relevant states for all traces. Note that these are RELATIVE to the parallel action sets and only
+ contain the states between the sets.
features (List[Callable]):
The list of functions to be used to create the feature vector.
learned_theta (List[float]):
The supplied theta vector.
actions (List[Action]):
The list of all actions used in the traces given (no duplicates).
+ propositions (Set[Fluent]):
+ The set of all fluents.
cross_actions (List[ActionPair]):
The list of all possible `ActionPairs`.
+ denominator (float):
+ The value used for the denominator in all probability calculations (stored so it doesn't need to be recalculated
+ each time).
probabilities (Dict[ActionPair, float]):
A dictionary that contains a mapping of each possible `ActionPair` and the probability that the actions
in them are disordered.
"""
- def __init__(self, traces: TraceList, Token: Type[Observation], features: List[Callable], learned_theta: List[float], **kwargs):
+
+ def __init__(
+ self,
+ traces: TraceList,
+ Token: Type[Observation],
+ features: List[Callable],
+ learned_theta: List[float],
+ **kwargs
+ ):
"""AI is creating summary for __init__
Args:
@@ -120,20 +150,27 @@ def __init__(self, traces: TraceList, Token: Type[Observation], features: List[C
**kwargs:
Any extra arguments to be supplied to the Token __init__.
"""
- self.traces = []
+ self.observations = []
+ self.type = Token
self.all_par_act_sets = []
+ self.all_states = []
self.features = features
self.learned_theta = learned_theta
actions = {step.action for trace in traces for step in trace if step.action}
# cast to list for iteration purposes
self.actions = list(actions)
+ # set of all fluents
+ self.propositions = {f for trace in traces for step in trace for f in step.state.fluents}
# create |A| (action x action set, no duplicates)
self.cross_actions = [ActionPair({self.actions[i], self.actions[j]}) for i in range(len(self.actions)) for j in range(i + 1, len(self.actions))]
+ self.denominator = self._calculate_denom()
# dictionary that holds the probabilities of all actions being disordered
self.probabilities = self._calculate_all_probabilities()
self.tokenize(traces, Token, **kwargs)
- def _theta_dot_features_calc(self, f_vec: List[float], theta_vec: List[float]):
+
+ @staticmethod
+ def _theta_dot_features_calc(f_vec: List[float], theta_vec: List[float]):
"""Calculate the dot product of the feature vector and the theta vector, then use that as an exponent
for 'e'.
@@ -148,6 +185,15 @@ def _theta_dot_features_calc(self, f_vec: List[float], theta_vec: List[float]):
"""
return exp(dot(f_vec, theta_vec))
+ def _calculate_denom(self):
+ """
+ Calculates and returns the denominator used in probability calculations.
+ """
+ denominator = 0
+ for combo in self.cross_actions:
+ denominator += self._theta_dot_features_calc(self._get_f_vec(*combo.tup()), self.learned_theta)
+ return denominator
+
def _get_f_vec(self, act_x: Action, act_y: Action):
"""Returns the feature vector.
@@ -176,12 +222,8 @@ def _calculate_probability(self, act_x: Action, act_y: Action):
"""
# calculate the probability of two given actions being disordered
f_vec = self._get_f_vec(act_x, act_y)
- theta_vec = self.learned_theta
- numerator = self._theta_dot_features_calc(f_vec, theta_vec)
- denominator = 0
- for combo in self.cross_actions:
- denominator += self._theta_dot_features_calc(self._get_f_vec(*combo.tup()), theta_vec)
- return numerator/denominator
+ numerator = self._theta_dot_features_calc(f_vec, self.learned_theta)
+        return numerator / self.denominator
def _calculate_all_probabilities(self):
"""Calculates the probabilities of all combinations of actions being disordered.
@@ -197,6 +239,25 @@ def _calculate_all_probabilities(self):
probabilities[combo] = self._calculate_probability(*combo.tup())
return probabilities
+ def _get_new_partial_state(self):
+ """
+ Return a PartialState with the fluents used in this observation, with each fluent set to None as default.
+ """
+ cur_state = PartialState()
+ for f in self.propositions:
+ cur_state[f] = None
+ return cur_state
+
+ def _update_partial_state(self, partial_state: PartialState, orig_state: State, action: Action):
+ """
+        Return a copy of the given PartialState, updated with the original state's values for every fluent in the action's add and delete effects.
+ """
+ new_partial = partial_state.copy()
+ effects = set([e for e in action.add] + [e for e in action.delete])
+ for e in effects:
+ new_partial[e] = orig_state[e]
+ return new_partial
+
def tokenize(self, traces: TraceList, Token: Type[Observation], **kwargs):
"""Main driver that handles the tokenization process.
@@ -209,26 +270,32 @@ def tokenize(self, traces: TraceList, Token: Type[Observation], **kwargs):
Any extra arguments to be supplied to the Token __init__.
"""
# build parallel action sets
- for trace in traces:
+ for trace in traces:
par_act_sets = []
states = []
cur_par_act = set()
cur_par_act_conditions = set()
# add initial state
states.append(trace[0].state)
- # for the compiler
- cur_state = trace[1].state
+
+ cur_state = self._get_new_partial_state()
# last step doesn't have an action/just contains the state after the last action
for i in range(len(trace)):
a = trace[i].action
if a:
- a_conditions = set([p for p in a.precond] + [e for e in a.add] + [e for e in a.delete])
+ a_conditions = set(
+ [p for p in a.precond]
+ + [e for e in a.add]
+ + [e for e in a.delete]
+ )
# if the action has any conditions in common with any actions in the previous parallel set (NOT parallel)
- if a_conditions.intersection(cur_par_act_conditions) != set():
+ if a_conditions.intersection(cur_par_act_conditions) != set():
# add psi_k and s'_k to the final (ordered) lists of parallel action sets and states
- par_act_sets.append(cur_par_act)
+ par_act_sets.append(cur_par_act)
states.append(cur_state)
+ # reset the state
+ cur_state = self._get_new_partial_state()
# reset psi_k (that is, create a new parallel action set)
cur_par_act = set()
# reset the conditions
@@ -236,11 +303,11 @@ def tokenize(self, traces: TraceList, Token: Type[Observation], **kwargs):
# add the action and state to the appropriate psi_k and s'_k (either the existing ones, or
# new/empty ones if the current action is NOT parallel with actions in the previous set of actions.)
cur_par_act.add(a)
- cur_state = trace[i + 1].state
+ cur_state = self._update_partial_state(cur_state, trace[i + 1].state, trace[i].action)
cur_par_act_conditions.update(a_conditions)
# if on the last step of the trace, add the current set/state to the final result before exiting the loop
if i == len(trace) - 1:
- par_act_sets.append(cur_par_act)
+ par_act_sets.append(cur_par_act)
states.append(cur_state)
# generate disordered actions - do trace by trace
@@ -252,16 +319,21 @@ def tokenize(self, traces: TraceList, Token: Type[Observation], **kwargs):
for act_y in par_act_sets[j]:
if act_x != act_y:
# get probability and divide by distance
- prob = self.probabilities[ActionPair({act_x, act_y})]/(j - i)
+ prob = self.probabilities[
+ ActionPair({act_x, act_y})
+ ] / (j - i)
if _decision(prob):
par_act_sets[i].discard(act_x)
par_act_sets[i].add(act_y)
par_act_sets[j].discard(act_y)
par_act_sets[j].add(act_x)
+
self.all_par_act_sets.append(par_act_sets)
+ self.all_states.append(states)
tokens = []
for i in range(len(par_act_sets)):
for act in par_act_sets[i]:
- tokens.append(Token(Step(state=states[i], action=act, index=i), par_act_set_ID = i, **kwargs))
- self.append(tokens)
-
\ No newline at end of file
+ tokens.append(Token(Step(state=states[i], action=act, index=i), par_act_set_ID=i, **kwargs))
+ # add the final token, with the final state but no action
+ tokens.append(Token(Step(state=states[-1], action=None, index=len(par_act_sets)), par_act_set_ID=len(par_act_sets), **kwargs))
+ self.append(tokens)
\ No newline at end of file
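Putting the pieces together, a sketch of building the disordered/parallel observations with the default features and theta vector (assuming `traces` is a `TraceList`):

```python
from macq.observation import NoisyPartialDisorderedParallelObservation
from macq.trace import DisorderedParallelActionsObservationLists
from macq.trace.disordered_parallel_actions_observation_lists import (
    default_theta_vec,
    num_parameters_feature,
    objects_shared_feature,
)

features = [objects_shared_feature, num_parameters_feature]
observations = traces.tokenize(
    Token=NoisyPartialDisorderedParallelObservation,
    ObsLists=DisorderedParallelActionsObservationLists,
    features=features,
    learned_theta=default_theta_vec(len(features)),
    percent_missing=0.10,
)
```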
diff --git a/macq/trace/fluent.py b/macq/trace/fluent.py
index d1581e3e..ff96f1fd 100644
--- a/macq/trace/fluent.py
+++ b/macq/trace/fluent.py
@@ -34,6 +34,9 @@ def __eq__(self, other):
def details(self):
return " ".join([self.obj_type, self.name])
+ def __repr__(self):
+ return self.details()
+
def _serialize(self):
return self.details()
@@ -66,7 +69,7 @@ def __hash__(self):
# Order of objects is important!
return hash(str(self))
- def __str__(self):
+ def __repr__(self):
return f"({self.name} {' '.join([o.details() for o in self.objects])})"
def __eq__(self, other):
diff --git a/macq/trace/observation_lists.py b/macq/trace/observation_lists.py
deleted file mode 100644
index fab452b5..00000000
--- a/macq/trace/observation_lists.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import macq.trace as TraceAPI
-from typing import List, Set, Type
-from ..observation import Observation
-from . import Trace
-
-class ObservationLists(TraceAPI.TraceList):
- traces: List[List[Observation]]
- # Disable methods
- generate_more = property()
- get_usage = property()
- tokenize = property()
- get_fluents = property()
-
- def __init__(self, traces: TraceAPI.TraceList, Token: Type[Observation], **kwargs):
- self.traces = []
- self.type = Token
- self.tokenize(traces, **kwargs)
-
- def tokenize(self, traces: TraceAPI.TraceList, **kwargs):
- trace: Trace
- for trace in traces:
- tokens = trace.tokenize(self.type, **kwargs)
- self.append(tokens)
-
- def fetch_observations(self, query: dict):
- matches: List[Set[Observation]] = list()
- trace: List[Observation]
- for i, trace in enumerate(self):
- matches.append(set())
- for obs in trace:
- if obs.matches(query): # if no matches, set can be empty
- matches[i].add(obs)
- return matches # list of sets of matching fluents from each trace
-
- def fetch_observation_windows(self, query: dict, left: int, right: int):
- windows = []
- matches = self.fetch_observations(query)
- trace: Set[Observation]
- for i, trace in enumerate(matches): # note obs.index starts at 1 (index = i+1)
- for obs in trace:
- start = obs.index - left - 1
- end = obs.index + right
- windows.append(self[i][start:end])
- return windows
-
- def get_transitions(self, action: str):
- query = {"action": action}
- return self.fetch_observation_windows(query, 0, 1)
-
- def get_all_transitions(self):
- actions = set()
- for trace in self:
- for obs in trace:
- action = obs.action
- if action:
- actions.add(action)
- # Actions in the observations can be either Action objects or strings depending on the type of observation
- try:
- return {
- action: self.get_transitions(action.details()) for action in actions
- }
- except AttributeError:
- return {action: self.get_transitions(str(action)) for action in actions}
\ No newline at end of file
diff --git a/macq/trace/partial_state.py b/macq/trace/partial_state.py
index f4dc18d6..737e8f98 100644
--- a/macq/trace/partial_state.py
+++ b/macq/trace/partial_state.py
@@ -1,16 +1,16 @@
from . import State
from . import Fluent
-from typing import Dict
+from typing import Dict, Union
class PartialState(State):
"""A Partial State where the value of some fluents are unknown."""
- def __init__(self, fluents: Dict[Fluent, bool] = {}):
+    def __init__(self, fluents: Dict[Fluent, Union[bool, None]] = None):
"""
Args:
fluents (dict):
Optional; A mapping of `Fluent` objects to their value in this
state. Defaults to an empty `dict`.
"""
- super().__init__(fluents)
+        self.fluents = fluents if fluents is not None else {}
diff --git a/macq/trace/state.py b/macq/trace/state.py
index c1b08405..8378d621 100644
--- a/macq/trace/state.py
+++ b/macq/trace/state.py
@@ -1,5 +1,4 @@
from __future__ import annotations
-from dataclasses import dataclass
from typing import Dict
from rich.text import Text
from . import Fluent
@@ -16,7 +15,7 @@ class State:
A mapping of `Fluent` objects to their value in this state.
"""
- def __init__(self, fluents: Dict[Fluent, bool] = {}):
+ def __init__(self, fluents: Dict[Fluent, bool] = None):
"""Initializes State with an optional fluent-value mapping.
Args:
@@ -24,12 +23,10 @@ def __init__(self, fluents: Dict[Fluent, bool] = {}):
Optional; A mapping of `Fluent` objects to their value in this
state. Defaults to an empty `dict`.
"""
- self.fluents = fluents
+ self.fluents = fluents if fluents is not None else {}
def __eq__(self, other):
- if not isinstance(other, State):
- return False
- return self.fluents == other.fluents
+ return isinstance(other, State) and self.fluents == other.fluents
def __str__(self):
return ", ".join([str(fluent) for (fluent, value) in self.items() if value])
@@ -84,10 +81,19 @@ def details(self):
string.append(", ")
return string[:-2]
- def clone(self):
+ def clone(self, atomic=False):
+ if atomic:
+ return AtomicState({str(fluent): value for fluent, value in self.items()})
return State(self.fluents.copy())
def holds(self, fluent: str):
fluents = dict(map(lambda f: (f.name, f), self.keys()))
if fluent in fluents.keys():
return self[fluents[fluent]]
+
+
+class AtomicState(State):
+ """A State where the fluents are represented by strings."""
+
+ def __init__(self, fluents: Dict[str, bool] = None):
+ self.fluents = fluents if fluents is not None else {}
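A sketch of `clone(atomic=True)` on states, mirroring `Action.clone` (the fluent is hypothetical):

```python
from macq.trace import Fluent, PlanningObject, State

on_a_b = Fluent("on", [PlanningObject("object", "a"), PlanningObject("object", "b")])
state = State({on_a_b: True})

state.clone()             # State, keyed by Fluent objects
state.clone(atomic=True)  # AtomicState, keyed by strings: {"(on object a object b)": True}
```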
diff --git a/macq/trace/trace.py b/macq/trace/trace.py
index a597126b..314bf1b3 100644
--- a/macq/trace/trace.py
+++ b/macq/trace/trace.py
@@ -54,6 +54,9 @@ def __init__(self, steps: List[Step] = None):
self.steps = steps if steps is not None else []
self.__reinit_actions_and_fluents()
+ def __eq__(self, other):
+ return isinstance(other, Trace) and self.steps == other.steps
+
def __len__(self):
return len(self.steps)
@@ -262,7 +265,7 @@ def get_post_states(self, action: Action):
post_states.add(self[i + 1].state)
return post_states
- def get_sas_triples(self, action: Action) -> Set[SAS]:
+ def get_sas_triples(self, action: Action) -> List[SAS]:
"""Retrieves the list of (S,A,S') triples for the action in this trace.
In a (S,A,S') triple, S is the pre-state, A is the action, and S' is
@@ -276,11 +279,11 @@ def get_sas_triples(self, action: Action) -> Set[SAS]:
            A list of `SAS` objects, each containing the `pre_state`, `action`,
            and `post_state`.
"""
- sas_triples = set()
+ sas_triples = []
for i, step in enumerate(self):
if step.action == action:
triple = SAS(step.state, action, self[i + 1].state)
- sas_triples.add(triple)
+ sas_triples.append(triple)
return sas_triples
def get_total_cost(self):
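Because `get_sas_triples` now returns a list, repeated occurrences of an action are kept in trace order instead of being collapsed by a set. A short sketch, assuming `trace` is a `Trace` and `action` one of its actions:

```python
for sas in trace.get_sas_triples(action):
    # every occurrence is preserved, even identical (S, A, S') duplicates
    print(sas.pre_state, sas.action, sas.post_state)
```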
diff --git a/macq/trace/trace_list.py b/macq/trace/trace_list.py
index a154fdb3..e7dd3cd6 100644
--- a/macq/trace/trace_list.py
+++ b/macq/trace/trace_list.py
@@ -1,28 +1,26 @@
+from warnings import warn
+from typing import List, Callable, Type, Optional, Union
-from logging import warn
-from typing import List, Callable, Type, Optional, Set
+from collections.abc import MutableSequence
from rich.console import Console
-from . import Action, Trace
-from ..observation import Observation
-import macq.trace as TraceAPI
+from . import Action, Trace
+from ..observation import Observation, ObservationLists
-class TraceList:
- """A collection of traces.
+class TraceList(MutableSequence):
+ """A sequence of traces.
A `list`-like object, where each element is a `Trace` of the same planning
problem.
Attributes:
- traces (list):
+ traces (List[Trace]):
The list of `Trace` objects.
- generator (function | None):
+ generator (Callable | None):
The function used to generate the traces.
"""
- # Allow child classes to have traces as a list of any type
- traces: List
-
class MissingGenerator(Exception):
def __init__(
self,
@@ -33,107 +31,49 @@ def __init__(
self.message = message
super().__init__(message)
+ traces: List[Trace]
+ generator: Union[Callable, None]
+
def __init__(
self,
traces: List[Trace] = None,
- generator: Optional[Callable] = None,
+ generator: Callable = None,
):
"""Initializes a TraceList with a list of traces and a generator.
Args:
- traces (list):
+ traces (List[Trace]):
Optional; The list of `Trace` objects.
- generator (function):
+ generator (Callable):
Optional; The function used to generate the traces.
"""
self.traces = [] if traces is None else traces
self.generator = generator
- def __len__(self):
- return len(self.traces)
+ def __getitem__(self, key: int):
+ return self.traces[key]
def __setitem__(self, key: int, value: Trace):
self.traces[key] = value
- def __getitem__(self, key: int):
- return self.traces[key]
-
def __delitem__(self, key: int):
del self.traces[key]
def __iter__(self):
return iter(self.traces)
- def __reversed__(self):
- return reversed(self.traces)
-
- def __contains__(self, item):
- return item in self.traces
-
- def append(self, item):
- self.traces.append(item)
-
- def clear(self):
- self.traces.clear()
+ def __len__(self):
+ return len(self.traces)
def copy(self):
return self.traces.copy()
- def extend(self, iterable):
- self.traces.extend(iterable)
-
- def index(self, value):
- return self.traces.index(value)
-
- def insert(self, index: int, item):
- self.traces.insert(index, item)
-
- def pop(self):
- return self.traces.pop()
-
- def remove(self, value):
- self.traces.remove(value)
-
- def reverse(self):
- self.traces.reverse()
+ def insert(self, key: int, value: Trace):
+ self.traces.insert(key, value)
def sort(self, reverse: bool = False, key: Callable = lambda e: e.get_total_cost()):
self.traces.sort(reverse=reverse, key=key)
- def print(self, view="details", filter_func=lambda _: True, wrap=None):
- """Pretty prints the trace list in the specified view.
-
- Arguments:
- view ("details" | "color"):
- Specifies the view format to print in. "details" provides a
- detailed summary of each step in a trace. "color" provides a
- color grid, mapping fluents in a step to either red or green
- corresponding to the truth value.
- """
- console = Console()
-
- views = ["details", "color"]
- if view not in views:
- warn(f'Invalid view {view}. Defaulting to "details".')
- view = "details"
-
- traces = []
- if view == "details":
- if wrap is None:
- wrap = False
- traces = [trace.details(wrap=wrap) for trace in self]
-
- elif view == "color":
- if wrap is None:
- wrap = True
- traces = [
- trace.colorgrid(filter_func=filter_func, wrap=wrap) for trace in self
- ]
-
- for trace in traces:
- console.print(trace)
- print()
-
def generate_more(self, num: int):
"""Generates more traces using the generator function.
@@ -179,17 +119,53 @@ def get_fluents(self):
fluents.update(step.state.fluents)
return fluents
- def tokenize(self, Token: Type[Observation], ObsLists = None, **kwargs):
+ def tokenize(
+ self,
+ Token: Type[Observation],
+ ObsLists: Type[ObservationLists] = ObservationLists,
+ **kwargs,
+ ):
"""Tokenizes the steps in this trace.
Args:
Token (Observation):
A subclass of `Observation`, defining the method of tokenization
for the steps.
- ObsLists (Type[TraceAPI.ObservationLists]):
+ ObsLists (Type[ObservationLists]):
The type of `ObservationLists` to be used. Defaults to the base `ObservationLists`.
"""
- ObsLists : Type[TraceAPI.ObservationLists]
- if not ObsLists:
- ObsLists = TraceAPI.ObservationLists
- return ObsLists(self, Token, **kwargs)
\ No newline at end of file
+ return ObsLists(self, Token, **kwargs)
+
+ def print(self, view="details", filter_func=lambda _: True, wrap=None):
+ """Pretty prints the trace list in the specified view.
+
+ Arguments:
+ view ("details" | "color"):
+ Specifies the view format to print in. "details" provides a
+ detailed summary of each step in a trace. "color" provides a
+ color grid, mapping fluents in a step to either red or green
+                corresponding to the truth value.
+            filter_func (function):
+                Optional; Used to filter which fluents are printed in the
+                colorgrid display.
+            wrap (bool):
+                Determines if the output is wrapped or cut off. Details defaults
+                to cut off (wrap=False), color defaults to wrap (wrap=True).
+        """
+ console = Console()
+
+ views = ["details", "color"]
+ if view not in views:
+ warn(f'Invalid view {view}. Defaulting to "details".')
+ view = "details"
+
+ traces = []
+ if view == "details":
+ if wrap is None:
+ wrap = False
+ traces = [trace.details(wrap=wrap) for trace in self]
+
+ elif view == "color":
+ if wrap is None:
+ wrap = True
+ traces = [
+ trace.colorgrid(filter_func=filter_func, wrap=wrap) for trace in self
+ ]
+
+ for trace in traces:
+ console.print(trace)
+ print()
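With `TraceList` now a `MutableSequence`, the hand-rolled list methods disappear; `append`, `pop`, `in`, `index`, etc. are inherited. A small sketch with hypothetical traces `trace_a` and `trace_b`:

```python
from macq.trace import TraceList

trace_list = TraceList([trace_a])
trace_list.append(trace_b)    # inherited from MutableSequence
assert trace_b in trace_list  # __contains__ for free
last = trace_list.pop()       # pop for free
```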
diff --git a/macq/utils/__init__.py b/macq/utils/__init__.py
index 84b24c99..f4adc905 100644
--- a/macq/utils/__init__.py
+++ b/macq/utils/__init__.py
@@ -1,9 +1,24 @@
-from .timer import set_timer_throw_exc, basic_timer, TraceSearchTimeOut
+from .timer import set_timer_throw_exc, basic_timer, TraceSearchTimeOut, InvalidTime
from .complex_encoder import ComplexEncoder
from .common_errors import PercentError
from .trace_errors import InvalidPlanLength, InvalidNumberOfTraces
from .trace_utils import set_num_traces, set_plan_length
from .tokenization_errors import TokenizationError
-#from .tokenization_utils import extract_fluent_subset
+from .progress import progress
-__all__ = ["set_timer_throw_exc", "basic_timer", "TraceSearchTimeOut", "ComplexEncoder", "PercentError", "set_num_traces", "set_plan_length", "InvalidPlanLength", "InvalidNumberOfTraces", "TokenizationError",]# "extract_fluent_subset"]
+# from .tokenization_utils import extract_fluent_subset
+
+__all__ = [
+ "set_timer_throw_exc",
+ "basic_timer",
+ "TraceSearchTimeOut",
+ "InvalidTime",
+ "ComplexEncoder",
+ "PercentError",
+ "set_num_traces",
+ "set_plan_length",
+ "InvalidPlanLength",
+ "InvalidNumberOfTraces",
+ "TokenizationError",
+ "progress",
+]
diff --git a/macq/utils/common_errors.py b/macq/utils/common_errors.py
index 681bcc9d..8c8b046f 100644
--- a/macq/utils/common_errors.py
+++ b/macq/utils/common_errors.py
@@ -5,4 +5,4 @@ def __init__(
self,
message="The percentage supplied is invalid.",
):
- super().__init__(message)
\ No newline at end of file
+ super().__init__(message)
diff --git a/macq/utils/progress.py b/macq/utils/progress.py
new file mode 100644
index 00000000..341a303f
--- /dev/null
+++ b/macq/utils/progress.py
@@ -0,0 +1,60 @@
+from typing import Iterable, Iterator, Sized, Any
+
+
+try:
+ from tqdm import tqdm, trange
+
+ TQDM = True
+
+except ModuleNotFoundError:
+ TQDM = False
+
+
+def tqdm_progress(iterable=None, *args, **kwargs) -> Any:
+ """Wraps a loop with tqdm to output a progress bar."""
+ if isinstance(iterable, range):
+ return trange(iterable.start, iterable.stop, iterable.step, *args, **kwargs)
+ return tqdm(iterable, *args, **kwargs)
+
+
+class vanilla_progress:
+ """Wraps a loop to output progress reports."""
+
+ def __init__(self, iterable: Iterable[Any], *args, **kwargs):
+ """Initializes a vanilla_progress object with the given iterable.
+
+ Args:
+ iterable (Iterable):
+ The iterable to loop over and track the progress of.
+ """
+ self.iterable = iterable
+ self.args = args
+ self.kwargs = kwargs
+
+ def __iter__(self) -> Iterator[Any]:
+ if isinstance(self.iterable, range):
+ start = self.iterable.start
+ stop = self.iterable.stop
+ step = self.iterable.step
+ total = (stop - start) / step
+ elif isinstance(self.iterable, Sized):
+ total = len(self.iterable)
+ else:
+ total = None
+
+ prev = 0
+ it = 1
+ for i in self.iterable:
+ yield i
+ if total is not None:
+                # take the tenths digit of the completed fraction, e.g. 0.37 -> 3 (i.e. "30% ...")
+                new = int(str(it / total)[2])
+ if new != prev:
+ prev = new
+ if new == 0:
+ print("100%")
+ else:
+ print(f"{new}0% ...")
+ it += 1
+
+
+progress = tqdm_progress if TQDM else vanilla_progress
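Callers can use `progress` without checking for tqdm themselves; it degrades to periodic percentage prints. A minimal sketch:

```python
from macq.utils import progress

total = 0
for i in progress(range(1_000_000)):
    total += i  # tqdm bar if installed, otherwise "10% ..." style prints
```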
diff --git a/macq/utils/pysat.py b/macq/utils/pysat.py
new file mode 100644
index 00000000..464d0f40
--- /dev/null
+++ b/macq/utils/pysat.py
@@ -0,0 +1,106 @@
+from typing import List, Tuple, Dict, Hashable
+from pysat.formula import WCNF
+from pysat.examples.rc2 import RC2
+from nnf import And, Or, Var
+from ..extract.exceptions import InvalidMaxSATModel
+
+
+def get_encoding(
+ clauses: And[Or[Var]], start: int = 1
+) -> Tuple[Dict[Hashable, int], Dict[int, Hashable]]:
+ """Maps NNF clauses to pysat clauses and vice-versa.
+
+ Args:
+ clauses (And[Or[Var]]):
+ NNF clauses (in conjunctive normal form) to be mapped to pysat clauses.
+ start (int):
+ Optional; The number to start the mapping from. Defaults to 1.
+
+ Returns:
+ Tuple[Dict[Hashable, int], Dict[int, Hashable]]:
+ The encode mapping (NNF to pysat), and the decode mapping (pysat to NNF).
+ """
+ decode = dict(enumerate(clauses.vars(), start=start))
+ encode = {v: k for k, v in decode.items()}
+ return encode, decode
+
+
+def encode(clauses: And[Or[Var]], encode: Dict[Hashable, int]) -> List[List[int]]:
+ """Encodes NNF clauses into pysat clauses.
+
+ Args:
+ clauses (And[Or[Var]]):
+ NNF clauses (in conjunctive normal form) to be converted to pysat clauses.
+ encode (Dict[Hashable, int]):
+ The encode mapping to apply to the NNF clauses.
+
+ Returns:
+ List[List[int]]:
+ The pysat encoded clauses.
+ """
+ encoded = [
+ [encode[var.name] if var.true else -encode[var.name] for var in clause]
+ for clause in clauses
+ ]
+ return encoded
+
+
+def to_wcnf(
+ soft_clauses: And[Or[Var]], weights: List[int], hard_clauses: And[Or[Var]] = None
+) -> Tuple[WCNF, Dict[int, Hashable]]:
+ """Builds a pysat weighted CNF theory from pysat clauses.
+
+ Args:
+ soft_clauses (And[Or[Var]]):
+ The soft clauses (NNF clauses, in CNF) for the WCNF theory.
+ weights (List[int]):
+ The weights to associate with the soft clauses.
+ hard_clauses (And[Or[Var]]):
+ Optional; Hard clauses (unweighted) to add to the WCNF theory.
+
+ Returns:
+ Tuple[WCNF, Dict[int, Hashable]]:
+ The WCNF theory, and the decode mapping to convert the pysat vars back to NNF.
+ """
+ wcnf = WCNF()
+ soft_encode, decode = get_encoding(soft_clauses)
+ encoded = encode(soft_clauses, soft_encode)
+ wcnf.extend(encoded, weights)
+
+ if hard_clauses:
+ hard_encode, hard_decode = get_encoding(hard_clauses, start=len(decode) + 1)
+ decode.update(hard_decode)
+ encoded = encode(hard_clauses, hard_encode)
+ wcnf.extend(encoded)
+
+ return wcnf, decode
+
+
+def extract_raw_model(max_sat: WCNF, decode: Dict[int, Hashable]) -> Dict[Hashable, bool]:
+ """Extracts a raw model given a WCNF and the corresponding decoding dictionary.
+
+ Args:
+ max_sat (WCNF):
+ The WCNF to solve for.
+ decode (Dict[int, Hashable]):
+ The decode dictionary mapping to convert the pysat vars back to NNF.
+
+ Raises:
+ InvalidMaxSATModel:
+ If the model is invalid.
+
+ Returns:
+ Dict[Hashable, bool]:
+ The raw model.
+ """
+ solver = RC2(max_sat)
+ encoded_model = solver.compute()
+
+ if not isinstance(encoded_model, list):
+ # should never be reached
+ raise InvalidMaxSATModel(encoded_model)
+
+ # decode the model (back to nnf vars)
+ model: Dict[Hashable, bool] = {
+ decode[abs(clause)]: clause > 0 for clause in encoded_model
+ }
+ return model
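A minimal round-trip sketch with hand-built nnf clauses (the variable names are illustrative):

```python
from nnf import And, Or, Var

from macq.utils.pysat import extract_raw_model, to_wcnf

a, b, c = Var("a"), Var("b"), Var("c")
soft = And({Or({a, b.negate()}), Or({b, c})})

# note: clause/weight pairing follows the iteration order of `soft`
wcnf, decode = to_wcnf(soft, weights=[1, 1])
model = extract_raw_model(wcnf, decode)  # e.g. {"a": True, "b": True, "c": True}
```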
diff --git a/macq/utils/timer.py b/macq/utils/timer.py
index 561b7b54..9bad9298 100644
--- a/macq/utils/timer.py
+++ b/macq/utils/timer.py
@@ -2,7 +2,7 @@
from typing import Union
-def set_timer_throw_exc(num_seconds: Union[float, int], exception: Exception):
+def set_timer_throw_exc(num_seconds: Union[float, int], exception: Exception, *exception_args, **exception_kwargs):
def timer(function):
"""
Checks that a function runs within the specified time and raises an exception if it doesn't.
@@ -27,7 +27,7 @@ def wrapper(*args, **kwargs):
return thr.get()
else:
# otherwise, raise an exception if the function takes too long
- raise exception()
+ raise exception(*exception_args, **exception_kwargs)
return wrapper
@@ -60,12 +60,23 @@ def wrapper(*args, **kwargs):
class TraceSearchTimeOut(Exception):
"""
Raised when the time it takes to generate (or attempt to generate) a single trace is
- longer than the MAX_TRACE_TIME constant. MAX_TRACE_TIME is 30 seconds by default.
+ longer than the generator's `max_time` attribute.
"""
def __init__(
self,
- message="The generator took longer than MAX_TRACE_TIME in its attempt to generate a trace. "
- + "MAX_TRACE_TIME can be changed through the trace generator used.",
+ max_time: float
):
+        message = (
+            f"The generator could not find a suitable trace in {max_time} seconds or less. "
+            "Change the `max_time` attribute of the trace generator if you would like more "
+            "time to generate a trace."
+        )
super().__init__(message)
+
+class InvalidTime(Exception):
+ """
+ Raised when the user supplies an invalid maximum time for a trace to be generated
+ to a generator.
+ """
+ def __init__(
+ self,
+ message="The provided maximum time is invalid.",
+ ):
+ super().__init__(message)
\ No newline at end of file
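The new `*exception_args`/`**exception_kwargs` pass-through is what lets the timer raise the richer `TraceSearchTimeOut(max_time)`. A sketch of how a generator might wire it up (`MAX_TIME` is illustrative):

```python
from macq.utils import TraceSearchTimeOut, set_timer_throw_exc

MAX_TIME = 30.0

@set_timer_throw_exc(num_seconds=MAX_TIME, exception=TraceSearchTimeOut, max_time=MAX_TIME)
def generate_single_trace():
    ...  # trace search; raises TraceSearchTimeOut(30.0) after MAX_TIME seconds
```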
diff --git a/setup.py b/setup.py
index 41619b15..ac4147a0 100644
--- a/setup.py
+++ b/setup.py
@@ -10,6 +10,7 @@
"tarski@git+git://github.com/aig-upf/tarski.git@devel#egg=tarski[arithmetic]",
"requests",
"rich",
+ "nnf",
"python-sat",
"bauhaus",
"numpy"
diff --git a/test_model.json b/test_model.json
deleted file mode 100644
index 26394024..00000000
--- a/test_model.json
+++ /dev/null
@@ -1 +0,0 @@
-{"fluents": ["(on object g object b)", "(on object e object d)", "(on object b object a)", "(on object i object j)", "(on object f object g)", "(on object j object g)", "(on object a object c)", "(on object i object d)", "(on object i object b)", "(on object h object e)", "(on object a object d)", "(holding object e)", "(on object e object a)", "(on object h object b)", "(ontable object h)", "(clear object b)", "(holding object a)", "(on object a object a)", "(on object c object i)", "(on object h object g)", "(clear object c)", "(on object d object e)", "(on object c object f)", "(on object g object h)", "(on object a object f)", "(on object i object g)", "(on object b object c)", "(holding object h)", "(on object i object h)", "(on object f object j)", "(on object e object f)", "(holding object i)", "(on object i object i)", "(on object e object i)", "(on object j object j)", "(on object d object d)", "(on object i object a)", "(on object h object d)", "(on object g object d)", "(on object d object c)", "(on object c object e)", "(on object d object b)", "(on object b object h)", "(on object g object e)", "(ontable object g)", "(on object c object j)", "(on object b object j)", "(holding object b)", "(on object j object h)", "(on object j object a)", "(holding object d)", "(on object j object f)", "(on object c object b)", "(on object b object f)", "(ontable object a)", "(on object a object i)", "(on object i object e)", "(on object h object j)", "(ontable object b)", "(on object a object b)", "(on object a object j)", "(ontable object f)", "(on object c object c)", "(on object h object f)", "(on object g object j)", "(on object e object c)", "(on object i object f)", "(ontable object e)", "(on object d object a)", "(handempty )", "(on object e object b)", "(on object f object f)", "(on object h object i)", "(on object f object c)", "(holding object j)", "(holding object g)", "(ontable object i)", "(ontable object d)", "(on object e object h)", "(on object f object a)", "(on object c object a)", "(on object a object g)", "(on object c object d)", "(on object j object c)", "(on object j object i)", "(on object f object i)", "(on object j object e)", "(on object g object g)", "(on object f object h)", "(on object a object h)", "(on object g object c)", "(on object a object e)", "(clear object g)", "(on object d object f)", "(on object j object d)", "(on object b object i)", "(on object b object d)", "(on object d object i)", "(on object h object c)", "(on object c object h)", "(on object g object a)", "(on object g object f)", "(on object d object j)", "(clear object e)", "(on object f object d)", "(on object d object h)", "(on object i object c)", "(clear object j)", "(on object g object i)", "(on object c object g)", "(clear object a)", "(clear object i)", "(on object h object a)", "(on object b object b)", "(ontable object c)", "(clear object f)", "(on object b object g)", "(holding object c)", "(on object e object j)", "(on object e object g)", "(on object d object g)", "(on object h object h)", "(on object f object b)", "(on object e object e)", "(on object f object e)", "(holding object f)", "(ontable object j)", "(clear object h)", "(on object b object e)", "(on object j object b)", "(clear object d)"], "actions": [{"name": "unstack", "obj_params": ["object c", "object e"], "cost": 0, "precond": ["(ontable object f)", "(clear object f)", "(on object d object i)", "(on object a object d)", "(handempty )", "(clear object c)", "(on object b object g)", "(on object c object e)", "(on 
object g object h)", "(on object h object a)", "(ontable object i)", "(on object j object b)", "(on object e object j)"], "add": ["(holding object c)", "(clear object e)"], "delete": ["(on object c object e)", "(handempty )", "(clear object c)"]}, {"name": "pick-up", "obj_params": ["object c"], "cost": 0, "precond": ["(ontable object f)", "(clear object f)", "(ontable object c)", "(on object d object i)", "(on object a object d)", "(handempty )", "(clear object c)", "(on object b object g)", "(on object g object h)", "(on object h object a)", "(clear object e)", "(ontable object i)", "(on object j object b)", "(on object e object j)"], "add": ["(holding object c)"], "delete": ["(ontable object c)", "(handempty )", "(clear object c)"]}, {"name": "put-down", "obj_params": ["object c"], "cost": 0, "precond": ["(ontable object f)", "(clear object f)", "(on object d object i)", "(on object a object d)", "(on object b object g)", "(on object g object h)", "(holding object c)", "(on object h object a)", "(clear object e)", "(ontable object i)", "(on object j object b)", "(on object e object j)"], "add": ["(ontable object c)", "(handempty )", "(clear object c)"], "delete": ["(holding object c)"]}, {"name": "stack", "obj_params": ["object f", "object e"], "cost": 0, "precond": ["(ontable object c)", "(on object d object i)", "(clear object c)", "(on object a object d)", "(on object b object g)", "(on object g object h)", "(on object h object a)", "(ontable object i)", "(on object j object b)"], "add": ["(clear object f)", "(clear object e)", "(handempty )", "(on object f object e)", "(on object e object f)"], "delete": ["(holding object e)", "(holding object f)", "(clear object f)", "(clear object e)"]}, {"name": "put-down", "obj_params": ["object f"], "cost": 0, "precond": ["(on object d object i)", "(clear object c)", "(on object a object d)", "(holding object f)", "(on object b object g)", "(on object g object h)", "(on object h object a)", "(ontable object i)", "(on object j object b)", "(on object e object j)"], "add": ["(ontable object f)", "(clear object f)", "(handempty )"], "delete": ["(holding object f)"]}, {"name": "stack", "obj_params": ["object c", "object f"], "cost": 0, "precond": ["(on object d object i)", "(on object a object d)", "(on object b object g)", "(on object g object h)", "(on object h object a)", "(ontable object i)", "(on object j object b)", "(on object e object j)"], "add": ["(on object c object f)", "(clear object f)", "(handempty )", "(on object f object c)", "(clear object c)"], "delete": ["(clear object c)", "(holding object c)", "(clear object f)", "(holding object f)"]}, {"name": "unstack", "obj_params": ["object c", "object f"], "cost": 0, "precond": ["(clear object f)", "(on object d object i)", "(on object a object d)", "(on object f object c)", "(handempty )", "(on object b object g)", "(on object c object e)", "(on object g object h)", "(on object h object a)", "(ontable object i)", "(on object j object b)", "(on object e object j)"], "add": ["(clear object c)", "(holding object f)"], "delete": ["(on object f object c)", "(clear object f)", "(handempty )"]}, {"name": "unstack", "obj_params": ["object e", "object j"], "cost": 0, "precond": ["(ontable object f)", "(clear object f)", "(ontable object c)", "(on object d object i)", "(on object a object d)", "(handempty )", "(clear object c)", "(on object b object g)", "(on object g object h)", "(on object h object a)", "(clear object e)", "(ontable object i)", "(on object j object b)", "(on object e object 
j)"], "add": ["(clear object j)", "(holding object e)"], "delete": ["(on object e object j)", "(handempty )", "(clear object e)"]}, {"name": "pick-up", "obj_params": ["object f"], "cost": 0, "precond": ["(ontable object f)", "(clear object f)", "(on object d object i)", "(on object a object d)", "(handempty )", "(clear object c)", "(on object b object g)", "(on object g object h)", "(on object h object a)", "(ontable object i)", "(on object j object b)", "(on object e object j)"], "add": ["(holding object f)"], "delete": ["(ontable object f)", "(clear object f)", "(handempty )"]}, {"name": "stack", "obj_params": ["object c", "object e"], "cost": 0, "precond": ["(ontable object f)", "(clear object f)", "(on object d object i)", "(on object a object d)", "(on object b object g)", "(on object g object h)", "(holding object c)", "(on object h object a)", "(clear object e)", "(ontable object i)", "(on object j object b)", "(on object e object j)"], "add": ["(on object c object e)", "(handempty )", "(clear object c)"], "delete": ["(holding object c)", "(clear object e)"]}]}
\ No newline at end of file
diff --git a/tests/extract/test_amdn.py b/tests/extract/test_amdn.py
index 418e87cd..9717d5e9 100644
--- a/tests/extract/test_amdn.py
+++ b/tests/extract/test_amdn.py
@@ -1,38 +1,197 @@
-from macq.trace.disordered_parallel_actions_observation_lists import default_theta_vec, num_parameters_feature, objects_shared_feature
+from macq.trace.disordered_parallel_actions_observation_lists import (
+ default_theta_vec,
+ num_parameters_feature,
+ objects_shared_feature,
+)
from macq.utils.tokenization_errors import TokenizationError
from tests.utils.generators import generate_blocks_traces
from macq.extract import Extract, modes
-from macq.generate.pddl import RandomGoalSampling
+from macq.generate.pddl import *
from macq.observation import *
from macq.trace import *
-
from pathlib import Path
import pytest
+
+
def test_tokenization_error():
with pytest.raises(TokenizationError):
trace = generate_blocks_traces(3)[0]
trace.tokenize(Token=NoisyPartialDisorderedParallelObservation)
+
+
+def test_tracelist():
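+    """Build a hand-crafted trace list for debugging AMDN.
+
+    Two trucks drive between four locations; the red and blue trucks'
+    drives are independent of one another, which is the kind of
+    potentially-parallel behavior DisorderedParallelActionsObservationLists
+    is meant to capture. The final step carries no action and only
+    observes the resulting state.
+    """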
+ # define objects
+ red_truck = PlanningObject("", "red_truck")
+ blue_truck = PlanningObject("", "blue_truck")
+ location_a = PlanningObject("", "location_a")
+ location_b = PlanningObject("", "location_b")
+ location_c = PlanningObject("", "location_c")
+ location_d = PlanningObject("", "location_d")
+
+ red_truck_is_truck = Fluent("truck", [red_truck])
+ blue_truck_is_truck = Fluent("truck", [blue_truck])
+ location_a_is_place = Fluent("place", [location_a])
+ location_b_is_place = Fluent("place", [location_b])
+ location_c_is_place = Fluent("place", [location_c])
+ location_d_is_place = Fluent("place", [location_d])
+ red_at_a = Fluent("at", [red_truck, location_a])
+ red_at_b = Fluent("at", [red_truck, location_b])
+ red_at_c = Fluent("at", [red_truck, location_c])
+ red_at_d = Fluent("at", [red_truck, location_d])
+ blue_at_a = Fluent("at", [blue_truck, location_a])
+ blue_at_b = Fluent("at", [blue_truck, location_b])
+ blue_at_c = Fluent("at", [blue_truck, location_c])
+ blue_at_d = Fluent("at", [blue_truck, location_d])
+
+ drive_red_a_b = Action("drive", [red_truck, location_a, location_b], precond={red_truck_is_truck, location_a_is_place, location_b_is_place, red_at_a}, add={red_at_b}, delete={red_at_a})
+ drive_blue_c_d = Action("drive", [blue_truck, location_c, location_d], precond={blue_truck_is_truck, location_c_is_place, location_d_is_place, blue_at_c}, add={blue_at_d}, delete={blue_at_c})
+ drive_blue_d_b = Action("drive", [blue_truck, location_d, location_b], precond={blue_truck_is_truck, location_d_is_place, location_b_is_place, blue_at_d}, add={blue_at_b}, delete={blue_at_d})
+ drive_red_b_d = Action("drive", [red_truck, location_b, location_d], precond={red_truck_is_truck, location_b_is_place, location_d_is_place, red_at_b}, add={red_at_d}, delete={red_at_b})
+
+    # trace: {red a -> b, blue c -> d}, {blue d -> b}, {red b -> d}; step 4 has no action
+ step_0 = Step(
+ State(
+ {
+ red_truck_is_truck: True,
+ blue_truck_is_truck: True,
+ location_a_is_place: True,
+ location_b_is_place: True,
+ location_c_is_place: True,
+ location_d_is_place: True,
+ red_at_a: True,
+ red_at_b: False,
+ red_at_c: False,
+ red_at_d: False,
+ blue_at_a: False,
+ blue_at_b: False,
+ blue_at_c: True,
+ blue_at_d: False,
+ }),
+ drive_red_a_b,
+ 0
+ )
+
+ step_1 = Step(
+ State(
+ {
+ red_truck_is_truck: True,
+ blue_truck_is_truck: True,
+ location_a_is_place: True,
+ location_b_is_place: True,
+ location_c_is_place: True,
+ location_d_is_place: True,
+ red_at_a: False,
+ red_at_b: True,
+ red_at_c: False,
+ red_at_d: False,
+ blue_at_a: False,
+ blue_at_b: False,
+ blue_at_c: True,
+ blue_at_d: False,
+ }),
+ drive_blue_c_d,
+ 1
+ )
+
+ step_2 = Step(
+ State(
+ {
+ red_truck_is_truck: True,
+ blue_truck_is_truck: True,
+ location_a_is_place: True,
+ location_b_is_place: True,
+ location_c_is_place: True,
+ location_d_is_place: True,
+ red_at_a: False,
+ red_at_b: True,
+ red_at_c: False,
+ red_at_d: False,
+ blue_at_a: False,
+ blue_at_b: False,
+ blue_at_c: False,
+ blue_at_d: True,
+ }),
+ drive_blue_d_b,
+ 2
+ )
+
+ step_3 = Step(
+ State(
+ {
+ red_truck_is_truck: True,
+ blue_truck_is_truck: True,
+ location_a_is_place: True,
+ location_b_is_place: True,
+ location_c_is_place: True,
+ location_d_is_place: True,
+ red_at_a: False,
+ red_at_b: True,
+ red_at_c: False,
+ red_at_d: False,
+ blue_at_a: False,
+ blue_at_b: True,
+ blue_at_c: False,
+ blue_at_d: False,
+ }),
+ drive_red_b_d,
+ 3
+ )
+
+ step_4 = Step(
+ State(
+ {
+ red_truck_is_truck: True,
+ blue_truck_is_truck: True,
+ location_a_is_place: True,
+ location_b_is_place: True,
+ location_c_is_place: True,
+ location_d_is_place: True,
+ red_at_a: False,
+ red_at_b: False,
+ red_at_c: False,
+ red_at_d: True,
+ blue_at_a: False,
+ blue_at_b: True,
+ blue_at_c: False,
+ blue_at_d: False,
+ }),
+ None,
+ 4
+ )
+    return TraceList([Trace([step_0, step_1, step_2, step_3, step_4])])
+
+
if __name__ == "__main__":
# exit out to the base macq folder so we can get to /tests
base = Path(__file__).parent.parent
+
+ # use blocksworld (NOTE: no actions are parallel in this domain)
dom = str((base / "pddl_testing_files/blocks_domain.pddl").resolve())
prob = str((base / "pddl_testing_files/blocks_problem.pddl").resolve())
-
# TODO: replace with a domain-specific random trace generator
traces = RandomGoalSampling(
prob=prob,
dom=dom,
- #problem_id=2337,
observe_pres_effs=True,
num_traces=1,
- steps_deep=10,
+ steps_deep=4,
subset_size_perc=0.1,
- enforced_hill_climbing_sampling=True
+ enforced_hill_climbing_sampling=True,
).traces
+
+ # use the simple truck domain for debugging
+ # traces = test_tracelist()
+
+ # use the simple door domain for debugging
+ # dom = str((base / "pddl_testing_files/door_dom.pddl").resolve())
+ # prob = str((base / "pddl_testing_files/door_prob.pddl").resolve())
+ # traces = TraceList([TraceFromGoal(dom=dom, prob=prob, observe_pres_effs=True).trace])
+ traces.print(wrap="y")
+
features = [objects_shared_feature, num_parameters_feature]
learned_theta = default_theta_vec(2)
observations = traces.tokenize(
@@ -40,6 +199,10 @@ def test_tokenization_error():
ObsLists=DisorderedParallelActionsObservationLists,
features=features,
learned_theta=learned_theta,
- percent_missing=0.10,
- percent_noisy=0.05,
- )
\ No newline at end of file
+ percent_missing=0,
+ percent_noisy=0,
+        percent_missing=0,
+        percent_noisy=0,
+        replace=True,
+    )
+    model = Extract(observations, modes.AMDN, debug=True, occ_threshold=2)
+    with open("results.txt", "w") as f:
+        f.write(model.details())
\ No newline at end of file
diff --git a/tests/extract/test_arms.py b/tests/extract/test_arms.py
new file mode 100644
index 00000000..ae9ec82e
--- /dev/null
+++ b/tests/extract/test_arms.py
@@ -0,0 +1,51 @@
+from pathlib import Path
+from typing import List
+from macq.trace import *
+from macq.extract import Extract, modes
+from macq.observation import PartialObservation
+from macq.generate.pddl import *
+
+
+def get_fluent(name: str, objs: List[str]):
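+    # each entry of objs is a "<type> <name>" string, e.g. "object a"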
+ objects = [PlanningObject(o.split()[0], o.split()[1]) for o in objs]
+ return Fluent(name, objects)
+
+
+def test_arms():
+ base = Path(__file__).parent.parent
+ dom = str((base / "pddl_testing_files/blocks_domain.pddl").resolve())
+ prob = str((base / "pddl_testing_files/blocks_problem.pddl").resolve())
+
+ traces = TraceList()
+ generator = TraceFromGoal(dom=dom, prob=prob)
+
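+    # generate two traces by re-planning for two different tower orderings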
+ generator.change_goal(
+ {
+ get_fluent("on", ["object a", "object b"]),
+ get_fluent("on", ["object b", "object c"]),
+ }
+ )
+ traces.append(generator.generate_trace())
+ generator.change_goal(
+ {
+ get_fluent("on", ["object b", "object a"]),
+ get_fluent("on", ["object c", "object b"]),
+ }
+ )
+ traces.append(generator.generate_trace())
+
+ observations = traces.tokenize(PartialObservation, percent_missing=0.5)
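+    # ARMS hyperparameters: constraint weights and thresholds for the
+    # underlying weighted MAX-SAT encoding; see the ARMS docs for the
+    # exact semantics of each parameter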
+ model = Extract(
+ observations,
+ modes.ARMS,
+ debug=False,
+ upper_bound=2,
+ min_support=2,
+ action_weight=110,
+ info_weight=100,
+ threshold=0.6,
+ info3_default=30,
+ plan_default=30,
+ )
+
+ assert model
diff --git a/tests/extract/test_extract.py b/tests/extract/test_extract.py
index 1ff02e9f..fa8fc4fe 100644
--- a/tests/extract/test_extract.py
+++ b/tests/extract/test_extract.py
@@ -3,8 +3,7 @@
from macq.observation import Observation
from tests.utils.test_traces import blocks_world
-# Other functionality of extract is implicitly tested by any extraction technique
-# This is reflected in coverage reports
+# Other functionality of extract is tested by extraction technique tests
def test_incompatible_observation_token():
diff --git a/tests/extract/test_observer.py b/tests/extract/test_observer.py
index e00c4aed..2d550a1d 100644
--- a/tests/extract/test_observer.py
+++ b/tests/extract/test_observer.py
@@ -19,12 +19,18 @@ def test_observer():
if __name__ == "__main__":
# exit out to the base macq folder so we can get to /tests
base = Path(__file__).parent.parent
- model_blocks_dom = str((base / "generated_testing_files/model_blocks_domain.pddl").resolve())
- model_blocks_prob = str((base / "generated_testing_files/model_blocks_problem.pddl").resolve())
+ model_blocks_dom = str(
+ (base / "generated_testing_files/model_blocks_domain.pddl").resolve()
+ )
+ model_blocks_prob = str(
+ (base / "generated_testing_files/model_blocks_problem.pddl").resolve()
+ )
traces = blocks_world(5)
observations = traces.tokenize(IdentityObservation)
traces.print()
model = Extract(observations, modes.OBSERVER)
print(model.details())
- model.to_pddl("model_blocks_dom", "model_blocks_prob", model_blocks_dom, model_blocks_prob)
+ model.to_pddl(
+ "model_blocks_dom", "model_blocks_prob", model_blocks_dom, model_blocks_prob
+ )
diff --git a/tests/extract/test_slaf.py b/tests/extract/test_slaf.py
index 6ce11f2d..5bbbc08f 100644
--- a/tests/extract/test_slaf.py
+++ b/tests/extract/test_slaf.py
@@ -10,7 +10,6 @@ def test_slaf():
traces = generate_blocks_traces(plan_len=2, num_traces=1)
observations = traces.tokenize(
AtomicPartialObservation,
- method=AtomicPartialObservation.random_subset,
percent_missing=0.10,
)
model = Extract(observations, modes.SLAF)
@@ -23,17 +22,22 @@ def test_slaf():
if __name__ == "__main__":
# exit out to the base macq folder so we can get to /tests
base = Path(__file__).parent.parent
- model_blocks_dom = str((base / "generated_testing_files/model_blocks_domain.pddl").resolve())
- model_blocks_prob = str((base / "generated_testing_files/model_blocks_problem.pddl").resolve())
+ model_blocks_dom = str(
+ (base / "generated_testing_files/model_blocks_domain.pddl").resolve()
+ )
+ model_blocks_prob = str(
+ (base / "generated_testing_files/model_blocks_problem.pddl").resolve()
+ )
traces = generate_blocks_traces(plan_len=2, num_traces=1)
observations = traces.tokenize(
AtomicPartialObservation,
- method=AtomicPartialObservation.random_subset,
percent_missing=0.10,
)
traces.print()
- model = Extract(observations, modes.SLAF, debug_mode=True)
+ model = Extract(observations, modes.SLAF)
print(model.details())
- model.to_pddl("model_blocks_dom", "model_blocks_prob", model_blocks_dom, model_blocks_prob)
+ model.to_pddl(
+ "model_blocks_dom", "model_blocks_prob", model_blocks_dom, model_blocks_prob
+ )
diff --git a/tests/generate/pddl/test_plan.py b/tests/generate/pddl/test_plan.py
index 0efe17c6..f09e800d 100644
--- a/tests/generate/pddl/test_plan.py
+++ b/tests/generate/pddl/test_plan.py
@@ -20,4 +20,4 @@
plan = vanilla.generate_plan(from_ipc_file=True, filename=path)
tracelist = TraceList()
tracelist.append(vanilla.generate_single_trace_from_plan(plan))
- tracelist.print(wrap="y")
\ No newline at end of file
+ tracelist.print(wrap="y")
diff --git a/tests/generate/pddl/test_random_goal_sampling.py b/tests/generate/pddl/test_random_goal_sampling.py
index e4b051b1..be73380c 100644
--- a/tests/generate/pddl/test_random_goal_sampling.py
+++ b/tests/generate/pddl/test_random_goal_sampling.py
@@ -11,10 +11,11 @@
random_sampler = RandomGoalSampling(
dom=dom,
prob=prob,
- num_traces=3,
- steps_deep=10,
+ num_traces=20,
+ steps_deep=20,
subset_size_perc=0.1,
- enforced_hill_climbing_sampling=False
+ enforced_hill_climbing_sampling=False,
+        max_time=10,
)
traces = random_sampler.traces
traces.print(wrap="y")
@@ -27,5 +28,5 @@
num_traces=3,
steps_deep=10,
subset_size_perc=0.1,
- enforced_hill_climbing_sampling=False
- ).traces
\ No newline at end of file
+ enforced_hill_climbing_sampling=False,
+ ).traces
diff --git a/tests/generate/pddl/test_trace_from_goal.py b/tests/generate/pddl/test_trace_from_goal.py
index 148a4fb4..25d461c7 100644
--- a/tests/generate/pddl/test_trace_from_goal.py
+++ b/tests/generate/pddl/test_trace_from_goal.py
@@ -32,8 +32,12 @@ def test_invalid_goal_change():
dom = str((base / "pddl_testing_files/blocks_domain.pddl").resolve())
prob = str((base / "pddl_testing_files/blocks_problem.pddl").resolve())
- new_blocks_dom = str((base / "generated_testing_files/new_blocks_dom.pddl").resolve())
- new_blocks_prob = str((base / "generated_testing_files/new_blocks_prob.pddl").resolve())
+ new_blocks_dom = str(
+ (base / "generated_testing_files/new_blocks_dom.pddl").resolve()
+ )
+ new_blocks_prob = str(
+ (base / "generated_testing_files/new_blocks_prob.pddl").resolve()
+ )
new_game_dom = str((base / "generated_testing_files/new_game_dom.pddl").resolve())
new_game_prob = str((base / "generated_testing_files/new_game_prob.pddl").resolve())
diff --git a/tests/generate/pddl/test_vanilla_sampling.py b/tests/generate/pddl/test_vanilla_sampling.py
index 0765562b..23eff7ff 100644
--- a/tests/generate/pddl/test_vanilla_sampling.py
+++ b/tests/generate/pddl/test_vanilla_sampling.py
@@ -4,6 +4,7 @@
from macq.generate.pddl.generator import InvalidGoalFluent
from macq.utils import InvalidNumberOfTraces, InvalidPlanLength
from macq.trace import Fluent, PlanningObject, TraceList
+from macq.utils import TraceSearchTimeOut, InvalidTime
def test_invalid_vanilla_sampling():
@@ -30,15 +31,28 @@ def test_invalid_vanilla_sampling():
"new_blocks_dom.pddl",
"new_blocks_prob.pddl",
)
+
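+    # max_time caps the trace search: too little time for plan_len=10
+    # should time out, and a non-positive max_time is rejected outright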
+ with pytest.raises(TraceSearchTimeOut):
+ VanillaSampling(dom=dom, prob=prob, plan_len=10, num_traces=1, max_time=5)
+
+ with pytest.raises(InvalidTime):
+ VanillaSampling(dom=dom, prob=prob, plan_len=10, num_traces=1, max_time=0)
if __name__ == "__main__":
# exit out to the base macq folder so we can get to /tests
base = Path(__file__).parent.parent.parent
+
dom = str((base / "pddl_testing_files/blocks_domain.pddl").resolve())
prob = str((base / "pddl_testing_files/blocks_problem.pddl").resolve())
vanilla = VanillaSampling(dom=dom, prob=prob, plan_len=7, num_traces=10)
+ traces = vanilla.traces
+ traces.generate_more(3)
+ dom = str((base / "pddl_testing_files/playlist_domain.pddl").resolve())
+ prob = str((base / "pddl_testing_files/playlist_problem.pddl").resolve())
+ VanillaSampling(dom=dom, prob=prob, plan_len=10, num_traces=10, max_time=3)
+
new_blocks_dom = str((base / "generated_testing_files/new_blocks_dom.pddl").resolve())
new_blocks_prob = str((base / "generated_testing_files/new_blocks_prob.pddl").resolve())
new_game_dom = str((base / "generated_testing_files/new_game_dom.pddl").resolve())
@@ -86,4 +100,6 @@ def test_invalid_vanilla_sampling():
tracelist.print(wrap="y")
# test generating traces with action preconditions/effects known
- vanilla_traces = VanillaSampling(problem_id=4627, plan_len=7, num_traces=10, observe_pres_effs=True).traces
+ vanilla_traces = VanillaSampling(
+ problem_id=123, plan_len=7, num_traces=10, observe_pres_effs=True
+ ).traces
diff --git a/tests/pddl_testing_files/door_dom.pddl b/tests/pddl_testing_files/door_dom.pddl
new file mode 100644
index 00000000..9a14e3b7
--- /dev/null
+++ b/tests/pddl_testing_files/door_dom.pddl
@@ -0,0 +1,20 @@
+(define (domain door)
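+    ; minimal two-room debugging domain: open the door, then walk from roomA to roomB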
+
+ (:requirements :strips )
+
+ (:predicates
+ (roomA ) (roomB ) (open )
+ )
+
+ (:action open
+ :parameters ( )
+ :precondition (and (roomA ))
+ :effect (and (open ))
+ )
+
+ (:action walk
+ :parameters ( )
+ :precondition (and (roomA ) (open ))
+ :effect (and (not (roomA )) (roomB ))
+ )
+)
\ No newline at end of file
diff --git a/tests/pddl_testing_files/door_prob.pddl b/tests/pddl_testing_files/door_prob.pddl
new file mode 100644
index 00000000..fabd0d89
--- /dev/null
+++ b/tests/pddl_testing_files/door_prob.pddl
@@ -0,0 +1,4 @@
+(define (problem doors)
+ (:domain door)
+ (:init (roomA ))
+ (:goal (and (roomB ))))
\ No newline at end of file
diff --git a/tests/test_readme.py b/tests/test_readme.py
index 594a0e2b..f81967d7 100644
--- a/tests/test_readme.py
+++ b/tests/test_readme.py
@@ -22,6 +22,8 @@ def test_readme():
assert len(traces) == 4
action1 = traces[0][0].action
+ assert action1
+ assert traces.get_usage(action1)
trace = traces[0]
assert len(trace) == 5
@@ -42,4 +44,4 @@ def test_readme():
# run as a script to look over the extracted model
traces = generate_traces()
model = extract_model(traces)
- print(model.details())
\ No newline at end of file
+ print(model.details())
diff --git a/tests/trace/test_action.py b/tests/trace/test_action.py
index 7358ed48..8425a1b4 100644
--- a/tests/trace/test_action.py
+++ b/tests/trace/test_action.py
@@ -13,5 +13,5 @@ def test_action():
assert str(a1)
obj = PlanningObject("test_obj", "test")
- a1.add_parameter(obj)
+ a1.obj_params.append(obj)
assert obj in a1.obj_params
diff --git a/tests/trace/test_trace.py b/tests/trace/test_trace.py
index 6e4c9be3..c21eb43d 100644
--- a/tests/trace/test_trace.py
+++ b/tests/trace/test_trace.py
@@ -61,7 +61,7 @@ def test_trace_get_sas_triples():
(state2, state3) = (trace.steps[1].state, trace.steps[2].state)
assert isinstance(action2, Action)
- assert trace.get_sas_triples(action2) == {SAS(state2, action2, state3)}
+ assert trace.get_sas_triples(action2) == [SAS(state2, action2, state3)]
# test that the total cost is working correctly
diff --git a/tests/trace/test_trace_list.py b/tests/trace/test_trace_list.py
index 200ff93f..29368cc9 100644
--- a/tests/trace/test_trace_list.py
+++ b/tests/trace/test_trace_list.py
@@ -22,6 +22,7 @@ def test_trace_list():
assert trace_list[0] is first
action = trace_list[0].steps[0].action
+ assert action
usages = trace_list.get_usage(action)
for i, trace in enumerate(trace_list):
assert usages[i] == trace.get_usage(action)