Skip to content

Commit

Permalink
Added exception and keyboard interrupt support in AsyncLoopThread. (#246
Browse files Browse the repository at this point in the history
)
  • Loading branch information
edavalosanaya authored Sep 8, 2023
1 parent 8cd89df commit aa63f32
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 88 deletions.
89 changes: 38 additions & 51 deletions chimerapy/engine/networking/async_loop_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,69 +20,55 @@ def __init__(self):
super().__init__(daemon=True)
self._loop = asyncio.new_event_loop()

def callback(self, coro: Coroutine) -> Tuple[Future, Coroutine]:

future = Future()

async def _wrapper():
output = None
try:
output = await coro
except KeyboardInterrupt:
logger.debug("KeyboardInterrupt DETECTED")
future.set_result(None)
self.stop()
return None
except Exception:
logger.error(traceback.format_exc())

future.set_result(output)
def callback(
self, func: Union[Callable, Coroutine], args: Optional[List[Any]] = None
) -> Tuple[Future, Coroutine]:

return future, _wrapper()

def waitable_callback(
self, callback: Callable, args: List[Any]
) -> Tuple[Future, Callable]:
future: Future = Future()

future = Future()
if args is None:
args = []

# Create wrapper that signals when the callback finished
def _wrapper(func: Callable, *args) -> Any:
output = None
async def _wrapper(func: Union[Callable, Coroutine], *args) -> Any:
try:
output = func(*args)
except KeyboardInterrupt:
logger.debug("KeyboardInterrupt DETECTED")
future.set_result(None)
if asyncio.iscoroutine(func):
result = await func # type: ignore
else:
result = func(*args) # type: ignore
future.set_result(result)
return result
except (KeyboardInterrupt, Exception) as e:
if isinstance(e, KeyboardInterrupt):
logger.debug("KeyboardInterrupt DETECTED")
else:
logger.error(traceback.format_exc())
future.set_exception(e)
self.stop()
return None
except Exception as e:
logger.error(traceback.format_exc())

future.set_result(output)
return output

wrapper = _wrapper(callback, *args)
wrapper = _wrapper(func, *args)
return future, wrapper

def exec(self, coro: Coroutine) -> threading.Event:
def exec(self, coro: Coroutine) -> Future:
if self._loop.is_closed():
raise RuntimeError(
"AsyncLoopThread: Event loop is closed, but a coroutine was sent to it."
)

finished, wrapper = self.callback(coro)
asyncio.run_coroutine_threadsafe(wrapper, self._loop)
return finished

def exec_noncoro(
self, callback: Callable, args: List[Any], waitable: bool = False
) -> Optional[threading.Event]:

if waitable:
finished, wrapper = self.waitable_callback(callback, args)
self._loop.call_soon_threadsafe(wrapper, *args)
return finished
def exec_noncoro(self, callback: Callable, args: List[Any]) -> Future:
if self._loop.is_closed():
raise RuntimeError(
"AsyncLoopThread: Event loop is closed, but a coroutine was sent to it."
)

else:
self._loop.call_soon_threadsafe(callback, *args)

return None
finished, wrapper = self.callback(callback, args)
asyncio.run_coroutine_threadsafe(wrapper, self._loop)
return finished

def run(self):
asyncio.set_event_loop(self._loop)
Expand All @@ -91,13 +77,14 @@ def run(self):
except KeyboardInterrupt:
...
finally:
self.stop()
self._loop.close()
def flush(self, timeout: Optional[Union[int, float]] = None):

def flush(self, timeout: Optional[Union[int, float]] = None):
tasks = asyncio.all_tasks(self._loop)
if tasks:
coro = asyncio.gather(*tasks)
self.exec(coro).result(timeout=timeout)
self.exec(coro).result(timeout=timeout) # type: ignore

def stop(self):
# Cancel all tasks
Expand Down
24 changes: 12 additions & 12 deletions test/logger/test_zmq_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,25 @@
)


# @pytest.mark.skip(reason="Flaky")
def get_log_and_messages(port, node_id):
zmq_push_handler = NodeIdZMQPushHandler("localhost", port)
logger = logging.getLogger("test")
logger.setLevel(logging.DEBUG)
zmq_push_handler.setLevel(logging.DEBUG)
logger.addHandler(zmq_push_handler)
zmq_push_handler.register_node_id(node_id)
assert len(logger.handlers) == 1
for j in range(10):
logger.debug(f"Message {j}")


def test_zmq_push_pull_node_id_logging():
handler = BufferingHandler(capacity=300)
handler.setLevel(logging.DEBUG)
logreceiver = NodeIDZMQPullListener(handlers=[handler])
logreceiver.start()
ids = [str(uuid.uuid4()) for _ in range(2)]

def get_log_and_messages(port, node_id):
zmq_push_handler = NodeIdZMQPushHandler("localhost", port)
logger = logging.getLogger("test")
logger.setLevel(logging.DEBUG)
zmq_push_handler.setLevel(logging.DEBUG)
logger.addHandler(zmq_push_handler)
zmq_push_handler.register_node_id(node_id)
assert len(logger.handlers) == 1
for j in range(10):
logger.debug(f"Message {j}")

p1 = Process(target=get_log_and_messages, args=(logreceiver.port, ids[0]))
p2 = Process(target=get_log_and_messages, args=(logreceiver.port, ids[1]))

Expand Down
7 changes: 4 additions & 3 deletions test/streams/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import time
import datetime
import glob
import uuid
import wave
import tempfile

Expand Down Expand Up @@ -42,10 +41,11 @@ def audio_node():


def get_wav_files():
data_dir = TEST_DATA_DIR / 'audio'
data_dir = TEST_DATA_DIR / "audio"
return glob.glob(str(data_dir / "*.wav"))


@pytest.mark.skip(reason="Test taking way to long?")
@pytest.mark.parametrize("input_file", get_wav_files())
def test_audio_writer(input_file):
save_dir = pathlib.Path(tempfile.mkdtemp())
Expand Down Expand Up @@ -80,6 +80,7 @@ def test_audio_writer(input_file):
assert inp.getparams() == out.getparams()
assert inp.readframes(inp.getnframes()) == out.readframes(out.getnframes())


def test_audio_record():

# Check that the audio was created
Expand All @@ -104,7 +105,7 @@ def test_audio_record():
"format": FORMAT,
"rate": RATE,
"recorder_version": 1,
"timestamp": datetime.datetime.now()
"timestamp": datetime.datetime.now(),
}
ar.write(audio_chunk)

Expand Down
91 changes: 69 additions & 22 deletions test/test_threaded_async.py → test/test_async_loop_thread.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Built-in Imports
import time
import asyncio
import threading

# Third-party Imports
import pytest
from pytest import raises

# ChimeraPy Imports
from chimerapy.engine.networking.async_loop_thread import AsyncLoopThread
Expand All @@ -27,13 +27,13 @@ def test_coroutine_waitable_execution(thread):

async def put(queue):
logger.debug("PUT")
await asyncio.sleep(1)
await asyncio.sleep(0.1)
logger.debug("AFTER SLEEP")
await queue.put(1)
logger.debug("FINISHED PUT")

future = thread.exec(put(queue))
future.result(timeout=5)
future.result(timeout=3)
assert queue.qsize() == 1


Expand All @@ -44,53 +44,100 @@ def put(queue):
logger.debug("put called")
queue.put_nowait(1)

thread.exec_noncoro(put, args=[queue])
time.sleep(5)
future = thread.exec_noncoro(put, args=[queue])
future.result(timeout=1)
assert queue.qsize() == 1


def test_callback_execution_with_wait(thread):
queue = asyncio.Queue()

def put(queue):
time.sleep(1)
time.sleep(0.1)
logger.debug("put called")
queue.put_nowait(1)

future = thread.exec_noncoro(put, args=[queue], waitable=True)
future.result(timeout=5)
future = thread.exec_noncoro(put, args=[queue])
future.result(timeout=1)
assert queue.qsize() == 1


def test_keyboard_interrupt_handling_noncoro(thread):
queue = asyncio.Queue()

# Let's simulate a KeyboardInterrupt using threading after a short delay.
def raise_keyboard_interrupt(queue):
time.sleep(1)
time.sleep(0.1)
queue.put_nowait(1)
raise KeyboardInterrupt

future = thread.exec_noncoro(raise_keyboard_interrupt, args=[queue], waitable=True)
future.result(timeout=5)

future = thread.exec_noncoro(raise_keyboard_interrupt, args=[queue])

with raises(KeyboardInterrupt):
future.result(timeout=1)

assert queue.qsize() == 1
thread.join()
assert thread._loop.is_closed()


def test_keyboard_interrupt_handling_coro(thread):
queue = asyncio.Queue()

# Let's simulate a KeyboardInterrupt using threading after a short delay.
async def raise_keyboard_interrupt(queue):
await asyncio.sleep(1)
await asyncio.sleep(0.1)
await queue.put(1)
raise KeyboardInterrupt

future = thread.exec(raise_keyboard_interrupt(queue))
future.result(timeout=5)
with raises(KeyboardInterrupt):
future.result()
assert queue.qsize() == 1

thread.join()
assert thread._loop.is_closed()



def test_exception_handling_noncoro(thread):
queue = asyncio.Queue()

# Let's simulate a KeyboardInterrupt using threading after a short delay.
def raise_keyboard_interrupt(queue):
time.sleep(0.1)
queue.put_nowait(1)
raise TypeError

future = thread.exec_noncoro(raise_keyboard_interrupt, args=[queue])

with raises(TypeError):
future.result(timeout=1)

assert queue.qsize() == 1
thread.join()
assert thread._loop.is_closed()

with raises(RuntimeError):
future = thread.exec_noncoro(raise_keyboard_interrupt, args=[queue])


def test_exception_handling_coro(thread):
queue = asyncio.Queue()

# Let's simulate a KeyboardInterrupt using threading after a short delay.
async def raise_keyboard_interrupt(queue):
await asyncio.sleep(0.1)
await queue.put(1)
raise TypeError

future = thread.exec(raise_keyboard_interrupt(queue))

with raises(TypeError):
future.result(timeout=1)

assert queue.qsize() == 1
thread.join()
assert thread._loop.is_closed()

with raises(RuntimeError):
future = thread.exec(raise_keyboard_interrupt(queue))

0 comments on commit aa63f32

Please sign in to comment.