Skip to content

Commit

Permalink
use OperatorGroup to handle callback operator
Browse files Browse the repository at this point in the history
  • Loading branch information
skim0119 committed Jun 29, 2024
1 parent 76d46f6 commit d88df67
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 29 deletions.
4 changes: 3 additions & 1 deletion elastica/modules/base_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def __init__(self) -> None:
self._feature_group_constrain_rates: OperatorGroupFIFO[
OperatorType, ModuleProtocol
] = OperatorGroupFIFO()
self._feature_group_callback: list[OperatorCallbackType] = []
self._feature_group_callback: OperatorGroupFIFO[
OperatorCallbackType, ModuleProtocol
] = OperatorGroupFIFO()
self._feature_group_finalize: list[OperatorFinalizeType] = []
# We need to initialize our mixin classes
super().__init__()
Expand Down
35 changes: 15 additions & 20 deletions elastica/modules/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from elastica.typing import SystemType, SystemIdxType, OperatorFinalizeType
from .protocol import ModuleProtocol

import functools

import numpy as np

from elastica.callback_functions import CallBackBaseClass
Expand All @@ -29,9 +31,7 @@ class CallBacks:

def __init__(self: SystemCollectionProtocol) -> None:
self._callback_list: list[ModuleProtocol] = []
self._callback_operators: list[tuple[int, CallBackBaseClass]] = []
super(CallBacks, self).__init__()
self._feature_group_callback.append(self._callback_execution)
self._feature_group_finalize.append(self._finalize_callback)

def collect_diagnostics(
Expand All @@ -54,31 +54,26 @@ def collect_diagnostics(
sys_idx: SystemIdxType = self.get_system_index(system)

# Create _Constraint object, cache it and return to user
_callbacks: ModuleProtocol = _CallBack(sys_idx)
self._callback_list.append(_callbacks)
_callback: ModuleProtocol = _CallBack(sys_idx)
self._callback_list.append(_callback)
self._feature_group_callback.append_id(_callback)

return _callbacks
return _callback

def _finalize_callback(self: SystemCollectionProtocol) -> None:
# dev : the first index stores the rod index to collect data.
self._callback_operators = [
(callback.id(), callback.instantiate()) for callback in self._callback_list
]
for callback in self._callback_list:
sys_id = callback.id()
callback_instance = callback.instantiate()

callback_operator = functools.partial(
callback_instance.make_callback, system=self[sys_id]
)
self._feature_group_callback.add_operators(callback, [callback_operator])

self._callback_list.clear()
del self._callback_list

# First callback execution
time = np.float64(0.0)
self._callback_execution(time=time, current_step=0)

def _callback_execution(
self: SystemCollectionProtocol,
time: np.float64,
current_step: int,
) -> None:
for sys_id, callback in self._callback_operators:
callback.make_callback(self[sys_id], time, current_step)


class _CallBack:
"""
Expand Down
11 changes: 3 additions & 8 deletions elastica/modules/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def _feature_group_constrain_rates(
def constrain_rates(self, time: np.float64) -> None: ...

@property
def _feature_group_callback(self) -> list[OperatorCallbackType]: ...
def _feature_group_callback(
self,
) -> OperatorGroupFIFO[OperatorCallbackType, ModuleProtocol]: ...

def apply_callbacks(self, time: np.float64, current_step: int) -> None: ...

Expand Down Expand Up @@ -102,18 +104,11 @@ def connect(
# CallBack API
_finalize_callback: OperatorFinalizeType
_callback_list: list[ModuleProtocol]
_callback_operators: list[tuple[int, CallBackBaseClass]]

@abstractmethod
def collect_diagnostics(self, system: SystemType) -> ModuleProtocol:
raise NotImplementedError

@abstractmethod
def _callback_execution(
self, time: np.float64, current_step: int, *args: Any, **kwargs: Any
) -> None:
raise NotImplementedError

# Constraints API
_constraints_list: list[ModuleProtocol]
_finalize_constraints: OperatorFinalizeType
Expand Down

0 comments on commit d88df67

Please sign in to comment.