Skip to content

Commit

Permalink
Fix generator
Browse files Browse the repository at this point in the history
  • Loading branch information
JanCBrammer committed Aug 23, 2024
1 parent f9bbb69 commit 8520978
Showing 1 changed file with 9 additions and 15 deletions.
24 changes: 9 additions & 15 deletions tucan/canonicalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from tucan.graph_utils import attribute_sequence
import networkx as nx
from igraph import Graph as iGraph
from typing import Iterator
from typing import Generator


def partition_molecule_by_attribute(m: nx.Graph, attribute: str) -> nx.Graph:
Expand All @@ -26,19 +26,14 @@ def get_number_of_partitions(m: nx.Graph) -> int:
return max(nx.get_node_attributes(m, PARTITION).values())


def refine_partitions(m: nx.Graph) -> Iterator[nx.Graph]:
n_current_partitions = get_number_of_partitions(m)

if n_current_partitions == m.number_of_nodes() - 1:
# partitions are discrete (i.e., each node in a separate partition)
return m
def refine_partitions(m: nx.Graph) -> Generator[nx.Graph, None, None]:

m_refined = partition_molecule_by_attribute(m, PARTITION)
if get_number_of_partitions(m_refined) == n_current_partitions:
# no refinement possible
return m

yield m_refined
if get_number_of_partitions(m_refined) == get_number_of_partitions(m):
# No more refinement possible.
yield m_refined
return

yield from refine_partitions(m_refined)

Expand Down Expand Up @@ -70,9 +65,8 @@ def assign_canonical_labels(m: nx.Graph) -> dict[int, int]:

def canonicalize_molecule(m: nx.Graph) -> nx.Graph:
m_partitioned_by_invariant_code = partition_molecule_by_attribute(m, INVARIANT_CODE)
m_refined = list(refine_partitions(m_partitioned_by_invariant_code))
m_partitioned = m_refined[-1] if m_refined else m_partitioned_by_invariant_code
m_refined = list(refine_partitions(m_partitioned_by_invariant_code))[-1]

canonical_labels = assign_canonical_labels(m_partitioned)
canonical_labels = assign_canonical_labels(m_refined)

return nx.relabel_nodes(m_partitioned, canonical_labels, copy=True)
return nx.relabel_nodes(m_refined, canonical_labels, copy=True)

0 comments on commit 8520978

Please sign in to comment.