Skip to content

Commit

Permalink
test: fix unittest according to change in removing operator list
Browse files Browse the repository at this point in the history
  • Loading branch information
skim0119 committed Jun 29, 2024
1 parent d88df67 commit 9519294
Show file tree
Hide file tree
Showing 12 changed files with 58 additions and 55 deletions.
2 changes: 0 additions & 2 deletions elastica/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@
integrate,
PositionVerlet,
PEFRL,
RungeKutta4,
EulerForward,
extend_stepper_interface,
)
from elastica.memory_block.memory_block_rigid_body import MemoryBlockRigidBody
Expand Down
2 changes: 1 addition & 1 deletion elastica/experimental/timestepper/explicit_steppers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
StateType,
)
from elastica.systems.protocol import ExplicitSystemProtocol
from .protocol import ExplicitStepperProtocol, MemoryProtocol
from elastica.timestepper.protocol import ExplicitStepperProtocol, MemoryProtocol


"""
Expand Down
10 changes: 5 additions & 5 deletions elastica/experimental/timestepper/memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import Iterator, TypeVar, Generic, Type
from elastica.timestepper.protocol import ExplicitStepperProtocol
from elastica.typing import SystemCollectionType
from elastica.experimental.timestepper.explicit_steppers import (
RungeKutta4,
EulerForward,
)


from copy import copy

Expand All @@ -12,11 +17,6 @@ def make_memory_for_explicit_stepper(
) -> "MemoryCollection":
# TODO Automated logic (class creation, memory management logic) agnostic of stepper details (RK, AB etc.)

from elastica.timestepper.explicit_steppers import (
RungeKutta4,
EulerForward,
)

# is_this_system_a_collection = is_system_a_collection(system)

memory_cls: Type
Expand Down
8 changes: 4 additions & 4 deletions elastica/modules/base_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def synchronize(self, time: np.float64) -> None:
Features are registered in _feature_group_synchronize.
"""
for func in self._feature_group_synchronize:
func(time)
func(time=time)

@final
def constrain_values(self, time: np.float64) -> None:
Expand All @@ -256,7 +256,7 @@ def constrain_values(self, time: np.float64) -> None:
Features are registered in _feature_group_constrain_values.
"""
for func in self._feature_group_constrain_values:
func(time)
func(time=time)

@final
def constrain_rates(self, time: np.float64) -> None:
Expand All @@ -265,7 +265,7 @@ def constrain_rates(self, time: np.float64) -> None:
Features are registered in _feature_group_constrain_rates.
"""
for func in self._feature_group_constrain_rates:
func(time)
func(time=time)

@final
def apply_callbacks(self, time: np.float64, current_step: int) -> None:
Expand All @@ -274,4 +274,4 @@ def apply_callbacks(self, time: np.float64, current_step: int) -> None:
Features are registered in _feature_group_callback.
"""
for func in self._feature_group_callback:
func(time, current_step)
func(time=time, current_step=current_step)
12 changes: 7 additions & 5 deletions elastica/modules/protocol.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Protocol, Generator, TypeVar, Any, Type, overload
from typing import TYPE_CHECKING
from typing_extensions import Self # python 3.11: from typing import Self

from abc import abstractmethod
Expand All @@ -20,7 +21,8 @@

import numpy as np

from .operator_group import OperatorGroupFIFO
if TYPE_CHECKING:
from .operator_group import OperatorGroupFIFO


class MixinProtocol(Protocol):
Expand Down Expand Up @@ -55,28 +57,28 @@ def __getitem__(self, i: slice | int) -> "list[SystemType] | SystemType": ...
@property
def _feature_group_synchronize(
self,
) -> OperatorGroupFIFO[OperatorType, ModuleProtocol]: ...
) -> "OperatorGroupFIFO[OperatorType, ModuleProtocol]": ...

def synchronize(self, time: np.float64) -> None: ...

@property
def _feature_group_constrain_values(
self,
) -> OperatorGroupFIFO[OperatorType, ModuleProtocol]: ...
) -> "OperatorGroupFIFO[OperatorType, ModuleProtocol]": ...

def constrain_values(self, time: np.float64) -> None: ...

@property
def _feature_group_constrain_rates(
self,
) -> OperatorGroupFIFO[OperatorType, ModuleProtocol]: ...
) -> "OperatorGroupFIFO[OperatorType, ModuleProtocol]": ...

def constrain_rates(self, time: np.float64) -> None: ...

@property
def _feature_group_callback(
self,
) -> OperatorGroupFIFO[OperatorCallbackType, ModuleProtocol]: ...
) -> "OperatorGroupFIFO[OperatorCallbackType, ModuleProtocol]": ...

def apply_callbacks(self, time: np.float64, current_step: int) -> None: ...

Expand Down
1 change: 0 additions & 1 deletion elastica/timestepper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from elastica.systems import is_system_a_collection

from .symplectic_steppers import PositionVerlet, PEFRL
from .explicit_steppers import RungeKutta4, EulerForward
from .protocol import StepperProtocol, SymplecticStepperProtocol


Expand Down
6 changes: 4 additions & 2 deletions tests/test_math/test_timestepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from elastica.timestepper import integrate, extend_stepper_interface

from elastica.timestepper.explicit_steppers import (
from elastica.experimental.timestepper.explicit_steppers import (
RungeKutta4,
EulerForward,
ExplicitStepperMixin,
Expand Down Expand Up @@ -245,7 +245,9 @@ def test_explicit_steppers(self, explicit_stepper):

# Before stepping, let's extend the interface of the stepper
# while providing memory slots
from elastica.systems.memory import make_memory_for_explicit_stepper
from elastica.experimental.timestepper.memory import (
make_memory_for_explicit_stepper,
)

memory_collection = make_memory_for_explicit_stepper(stepper, collective_system)
from elastica.timestepper import extend_stepper_interface
Expand Down
18 changes: 16 additions & 2 deletions tests/test_modules/test_base_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,16 @@ def test_constraint(self, load_collection, legal_constraint):
simulator_class.finalize()
# After finalize check if the created constrain object is instance of the class we have given.
assert isinstance(
simulator_class._constraints_operators[-1][-1], legal_constraint
simulator_class._feature_group_constrain_values._operator_collection[-1][
-1
].func.__self__,
legal_constraint,
)
assert isinstance(
simulator_class._feature_group_constrain_rates._operator_collection[-1][
-1
].func.__self__,
legal_constraint,
)

# TODO: this is a dummy test for constrain values and rates find a better way to test them
Expand Down Expand Up @@ -225,7 +234,12 @@ def test_callback(self, load_collection, legal_callback):
simulator_class.collect_diagnostics(rod).using(legal_callback)
simulator_class.finalize()
# After finalize check if the created callback object is instance of the class we have given.
assert isinstance(simulator_class._callback_operators[-1][-1], legal_callback)
assert isinstance(
simulator_class._feature_group_callback._operator_collection[-1][
-1
].func.__self__,
legal_callback,
)

# TODO: this is a dummy test for apply_callbacks find a better way to test them
simulator_class.apply_callbacks(time=0, current_step=0)
5 changes: 4 additions & 1 deletion tests/test_modules/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,13 @@ def mock_init(self, *args, **kwargs):

def test_callback_finalize_correctness(self, load_rod_with_callbacks):
scwc, callback_cls = load_rod_with_callbacks
callback_features = [d for d in scwc._callback_list]

scwc._finalize_callback()

for x, y in scwc._callback_operators:
for _callback in callback_features:
x = _callback.id()
y = _callback.instantiate()
assert type(x) is int
assert type(y) is callback_cls

Expand Down
22 changes: 9 additions & 13 deletions tests/test_modules/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,24 +315,20 @@ def constrain_rates(self, *args, **kwargs) -> None:

def test_constrain_finalize_correctness(self, load_rod_with_constraints):
scwc, bc_cls = load_rod_with_constraints
bc_features = [bc for bc in scwc._constraints_list]

scwc._finalize_constraints()
assert not hasattr(scwc, "_constraints_list")

for x, y in scwc._constraints_operators:
assert type(x) is int
assert type(y) is bc_cls
for _constraint in bc_features:
x = _constraint.id()
y = _constraint.instantiate(scwc[x])
assert isinstance(x, int)
assert isinstance(y, bc_cls)

def test_constraint_properties(self, load_rod_with_constraints):
scwc, _ = load_rod_with_constraints
scwc._finalize_constraints()

for i in [0, 1, -1]:
x, y = scwc._constraints_operators[i]
mock_rod = scwc[i]
# Test system
assert type(x) is int
assert type(y.system) is type(mock_rod)
assert y.system is mock_rod, f"{len(scwc)}"
assert type(y.system) is type(scwc[x])
assert y.system is scwc[x], f"{len(scwc)}"
# Test node indices
assert y.constrained_position_idx.size == 0
# Test element indices. TODO: maybe add more generalized test
Expand Down
24 changes: 8 additions & 16 deletions tests/test_modules/test_damping.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,26 +180,18 @@ def dampen_rates(self, *args, **kwargs) -> None:

return scwd, MockDamper

def test_dampen_finalize_correctness(self, load_rod_with_dampers):
def test_dampen_finalize_clear_instances(self, load_rod_with_dampers):
scwd, damper_cls = load_rod_with_dampers
damping_features = [d for d in scwd._damping_list]

scwd._finalize_dampers()
assert not hasattr(scwd, "_damping_list")

for x, y in scwd._damping_operators:
assert type(x) is int
assert type(y) is damper_cls

def test_damper_properties(self, load_rod_with_dampers):
scwd, _ = load_rod_with_dampers
scwd._finalize_dampers()

for i in [0, 1, -1]:
x, y = scwd._damping_operators[i]
mock_rod = scwd[i]
# Test system
assert type(x) is int
assert type(y.system) is type(mock_rod)
assert y.system is mock_rod, f"{len(scwd)}"
for _damping in damping_features:
x = _damping.id()
y = _damping.instantiate(scwd[x])
assert isinstance(x, int)
assert isinstance(y, damper_cls)

@pytest.mark.xfail
def test_dampers_finalize_sorted(self, load_rod_with_dampers):
Expand Down
3 changes: 0 additions & 3 deletions tests/test_rigid_body/test_rigid_body_data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from elastica.rigidbody.data_structures import _RigidRodSymplecticStepperMixin
from elastica._rotations import _rotate
from elastica.timestepper import (
RungeKutta4,
EulerForward,
PEFRL,
PositionVerlet,
integrate,
Expand Down Expand Up @@ -92,7 +90,6 @@ def analytical_solution(self, type, time):
return analytical_solution


ExplicitSteppers = [EulerForward, RungeKutta4]
SymplecticSteppers = [PositionVerlet, PEFRL]


Expand Down

0 comments on commit 9519294

Please sign in to comment.