Skip to content

Commit

Permalink
added tests and docs for mdpow.restart
Browse files Browse the repository at this point in the history
- add tests for mdpow.restart
- clarified docs for Journal and Journalled
- update CHANGES
  • Loading branch information
orbeckst committed Oct 11, 2024
1 parent c31eed0 commit 3c0b00c
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ Fixes
* fix rcoulomb in CHARMM energy minimization MDP template file (PR #210)
* fix ensemble.EnsembleAnalysis.check_groups_from_common_ensemble (#212)
* updated versioneer (#285)
* fix that simulation stages cannot be restarted after error (#272)


2022-01-03 0.8.0
Expand Down
27 changes: 20 additions & 7 deletions mdpow/restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def incomplete(self):
def incomplete(self, stage):
if not stage in self.stages:
raise ValueError(
"can only assign a registered stage from %(stages)r" % vars(self)
"Can only assign a registered stage from %(stages)r" % vars(self)
)
self.__incomplete = stage

Expand Down Expand Up @@ -157,7 +157,16 @@ def has_completed(self, stage):
return stage in self.history

def has_not_completed(self, stage):
"""Returns ``True`` if the *stage* had been started but not completed yet."""
"""Returns ``True`` if the *stage* had been started but not completed yet.
This is subtly different from ``not`` :func:`has_completed` in
that two things have to be true:
1. No stage is active (which is the case when a restart is attempted).
2. The `stage` has not been completed previously (i.e.,
:func:`has_completed` returns ``False``)
"""
return self.current is None and not self.has_completed(stage)

def clear(self):
Expand Down Expand Up @@ -190,13 +199,14 @@ def __init__(self, *args, **kwargs):
len(self.journal.history)
except AttributeError:
self.journal = Journal(self.protocols)
super(Journalled, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)

def get_protocol(self, protocol):
"""Return method for *protocol*.
- If *protocol* is a real method of the class then the method is
returned.
returned. This method should implement its own use of
:meth:`Journal.start` and :meth:`Journal.completed`.
- If *protocol* is a registered protocol name but no method of
the name exists (i.e. *protocol* is a "dummy protocol") then
Expand All @@ -205,9 +215,12 @@ def get_protocol(self, protocol):
.. function:: dummy_protocol(func, *args, **kwargs)
Runs *func* with the arguments and keywords between calls
to :meth:`Journal.start` and :meth:`Journal.completed`,
with the stage set to *protocol*.
Runs *func* with the arguments and keywords between calls to
:meth:`Journal.start` and :meth:`Journal.completed`, with the
stage set to *protocol*.
The function should return ``True`` on success and ``False`` on
failure.
- Raises a :exc:`ValueError` if the *protocol* is not
registered (i.e. not found in :attr:`Journalled.protocols`).
Expand Down
175 changes: 175 additions & 0 deletions mdpow/tests/test_journals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import pytest

from mdpow import restart


@pytest.fixture
def journal():
return restart.Journal(["pre", "main", "post"])


class TestJournal:
def test_full_sequence(self, journal):
journal.start("pre")
assert journal.current == "pre"
journal.completed("pre")

journal.start("main")
assert journal.current == "main"
journal.completed("main")

journal.start("post")
assert journal.current == "post"
journal.completed("post")

def test_set_wrong_stage_ValueError(self, journal):
with pytest.raises(ValueError, match="Can only assign a registered stage"):
journal.start("BEGIN !")

def test_JournalSequenceError_no_completion(self, journal):
with pytest.raises(restart.JournalSequenceError, match="Cannot start stage"):
journal.start("pre")
assert journal.current == "pre"

journal.start("main")

@pytest.mark.xfail
def test_JournalSequenceError_skip_stage(self, journal):
# Currently allows skipping a stage and does not enforce ALL previous
# stages to have completed.
with pytest.raises(restart.JournalSequenceError, match="Cannot start stage"):
journal.start("pre")
assert journal.current == "pre"
journal.completed("pre")

journal.start("post")

def test_start_idempotent(self, journal):
# test that start() can be called multiple time (#278)
journal.start("pre")
journal.start("pre")
assert journal.current == "pre"

def test_incomplete_known_stage(self, journal):
journal.incomplete = "main"
assert journal.incomplete == "main"

def test_incomplete_unknown_stage_ValueError(self, journal):
with pytest.raises(ValueError, match="Can only assign a registered stage from"):
journal.incomplete = "BEGIN !"

def test_clear(self, journal):
journal.start("pre")
journal.completed("pre")
journal.start("main")
# manually setting incomplete
journal.incomplete = journal.current

assert journal.current == "main"
assert journal.incomplete == journal.current

journal.clear()
assert journal.current is None
assert journal.incomplete is None

def test_history(self, journal):
journal.start("pre")
journal.completed("pre")
journal.start("main")
journal.completed("main")
journal.start("post")

# completed stages
assert journal.history == ["pre", "main"]

def test_history_del(self, journal):
journal.start("pre")
journal.completed("pre")
journal.start("main")
journal.completed("main")
assert journal.history

del journal.history
assert journal.history == []

def test_has_completed(self, journal):
journal.start("pre")
journal.completed("pre")

assert journal.has_completed("pre")
assert not journal.has_completed("main")

def test_has_not_completed(self, journal):
journal.start("pre")
journal.completed("pre")
journal.start("main")
# simulate crash/restart
del journal.current

assert journal.has_not_completed("main")
assert not journal.has_not_completed("pre")


# need a real class so that it can be pickled later
class JournalledMemory(restart.Journalled):
# divide is a dummy protocol
protocols = ["divide", "multiply"]

def __init__(self):
self.memory = 1
super().__init__()

def multiply(self, x):
self.journal.start("multiply")
self.memory *= x
self.journal.completed("multiply")


@pytest.fixture
def journalled():
return JournalledMemory()


class TestJournalled:
@staticmethod
def divide(m, x):
return m.memory / x

def test_get_protocol_of_class(self, journalled):
f = journalled.get_protocol("multiply")
f(10)
assert journalled.memory == 10
assert journalled.journal.has_completed("multiply")

def test_get_protocol_dummy(self, journalled):
dummy_protocol = journalled.get_protocol("divide")
result = dummy_protocol(self.divide, journalled, 10)

assert result == 1 / 10
assert journalled.journal.has_completed("divide")

def test_get_protocol_dummy_incomplete(self, journalled):
dummy_protocol = journalled.get_protocol("divide")
with pytest.raises(ZeroDivisionError):
result = dummy_protocol(self.divide, journalled, 0)
assert not journalled.journal.has_completed("divide")

def test_save_load(self, tmp_path):
# instantiate a class that can be pickled (without pytest magic)
journalled = JournalledMemory()
f = journalled.get_protocol("multiply")
f(10)
assert journalled.memory == 10

pickle = tmp_path / "memory.pkl"
journalled.save(pickle)

assert pickle.exists()

# change instance
f(99)
assert journalled.memory == 10 * 99

# reload previous state
journalled.load(pickle)
assert journalled.memory == 10

0 comments on commit 3c0b00c

Please sign in to comment.