diff --git a/chimerapy/engine/networking/server.py b/chimerapy/engine/networking/server.py index 541d61b..3710773 100644 --- a/chimerapy/engine/networking/server.py +++ b/chimerapy/engine/networking/server.py @@ -314,7 +314,7 @@ async def _websocket_handler(self, request): # Remove client id target_client_id: Optional[str] = None - for client_id, client_ws in self.ws_clients.values(): + for client_id, client_ws in self.ws_clients.items(): if client_ws == ws: target_client_id = client_id @@ -429,14 +429,18 @@ async def async_shutdown(self) -> bool: self.logger.debug(f"{self}: Tried to shutdown while not running.") return True - for ws in self.ws_clients.values(): - try: - await asyncio.wait_for( - ws.close(), - timeout=2, - ) - except (asyncio.exceptions.TimeoutError, RuntimeError): - pass + for client_id in list( + self.ws_clients + ): # Copying the list to avoid changing the dict + ws = self.ws_clients.get(client_id, None) + if ws is not None: + try: + await asyncio.wait_for( + ws.close(), + timeout=2, + ) + except (asyncio.exceptions.TimeoutError, RuntimeError): + pass # Cleanup and signal complete await asyncio.wait_for(self._runner.shutdown(), timeout=10) diff --git a/chimerapy/engine/node/node.py b/chimerapy/engine/node/node.py index 83524bb..3987cce 100644 --- a/chimerapy/engine/node/node.py +++ b/chimerapy/engine/node/node.py @@ -82,7 +82,6 @@ def __init__( # Generic Node needs self.logger: logging.Logger = logging.getLogger("chimerapy-engine-node") self.logging_level: int = logging.DEBUG - self.start_time = datetime.datetime.now() # Default values self.node_config = NodeConfig() @@ -204,7 +203,6 @@ def save_video(self, name: str, data: np.ndarray, fps: int): "data": data, "dtype": "video", "fps": fps, - "elapsed": (timestamp - self.start_time).total_seconds(), "timestamp": timestamp, } self.recorder.submit(video_entry) @@ -373,6 +371,43 @@ def save_json(self, name: str, data: Dict[Any, Any]): } self.recorder.submit(json_entry) + def save_text(self, name: str, data: str, suffix="txt"): + """Record text data from the node to a text file. + + Parameters + ---------- + name : str + Name of the text file (.suffix extension will be suffixed). + + data : str + The data to be recorded. + + suffix : str + The suffix of the text file. + + Notes + ----- + It should be noted that new lines addition should be taken by the callee. + """ + + if not self.recorder: + self.logger.warning( + f"{self}: cannot perform recording operation without RecorderService " + "initialization" + ) + return False + + if self.recorder.enabled: + text_entry = { + "uuid": uuid.uuid4(), + "name": name, + "data": data, + "suffix": suffix, + "dtype": "text", + "timestamp": datetime.datetime.now(), + } + self.recorder.submit(text_entry) + #################################################################### ## Back-End Lifecycle API #################################################################### diff --git a/chimerapy/engine/node/record_service.py b/chimerapy/engine/node/record_service.py index a5489bf..b5020f8 100644 --- a/chimerapy/engine/node/record_service.py +++ b/chimerapy/engine/node/record_service.py @@ -14,6 +14,7 @@ TabularRecord, ImageRecord, JSONRecord, + TextRecord, ) from ..service import Service @@ -49,6 +50,7 @@ def __init__( "tabular": TabularRecord, "image": ImageRecord, "json": JSONRecord, + "text": TextRecord, } # Making sure the attribute exists @@ -130,6 +132,8 @@ def run(self): if data_entry["name"] not in self.records: entry_cls = self.record_map[data_entry["dtype"]] entry = entry_cls(dir=self.state.logdir, name=data_entry["name"]) + + # FixMe: Potential overwrite of existing entry? self.records[data_entry["name"]] = entry # Case 2 diff --git a/chimerapy/engine/records/__init__.py b/chimerapy/engine/records/__init__.py index 36e7ff6..12871d9 100644 --- a/chimerapy/engine/records/__init__.py +++ b/chimerapy/engine/records/__init__.py @@ -5,6 +5,7 @@ from .tabular_record import TabularRecord from .video_record import VideoRecord from .json_record import JSONRecord +from .text_record import TextRecord __all__ = [ "Record", @@ -13,4 +14,5 @@ "TabularRecord", "VideoRecord", "JSONRecord", + "TextRecord", ] diff --git a/chimerapy/engine/records/text_record.py b/chimerapy/engine/records/text_record.py new file mode 100644 index 0000000..afa8eaf --- /dev/null +++ b/chimerapy/engine/records/text_record.py @@ -0,0 +1,46 @@ +# Built-in Imports +from typing import Dict, Any, Optional, IO +import pathlib + +# Third-party Imports + +# Internal Import +from .record import Record + + +class TextRecord(Record): + def __init__( + self, + dir: pathlib.Path, + name: str, + ): + """Construct a text file Record. + + Args: + dir (pathlib.Path): The directory to store the snap shots of data. + name (str): The name of the ``Record``. + suffix (str): The suffix of the text file. Defaults to "txt". + """ + super().__init__() + + # Saving the Record attributes + self.dir = dir + self.name = name + self.first_frame = False + self.file_handler: Optional[IO[str]] = None + + def write(self, data_chunk: Dict[str, Any]): + if not self.first_frame: + self.file_handler = (self.dir / f"{self.name}.{data_chunk['suffix']}").open( + "w" + ) + self.first_frame = True + + text_data = data_chunk["data"] + assert self.file_handler is not None + self.file_handler.write(text_data) + + def close(self): + if self.file_handler is not None: + self.file_handler.close() + self.file_handler = None diff --git a/chimerapy/engine/records/video_record.py b/chimerapy/engine/records/video_record.py index dd88522..8674cc6 100644 --- a/chimerapy/engine/records/video_record.py +++ b/chimerapy/engine/records/video_record.py @@ -9,6 +9,8 @@ # Internal Imports from .record import Record +from datetime import datetime + class VideoRecord(Record): def __init__( @@ -36,6 +38,7 @@ def __init__( # Handling unstable FPS self.frame_count: int = 0 self.previous_frame: np.ndarray = np.array([]) + self.start_time: datetime = datetime.now() def write(self, data_chunk: Dict[str, Any]): """Commit the unsaved changes to memory.""" @@ -43,7 +46,8 @@ def write(self, data_chunk: Dict[str, Any]): # Determine the size frame = data_chunk["data"] fps = data_chunk["fps"] - elapsed = data_chunk["elapsed"] + timestamp = data_chunk["timestamp"] + elapsed = (timestamp - self.start_time).total_seconds() h, w = frame.shape[:2] # Determine if RGB or grey video @@ -60,7 +64,6 @@ def write(self, data_chunk: Dict[str, Any]): self.video_writer = cv2.VideoWriter( str(self.video_file_path), self.video_fourcc, fps, (w, h), 0 ) - # Write self.first_frame = False self.video_writer.write(np.uint8(frame)) diff --git a/test/streams/data_nodes.py b/test/streams/data_nodes.py index 7a54a0d..bb051dd 100644 --- a/test/streams/data_nodes.py +++ b/test/streams/data_nodes.py @@ -1,5 +1,6 @@ # Build-in Imports import time +import random # Third-party Imports import pyaudio @@ -75,3 +76,18 @@ def step(self): time.sleep(1 / 10) data = {"time": time.time(), "content": "HELLO"} self.save_json(name="test", data=data) + + +class TextNode(cpe.Node): + def setup(self): + self.step_count = 0 + + def step(self): + time.sleep(1 / 10) + num_lines = random.randint(1, 5) + self.step_count += 1 + lines = [] + for j in range(num_lines): + lines.append(f"This is a test - Step Count - {self.step_count + 1}\n") + + self.save_text(name="test", data="".join(lines), suffix="text") diff --git a/test/streams/test_text.py b/test/streams/test_text.py new file mode 100644 index 0000000..c138925 --- /dev/null +++ b/test/streams/test_text.py @@ -0,0 +1,102 @@ +from .data_nodes import TextNode + +# Built-in Imports +import os +import pathlib +import time +import uuid + +# Third-party +import pytest + +# Internal Imports +import chimerapy.engine as cpe +from chimerapy.engine.records.text_record import TextRecord +from chimerapy.engine.networking.async_loop_thread import AsyncLoopThread +from chimerapy.engine.eventbus import EventBus, Event + +logger = cpe._logger.getLogger("chimerapy-engine") + +# Constants +CWD = pathlib.Path(os.path.abspath(__file__)).parent.parent +TEST_DATA_DIR = CWD / "data" + + +@pytest.fixture +def text_node(): + + # Create a node + text_n = TextNode(name="text_n", logdir=TEST_DATA_DIR) + + return text_n + + +def test_text_record(): + + # Check that the image was created + expected_text_path = TEST_DATA_DIR / "test-5.log" + try: + os.rmdir(expected_text_path.parent) + except OSError: + ... + + # Create the record + text_r = TextRecord(dir=TEST_DATA_DIR, name="test-5") + + data = [ + "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.\n", + "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi " + "ut aliquip ex ea commodo consequat.\n", + ] + + # Write to image file + for i in range(5): + print("\n".join(data)) + text_chunk = { + "uuid": uuid.uuid4(), + "name": "test-5", + "suffix": "log", + "data": "".join(data), + "dtype": "text", + } + text_r.write(text_chunk) + + # Check that the image was created + assert expected_text_path.exists() + + with expected_text_path.open("r") as jlf: + for idx, line in enumerate(jlf): + assert line.strip() == (data[idx % len(data)]).strip() + + +def test_node_save_text_stream(text_node): + + # Event Loop + thread = AsyncLoopThread() + thread.start() + eventbus = EventBus(thread=thread) + + # Check that the image was created + expected_text_path = pathlib.Path(text_node.state.logdir) / "test.text" + try: + os.rmdir(expected_text_path.parent) + except OSError: + ... + + # Stream + text_node.run(blocking=False, eventbus=eventbus) + + # Wait to generate files + eventbus.send(Event("start")).result() + logger.debug("Finish start") + eventbus.send(Event("record")).result() + logger.debug("Finish record") + time.sleep(3) + eventbus.send(Event("stop")).result() + logger.debug("Finish stop") + + text_node.shutdown() + + # Check that the image was created + assert expected_text_path.exists() diff --git a/test/streams/test_video.py b/test/streams/test_video.py index 86f0e1a..a0bdab7 100644 --- a/test/streams/test_video.py +++ b/test/streams/test_video.py @@ -3,6 +3,7 @@ import asyncio import pathlib import uuid +from datetime import timedelta # Third-party import cv2 @@ -47,6 +48,7 @@ def test_video_record(): # Write to video file fps = 30 + start_time = vr.start_time for i in range(fps): data = np.random.rand(200, 300, 3) * 255 video_chunk = { @@ -55,8 +57,7 @@ def test_video_record(): "data": data, "dtype": "video", "fps": fps, - "timestamp": i / fps, - "elapsed": i / fps, + "timestamp": start_time + timedelta(seconds=i / fps), } vr.write(video_chunk) @@ -88,10 +89,11 @@ def test_video_record_with_unstable_frames(): fps = 30 actual_fps = 10 rec_time = 5 + start_time = vr.start_time for i in range(rec_time * actual_fps): # But actually, we are getting frames at 20 fps - timestamp = i / actual_fps + timestamp = start_time + timedelta(seconds=i / actual_fps) data = np.random.rand(200, 300, 3) * 255 video_chunk = { "uuid": uuid.uuid4(),