Skip to content

Commit

Permalink
change signature of Protocol.create to have a list of ComponentMappin…
Browse files Browse the repository at this point in the history
…gs (#260)

* change signature of Protocol.create to have a list of ComponentMappings

was previously a dict[str, ComponentMapping] where the strings were arbitrary labels.  These arbitrary labels were superfluous.  Instead can use the arbitrary labels on ChemicalSystems to label components.  ComponentMappings can (should) be matched against components in ChemicalSystem using the Component eq operations.

also affects Transformation object, which was essentially wrapper around Protocol + ChemicalSystem

* make mapping input ComponentMapping or list thereof

makes input for most cases simpler

* fixing up change to mapping argument

* fix up test for .to_rbfe_alchemical_network

* make _create signature more predictable

massage mapping inputs into list in Protocol.create

* final fixups for Protocol.create API change

* final fixups for Protocol.create API change

* allow None type for mapping in _create

* doc fixups

* another doc fixup

* typing fixes

---------

Co-authored-by: Irfan Alibay <[email protected]>
  • Loading branch information
richardjgowers and IAlibay authored Jan 24, 2024
1 parent a426dce commit c12bba6
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 31 deletions.
13 changes: 5 additions & 8 deletions gufe/ligandnetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,10 @@ def _to_rfe_alchemical_network(
components: dict[str, :class:`.Component`]
non-alchemical components (components that will be on both sides
of a transformation)
leg_label: dict[str, list[str]]
leg_labels: dict[str, list[str]]
mapping of the names for legs (the keys of this dict) to a list
of the component names. The componnent names must be the same as
as used in the ``componentns`` dict.
of the component names. The component names must be the same as
used in the ``components`` dict.
protocol: :class:`.Protocol`
the protocol to apply
alchemical_label: str
Expand Down Expand Up @@ -237,12 +237,9 @@ def sys_from_dict(component):
else:
name = ""

mapping: dict[str, gufe.ComponentMapping] = {
alchemical_label: edge,
}

transformation = gufe.Transformation(sysA, sysB, protocol,
mapping, name)
mapping=edge,
name=name)

transformations.append(transformation)

Expand Down
17 changes: 9 additions & 8 deletions gufe/protocols/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,16 @@ def _create(
self,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
mapping: Optional[dict[str, ComponentMapping]] = None,
mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]],
extends: Optional[ProtocolDAGResult] = None,
) -> list[ProtocolUnit]:
"""Method to override in custom :class:`Protocol` subclasses.
This method should take two `ChemicalSystem`s, and optionally a
dict mapping string to ``ComponentMapping``, and prepare a collection of ``ProtocolUnit`` instances
that when executed in order give sufficient information to estimate the
free energy difference between those two `ChemicalSystem`s.
This method should take two `ChemicalSystem`s, and optionally one or
more ``ComponentMapping`` objects, and prepare a collection of
``ProtocolUnit`` instances that when executed in order give sufficient
information to estimate the free energy difference between those two
`ChemicalSystem`s.
This method should return a list of `ProtocolUnit` instances.
For an instance in which another `ProtocolUnit` is given as a parameter
Expand All @@ -170,7 +171,7 @@ def create(
*,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
mapping: Union[dict[str, ComponentMapping], None],
mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]],
extends: Optional[ProtocolDAGResult] = None,
name: Optional[str] = None,
transformation_key: Optional[GufeKey] = None
Expand All @@ -192,9 +193,9 @@ def create(
The starting `ChemicalSystem` for the transformation.
stateB : ChemicalSystem
The ending `ChemicalSystem` for the transformation.
mapping : Optional[dict[str, ComponentMapping]]
mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]]
Mappings of e.g. atoms between a labelled component in the
stateA and stateB `ChemicalSystem` .
stateA and stateB `ChemicalSystem` .
extends : Optional[ProtocolDAGResult]
If provided, then the `ProtocolDAG` produced will start from the
end state of the given `ProtocolDAGResult`. This allows for
Expand Down
4 changes: 2 additions & 2 deletions gufe/tests/test_ligand_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,8 @@ def test_to_rbfe_alchemical_network(
assert compsA.get('protein') == compsB.get('protein')
assert compsA.get('cofactor') == compsB.get('cofactor')

assert list(edge.mapping) == ['ligand']
assert edge.mapping['ligand'] in real_molecules_network.edges
assert isinstance(edge.mapping, gufe.ComponentMapping)
assert edge.mapping in real_molecules_network.edges

def test_to_rbfe_alchemical_network_autoname_false(
self,
Expand Down
6 changes: 3 additions & 3 deletions gufe/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _create(
self,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
mapping: Optional[dict[str, ComponentMapping]] = None,
mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]]=None,
extends: Optional[ProtocolDAGResult] = None,
) -> List[ProtocolUnit]:

Expand Down Expand Up @@ -172,7 +172,7 @@ def _create(
self,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
mapping: Optional[dict[str, ComponentMapping]] = None,
mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]]=None,
extends: Optional[ProtocolDAGResult] = None,
) -> list[ProtocolUnit]:

Expand Down Expand Up @@ -513,7 +513,7 @@ def _create(
self,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
mapping: Optional[dict[str, ComponentMapping]] = None,
mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]] = None,
extends: Optional[ProtocolDAGResult] = None,
) -> List[ProtocolUnit]:
return [NoDepUnit(settings=self.settings,
Expand Down
6 changes: 3 additions & 3 deletions gufe/tests/test_protocoldag.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def _default_settings(cls):
def _defaults(cls):
return {}

def _create(self, stateA, stateB, mapping=None, extends=None) -> list[gufe.ProtocolUnit]:
def _create(self, stateA, stateB, mapping, extends=None) -> list[gufe.ProtocolUnit]:
return [
WriterUnit(identity=i) for i in range(self.settings.n_repeats) # type: ignore
WriterUnit(identity=i) for i in range(self.settings.n_repeats) # type: ignore
]

def _gather(self, results):
Expand All @@ -69,7 +69,7 @@ def writefile_dag():

p = WriterProtocol(settings=WriterProtocol.default_settings())

return p.create(stateA=s1, stateB=s2, mapping={})
return p.create(stateA=s1, stateB=s2, mapping=[])


@pytest.mark.parametrize('keep_shared', [False, True])
Expand Down
16 changes: 9 additions & 7 deletions gufe/transformations/transformation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/gufe

from typing import Optional, Iterable
from typing import Optional, Iterable, Union
import json

from ..tokenization import GufeTokenizable, JSON_HANDLER
Expand All @@ -17,16 +17,20 @@ class Transformation(GufeTokenizable):
Connects two :class:`ChemicalSystem` objects, with directionality.
"""
_stateA: ChemicalSystem
_stateB: ChemicalSystem
_name: Optional[str]
_mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]]
_protocol: Protocol

def __init__(
self,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
protocol: Protocol,
mapping: Optional[dict[str, ComponentMapping]] = None,
mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]] = None,
name: Optional[str] = None,
):

self._stateA = stateA
self._stateB = stateB
self._mapping = mapping
Expand Down Expand Up @@ -64,10 +68,8 @@ def protocol(self) -> Protocol:
return self._protocol

@property
def mapping(self) -> Optional[dict[str, ComponentMapping]]:
"""
Mapping of e.g. atoms between ``stateA`` and ``stateB``.
"""
def mapping(self) -> Optional[Union[ComponentMapping, list[ComponentMapping]]]:
"""The mappings relevant for this Transformation"""
return self._mapping

@property
Expand Down

0 comments on commit c12bba6

Please sign in to comment.