Skip to content

Commit

Permalink
Managed to correctly start and stop nodes with the new async Pool app…
Browse files Browse the repository at this point in the history
…roach.
  • Loading branch information
edavalosanaya committed Sep 20, 2023
1 parent 8d264db commit 49d486a
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 65 deletions.
34 changes: 30 additions & 4 deletions chimerapy/engine/eventbus/eventbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
Literal,
TypeVar,
Dict,
Coroutine,
Tuple,
)

from aioreactive import AsyncObservable, AsyncObserver, AsyncSubject
Expand All @@ -34,6 +36,20 @@ class Event:
timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat())


def future_wrapper(coroutine: Coroutine) -> Tuple[Coroutine, Future]:

future: Future = Future()

async def wrapper():
try:
result = await coroutine
future.set_result(result)
except Exception as e:
future.set_exception(e)

return wrapper(), future


class EventBus(AsyncObservable):
def __init__(self, thread: Optional[AsyncLoopThread] = None):
self.stream = AsyncSubject()
Expand Down Expand Up @@ -87,12 +103,22 @@ async def await_event(self, event_type: str) -> Event:
####################################################################

def send(self, event: Event) -> Future:
assert isinstance(self.thread, AsyncLoopThread)
return self.thread.exec(self.asend(event))
if isinstance(self.thread, AsyncLoopThread):
return self.thread.exec(self.asend(event))
else:
loop = asyncio.get_event_loop()
wrapper, future = future_wrapper(self.asend(event))
loop.create_task(wrapper)
return future

def subscribe(self, observer: AsyncObserver) -> Future:
assert isinstance(self.thread, AsyncLoopThread)
return self.thread.exec(self.asubscribe(observer))
if isinstance(self.thread, AsyncLoopThread):
return self.thread.exec(self.asubscribe(observer))
else:
loop = asyncio.get_event_loop()
wrapper, future = future_wrapper(self.asubscribe(observer))
loop.create_task(wrapper)
return future


class TypedObserver(AsyncObserver, Generic[T]):
Expand Down
32 changes: 15 additions & 17 deletions chimerapy/engine/worker/http_server_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
NodeDiagnostics,
)
from ..networking import Server
from ..networking.async_loop_thread import AsyncLoopThread
from ..networking.enums import NODE_MESSAGE
from ..utils import update_dataclass
from ..eventbus import EventBus, Event, TypedObserver
Expand All @@ -38,15 +37,13 @@ def __init__(
self,
name: str,
state: WorkerState,
thread: AsyncLoopThread,
eventbus: EventBus,
logger: logging.Logger,
):

# Save input parameters
self.name = name
self.state = state
self.thread = thread
self.eventbus = eventbus
self.logger = logger

Expand Down Expand Up @@ -80,9 +77,22 @@ def __init__(
NODE_MESSAGE.DIAGNOSTICS: self._async_node_diagnostics,
},
parent_logger=self.logger,
thread=self.thread,
)

@property
def ip(self) -> str:
return self._ip

@property
def port(self) -> int:
return self._port

@property
def url(self) -> str:
return f"http://{self._ip}:{self._port}"

async def async_init(self):

# Specify observers
self.observers: Dict[str, TypedObserver] = {
"start": TypedObserver("start", on_asend=self.start, handle_event="drop"),
Expand All @@ -103,19 +113,7 @@ def __init__(
),
}
for ob in self.observers.values():
self.eventbus.subscribe(ob).result(timeout=1)

@property
def ip(self) -> str:
return self._ip

@property
def port(self) -> int:
return self._port

@property
def url(self) -> str:
return f"http://{self._ip}:{self._port}"
await self.eventbus.asubscribe(ob)

async def start(self):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def __init__(
"threading": ThreadNodeController,
}

async def async_init(self):

# Specify observers
self.observers: Dict[str, TypedObserver] = {
"start": TypedObserver("start", on_asend=self.start, handle_event="drop"),
Expand Down Expand Up @@ -126,7 +128,7 @@ def __init__(
),
}
for ob in self.observers.values():
self.eventbus.subscribe(ob).result(timeout=1)
await self.eventbus.asubscribe(ob)

async def start(self) -> bool:
# Containers
Expand Down Expand Up @@ -246,7 +248,7 @@ async def async_create_node(self, node_config: Union[NodeConfig, Dict]) -> bool:
self.node_controllers[node_config.id] = controller

# Mark success
# self.logger.debug(f"{self}: completed node creation: {id}")
self.logger.debug(f"{self}: completed node creation: {id}")
break

if not success:
Expand All @@ -261,7 +263,9 @@ async def async_destroy_node(self, node_id: str) -> bool:
success = False

if node_id in self.node_controllers:
self.logger.debug(f"{self}: destroying Node {node_id}")
await self.node_controllers[node_id].shutdown()
self.logger.debug(f"{self}: destroyed Node {node_id}")

if node_id in self.state.nodes:
del self.state.nodes[node_id]
Expand Down
4 changes: 2 additions & 2 deletions chimerapy/engine/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(
self.http_server = HttpServerService(
name="http_server",
state=self.state,
thread=self._thread,
# thread=self._thread,
eventbus=self.eventbus,
logger=self.logger,
)
Expand Down Expand Up @@ -358,7 +358,7 @@ def shutdown(self, blocking: bool = True) -> Union[Future[bool], bool]:
"""
if not self._alive:
return True

self.logger.info(f"{self}: Shutting down")

# Only execute if thread exists
Expand Down
17 changes: 4 additions & 13 deletions test/worker/node_handler/test_node_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,26 @@

OUTPUT = 1

# class TestNode:
# def run(
# self,
# blocking: bool = True,
# running: Optional[mp.Value] = None, # type: ignore
# eventbus=None,
# ):
# while running:
# time.sleep(0.1)
# return OUTPUT


async def test_mp_node_controller():
session = MPSession()
# node = TestNode()
node = GenNode(name="Gen1")

node_controller = MPNodeController(node, logger) # type: ignore
node_controller.run(session)
await asyncio.sleep(0.25)

await node_controller.shutdown()
assert node_controller.future.result() == OUTPUT


async def test_thread_node_controller():
session = ThreadSession()
# node = TestNode()
node = GenNode(name="Gen1")

node_controller = ThreadNodeController(node, logger) # type: ignore
node_controller.run(session)
await asyncio.sleep(0.25)

await node_controller.shutdown()
assert node_controller.future.result() == OUTPUT
24 changes: 10 additions & 14 deletions test/worker/node_handler/test_node_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import chimerapy.engine as cpe
from chimerapy.engine.worker.node_handler_service import NodeHandlerService
from chimerapy.engine.worker.http_server_service import HttpServerService
from chimerapy.engine.networking.async_loop_thread import AsyncLoopThread
from chimerapy.engine.eventbus import EventBus, make_evented, Event
from chimerapy.engine.states import WorkerState

Expand Down Expand Up @@ -71,13 +70,11 @@ def node_with_reg_methods(logreceiver):
return NodeWithRegisteredMethods(name="RegNode1", debug_port=logreceiver.port)


@pytest.fixture(scope="module")
def node_handler_setup():
@pytest.fixture
async def node_handler_setup():

# Event Loop
thread = AsyncLoopThread()
thread.start()
eventbus = EventBus(thread=thread)
eventbus = EventBus()

# Requirements
state = make_evented(WorkerState(), event_bus=eventbus)
Expand All @@ -93,20 +90,20 @@ def node_handler_setup():
logger=logger,
logreceiver=log_receiver,
)
await node_handler.async_init()

# Necessary dependency
http_server = HttpServerService(
name="http_server", state=state, thread=thread, eventbus=eventbus, logger=logger
name="http_server", state=state, eventbus=eventbus, logger=logger
)
# thread.exec(http_server.start()).result(timeout=10)
eventbus.send(Event("start")).result(timeout=10)
await http_server.async_init()

await eventbus.asend(Event("start"))
yield (node_handler, http_server)

eventbus.send(Event("shutdown"))
await eventbus.asend(Event("shutdown"))


def test_create_service_instance(node_handler_setup):
async def test_create_service_instance(node_handler_setup):
...


Expand Down Expand Up @@ -159,7 +156,7 @@ def step(self):
assert await node_handler.async_destroy_node(node_id)


# @pytest.mark.parametrize("context", ["multiprocessing"]) # , "threading"])
@pytest.mark.skip()
@pytest.mark.parametrize("context", ["multiprocessing", "threading"])
async def test_processing_node_pub_table(
node_handler_setup, gen_node, con_node, context
Expand Down Expand Up @@ -217,7 +214,6 @@ async def test_record_and_collect(node_handler_setup, context):
cpe.NodeConfig(node, context=context)
)

logger.debug("Starting")
assert await node_handler.async_start_nodes()
await asyncio.sleep(1)

Expand Down
23 changes: 10 additions & 13 deletions test/worker/test_http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pytest_lazyfixture import lazy_fixture

from chimerapy.engine.worker.http_server_service import HttpServerService
from chimerapy.engine.networking.async_loop_thread import AsyncLoopThread
from chimerapy.engine.networking.data_chunk import DataChunk
from chimerapy.engine.networking.client import Client
from chimerapy.engine.networking.enums import NODE_MESSAGE
Expand All @@ -28,27 +27,25 @@ def pickled_gen_node_config(gen_node):
return pickle.dumps(NodeConfig(gen_node))


@pytest.fixture(scope="module")
def http_server():
@pytest.fixture
async def http_server():

# Event Loop
thread = AsyncLoopThread()
thread.start()
eventbus = EventBus(thread=thread)
eventbus = EventBus()

# Requirements
state = WorkerState()

# Create the services
http_server = HttpServerService(
name="http_server", state=state, thread=thread, eventbus=eventbus, logger=logger
name="http_server", state=state, eventbus=eventbus, logger=logger
)
thread.exec(http_server.start()).result(timeout=10)
await http_server.start()
return http_server


@pytest.fixture(scope="module")
def ws_client(http_server):
@pytest.fixture
async def ws_client(http_server):

client = Client(
host=http_server.ip,
Expand All @@ -57,12 +54,12 @@ def ws_client(http_server):
ws_handlers={},
parent_logger=logger,
)
client.connect()
await client.async_connect()
yield client
client.shutdown()
await client.async_shutdown()


def test_http_server_instanciate(http_server):
async def test_http_server_instanciate(http_server):
...


Expand Down

0 comments on commit 49d486a

Please sign in to comment.