Skip to content

Commit

Permalink
Improve Keep Redundant Spaces algorithm for PatchedPhaseDiagram (mate…
Browse files Browse the repository at this point in the history
…rialsproject#3900)

* fix: old algorithm to deduplicate spaces didn't find the minimum subset

* test: direct test for remove_redundant_spaces static method

* doc: clean up old comments, add details explaining why patchedphasediagram as_dict doesn't save computations due to shared memory id issue.

* pre-commit auto-fixes

* lint: spelling

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
CompRhys and pre-commit-ci[bot] authored Jul 17, 2024
1 parent 454aa5e commit 0234182
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 38 deletions.
96 changes: 61 additions & 35 deletions src/pymatgen/analysis/phase_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,6 @@ def get_decomp_and_phase_separation_energy(
}

# NOTE calling PhaseDiagram is only reasonable if the composition has fewer than 5 elements
# TODO can we call PatchedPhaseDiagram here?
inner_hull = PhaseDiagram(reduced_space)

competing_entries = inner_hull.stable_entries | {*self._get_stable_entries_in_space(entry_elems)}
Expand Down Expand Up @@ -1036,8 +1035,6 @@ def get_critical_compositions(self, comp1, comp2):
if np.all(c1 == c2):
return [comp1.copy(), comp2.copy()]

# NOTE made into method to facilitate inheritance of this method
# in PatchedPhaseDiagram if approximate solution can be found.
intersections = self._get_simplex_intersections(c1, c2)

# find position along line
Expand Down Expand Up @@ -1619,29 +1616,21 @@ def __init__(
# Add the elemental references
inds.extend([min_entries.index(el) for el in el_refs.values()])

self.qhull_entries = tuple(min_entries[idx] for idx in inds)
qhull_entries = tuple(min_entries[idx] for idx in inds)
# make qhull spaces frozensets since they become keys to self.pds dict and frozensets are hashable
# prevent repeating elements in chemical space and avoid the ordering problem (i.e. Fe-O == O-Fe automatically)
self._qhull_spaces = tuple(frozenset(entry.elements) for entry in self.qhull_entries)
qhull_spaces = tuple(frozenset(entry.elements) for entry in qhull_entries)

# Get all unique chemical spaces
spaces = {s for s in self._qhull_spaces if len(s) > 1}
spaces = {s for s in qhull_spaces if len(s) > 1}

# Remove redundant chemical spaces
if not keep_all_spaces and len(spaces) > 1:
max_size = max(len(s) for s in spaces)

systems = set()
# NOTE reduce the number of comparisons by only comparing to larger sets
for idx in range(2, max_size + 1):
test = (s for s in spaces if len(s) == idx)
refer = (s for s in spaces if len(s) > idx)
systems |= {t for t in test if not any(t.issubset(r) for r in refer)}

spaces = systems
spaces = self.remove_redundant_spaces(spaces, keep_all_spaces)

# TODO comprhys: refactor to have self._compute method to allow serialization
self.spaces = sorted(spaces, key=len, reverse=False) # Calculate pds for smaller dimension spaces first
self.spaces = sorted(spaces, key=len, reverse=True) # Calculate pds for smaller dimension spaces last
self.qhull_entries = qhull_entries
self._qhull_spaces = qhull_spaces
self.pds = dict(self._get_pd_patch_for_space(s) for s in tqdm(self.spaces, disable=not verbose))
self.all_entries = all_entries
self.el_refs = el_refs
Expand Down Expand Up @@ -1675,7 +1664,19 @@ def __contains__(self, item: frozenset[Element]) -> bool:
return item in self.pds

def as_dict(self) -> dict[str, Any]:
"""
"""Write the entries and elements used to construct the PatchedPhaseDiagram
to a dictionary.
NOTE unlike PhaseDiagram the computation involved in constructing the
PatchedPhaseDiagram is not saved on serialisation. This is done because
hierarchically calling the `PhaseDiagram.as_dict()` method would break the
link in memory between entries in overlapping patches leading to a
ballooning of the amount of memory used.
NOTE For memory efficiency the best way to store patched phase diagrams is
via pickling, as this allows all the entries in overlapping patches to share
the same id in memory when unpickling.
Returns:
dict[str, Any]: MSONable dictionary representation of PatchedPhaseDiagram.
"""
Expand All @@ -1688,7 +1689,18 @@ def as_dict(self) -> dict[str, Any]:

@classmethod
def from_dict(cls, dct: dict) -> Self:
"""
"""Reconstruct PatchedPhaseDiagram from dictionary serialisation.
NOTE unlike PhaseDiagram the computation involved in constructing the
PatchedPhaseDiagram is not saved on serialisation. This is done because
hierarchically calling the `PhaseDiagram.as_dict()` method would break the
link in memory between entries in overlapping patches leading to a
ballooning of the amount of memory used.
NOTE For memory efficiency the best way to store patched phase diagrams is
via pickling, as this allows all the entries in overlapping patches to share
the same id in memory when unpickling.
Args:
dct (dict): dictionary representation of PatchedPhaseDiagram.
Expand All @@ -1699,9 +1711,23 @@ def from_dict(cls, dct: dict) -> Self:
elements = [Element.from_dict(elem) for elem in dct["elements"]]
return cls(entries, elements)

@staticmethod
def remove_redundant_spaces(spaces, keep_all_spaces=False):
if keep_all_spaces or len(spaces) <= 1:
return spaces

# Sort spaces by size in descending order and pre-compute lengths
sorted_spaces = sorted(spaces, key=len, reverse=True)

result = []
for i, space_i in enumerate(sorted_spaces):
if not any(space_i.issubset(larger_space) for larger_space in sorted_spaces[:i]):
result.append(space_i)

return result

# NOTE following methods are inherited unchanged from PhaseDiagram:
# __repr__,
# as_dict,
# all_entries_hulldata,
# unstable_entries,
# stable_entries,
Expand Down Expand Up @@ -1771,8 +1797,6 @@ def get_equilibrium_reaction_energy(self, entry: Entry) -> float:
"""
return self.get_phase_separation_energy(entry, stable_only=True)

# NOTE the following functions are not implemented for PatchedPhaseDiagram

def get_decomp_and_e_above_hull(
self,
entry: PDEntry,
Expand All @@ -1787,6 +1811,20 @@ def get_decomp_and_e_above_hull(
entry=entry, allow_negative=allow_negative, check_stable=check_stable, on_error=on_error
)

def _get_pd_patch_for_space(self, space: frozenset[Element]) -> tuple[frozenset[Element], PhaseDiagram]:
    """Build the PhaseDiagram patch for a single chemical space.

    Args:
        space (frozenset[Element]): chemical space of the form A-B-X.

    Returns:
        space, PhaseDiagram for the given chemical space
    """
    patch_entries = []
    for entry, entry_space in zip(self.qhull_entries, self._qhull_spaces):
        # keep only the entries whose own chemical space lies within `space`
        if space.issuperset(entry_space):
            patch_entries.append(entry)

    return space, PhaseDiagram(patch_entries)

# NOTE the following functions are not implemented for PatchedPhaseDiagram

def _get_facet_and_simplex(self):
"""Not Implemented - See PhaseDiagram."""
raise NotImplementedError("_get_facet_and_simplex() not implemented for PatchedPhaseDiagram")
Expand Down Expand Up @@ -1835,18 +1873,6 @@ def get_chempot_range_stability_phase(self):
"""Not Implemented - See PhaseDiagram."""
raise NotImplementedError("get_chempot_range_stability_phase() not implemented for PatchedPhaseDiagram")

def _get_pd_patch_for_space(self, space: frozenset[Element]) -> tuple[frozenset[Element], PhaseDiagram]:
"""
Args:
space (frozenset[Element]): chemical space of the form A-B-X.
Returns:
space, PhaseDiagram for the given chemical space
"""
space_entries = [e for e, s in zip(self.qhull_entries, self._qhull_spaces) if space.issuperset(s)]

return space, PhaseDiagram(space_entries)


class ReactionDiagram:
"""
Expand Down
26 changes: 23 additions & 3 deletions tests/analysis/test_phase_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import collections
import unittest
import unittest.mock
from itertools import combinations
from numbers import Number
from unittest import TestCase

Expand Down Expand Up @@ -709,6 +710,7 @@ def setUp(self):

self.pd = PhaseDiagram(entries=self.entries)
self.ppd = PatchedPhaseDiagram(entries=self.entries)
self.ppd_all = PatchedPhaseDiagram(entries=self.entries, keep_all_spaces=True)

# novel entries not in any of the patches
self.novel_comps = [Composition("H5C2OP"), Composition("V2PH4C")]
Expand Down Expand Up @@ -756,7 +758,11 @@ def test_dimensionality(self):

# test dims of sub PDs
dim_counts = collections.Counter(pd.dim for pd in self.ppd.pds.values())
assert dim_counts == {3: 7, 2: 6, 4: 2}
assert dim_counts == {4: 2, 3: 2}

# test dims of sub PDs
dim_counts = collections.Counter(pd.dim for pd in self.ppd_all.pds.values())
assert dim_counts == {2: 8, 3: 7, 4: 2}

def test_get_hull_energy(self):
for comp in self.novel_comps:
Expand All @@ -772,7 +778,7 @@ def test_get_decomp_and_e_above_hull(self):
assert np.isclose(e_above_hull_pd, e_above_hull_ppd)

def test_repr(self):
assert repr(self.ppd) == str(self.ppd) == "PatchedPhaseDiagram covering 15 sub-spaces"
assert repr(self.ppd) == str(self.ppd) == "PatchedPhaseDiagram covering 4 sub-spaces"

def test_as_from_dict(self):
ppd_dict = self.ppd.as_dict()
Expand Down Expand Up @@ -810,7 +816,8 @@ def test_getitem(self):
pd = self.ppd[chem_space]
assert isinstance(pd, PhaseDiagram)
assert chem_space in pd._qhull_spaces
assert str(pd) == "V-C phase diagram\n4 stable phases: \nC, V, V6C5, V2C"
assert len(str(pd)) == 186
assert str(pd).startswith("V-H-C-O phase diagram\n25 stable phases:")

with pytest.raises(KeyError, match="frozenset"):
self.ppd[frozenset(map(Element, "HBCNOFPS"))]
Expand All @@ -830,6 +837,19 @@ def test_setitem_and_delitem(self):
assert self.ppd[unlikely_chem_space] == self.pd
del self.ppd[unlikely_chem_space] # test __delitem__() and restore original state

def test_remove_redundant_spaces(self):
    """Exercise remove_redundant_spaces on the fixture's spaces and on a synthetic collection."""
    qhull_spaces = tuple(frozenset(entry.elements) for entry in self.ppd.qhull_entries)
    # NOTE this is 5 not 4 as "He" is a non redundant space that gets dropped for other reasons
    reduced = self.ppd.remove_redundant_spaces(qhull_spaces)
    assert len(reduced) == 5

    # all 4-subsets of 1..6, all pairs of 1..9 and a single 5-element set
    synthetic = [
        frozenset(combo)
        for combo in (
            *combinations(range(1, 7), 4),
            *combinations(range(1, 10), 2),
            *combinations([1, 4, 7, 9, 2], 5),
        )
    ]
    assert len(self.ppd.remove_redundant_spaces(synthetic)) == 30


class TestReactionDiagram(TestCase):
def setUp(self):
Expand Down

0 comments on commit 0234182

Please sign in to comment.