Skip to content

Commit

Permalink
Merge pull request #109 from ssardina-research/string-representation
Browse files Browse the repository at this point in the history
Implement __str__ for Domain and Problem
  • Loading branch information
francescofuggitti authored Mar 21, 2024
2 parents 4ee8d63 + 553de64 commit 8e3be08
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 55 deletions.
57 changes: 56 additions & 1 deletion pddl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
#

"""
Core module of the package.
It contains the class definitions to build and modify PDDL domains or problems.
"""
from textwrap import indent
from typing import AbstractSet, Collection, Dict, Optional, Tuple, cast

from pddl._validation import (
Expand All @@ -27,6 +27,14 @@
from pddl.action import Action
from pddl.custom_types import name as name_type
from pddl.custom_types import namelike, parse_name, to_names, to_types # noqa: F401
from pddl.formatter import (
print_constants,
print_function_skeleton,
print_predicates_with_types,
print_types_or_functions_with_parents,
remove_empty_lines,
sort_and_print_collection,
)
from pddl.helpers.base import assert_, check, ensure, ensure_set
from pddl.logic.base import And, Formula, is_literal
from pddl.logic.functions import FunctionExpression, Metric, NumericFunction
Expand Down Expand Up @@ -148,6 +156,36 @@ def __eq__(self, other):
and self.actions == other.actions
)

def __str__(self) -> str:
"""Print a PDDL domain object."""
result = f"(define (domain {self.name})"
body = ""
indentation = " " * 4
body += sort_and_print_collection("(:requirements ", self.requirements, ")\n")
body += print_types_or_functions_with_parents("(:types", self.types, ")\n")
body += print_constants("(:constants", self.constants, ")\n")
if self.predicates:
body += f"(:predicates {print_predicates_with_types(self.predicates)})\n"
if self.functions:
body += print_types_or_functions_with_parents(
"(:functions", self.functions, ")\n", print_function_skeleton
)
body += sort_and_print_collection(
"",
self.derived_predicates,
"",
to_string=lambda obj: str(obj) + "\n",
)
body += sort_and_print_collection(
"",
self.actions,
"",
to_string=lambda obj: str(obj) + "\n",
)
result = result + "\n" + indent(body, indentation) + "\n)"
result = remove_empty_lines(result)
return result


class Problem:
"""A class for a PDDL problem file."""
Expand Down Expand Up @@ -331,3 +369,20 @@ def __eq__(self, other):
and self.goal == other.goal
and self.metric == other.metric
)

def __str__(self) -> str:
"""Print a PDDL problem object."""
result = f"(define (problem {self.name})"
body = f"(:domain {self.domain_name})\n"
indentation = " " * 4
body += sort_and_print_collection("(:requirements ", self.requirements, ")\n")
if self.objects:
body += print_constants("(:objects", self.objects, ")\n")
body += sort_and_print_collection(
"(:init ", self.init, ")\n", is_mandatory=True
)
body += f"{'(:goal ' + str(self.goal) + ')'}\n"
body += f"{'(:metric ' + str(self.metric) + ')'}\n" if self.metric else ""
result = result + "\n" + indent(body, indentation) + "\n)"
result = remove_empty_lines(result)
return result
84 changes: 30 additions & 54 deletions pddl/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,51 @@
#

"""Formatting utilities for PDDL domains and problems."""
from textwrap import indent
from typing import Callable, Collection, Dict, List, Optional, TypeVar

from pddl.core import Domain, Problem
from pddl.custom_types import name
from pddl.logic.functions import NumericFunction
from pddl.logic.terms import Constant

T = TypeVar("T", name, NumericFunction)


def _remove_empty_lines(s: str) -> str:
def remove_empty_lines(s: str) -> str:
"""Remove empty lines from string."""
return "\n".join(filter(str.strip, s.splitlines()))


def _sort_and_print_collection(
def sort_and_print_collection(
prefix,
collection: Collection,
postfix,
to_string: Callable = str,
is_mandatory: bool = False,
):
r"""Produce the string of a PDDL section for a collection (e.g., requirements, actions, objects, etc.).
Prefix starts the PDDL section, like "(:requirements" or "(:actions"
Postfix ends the section, usually with ")\n"
The collection is sorted and printed as a string, using to_string to convert each element to a string.
Args:
prefix (str): start of the string
collection (Collection): the collection of entities to report as a string
postfix (str): the end of the string
to_string (Callable, optional): the function to use to convert to string. Defaults to str.
is_mandatory (bool, optional): if the string is mandatory even if the collection is empty. Defaults to False.
Returns:
str: a string with <prefix> <string of collection> <postfix>
"""
if len(collection) > 0:
return prefix + " ".join(sorted(map(to_string, collection))) + postfix
elif is_mandatory:
return prefix + postfix
return ""


def _print_types_or_functions_with_parents(
def print_types_or_functions_with_parents(
prefix: str,
types_dict: Dict[T, Optional[name]],
postfix: str,
Expand All @@ -53,10 +67,10 @@ def _print_types_or_functions_with_parents(
name_by_obj.setdefault(parent_type, []).append(obj_name) # type: ignore
if not bool(name_by_obj):
return ""
return _print_typed_lists(prefix, name_by_obj, postfix, to_string)
return print_typed_lists(prefix, name_by_obj, postfix, to_string)


def _print_constants(
def print_constants(
prefix, constants: Collection[Constant], postfix, to_string: Callable = str
):
"""Print constants in a PDDL domain."""
Expand All @@ -65,10 +79,11 @@ def _print_constants(
term_by_type_tags.setdefault(c.type_tag, []).append(c.name)
if not bool(term_by_type_tags):
return ""
return _print_typed_lists(prefix, term_by_type_tags, postfix, to_string)
return print_typed_lists(prefix, term_by_type_tags, postfix, to_string)


def _print_predicates_with_types(predicates: Collection):
def print_predicates_with_types(predicates: Collection):
"""Generate a string with predicates with type tags for the :predicates section."""
result = ""
for p in sorted(predicates):
if p.arity == 0:
Expand All @@ -89,7 +104,7 @@ def _print_predicates_with_types(predicates: Collection):
return result.strip()


def _print_function_skeleton(function: NumericFunction) -> str:
def print_function_skeleton(function: NumericFunction) -> str:
"""Callable to print a function skeleton with type tags."""
result = ""
if function.arity == 0:
Expand All @@ -106,7 +121,7 @@ def _print_function_skeleton(function: NumericFunction) -> str:
return result


def _print_typed_lists(
def print_typed_lists(
prefix,
names_by_obj: Dict[Optional[T], List[name]],
postfix,
Expand Down Expand Up @@ -139,50 +154,11 @@ def _print_typed_lists(
return result


def domain_to_string(domain: Domain) -> str:
def domain_to_string(domain) -> str:
"""Print a PDDL domain object."""
result = f"(define (domain {domain.name})"
body = ""
indentation = " " * 4
body += _sort_and_print_collection("(:requirements ", domain.requirements, ")\n")
body += _print_types_or_functions_with_parents("(:types", domain.types, ")\n")
body += _print_constants("(:constants", domain.constants, ")\n")
if domain.predicates:
body += f"(:predicates {_print_predicates_with_types(domain.predicates)})\n"
if domain.functions:
body += _print_types_or_functions_with_parents(
"(:functions", domain.functions, ")\n", _print_function_skeleton
)
body += _sort_and_print_collection(
"",
domain.derived_predicates,
"",
to_string=lambda obj: str(obj) + "\n",
)
body += _sort_and_print_collection(
"",
domain.actions,
"",
to_string=lambda obj: str(obj) + "\n",
)
result = result + "\n" + indent(body, indentation) + "\n)"
result = _remove_empty_lines(result)
return result
return str(domain)


def problem_to_string(problem: Problem) -> str:
def problem_to_string(problem) -> str:
"""Print a PDDL problem object."""
result = f"(define (problem {problem.name})"
body = f"(:domain {problem.domain_name})\n"
indentation = " " * 4
body += _sort_and_print_collection("(:requirements ", problem.requirements, ")\n")
if problem.objects:
body += _print_constants("(:objects", problem.objects, ")\n")
body += _sort_and_print_collection(
"(:init ", problem.init, ")\n", is_mandatory=True
)
body += f"{'(:goal ' + str(problem.goal) + ')'}\n"
body += f"{'(:metric ' + str(problem.metric) + ')'}\n" if problem.metric else ""
result = result + "\n" + indent(body, indentation) + "\n)"
result = _remove_empty_lines(result)
return result
return str(problem)

0 comments on commit 8e3be08

Please sign in to comment.