diff --git a/tucan/canonicalization.py b/tucan/canonicalization.py index 7fcf778..0aa0f89 100644 --- a/tucan/canonicalization.py +++ b/tucan/canonicalization.py @@ -2,7 +2,7 @@ from tucan.graph_utils import attribute_sequence import networkx as nx from igraph import Graph as iGraph -from typing import Iterator +from typing import Generator def partition_molecule_by_attribute(m: nx.Graph, attribute: str) -> nx.Graph: @@ -26,19 +26,14 @@ def get_number_of_partitions(m: nx.Graph) -> int: return max(nx.get_node_attributes(m, PARTITION).values()) -def refine_partitions(m: nx.Graph) -> Iterator[nx.Graph]: - n_current_partitions = get_number_of_partitions(m) - - if n_current_partitions == m.number_of_nodes() - 1: - # partitions are discrete (i.e., each node in a separate partition) - return m +def refine_partitions(m: nx.Graph) -> Generator[nx.Graph, None, None]: m_refined = partition_molecule_by_attribute(m, PARTITION) - if get_number_of_partitions(m_refined) == n_current_partitions: - # no refinement possible - return m - yield m_refined + if get_number_of_partitions(m_refined) == get_number_of_partitions(m): + # No more refinement possible. + yield m_refined + return yield from refine_partitions(m_refined) @@ -70,9 +65,8 @@ def assign_canonical_labels(m: nx.Graph) -> dict[int, int]: def canonicalize_molecule(m: nx.Graph) -> nx.Graph: m_partitioned_by_invariant_code = partition_molecule_by_attribute(m, INVARIANT_CODE) - m_refined = list(refine_partitions(m_partitioned_by_invariant_code)) - m_partitioned = m_refined[-1] if m_refined else m_partitioned_by_invariant_code + m_refined = list(refine_partitions(m_partitioned_by_invariant_code))[-1] - canonical_labels = assign_canonical_labels(m_partitioned) + canonical_labels = assign_canonical_labels(m_refined) - return nx.relabel_nodes(m_partitioned, canonical_labels, copy=True) + return nx.relabel_nodes(m_refined, canonical_labels, copy=True)