From 02341828b94a001b4f3cbdae6b24eaeeb0df7d7d Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 17 Jul 2024 10:59:37 -0400 Subject: [PATCH] Improve Keep Redundant Spaces algorithm for PatchedPhaseDiagram (#3900) * fix: old algorithm to deduplicate spaces didn't find the minimum subset * test: direct test for remove_redundant_spaces static method * doc: clean up old comments, add details explaining why patchedphasediagram as_dict doesn't save computations due to shared memory id issue. * pre-commit auto-fixes * lint: spelling --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/pymatgen/analysis/phase_diagram.py | 96 ++++++++++++++++---------- tests/analysis/test_phase_diagram.py | 26 ++++++- 2 files changed, 84 insertions(+), 38 deletions(-) diff --git a/src/pymatgen/analysis/phase_diagram.py b/src/pymatgen/analysis/phase_diagram.py index b9d99a7ca44..6369e2369a0 100644 --- a/src/pymatgen/analysis/phase_diagram.py +++ b/src/pymatgen/analysis/phase_diagram.py @@ -903,7 +903,6 @@ def get_decomp_and_phase_separation_energy( } # NOTE calling PhaseDiagram is only reasonable if the composition has fewer than 5 elements - # TODO can we call PatchedPhaseDiagram here? inner_hull = PhaseDiagram(reduced_space) competing_entries = inner_hull.stable_entries | {*self._get_stable_entries_in_space(entry_elems)} @@ -1036,8 +1035,6 @@ def get_critical_compositions(self, comp1, comp2): if np.all(c1 == c2): return [comp1.copy(), comp2.copy()] - # NOTE made into method to facilitate inheritance of this method - # in PatchedPhaseDiagram if approximate solution can be found. intersections = self._get_simplex_intersections(c1, c2) # find position along line @@ -1619,29 +1616,21 @@ def __init__( # Add the elemental references inds.extend([min_entries.index(el) for el in el_refs.values()]) - self.qhull_entries = tuple(min_entries[idx] for idx in inds) + qhull_entries = tuple(min_entries[idx] for idx in inds) # make qhull spaces frozensets since they become keys to self.pds dict and frozensets are hashable # prevent repeating elements in chemical space and avoid the ordering problem (i.e. Fe-O == O-Fe automatically) - self._qhull_spaces = tuple(frozenset(entry.elements) for entry in self.qhull_entries) + qhull_spaces = tuple(frozenset(entry.elements) for entry in qhull_entries) # Get all unique chemical spaces - spaces = {s for s in self._qhull_spaces if len(s) > 1} + spaces = {s for s in qhull_spaces if len(s) > 1} # Remove redundant chemical spaces - if not keep_all_spaces and len(spaces) > 1: - max_size = max(len(s) for s in spaces) - - systems = set() - # NOTE reduce the number of comparisons by only comparing to larger sets - for idx in range(2, max_size + 1): - test = (s for s in spaces if len(s) == idx) - refer = (s for s in spaces if len(s) > idx) - systems |= {t for t in test if not any(t.issubset(r) for r in refer)} - - spaces = systems + spaces = self.remove_redundant_spaces(spaces, keep_all_spaces) # TODO comprhys: refactor to have self._compute method to allow serialization - self.spaces = sorted(spaces, key=len, reverse=False) # Calculate pds for smaller dimension spaces first + self.spaces = sorted(spaces, key=len, reverse=True) # Calculate pds for smaller dimension spaces last + self.qhull_entries = qhull_entries + self._qhull_spaces = qhull_spaces self.pds = dict(self._get_pd_patch_for_space(s) for s in tqdm(self.spaces, disable=not verbose)) self.all_entries = all_entries self.el_refs = el_refs @@ -1675,7 +1664,19 @@ def __contains__(self, item: frozenset[Element]) -> bool: return item in self.pds def as_dict(self) -> dict[str, Any]: - """ + """Write the entries and elements used to construct the PatchedPhaseDiagram + to a dictionary. + + NOTE unlike PhaseDiagram the computation involved in constructing the + PatchedPhaseDiagram is not saved on serialisation. This is done because + hierarchically calling the `PhaseDiagram.as_dict()` method would break the + link in memory between entries in overlapping patches leading to a + ballooning of the amount of memory used. + + NOTE For memory efficiency the best way to store patched phase diagrams is + via pickling. As this allows all the entries in overlapping patches to share + the same id in memory when unpickling. + Returns: dict[str, Any]: MSONable dictionary representation of PatchedPhaseDiagram. """ @@ -1688,7 +1689,18 @@ def as_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, dct: dict) -> Self: - """ + """Reconstruct PatchedPhaseDiagram from dictionary serialisation. + + NOTE unlike PhaseDiagram the computation involved in constructing the + PatchedPhaseDiagram is not saved on serialisation. This is done because + hierarchically calling the `PhaseDiagram.as_dict()` method would break the + link in memory between entries in overlapping patches leading to a + ballooning of the amount of memory used. + + NOTE For memory efficiency the best way to store patched phase diagrams is + via pickling. As this allows all the entries in overlapping patches to share + the same id in memory when unpickling. + Args: dct (dict): dictionary representation of PatchedPhaseDiagram. @@ -1699,9 +1711,23 @@ def from_dict(cls, dct: dict) -> Self: elements = [Element.from_dict(elem) for elem in dct["elements"]] return cls(entries, elements) + @staticmethod + def remove_redundant_spaces(spaces, keep_all_spaces=False): + if keep_all_spaces or len(spaces) <= 1: + return spaces + + # Sort spaces by size in descending order and pre-compute lengths + sorted_spaces = sorted(spaces, key=len, reverse=True) + + result = [] + for i, space_i in enumerate(sorted_spaces): + if not any(space_i.issubset(larger_space) for larger_space in sorted_spaces[:i]): + result.append(space_i) + + return result + # NOTE following methods are inherited unchanged from PhaseDiagram: # __repr__, - # as_dict, # all_entries_hulldata, # unstable_entries, # stable_entries, @@ -1771,8 +1797,6 @@ def get_equilibrium_reaction_energy(self, entry: Entry) -> float: """ return self.get_phase_separation_energy(entry, stable_only=True) - # NOTE the following functions are not implemented for PatchedPhaseDiagram - def get_decomp_and_e_above_hull( self, entry: PDEntry, @@ -1787,6 +1811,20 @@ def get_decomp_and_e_above_hull( entry=entry, allow_negative=allow_negative, check_stable=check_stable, on_error=on_error ) + def _get_pd_patch_for_space(self, space: frozenset[Element]) -> tuple[frozenset[Element], PhaseDiagram]: + """ + Args: + space (frozenset[Element]): chemical space of the form A-B-X. + + Returns: + space, PhaseDiagram for the given chemical space + """ + space_entries = [e for e, s in zip(self.qhull_entries, self._qhull_spaces) if space.issuperset(s)] + + return space, PhaseDiagram(space_entries) + + # NOTE the following functions are not implemented for PatchedPhaseDiagram + def _get_facet_and_simplex(self): """Not Implemented - See PhaseDiagram.""" raise NotImplementedError("_get_facet_and_simplex() not implemented for PatchedPhaseDiagram") @@ -1835,18 +1873,6 @@ def get_chempot_range_stability_phase(self): """Not Implemented - See PhaseDiagram.""" raise NotImplementedError("get_chempot_range_stability_phase() not implemented for PatchedPhaseDiagram") - def _get_pd_patch_for_space(self, space: frozenset[Element]) -> tuple[frozenset[Element], PhaseDiagram]: - """ - Args: - space (frozenset[Element]): chemical space of the form A-B-X. - - Returns: - space, PhaseDiagram for the given chemical space - """ - space_entries = [e for e, s in zip(self.qhull_entries, self._qhull_spaces) if space.issuperset(s)] - - return space, PhaseDiagram(space_entries) - class ReactionDiagram: """ diff --git a/tests/analysis/test_phase_diagram.py b/tests/analysis/test_phase_diagram.py index e05464d774d..997a7cd0bac 100644 --- a/tests/analysis/test_phase_diagram.py +++ b/tests/analysis/test_phase_diagram.py @@ -3,6 +3,7 @@ import collections import unittest import unittest.mock +from itertools import combinations from numbers import Number from unittest import TestCase @@ -709,6 +710,7 @@ def setUp(self): self.pd = PhaseDiagram(entries=self.entries) self.ppd = PatchedPhaseDiagram(entries=self.entries) + self.ppd_all = PatchedPhaseDiagram(entries=self.entries, keep_all_spaces=True) # novel entries not in any of the patches self.novel_comps = [Composition("H5C2OP"), Composition("V2PH4C")] @@ -756,7 +758,11 @@ def test_dimensionality(self): # test dims of sub PDs dim_counts = collections.Counter(pd.dim for pd in self.ppd.pds.values()) - assert dim_counts == {3: 7, 2: 6, 4: 2} + assert dim_counts == {4: 2, 3: 2} + + # test dims of sub PDs + dim_counts = collections.Counter(pd.dim for pd in self.ppd_all.pds.values()) + assert dim_counts == {2: 8, 3: 7, 4: 2} def test_get_hull_energy(self): for comp in self.novel_comps: @@ -772,7 +778,7 @@ def test_get_decomp_and_e_above_hull(self): assert np.isclose(e_above_hull_pd, e_above_hull_ppd) def test_repr(self): - assert repr(self.ppd) == str(self.ppd) == "PatchedPhaseDiagram covering 15 sub-spaces" + assert repr(self.ppd) == str(self.ppd) == "PatchedPhaseDiagram covering 4 sub-spaces" def test_as_from_dict(self): ppd_dict = self.ppd.as_dict() @@ -810,7 +816,8 @@ def test_getitem(self): pd = self.ppd[chem_space] assert isinstance(pd, PhaseDiagram) assert chem_space in pd._qhull_spaces - assert str(pd) == "V-C phase diagram\n4 stable phases: \nC, V, V6C5, V2C" + assert len(str(pd)) == 186 + assert str(pd).startswith("V-H-C-O phase diagram\n25 stable phases:") with pytest.raises(KeyError, match="frozenset"): self.ppd[frozenset(map(Element, "HBCNOFPS"))] @@ -830,6 +837,19 @@ def test_setitem_and_delitem(self): assert self.ppd[unlikely_chem_space] == self.pd del self.ppd[unlikely_chem_space] # test __delitem__() and restore original state + def test_remove_redundant_spaces(self): + spaces = tuple(frozenset(entry.elements) for entry in self.ppd.qhull_entries) + # NOTE this is 5 not 4 as "He" is a non redundant space that gets dropped for other reasons + assert len(self.ppd.remove_redundant_spaces(spaces)) == 5 + + test = ( + list(combinations(range(1, 7), 4)) + + list(combinations(range(1, 10), 2)) + + list(combinations([1, 4, 7, 9, 2], 5)) + ) + test = [frozenset(t) for t in test] + assert len(self.ppd.remove_redundant_spaces(test)) == 30 + class TestReactionDiagram(TestCase): def setUp(self):