diff --git a/openfe/setup/chemicalsystem_generator/easy_chemicalsystem_generator.py b/openfe/setup/chemicalsystem_generator/easy_chemicalsystem_generator.py index 2b5907dc8..829a41b4f 100644 --- a/openfe/setup/chemicalsystem_generator/easy_chemicalsystem_generator.py +++ b/openfe/setup/chemicalsystem_generator/easy_chemicalsystem_generator.py @@ -90,7 +90,7 @@ def __call__( if self.do_vacuum: chem_sys = ChemicalSystem( - components={RFEComponentLabels.LIGAND: component}, + components={RFEComponentLabels.LIGAND.value: component}, name=component.name + "_vacuum", ) yield chem_sys @@ -98,8 +98,8 @@ def __call__( if self.solvent is not None: chem_sys = ChemicalSystem( components={ - RFEComponentLabels.LIGAND: component, - RFEComponentLabels.SOLVENT: self.solvent, + RFEComponentLabels.LIGAND.value: component, + RFEComponentLabels.SOLVENT.value: self.solvent, }, name=component.name + "_solvent", ) @@ -108,13 +108,13 @@ def __call__( components: dict[str, Component] if self.protein is not None: components = { - RFEComponentLabels.LIGAND: component, - RFEComponentLabels.PROTEIN: self.protein, + RFEComponentLabels.LIGAND.value: component, + RFEComponentLabels.PROTEIN.value: self.protein, } for i, c in enumerate(self.cofactors): - components.update({f'{RFEComponentLabels.COFACTOR}{i+1}': c}) + components.update({f'{RFEComponentLabels.COFACTOR.value}{i+1}': c}) if self.solvent is not None: - components.update({RFEComponentLabels.SOLVENT: self.solvent}) + components.update({RFEComponentLabels.SOLVENT.value: self.solvent}) chem_sys = ChemicalSystem( components=components, name=component.name + "_complex" ) diff --git a/openfe/tests/setup/chemicalsystem_generator/component_checks.py b/openfe/tests/setup/chemicalsystem_generator/component_checks.py index cfea082ce..6b703a346 100644 --- a/openfe/tests/setup/chemicalsystem_generator/component_checks.py +++ b/openfe/tests/setup/chemicalsystem_generator/component_checks.py @@ -14,3 +14,7 @@ def solventC_in_chem_sys(chemical_system: ChemicalSystem) -> bool: def proteinC_in_chem_sys(chemical_system: ChemicalSystem) -> bool: return RFEComponentLabels.PROTEIN in chemical_system.components + +def cofactorC_in_chem_sys(chemical_system: ChemicalSystem) -> bool: + # cofactors are numbered from 1 + return f"{RFEComponentLabels.COFACTOR.value}1" in chemical_system.components diff --git a/openfe/tests/setup/chemicalsystem_generator/test_easy_chemicalsystem_generator.py b/openfe/tests/setup/chemicalsystem_generator/test_easy_chemicalsystem_generator.py index bf706923f..29dc074a9 100644 --- a/openfe/tests/setup/chemicalsystem_generator/test_easy_chemicalsystem_generator.py +++ b/openfe/tests/setup/chemicalsystem_generator/test_easy_chemicalsystem_generator.py @@ -11,7 +11,7 @@ from ...conftest import T4_protein_component from gufe import SolventComponent -from .component_checks import proteinC_in_chem_sys, solventC_in_chem_sys, ligandC_in_chem_sys +from .component_checks import proteinC_in_chem_sys, solventC_in_chem_sys, ligandC_in_chem_sys, cofactorC_in_chem_sys def test_easy_chemical_system_generator_init(T4_protein_component): @@ -55,7 +55,6 @@ def test_build_solvent_chemical_system(ethane): def test_build_protein_chemical_system(ethane, T4_protein_component): - # TODO: cofactors with eg5 system chem_sys_generator = EasyChemicalSystemGenerator( protein=T4_protein_component, ) @@ -66,6 +65,21 @@ def test_build_protein_chemical_system(ethane, T4_protein_component): assert proteinC_in_chem_sys(chem_sys) assert not solventC_in_chem_sys(chem_sys) assert ligandC_in_chem_sys(chem_sys) + assert not cofactorC_in_chem_sys(chem_sys) + +def test_build_cofactor_chemical_system(eg5_cofactor, eg5_ligands, eg5_protein): + chem_sys_generator = EasyChemicalSystemGenerator( + cofactors=[eg5_cofactor], protein=eg5_protein + ) + chem_sys = next(chem_sys_generator(eg5_ligands[0])) + + assert chem_sys is not None + assert isinstance(chem_sys, ChemicalSystem) + assert proteinC_in_chem_sys(chem_sys) + assert not solventC_in_chem_sys(chem_sys) + assert ligandC_in_chem_sys(chem_sys) + assert cofactorC_in_chem_sys(chem_sys) + def test_build_hydr_scenario_chemical_systems(ethane): @@ -91,7 +105,6 @@ def test_build_binding_scenario_chemical_systems(ethane, T4_protein_component): assert len(chem_syss) == 2 assert all([isinstance(chem_sys, ChemicalSystem) for chem_sys in chem_syss]) - print(chem_syss) assert [proteinC_in_chem_sys(chem_sys) for chem_sys in chem_syss] == [False, True] assert [solventC_in_chem_sys(chem_sys) for chem_sys in chem_syss] == [True, True] assert [ligandC_in_chem_sys(chem_sys) for chem_sys in chem_syss] == [True, True] diff --git a/openfecli/tests/commands/test_plan_rbfe_network.py b/openfecli/tests/commands/test_plan_rbfe_network.py index 523909344..25b703811 100644 --- a/openfecli/tests/commands/test_plan_rbfe_network.py +++ b/openfecli/tests/commands/test_plan_rbfe_network.py @@ -2,7 +2,6 @@ import pytest from importlib import resources -import os import shutil from click.testing import CliRunner @@ -148,9 +147,18 @@ def test_plan_rbfe_network_cofactors(eg5_files): with runner.isolated_filesystem(): result = runner.invoke(plan_rbfe_network, args) - print(result.output) - assert result.exit_code == 0 + # make sure the cofactor is in the transformations + network = AlchemicalNetwork.from_dict( + json.load(open("alchemicalNetwork/alchemicalNetwork.json"), cls=JSON_HANDLER.decoder) + ) + for edge in network.edges: + if "protein" in edge.stateA.components: + assert "cofactor1" in edge.stateA.components + assert "cofactor1" in edge.stateB.components + else: + assert "cofactor1" not in edge.stateA.components + assert "cofactor1" not in edge.stateB.components @pytest.fixture def cdk8_files():