From 9109e115778fb2eb5ee9df9a60a91adb82eff698 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Mon, 1 Jan 2024 23:01:26 -0500 Subject: [PATCH] adding typing to tree branches --- penman/layout.py | 13 +++++++------ penman/transform.py | 4 ++-- penman/tree.py | 35 +++++++++++++++++++++++++++-------- penman/types.py | 5 +++-- 4 files changed, 39 insertions(+), 18 deletions(-) diff --git a/penman/layout.py b/penman/layout.py index c81dc89..e26b30e 100644 --- a/penman/layout.py +++ b/penman/layout.py @@ -59,7 +59,7 @@ from penman.graph import CONCEPT_ROLE, Graph from penman.model import Model from penman.surface import Alignment, RoleAlignment -from penman.tree import Tree, is_atomic +from penman.tree import Tree, is_atomic, is_tgt_node, is_tgt_symbol from penman.types import BasicTriple, Branch, Node, Role, Variable logger = logging.getLogger(__name__) @@ -166,10 +166,10 @@ def _interpret_node(t: Node, variables: Set[Variable], model: Model): has_concept |= role == CONCEPT_ROLE # atomic targets - if is_atomic(target): - target, target_epis = _process_atomic(target) + if is_tgt_symbol(target): + tgt, target_epis = _process_atomic(target) epis.extend(target_epis) - triple = (var, role, target) + triple = (var, role, tgt) if model.is_role_inverted(role): if target in variables: triple = model.invert(triple) @@ -178,7 +178,8 @@ def _interpret_node(t: Node, variables: Set[Variable], model: Model): triples.append(triple) epidata.append((triple, epis)) # nested nodes - else: + # mypy forgets that (Node ∨ Sym) ^ ¬Sym → Node + elif is_tgt_node(target): triple = model.deinvert((var, role, target[0])) triples.append(triple) @@ -566,7 +567,7 @@ def _rearrange(node: Node, key: Callable[[Branch], Any]) -> None: first = [] rest = branches[:] for _, target in rest: - if not is_atomic(target): + if is_tgt_node(target): _rearrange(target, key=key) branches[:] = first + sorted(rest, key=key) diff --git a/penman/transform.py b/penman/transform.py index d8e05a2..88b0260 100644 --- a/penman/transform.py +++ b/penman/transform.py @@ -17,7 +17,7 @@ ) from penman.model import Model from penman.surface import Alignment, RoleAlignment, alignments -from penman.tree import Tree, is_atomic +from penman.tree import Tree, is_tgt_node from penman.types import BasicTriple, Node, Target, Variable logger = logging.getLogger(__name__) @@ -62,7 +62,7 @@ def _canonicalize_node(node: Node, model: Model) -> Node: role, tgt = edge # alignments aren't parsed off yet, so handle them superficially role, tilde, alignment = role.partition('~') - if not is_atomic(tgt): + if is_tgt_node(tgt): tgt = _canonicalize_node(tgt, model) canonical_role = model.canonicalize_role(role) + tilde + alignment canonical_edges.append((canonical_role, tgt)) diff --git a/penman/tree.py b/penman/tree.py index 53c811d..50c8373 100644 --- a/penman/tree.py +++ b/penman/tree.py @@ -4,7 +4,9 @@ from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Tuple -from penman.types import Branch, Node, Variable +from typing_extensions import TypeGuard + +from penman.types import Branch, Node, Symbol, Variable _Step = Tuple[Tuple[int, ...], Branch] # see Tree.walk() @@ -112,10 +114,10 @@ def _format(node: Node, level: int) -> str: def _format_branch(branch: Branch, level: int) -> str: role, target = branch - if is_atomic(target): - target = repr(target) - else: + if is_tgt_node(target): target = _format(target, level) + else: + target = repr(target) return f'({role!r}, {target})' @@ -124,7 +126,7 @@ def _nodes(node: Node) -> List[Node]: ns = [] if var is None else [node] for _, target in branches: # if target is not atomic, assume it's a valid tree node - if not is_atomic(target): + if is_tgt_node(target): ns.extend(_nodes(target)) return ns @@ -135,7 +137,7 @@ def _walk(node: Node, path: Tuple[int, ...]) -> Iterator[_Step]: curpath = path + (i,) yield (curpath, branch) _, target = branch - if not is_atomic(target): + if is_tgt_node(target): yield from _walk(target, curpath) @@ -180,15 +182,32 @@ def _map_vars( newbranches: List[Branch] = [] for role, tgt in branches: - if not is_atomic(tgt): + if is_tgt_node(tgt): tgt = _map_vars(tgt, varmap) - elif role != '/' and tgt in varmap: + # MyPy forgets that (Node ∨ Sym) ^ ¬Node → Sym + elif is_tgt_symbol(tgt) and role != '/' and tgt in varmap: tgt = varmap[tgt] newbranches.append((role, tgt)) return (varmap[var], newbranches) +def is_tgt_node(target: Symbol | Node) -> TypeGuard[Node]: + """ + Inverse of :func:`is_atomic`, only for Symbol | Node from branches. + Automatically narrows the type to Node for better type inference + """ + return not is_atomic(target) + + +def is_tgt_symbol(target: Symbol | Node) -> TypeGuard[Symbol]: + """ + Same as :func:`is_atomic`, only for Symbol | Node from branches. + Automatically narrows the type to Symbol for better type inference + """ + return is_atomic(target) + + def is_atomic(x: Any) -> bool: """ Return ``True`` if *x* is a valid atomic value. diff --git a/penman/types.py b/penman/types.py index 55e8653..1d2545f 100644 --- a/penman/types.py +++ b/penman/types.py @@ -2,14 +2,15 @@ Basic types used by various Penman modules. """ -from typing import Any, Iterable, List, Tuple, Union +from typing import Iterable, List, Tuple, Union Variable = str Constant = Union[str, float, int, None] # None for missing values Role = str # '' for anonymous relations +Symbol = str # Tree types -Branch = Tuple[Role, Any] +Branch = Tuple[Role, Union[Symbol, "Node"]] Node = Tuple[Variable, List[Branch]] # Graph types