From 9519294811c9fbca162412fc361e848bdc6b2117 Mon Sep 17 00:00:00 2001 From: Seung Hyun Kim Date: Sat, 29 Jun 2024 01:03:20 -0500 Subject: [PATCH] test: fix unittest according to change in removing operator list --- elastica/__init__.py | 2 -- .../timestepper/explicit_steppers.py | 2 +- elastica/experimental/timestepper/memory.py | 10 ++++---- elastica/modules/base_system.py | 8 +++---- elastica/modules/protocol.py | 12 ++++++---- elastica/timestepper/__init__.py | 1 - tests/test_math/test_timestepper.py | 6 +++-- tests/test_modules/test_base_system.py | 18 ++++++++++++-- tests/test_modules/test_callbacks.py | 5 +++- tests/test_modules/test_constraints.py | 22 +++++++---------- tests/test_modules/test_damping.py | 24 +++++++------------ .../test_rigid_body_data_structures.py | 3 --- 12 files changed, 58 insertions(+), 55 deletions(-) diff --git a/elastica/__init__.py b/elastica/__init__.py index 21dfddb7..b5a1e8ea 100644 --- a/elastica/__init__.py +++ b/elastica/__init__.py @@ -77,8 +77,6 @@ integrate, PositionVerlet, PEFRL, - RungeKutta4, - EulerForward, extend_stepper_interface, ) from elastica.memory_block.memory_block_rigid_body import MemoryBlockRigidBody diff --git a/elastica/experimental/timestepper/explicit_steppers.py b/elastica/experimental/timestepper/explicit_steppers.py index 827e77b9..6ebda1da 100644 --- a/elastica/experimental/timestepper/explicit_steppers.py +++ b/elastica/experimental/timestepper/explicit_steppers.py @@ -13,7 +13,7 @@ StateType, ) from elastica.systems.protocol import ExplicitSystemProtocol -from .protocol import ExplicitStepperProtocol, MemoryProtocol +from elastica.timestepper.protocol import ExplicitStepperProtocol, MemoryProtocol """ diff --git a/elastica/experimental/timestepper/memory.py b/elastica/experimental/timestepper/memory.py index c669be9b..0948a5ee 100644 --- a/elastica/experimental/timestepper/memory.py +++ b/elastica/experimental/timestepper/memory.py @@ -1,6 +1,11 @@ from typing import Iterator, TypeVar, Generic, Type from elastica.timestepper.protocol import ExplicitStepperProtocol from elastica.typing import SystemCollectionType +from elastica.experimental.timestepper.explicit_steppers import ( + RungeKutta4, + EulerForward, +) + from copy import copy @@ -12,11 +17,6 @@ def make_memory_for_explicit_stepper( ) -> "MemoryCollection": # TODO Automated logic (class creation, memory management logic) agnostic of stepper details (RK, AB etc.) - from elastica.timestepper.explicit_steppers import ( - RungeKutta4, - EulerForward, - ) - # is_this_system_a_collection = is_system_a_collection(system) memory_cls: Type diff --git a/elastica/modules/base_system.py b/elastica/modules/base_system.py index 84910587..759abc6b 100644 --- a/elastica/modules/base_system.py +++ b/elastica/modules/base_system.py @@ -247,7 +247,7 @@ def synchronize(self, time: np.float64) -> None: Features are registered in _feature_group_synchronize. """ for func in self._feature_group_synchronize: - func(time) + func(time=time) @final def constrain_values(self, time: np.float64) -> None: @@ -256,7 +256,7 @@ def constrain_values(self, time: np.float64) -> None: Features are registered in _feature_group_constrain_values. """ for func in self._feature_group_constrain_values: - func(time) + func(time=time) @final def constrain_rates(self, time: np.float64) -> None: @@ -265,7 +265,7 @@ def constrain_rates(self, time: np.float64) -> None: Features are registered in _feature_group_constrain_rates. """ for func in self._feature_group_constrain_rates: - func(time) + func(time=time) @final def apply_callbacks(self, time: np.float64, current_step: int) -> None: @@ -274,4 +274,4 @@ def apply_callbacks(self, time: np.float64, current_step: int) -> None: Features are registered in _feature_group_callback. """ for func in self._feature_group_callback: - func(time, current_step) + func(time=time, current_step=current_step) diff --git a/elastica/modules/protocol.py b/elastica/modules/protocol.py index 2dac0c6d..11c39dca 100644 --- a/elastica/modules/protocol.py +++ b/elastica/modules/protocol.py @@ -1,4 +1,5 @@ from typing import Protocol, Generator, TypeVar, Any, Type, overload +from typing import TYPE_CHECKING from typing_extensions import Self # python 3.11: from typing import Self from abc import abstractmethod @@ -20,7 +21,8 @@ import numpy as np -from .operator_group import OperatorGroupFIFO +if TYPE_CHECKING: + from .operator_group import OperatorGroupFIFO class MixinProtocol(Protocol): @@ -55,28 +57,28 @@ def __getitem__(self, i: slice | int) -> "list[SystemType] | SystemType": ... @property def _feature_group_synchronize( self, - ) -> OperatorGroupFIFO[OperatorType, ModuleProtocol]: ... + ) -> "OperatorGroupFIFO[OperatorType, ModuleProtocol]": ... def synchronize(self, time: np.float64) -> None: ... @property def _feature_group_constrain_values( self, - ) -> OperatorGroupFIFO[OperatorType, ModuleProtocol]: ... + ) -> "OperatorGroupFIFO[OperatorType, ModuleProtocol]": ... def constrain_values(self, time: np.float64) -> None: ... @property def _feature_group_constrain_rates( self, - ) -> OperatorGroupFIFO[OperatorType, ModuleProtocol]: ... + ) -> "OperatorGroupFIFO[OperatorType, ModuleProtocol]": ... def constrain_rates(self, time: np.float64) -> None: ... @property def _feature_group_callback( self, - ) -> OperatorGroupFIFO[OperatorCallbackType, ModuleProtocol]: ... + ) -> "OperatorGroupFIFO[OperatorCallbackType, ModuleProtocol]": ... def apply_callbacks(self, time: np.float64, current_step: int) -> None: ... diff --git a/elastica/timestepper/__init__.py b/elastica/timestepper/__init__.py index 25f88341..b1797536 100644 --- a/elastica/timestepper/__init__.py +++ b/elastica/timestepper/__init__.py @@ -9,7 +9,6 @@ from elastica.systems import is_system_a_collection from .symplectic_steppers import PositionVerlet, PEFRL -from .explicit_steppers import RungeKutta4, EulerForward from .protocol import StepperProtocol, SymplecticStepperProtocol diff --git a/tests/test_math/test_timestepper.py b/tests/test_math/test_timestepper.py index 2afa6de6..6a616985 100644 --- a/tests/test_math/test_timestepper.py +++ b/tests/test_math/test_timestepper.py @@ -15,7 +15,7 @@ ) from elastica.timestepper import integrate, extend_stepper_interface -from elastica.timestepper.explicit_steppers import ( +from elastica.experimental.timestepper.explicit_steppers import ( RungeKutta4, EulerForward, ExplicitStepperMixin, @@ -245,7 +245,9 @@ def test_explicit_steppers(self, explicit_stepper): # Before stepping, let's extend the interface of the stepper # while providing memory slots - from elastica.systems.memory import make_memory_for_explicit_stepper + from elastica.experimental.timestepper.memory import ( + make_memory_for_explicit_stepper, + ) memory_collection = make_memory_for_explicit_stepper(stepper, collective_system) from elastica.timestepper import extend_stepper_interface diff --git a/tests/test_modules/test_base_system.py b/tests/test_modules/test_base_system.py index 694bd84f..9d4d9c3b 100644 --- a/tests/test_modules/test_base_system.py +++ b/tests/test_modules/test_base_system.py @@ -186,7 +186,16 @@ def test_constraint(self, load_collection, legal_constraint): simulator_class.finalize() # After finalize check if the created constrain object is instance of the class we have given. assert isinstance( - simulator_class._constraints_operators[-1][-1], legal_constraint + simulator_class._feature_group_constrain_values._operator_collection[-1][ + -1 + ].func.__self__, + legal_constraint, + ) + assert isinstance( + simulator_class._feature_group_constrain_rates._operator_collection[-1][ + -1 + ].func.__self__, + legal_constraint, ) # TODO: this is a dummy test for constrain values and rates find a better way to test them @@ -225,7 +234,12 @@ def test_callback(self, load_collection, legal_callback): simulator_class.collect_diagnostics(rod).using(legal_callback) simulator_class.finalize() # After finalize check if the created callback object is instance of the class we have given. - assert isinstance(simulator_class._callback_operators[-1][-1], legal_callback) + assert isinstance( + simulator_class._feature_group_callback._operator_collection[-1][ + -1 + ].func.__self__, + legal_callback, + ) # TODO: this is a dummy test for apply_callbacks find a better way to test them simulator_class.apply_callbacks(time=0, current_step=0) diff --git a/tests/test_modules/test_callbacks.py b/tests/test_modules/test_callbacks.py index a58901e2..d5830229 100644 --- a/tests/test_modules/test_callbacks.py +++ b/tests/test_modules/test_callbacks.py @@ -161,10 +161,13 @@ def mock_init(self, *args, **kwargs): def test_callback_finalize_correctness(self, load_rod_with_callbacks): scwc, callback_cls = load_rod_with_callbacks + callback_features = [d for d in scwc._callback_list] scwc._finalize_callback() - for x, y in scwc._callback_operators: + for _callback in callback_features: + x = _callback.id() + y = _callback.instantiate() assert type(x) is int assert type(y) is callback_cls diff --git a/tests/test_modules/test_constraints.py b/tests/test_modules/test_constraints.py index c6072c39..6ff952eb 100644 --- a/tests/test_modules/test_constraints.py +++ b/tests/test_modules/test_constraints.py @@ -315,24 +315,20 @@ def constrain_rates(self, *args, **kwargs) -> None: def test_constrain_finalize_correctness(self, load_rod_with_constraints): scwc, bc_cls = load_rod_with_constraints + bc_features = [bc for bc in scwc._constraints_list] scwc._finalize_constraints() + assert not hasattr(scwc, "_constraints_list") - for x, y in scwc._constraints_operators: - assert type(x) is int - assert type(y) is bc_cls + for _constraint in bc_features: + x = _constraint.id() + y = _constraint.instantiate(scwc[x]) + assert isinstance(x, int) + assert isinstance(y, bc_cls) - def test_constraint_properties(self, load_rod_with_constraints): - scwc, _ = load_rod_with_constraints - scwc._finalize_constraints() - - for i in [0, 1, -1]: - x, y = scwc._constraints_operators[i] - mock_rod = scwc[i] # Test system - assert type(x) is int - assert type(y.system) is type(mock_rod) - assert y.system is mock_rod, f"{len(scwc)}" + assert type(y.system) is type(scwc[x]) + assert y.system is scwc[x], f"{len(scwc)}" # Test node indices assert y.constrained_position_idx.size == 0 # Test element indices. TODO: maybe add more generalized test diff --git a/tests/test_modules/test_damping.py b/tests/test_modules/test_damping.py index 2634da84..d3b23dc8 100644 --- a/tests/test_modules/test_damping.py +++ b/tests/test_modules/test_damping.py @@ -180,26 +180,18 @@ def dampen_rates(self, *args, **kwargs) -> None: return scwd, MockDamper - def test_dampen_finalize_correctness(self, load_rod_with_dampers): + def test_dampen_finalize_clear_instances(self, load_rod_with_dampers): scwd, damper_cls = load_rod_with_dampers + damping_features = [d for d in scwd._damping_list] scwd._finalize_dampers() + assert not hasattr(scwd, "_damping_list") - for x, y in scwd._damping_operators: - assert type(x) is int - assert type(y) is damper_cls - - def test_damper_properties(self, load_rod_with_dampers): - scwd, _ = load_rod_with_dampers - scwd._finalize_dampers() - - for i in [0, 1, -1]: - x, y = scwd._damping_operators[i] - mock_rod = scwd[i] - # Test system - assert type(x) is int - assert type(y.system) is type(mock_rod) - assert y.system is mock_rod, f"{len(scwd)}" + for _damping in damping_features: + x = _damping.id() + y = _damping.instantiate(scwd[x]) + assert isinstance(x, int) + assert isinstance(y, damper_cls) @pytest.mark.xfail def test_dampers_finalize_sorted(self, load_rod_with_dampers): diff --git a/tests/test_rigid_body/test_rigid_body_data_structures.py b/tests/test_rigid_body/test_rigid_body_data_structures.py index 0ca7715a..84419952 100644 --- a/tests/test_rigid_body/test_rigid_body_data_structures.py +++ b/tests/test_rigid_body/test_rigid_body_data_structures.py @@ -6,8 +6,6 @@ from elastica.rigidbody.data_structures import _RigidRodSymplecticStepperMixin from elastica._rotations import _rotate from elastica.timestepper import ( - RungeKutta4, - EulerForward, PEFRL, PositionVerlet, integrate, @@ -92,7 +90,6 @@ def analytical_solution(self, type, time): return analytical_solution -ExplicitSteppers = [EulerForward, RungeKutta4] SymplecticSteppers = [PositionVerlet, PEFRL]