Skip to content

Commit

Permalink
Merge pull request GazzolaLab#372 from skim0119/typing/timestepper
Browse files Browse the repository at this point in the history
Typing: timestepper module
  • Loading branch information
skim0119 authored Apr 30, 2024
2 parents f95925d + df474eb commit b152c20
Show file tree
Hide file tree
Showing 20 changed files with 556 additions and 341 deletions.
1 change: 0 additions & 1 deletion elastica/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@
)
from elastica._linalg import levi_civita_tensor
from elastica.utils import isqrt
from elastica.typing import RodType, SystemType, AllowedContactType
from elastica.timestepper import (
integrate,
PositionVerlet,
Expand Down
4 changes: 2 additions & 2 deletions elastica/boundary_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,12 @@ def __init__(
rotational_constraint_selector = np.array([True, True, True])
# properly validate the user-provided constraint selectors
assert (
type(translational_constraint_selector) == np.ndarray
isinstance(translational_constraint_selector, np.ndarray)
and translational_constraint_selector.dtype == bool
and translational_constraint_selector.shape == (3,)
), "Translational constraint selector must be a 1D boolean array of length 3."
assert (
type(rotational_constraint_selector) == np.ndarray
isinstance(rotational_constraint_selector, np.ndarray)
and rotational_constraint_selector.dtype == bool
and rotational_constraint_selector.shape == (3,)
), "Rotational constraint selector must be a 1D boolean array of length 3."
Expand Down
10 changes: 6 additions & 4 deletions elastica/modules/base_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
"""
from typing import Iterable, Callable, AnyStr

import numpy as np

from collections.abc import MutableSequence

from elastica.rod import RodBase
Expand Down Expand Up @@ -181,22 +183,22 @@ def finalize(self):
self._feature_group_synchronize.pop(index)
)

def synchronize(self, time: float):
def synchronize(self, time: np.floating):
# Collection call _feature_group_synchronize
for feature in self._feature_group_synchronize:
feature(time)

def constrain_values(self, time: float):
def constrain_values(self, time: np.floating):
# Collection call _feature_group_constrain_values
for feature in self._feature_group_constrain_values:
feature(time)

def constrain_rates(self, time: float):
def constrain_rates(self, time: np.floating):
# Collection call _feature_group_constrain_rates
for feature in self._feature_group_constrain_rates:
feature(time)

def apply_callbacks(self, time: float, current_step: int):
def apply_callbacks(self, time: np.floating, current_step: int):
# Collection call _feature_group_callback
for feature in self._feature_group_callback:
feature(time, current_step)
4 changes: 2 additions & 2 deletions elastica/modules/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,8 @@ def __init__(
def set_index(self, first_idx, second_idx):
# TODO assert range
# First check if the types of first rod idx and second rod idx variable are same.
assert type(first_idx) == type(
second_idx
assert isinstance(
first_idx, type(second_idx)
), "Type of first_connect_idx :{}".format(
type(first_idx)
) + " is different than second_connect_idx :{}".format(
Expand Down
70 changes: 34 additions & 36 deletions elastica/rod/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,40 +8,38 @@

# FIXME : Explicit Stepper doesn't work as States lose the
# views they initially had when working with a timestepper.
"""
class _RodExplicitStepperMixin:
def __init__(self):
(
self.state,
self.__deriv_state,
self.position_collection,
self.director_collection,
self.velocity_collection,
self.omega_collection,
self.acceleration_collection,
self.alpha_collection, # angular acceleration
) = _bootstrap_from_data(
"explicit", self.n_elems, self._vector_states, self._matrix_states
)
# def __setattr__(self, name, value):
# np.copy(self.__dict__[name], value)
def __call__(self, time, *args, **kwargs):
self.update_accelerations(time) # Internal, external
# print("KRC", self.state.kinematic_rate_collection)
# print("DEr", self.__deriv_state.rate_collection)
if np.shares_memory(
self.state.kinematic_rate_collection,
self.velocity_collection
# self.__deriv_state.rate_collection
):
print("Shares memory")
else:
print("Explicit states does not share memory")
return self.__deriv_state
"""
# class _RodExplicitStepperMixin:
# def __init__(self):
# (
# self.state,
# self.__deriv_state,
# self.position_collection,
# self.director_collection,
# self.velocity_collection,
# self.omega_collection,
# self.acceleration_collection,
# self.alpha_collection, # angular acceleration
# ) = _bootstrap_from_data(
# "explicit", self.n_elems, self._vector_states, self._matrix_states
# )
#
# # def __setattr__(self, name, value):
# # np.copy(self.__dict__[name], value)
#
# def __call__(self, time, *args, **kwargs):
# self.update_accelerations(time) # Internal, external
#
# # print("KRC", self.state.kinematic_rate_collection)
# # print("DEr", self.__deriv_state.rate_collection)
# if np.shares_memory(
# self.state.kinematic_rate_collection,
# self.velocity_collection
# # self.__deriv_state.rate_collection
# ):
# print("Shares memory")
# else:
# print("Explicit states does not share memory")
# return self.__deriv_state


class _RodSymplecticStepperMixin:
Expand Down Expand Up @@ -472,7 +470,7 @@ def __init__(
self.velocity_collection = velocity_collection
self.omega_collection = omega_collection

def kinematic_rates(self, time, *args, **kwargs):
def kinematic_rates(self, time, prefac):
"""Yields kinematic rates to interact with _KinematicState
Returns
Expand All @@ -488,7 +486,7 @@ def kinematic_rates(self, time, *args, **kwargs):
# Comes from kin_state -> (x,Q) += dt * (v,w) <- First part of dyn_state
return self.velocity_collection, self.omega_collection

def dynamic_rates(self, time, prefac, *args, **kwargs):
def dynamic_rates(self, time, prefac):
"""Yields dynamic rates to add to with _DynamicState
Returns
-------
Expand Down
6 changes: 0 additions & 6 deletions elastica/rod/rod_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,3 @@ def __init__(self):
RodBase does not take any arguments.
"""
pass
# self.position_collection = NotImplemented
# self.omega_collection = NotImplemented
# self.acceleration_collection = NotImplemented
# self.alpha_collection = NotImplemented
# self.external_forces = NotImplemented
# self.external_torques = NotImplemented
2 changes: 1 addition & 1 deletion elastica/systems/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
def is_system_a_collection(system):
def is_system_a_collection(system: object) -> bool:
# Check if system is a "collection" of smaller systems
# by checking for the [] method
"""
Expand Down
77 changes: 77 additions & 0 deletions elastica/systems/protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
__doc__ = """Base class for elastica system"""

from typing import Protocol

from elastica.typing import StateType
from elastica.rod.data_structures import _KinematicState, _DynamicState

import numpy as np
from numpy.typing import NDArray


class SystemProtocol(Protocol):
"""
Protocol for all elastica system
"""

@property
def n_nodes(self) -> int: ...

@property
def position_collection(self) -> NDArray: ...

@property
def velocity_collection(self) -> NDArray: ...

@property
def acceleration_collection(self) -> NDArray: ...

@property
def omega_collection(self) -> NDArray: ...

@property
def alpha_collection(self) -> NDArray: ...

@property
def external_forces(self) -> NDArray: ...

@property
def external_torques(self) -> NDArray: ...


class SymplecticSystemProtocol(SystemProtocol, Protocol):
"""
Protocol for system with symplectic state variables
"""

@property
def kinematic_states(self) -> _KinematicState: ...

@property
def dynamic_states(self) -> _DynamicState: ...

@property
def rate_collection(self) -> NDArray: ...

@property
def dvdt_dwdt_collection(self) -> NDArray: ...

def kinematic_rates(
self, time: np.floating, prefac: np.floating
) -> tuple[NDArray, NDArray]: ...

def dynamic_rates(
self, time: np.floating, prefac: np.floating
) -> tuple[NDArray]: ...

def update_internal_forces_and_torques(self, time: np.floating) -> None: ...


class ExplicitSystemProtocol(SystemProtocol, Protocol):
# TODO: Temporarily made to handle explicit stepper.
# Need to be refactored as the explicit stepper is further developed.
def __call__(self, time: np.floating, dt: np.floating) -> np.floating: ...
@property
def state(self) -> StateType: ...
@state.setter
def state(self, state: StateType) -> None: ...
85 changes: 39 additions & 46 deletions elastica/timestepper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,85 +1,78 @@
__doc__ = """Timestepping utilities to be used with Rod and RigidBody classes"""

from typing import Tuple, List, Callable, Type
from elastica.typing import SystemType

import numpy as np
from tqdm import tqdm
from elastica.timestepper.symplectic_steppers import (
SymplecticStepperTag,
PositionVerlet,
PEFRL,
)
from elastica.timestepper.explicit_steppers import (
ExplicitStepperTag,
RungeKutta4,
EulerForward,
)

from elastica.systems import is_system_a_collection

# TODO: Both extend_stepper_interface and integrate should be in separate file.
# __init__ is probably not an ideal place to have these scripts.
def extend_stepper_interface(Stepper, System):
from elastica.utils import extend_instance
from elastica.systems import is_system_a_collection
from .symplectic_steppers import PositionVerlet, PEFRL
from .explicit_steppers import RungeKutta4, EulerForward

# Check if system is a "collection" of smaller systems
# by checking for the [] method
is_this_system_a_collection = is_system_a_collection(System)
from .tag import SymplecticStepperTag, ExplicitStepperTag
from .protocol import StepperProtocol, StatefulStepperProtocol
from .protocol import MethodCollectorProtocol

"""
# Stateful steppers are no more used so remove them
ConcreteStepper = (
Stepper.stepper if _StatefulStepper in Stepper.__class__.mro() else Stepper
)
"""
ConcreteStepper = Stepper

if type(ConcreteStepper.Tag) == SymplecticStepperTag:
# TODO: Both extend_stepper_interface and integrate should be in separate file.
# __init__ is probably not an ideal place to have these scripts.
def extend_stepper_interface(
Stepper: StepperProtocol, System: SystemType
) -> Tuple[Callable, Tuple[Callable]]:

# StepperMethodCollector: Type[MethodCollectorProtocol]
# SystemStepper: Type[StepperProtocol]
if isinstance(Stepper.Tag, SymplecticStepperTag):
from elastica.timestepper.symplectic_steppers import (
_SystemInstanceStepper,
_SystemCollectionStepper,
SymplecticStepperMethods as StepperMethodCollector,
SymplecticStepperMethods,
)
elif type(ConcreteStepper.Tag) == ExplicitStepperTag:

StepperMethodCollector = SymplecticStepperMethods
elif isinstance(Stepper.Tag, ExplicitStepperTag): # type: ignore[no-redef]
from elastica.timestepper.explicit_steppers import (
_SystemInstanceStepper,
_SystemCollectionStepper,
ExplicitStepperMethods as StepperMethodCollector,
ExplicitStepperMethods,
)
# elif SymplecticCosseratRodStepper in ConcreteStepper.__class__.mro():
# return # hacky fix for now. remove HybridSteppers in a future version.

StepperMethodCollector = ExplicitStepperMethods
else:
raise NotImplementedError(
"Only explicit and symplectic steppers are supported, given stepper is {}".format(
ConcreteStepper.__class__.__name__
Stepper.__class__.__name__
)
)

stepper_methods = StepperMethodCollector(ConcreteStepper)
do_step_method = (
_SystemCollectionStepper.do_step
if is_this_system_a_collection
else _SystemInstanceStepper.do_step
)
return do_step_method, stepper_methods.step_methods()
# Check if system is a "collection" of smaller systems
if is_system_a_collection(System):
SystemStepper = _SystemCollectionStepper
else:
SystemStepper = _SystemInstanceStepper

stepper_methods: Tuple[Callable] = StepperMethodCollector(Stepper).step_methods()
do_step_method: Callable = SystemStepper.do_step
return do_step_method, stepper_methods


# TODO Improve interface of this function to take args and kwargs for ease of use
def integrate(
StatefulStepper,
System,
StatefulStepper: StatefulStepperProtocol,
System: SystemType,
final_time: float,
n_steps: int = 1000,
restart_time: float = 0.0,
progress_bar: bool = True,
**kwargs,
):
) -> float:
"""
Parameters
----------
StatefulStepper :
StatefulStepper : StatefulStepperProtocol
Stepper algorithm to use.
System :
System : SystemType
The elastica-system to simulate.
final_time : float
Total simulation time. The timestep is determined by final_time / n_steps.
Expand Down
Loading

0 comments on commit b152c20

Please sign in to comment.