diff --git a/elastica/modules/base_system.py b/elastica/modules/base_system.py index d6b562f5..84910587 100644 --- a/elastica/modules/base_system.py +++ b/elastica/modules/base_system.py @@ -65,7 +65,9 @@ def __init__(self) -> None: self._feature_group_constrain_rates: OperatorGroupFIFO[ OperatorType, ModuleProtocol ] = OperatorGroupFIFO() - self._feature_group_callback: list[OperatorCallbackType] = [] + self._feature_group_callback: OperatorGroupFIFO[ + OperatorCallbackType, ModuleProtocol + ] = OperatorGroupFIFO() self._feature_group_finalize: list[OperatorFinalizeType] = [] # We need to initialize our mixin classes super().__init__() diff --git a/elastica/modules/callbacks.py b/elastica/modules/callbacks.py index de1e5091..9d886926 100644 --- a/elastica/modules/callbacks.py +++ b/elastica/modules/callbacks.py @@ -9,6 +9,8 @@ from elastica.typing import SystemType, SystemIdxType, OperatorFinalizeType from .protocol import ModuleProtocol +import functools + import numpy as np from elastica.callback_functions import CallBackBaseClass @@ -29,9 +31,7 @@ class CallBacks: def __init__(self: SystemCollectionProtocol) -> None: self._callback_list: list[ModuleProtocol] = [] - self._callback_operators: list[tuple[int, CallBackBaseClass]] = [] super(CallBacks, self).__init__() - self._feature_group_callback.append(self._callback_execution) self._feature_group_finalize.append(self._finalize_callback) def collect_diagnostics( @@ -54,31 +54,26 @@ def collect_diagnostics( sys_idx: SystemIdxType = self.get_system_index(system) # Create _Constraint object, cache it and return to user - _callbacks: ModuleProtocol = _CallBack(sys_idx) - self._callback_list.append(_callbacks) + _callback: ModuleProtocol = _CallBack(sys_idx) + self._callback_list.append(_callback) + self._feature_group_callback.append_id(_callback) - return _callbacks + return _callback def _finalize_callback(self: SystemCollectionProtocol) -> None: # dev : the first index stores the rod index to collect data. - self._callback_operators = [ - (callback.id(), callback.instantiate()) for callback in self._callback_list - ] + for callback in self._callback_list: + sys_id = callback.id() + callback_instance = callback.instantiate() + + callback_operator = functools.partial( + callback_instance.make_callback, system=self[sys_id] + ) + self._feature_group_callback.add_operators(callback, [callback_operator]) + self._callback_list.clear() del self._callback_list - # First callback execution - time = np.float64(0.0) - self._callback_execution(time=time, current_step=0) - - def _callback_execution( - self: SystemCollectionProtocol, - time: np.float64, - current_step: int, - ) -> None: - for sys_id, callback in self._callback_operators: - callback.make_callback(self[sys_id], time, current_step) - class _CallBack: """ diff --git a/elastica/modules/protocol.py b/elastica/modules/protocol.py index 90cd5448..2dac0c6d 100644 --- a/elastica/modules/protocol.py +++ b/elastica/modules/protocol.py @@ -74,7 +74,9 @@ def _feature_group_constrain_rates( def constrain_rates(self, time: np.float64) -> None: ... @property - def _feature_group_callback(self) -> list[OperatorCallbackType]: ... + def _feature_group_callback( + self, + ) -> OperatorGroupFIFO[OperatorCallbackType, ModuleProtocol]: ... def apply_callbacks(self, time: np.float64, current_step: int) -> None: ... @@ -102,18 +104,11 @@ def connect( # CallBack API _finalize_callback: OperatorFinalizeType _callback_list: list[ModuleProtocol] - _callback_operators: list[tuple[int, CallBackBaseClass]] @abstractmethod def collect_diagnostics(self, system: SystemType) -> ModuleProtocol: raise NotImplementedError - @abstractmethod - def _callback_execution( - self, time: np.float64, current_step: int, *args: Any, **kwargs: Any - ) -> None: - raise NotImplementedError - # Constraints API _constraints_list: list[ModuleProtocol] _finalize_constraints: OperatorFinalizeType