diff --git a/chimerapy/engine/eventbus/eventbus.py b/chimerapy/engine/eventbus/eventbus.py index b091210..a5d0eef 100644 --- a/chimerapy/engine/eventbus/eventbus.py +++ b/chimerapy/engine/eventbus/eventbus.py @@ -13,6 +13,8 @@ Literal, TypeVar, Dict, + Coroutine, + Tuple, ) from aioreactive import AsyncObservable, AsyncObserver, AsyncSubject @@ -34,6 +36,20 @@ class Event: timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat()) +def future_wrapper(coroutine: Coroutine) -> Tuple[Coroutine, Future]: + + future: Future = Future() + + async def wrapper(): + try: + result = await coroutine + future.set_result(result) + except Exception as e: + future.set_exception(e) + + return wrapper(), future + + class EventBus(AsyncObservable): def __init__(self, thread: Optional[AsyncLoopThread] = None): self.stream = AsyncSubject() @@ -87,12 +103,22 @@ async def await_event(self, event_type: str) -> Event: #################################################################### def send(self, event: Event) -> Future: - assert isinstance(self.thread, AsyncLoopThread) - return self.thread.exec(self.asend(event)) + if isinstance(self.thread, AsyncLoopThread): + return self.thread.exec(self.asend(event)) + else: + loop = asyncio.get_event_loop() + wrapper, future = future_wrapper(self.asend(event)) + loop.create_task(wrapper) + return future def subscribe(self, observer: AsyncObserver) -> Future: - assert isinstance(self.thread, AsyncLoopThread) - return self.thread.exec(self.asubscribe(observer)) + if isinstance(self.thread, AsyncLoopThread): + return self.thread.exec(self.asubscribe(observer)) + else: + loop = asyncio.get_event_loop() + wrapper, future = future_wrapper(self.asubscribe(observer)) + loop.create_task(wrapper) + return future class TypedObserver(AsyncObserver, Generic[T]): diff --git a/chimerapy/engine/worker/http_server_service.py b/chimerapy/engine/worker/http_server_service.py index 16cf02a..b91ba5e 100644 --- a/chimerapy/engine/worker/http_server_service.py +++ b/chimerapy/engine/worker/http_server_service.py @@ -15,7 +15,6 @@ NodeDiagnostics, ) from ..networking import Server -from ..networking.async_loop_thread import AsyncLoopThread from ..networking.enums import NODE_MESSAGE from ..utils import update_dataclass from ..eventbus import EventBus, Event, TypedObserver @@ -38,7 +37,6 @@ def __init__( self, name: str, state: WorkerState, - thread: AsyncLoopThread, eventbus: EventBus, logger: logging.Logger, ): @@ -46,7 +44,6 @@ def __init__( # Save input parameters self.name = name self.state = state - self.thread = thread self.eventbus = eventbus self.logger = logger @@ -80,9 +77,22 @@ def __init__( NODE_MESSAGE.DIAGNOSTICS: self._async_node_diagnostics, }, parent_logger=self.logger, - thread=self.thread, ) + @property + def ip(self) -> str: + return self._ip + + @property + def port(self) -> int: + return self._port + + @property + def url(self) -> str: + return f"http://{self._ip}:{self._port}" + + async def async_init(self): + # Specify observers self.observers: Dict[str, TypedObserver] = { "start": TypedObserver("start", on_asend=self.start, handle_event="drop"), @@ -103,19 +113,7 @@ def __init__( ), } for ob in self.observers.values(): - self.eventbus.subscribe(ob).result(timeout=1) - - @property - def ip(self) -> str: - return self._ip - - @property - def port(self) -> int: - return self._port - - @property - def url(self) -> str: - return f"http://{self._ip}:{self._port}" + await self.eventbus.asubscribe(ob) async def start(self): diff --git a/chimerapy/engine/worker/node_handler_service/node_handler_service.py b/chimerapy/engine/worker/node_handler_service/node_handler_service.py index 94e59c3..611267c 100644 --- a/chimerapy/engine/worker/node_handler_service/node_handler_service.py +++ b/chimerapy/engine/worker/node_handler_service/node_handler_service.py @@ -58,6 +58,8 @@ def __init__( "threading": ThreadNodeController, } + async def async_init(self): + # Specify observers self.observers: Dict[str, TypedObserver] = { "start": TypedObserver("start", on_asend=self.start, handle_event="drop"), @@ -126,7 +128,7 @@ def __init__( ), } for ob in self.observers.values(): - self.eventbus.subscribe(ob).result(timeout=1) + await self.eventbus.asubscribe(ob) async def start(self) -> bool: # Containers @@ -246,7 +248,7 @@ async def async_create_node(self, node_config: Union[NodeConfig, Dict]) -> bool: self.node_controllers[node_config.id] = controller # Mark success - # self.logger.debug(f"{self}: completed node creation: {id}") + self.logger.debug(f"{self}: completed node creation: {id}") break if not success: @@ -261,7 +263,9 @@ async def async_destroy_node(self, node_id: str) -> bool: success = False if node_id in self.node_controllers: + self.logger.debug(f"{self}: destroying Node {node_id}") await self.node_controllers[node_id].shutdown() + self.logger.debug(f"{self}: destroyed Node {node_id}") if node_id in self.state.nodes: del self.state.nodes[node_id] diff --git a/chimerapy/engine/worker/worker.py b/chimerapy/engine/worker/worker.py index f59e7ad..95eec62 100644 --- a/chimerapy/engine/worker/worker.py +++ b/chimerapy/engine/worker/worker.py @@ -99,7 +99,7 @@ def __init__( self.http_server = HttpServerService( name="http_server", state=self.state, - thread=self._thread, + # thread=self._thread, eventbus=self.eventbus, logger=self.logger, ) @@ -358,7 +358,7 @@ def shutdown(self, blocking: bool = True) -> Union[Future[bool], bool]: """ if not self._alive: return True - + self.logger.info(f"{self}: Shutting down") # Only execute if thread exists diff --git a/test/worker/node_handler/test_node_controller.py b/test/worker/node_handler/test_node_controller.py index 6717b91..89e8e3d 100644 --- a/test/worker/node_handler/test_node_controller.py +++ b/test/worker/node_handler/test_node_controller.py @@ -17,35 +17,26 @@ OUTPUT = 1 -# class TestNode: -# def run( -# self, -# blocking: bool = True, -# running: Optional[mp.Value] = None, # type: ignore -# eventbus=None, -# ): -# while running: -# time.sleep(0.1) -# return OUTPUT - async def test_mp_node_controller(): session = MPSession() - # node = TestNode() node = GenNode(name="Gen1") + node_controller = MPNodeController(node, logger) # type: ignore node_controller.run(session) await asyncio.sleep(0.25) + await node_controller.shutdown() assert node_controller.future.result() == OUTPUT async def test_thread_node_controller(): session = ThreadSession() - # node = TestNode() node = GenNode(name="Gen1") + node_controller = ThreadNodeController(node, logger) # type: ignore node_controller.run(session) await asyncio.sleep(0.25) + await node_controller.shutdown() assert node_controller.future.result() == OUTPUT diff --git a/test/worker/node_handler/test_node_handler.py b/test/worker/node_handler/test_node_handler.py index 3016905..e941336 100644 --- a/test/worker/node_handler/test_node_handler.py +++ b/test/worker/node_handler/test_node_handler.py @@ -6,7 +6,6 @@ import chimerapy.engine as cpe from chimerapy.engine.worker.node_handler_service import NodeHandlerService from chimerapy.engine.worker.http_server_service import HttpServerService -from chimerapy.engine.networking.async_loop_thread import AsyncLoopThread from chimerapy.engine.eventbus import EventBus, make_evented, Event from chimerapy.engine.states import WorkerState @@ -71,13 +70,11 @@ def node_with_reg_methods(logreceiver): return NodeWithRegisteredMethods(name="RegNode1", debug_port=logreceiver.port) -@pytest.fixture(scope="module") -def node_handler_setup(): +@pytest.fixture +async def node_handler_setup(): # Event Loop - thread = AsyncLoopThread() - thread.start() - eventbus = EventBus(thread=thread) + eventbus = EventBus() # Requirements state = make_evented(WorkerState(), event_bus=eventbus) @@ -93,20 +90,20 @@ def node_handler_setup(): logger=logger, logreceiver=log_receiver, ) + await node_handler.async_init() # Necessary dependency http_server = HttpServerService( - name="http_server", state=state, thread=thread, eventbus=eventbus, logger=logger + name="http_server", state=state, eventbus=eventbus, logger=logger ) - # thread.exec(http_server.start()).result(timeout=10) - eventbus.send(Event("start")).result(timeout=10) + await http_server.async_init() + await eventbus.asend(Event("start")) yield (node_handler, http_server) - - eventbus.send(Event("shutdown")) + await eventbus.asend(Event("shutdown")) -def test_create_service_instance(node_handler_setup): +async def test_create_service_instance(node_handler_setup): ... @@ -159,7 +156,7 @@ def step(self): assert await node_handler.async_destroy_node(node_id) -# @pytest.mark.parametrize("context", ["multiprocessing"]) # , "threading"]) +@pytest.mark.skip() @pytest.mark.parametrize("context", ["multiprocessing", "threading"]) async def test_processing_node_pub_table( node_handler_setup, gen_node, con_node, context @@ -217,7 +214,6 @@ async def test_record_and_collect(node_handler_setup, context): cpe.NodeConfig(node, context=context) ) - logger.debug("Starting") assert await node_handler.async_start_nodes() await asyncio.sleep(1) diff --git a/test/worker/test_http_server.py b/test/worker/test_http_server.py index 16e9d1a..f4583f2 100644 --- a/test/worker/test_http_server.py +++ b/test/worker/test_http_server.py @@ -7,7 +7,6 @@ from pytest_lazyfixture import lazy_fixture from chimerapy.engine.worker.http_server_service import HttpServerService -from chimerapy.engine.networking.async_loop_thread import AsyncLoopThread from chimerapy.engine.networking.data_chunk import DataChunk from chimerapy.engine.networking.client import Client from chimerapy.engine.networking.enums import NODE_MESSAGE @@ -28,27 +27,25 @@ def pickled_gen_node_config(gen_node): return pickle.dumps(NodeConfig(gen_node)) -@pytest.fixture(scope="module") -def http_server(): +@pytest.fixture +async def http_server(): # Event Loop - thread = AsyncLoopThread() - thread.start() - eventbus = EventBus(thread=thread) + eventbus = EventBus() # Requirements state = WorkerState() # Create the services http_server = HttpServerService( - name="http_server", state=state, thread=thread, eventbus=eventbus, logger=logger + name="http_server", state=state, eventbus=eventbus, logger=logger ) - thread.exec(http_server.start()).result(timeout=10) + await http_server.start() return http_server -@pytest.fixture(scope="module") -def ws_client(http_server): +@pytest.fixture +async def ws_client(http_server): client = Client( host=http_server.ip, @@ -57,12 +54,12 @@ def ws_client(http_server): ws_handlers={}, parent_logger=logger, ) - client.connect() + await client.async_connect() yield client - client.shutdown() + await client.async_shutdown() -def test_http_server_instanciate(http_server): +async def test_http_server_instanciate(http_server): ...