Skip to content

Commit

Permalink
Refactor collection with HTTP get
Browse files Browse the repository at this point in the history
  • Loading branch information
umesh-timalsina committed Sep 28, 2023
1 parent 0356aad commit a9dd21b
Show file tree
Hide file tree
Showing 16 changed files with 391 additions and 5 deletions.
3 changes: 3 additions & 0 deletions chimerapy/engine/chimerapyrc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ config:
client-shutdown: 10
server-shutdown: 15
pub-delay: 1
artifacts-ready: 50
diagnostics:
deque-length: 10000
interval: 10
logging-enabled: false
streaming-responses:
chunk-size: 1024 # KB
187 changes: 187 additions & 0 deletions chimerapy/engine/manager/artifacts_collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import json
import logging
import asyncio
import pathlib

import aioshutil
from tqdm import tqdm

import aiofiles
import aiohttp
from aiohttp import ClientSession
from typing import Dict, Any

from chimerapy.engine._logger import fork, getLogger
from chimerapy.engine.states import ManagerState
from ..config import get
from chimerapy.engine.utils import async_waiting_for


class ArtifactsCollector:
"""A utility class to collect artifacts recorded by the nodes."""

def __init__(
self, state: ManagerState, worker_id: str, parent_logger: logging.Logger = None
):
self._payload = None

if parent_logger:
worker_state = state.workers[worker_id]
self.logger = fork(
parent_logger,
f"ArtifactsCollector[Worker{worker_state.name}-{worker_state.id[:8]}]",
)
else:
logger = getLogger("chimerapy-engine")
self.logger = fork(logger, "collector")

self.state = state
self.worker_id = worker_id
self.base_url = (
f"http://{self.state.workers[self.worker_id].ip}:"
f"{self.state.workers[self.worker_id].port}"
)

async def _request_artifacts_gather(
self, session: ClientSession, timeout: int
) -> None:
"""Request the nodes to gather recorded artifacts."""
self.logger.debug("Requesting nodes to gather recorded artifacts")
async with session.post(
url="/nodes/gather_artifacts", data=json.dumps({})
) as _:
...

self.logger.debug("Waiting for nodes to gather recorded artifacts")
success = await async_waiting_for(self._have_nodes_saved, timeout=timeout)

if not success:
e_msg = "Nodes did not gather recorded artifacts in time"
self.logger.error(e_msg)
raise TimeoutError(e_msg)

self.logger.info("Nodes gathered recorded artifacts")

async def _request_artifacts_info(self, session) -> Dict[str, Any]:
"""Request the nodes to send the artifacts info."""
self.logger.debug("Requesting nodes to send artifacts info")
async with session.get(
url="/nodes/artifacts",
) as resp:
if resp.status != 200:
e_msg = "Could not get artifacts info from nodes"
self.logger.error(e_msg)
artifacts = {}
else:
artifacts = await resp.json()

return artifacts

def _have_nodes_saved(self) -> bool:
"""Check if all nodes have saved the recorded artifacts."""
worker_state = self.state.workers[self.worker_id]
node_fsm = map(lambda node: node.fsm, worker_state.nodes.values())

return all(map(lambda fsm: fsm == "SAVED", node_fsm))

async def _download_artifacts(self, session, artifacts) -> bool:
"""Download the artifacts from the nodes."""
parent_path = self._create_worker_dir()
coros = []
for node_id, node_artifacts in artifacts.items():
node_state = self._find_node_state_by_id(node_id)
node_dir = parent_path / node_state.name
node_dir.mkdir(exist_ok=True, parents=True)
for artifact in node_artifacts:
if self._is_remote_worker_collector():
coros.append(
self._download_remote_artifact(
session, node_id, node_dir, artifact
)
)
else:
coros.append(self._download_local_artifact(node_dir, artifact))

results = await asyncio.gather(*coros)
return all(results)

def _is_remote_worker_collector(self) -> bool:
return self.state.workers[self.worker_id].ip != self.state.ip

async def _download_local_artifact(
self, parent_dir: pathlib.Path, artifact: Dict[str, Any]
) -> bool:
file_path = parent_dir / pathlib.Path(artifact["path"]).name
src_path = pathlib.Path(artifact["path"])

if not src_path.exists():
return False

self.logger.debug(f"Copying {src_path} to {file_path}")
await aioshutil.copyfile(src_path, file_path)
return True

async def _download_remote_artifact(
self,
session: ClientSession,
node_id: str,
parent_dir: pathlib.Path,
artifact: Dict[str, Any],
) -> bool:
"""Download a single artifact from a node."""
file_path = parent_dir / pathlib.Path(artifact["path"]).name
# Stream and Save
async with session.get(
f"/nodes/artifacts/{node_id}/{artifact['name']}"
) as resp:

if resp.status != 200:
print(await resp.text())
e_msg = (
f"Could not download artifact "
f"{artifact['name']} from node {node_id}"
)
self.logger.error(e_msg)
return False

total_size = artifact["size"]
try:
async with aiofiles.open(file_path, mode="wb") as f:
with tqdm(
total=1,
desc=f"Downloading {file_path.name}",
unit="B",
unit_scale=True,
) as pbar:
async for chunk in resp.content.iter_chunked(
get("streaming-responses.chunk-size") * 1024
):
await f.write(chunk)
pbar.update(len(chunk) / total_size)
except Exception as e:
self.logger.error(
f"Could not save artifact {artifact['name']} "
f"from node {node_id}. Error: {e}"
)
return False

return True

def _create_worker_dir(self):
worker_dir = (
self.state.logdir / self.state.workers[self.worker_id].name
) # TODO: Match current format
worker_dir.mkdir(exist_ok=True, parents=True)
return worker_dir

def _find_node_state_by_id(self, node_id):
worker_state = self.state.workers[self.worker_id]
node_state = worker_state.nodes[node_id]
return node_state

async def collect(self, timeout=get("comms.timeout.artifacts-ready")) -> bool:
"""Collect the recorded artifacts from the nodes."""
async with aiohttp.ClientSession(base_url=self.base_url) as session:
await self._request_artifacts_gather(session, timeout=timeout)
artifacts = await self._request_artifacts_info(session)
return await self._download_artifacts(session, artifacts)
8 changes: 7 additions & 1 deletion chimerapy/engine/manager/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,10 @@ async def async_stop(self) -> bool:
return await self.worker_handler.stop()

async def async_collect(self) -> bool:
return await self.worker_handler.collect()
return await self.worker_handler.collect_v2()

async def async_collect_v2(self) -> bool:
return await self.worker_handler.collect_v2()

async def async_reset(self, keep_workers: bool = True):
return await self.worker_handler.reset(keep_workers)
Expand Down Expand Up @@ -466,6 +469,9 @@ def collect(self) -> Future[bool]:
"""
return self._exec_coro(self.async_collect())

def collect_v2(self) -> Future[bool]:
return self._exec_coro(self.async_collect_v2())

def reset(
self, keep_workers: bool = True, blocking: bool = True
) -> Union[bool, Future[bool]]:
Expand Down
17 changes: 17 additions & 0 deletions chimerapy/engine/manager/worker_handler_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ..networking import Client, DataChunk
from ..service import Service
from ..graph import Graph
from .artifacts_collector import ArtifactsCollector
from ..exceptions import CommitGraphError
from ..states import WorkerState, ManagerState
from ..eventbus import EventBus, TypedObserver, Event, make_evented
Expand Down Expand Up @@ -850,3 +851,19 @@ async def reset(self, keep_workers: bool = True):
self._deregister_graph()

return all(results)

async def collect_v2(self) -> bool:
client_session = aiohttp.ClientSession()
futures = []
for worker_id in self.state.workers:
collector = ArtifactsCollector(
state=self.state,
worker_id=worker_id,
parent_logger=logger,
)
future = asyncio.create_task(collector.collect())
futures.append(future)

results = await asyncio.gather(*futures)
await client_session.close()
return all(results)
1 change: 1 addition & 0 deletions chimerapy/engine/networking/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@ class NODE_MESSAGE(Enum):
REPORT_SAVING = 52
REPORT_RESULTS = 53
DIAGNOSTICS = 54
ARTIFACT = 55
14 changes: 12 additions & 2 deletions chimerapy/engine/node/events.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import pathlib
from dataclasses import dataclass
from typing import Dict, Any
from typing import Dict, Any, Optional

from ..networking.client import Client
from ..networking.data_chunk import DataChunk
from ..data_protocols import NodePubTable, NodeDiagnostics


@dataclass
class EnableDiagnosticsEvent: # enable_diagnostics
class EnableDiagnosticsEvent: # enable_diagnostics
enable: bool


Expand Down Expand Up @@ -41,3 +42,12 @@ class GatherEvent:
@dataclass
class DiagnosticsReportEvent: # diagnostics_report
diagnostics: NodeDiagnostics


@dataclass
class ArtifactEvent:
name: str
path: pathlib.Path
mime_type: str
size: Optional[int]
glob: Optional[str] = None
17 changes: 16 additions & 1 deletion chimerapy/engine/node/record_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from chimerapy.engine import _logger
from ..states import NodeState
from ..eventbus import EventBus, TypedObserver
from ..eventbus import EventBus, TypedObserver, Event
from ..records import (
Record,
VideoRecord,
Expand All @@ -17,6 +17,7 @@
TextRecord,
)
from ..service import Service
from .events import ArtifactEvent

logger = _logger.getLogger("chimerapy-engine")

Expand Down Expand Up @@ -151,7 +152,21 @@ def collect(self):

# Signal to stop and save
self.is_running.clear()
artifacts = {}
for name, entry in self.records.items():
artifacts[name] = entry.get_meta()

if self._record_thread:
self._record_thread.join()

for name, artifact in artifacts.items():
event_data = ArtifactEvent(
name=artifact["name"],
mime_type=artifact["mime_type"],
path=artifact["path"],
glob=artifact["glob"],
size=artifact["path"].stat().st_size,
)
assert self.eventbus.send(Event("artifact", event_data)).result(timeout=5)

# self.logger.debug(f"{self}: Finish saving records")
24 changes: 24 additions & 0 deletions chimerapy/engine/node/worker_comms_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
RegisteredMethodEvent,
GatherEvent,
DiagnosticsReportEvent,
ArtifactEvent,
)


Expand Down Expand Up @@ -97,6 +98,12 @@ def add_observers(self):
"teardown": TypedObserver(
"teardown", on_asend=self.teardown, handle_event="drop"
),
"artifact": TypedObserver(
"artifact",
ArtifactEvent,
on_asend=self.send_artifact_info,
handle_event="unpack",
),
}
for ob in observers.values():
self.eventbus.subscribe(ob).result(timeout=1)
Expand Down Expand Up @@ -175,6 +182,23 @@ async def send_diagnostics(self, diagnostics: NodeDiagnostics):
data = {"node_id": self.state.id, "diagnostics": diagnostics.to_dict()}
await self.client.async_send(signal=NODE_MESSAGE.DIAGNOSTICS, data=data)

async def send_artifact_info(
self, name: str, path: pathlib.Path, mime_type: str, size: int, glob: bool
):
assert self.state and self.eventbus and self.logger
if self.client:
data = {
"node_id": self.state.id,
"artifact": {
"name": name,
"path": str(path),
"mime_type": mime_type,
"glob": glob,
"size": size,
},
}
await self.client.async_send(signal=NODE_MESSAGE.ARTIFACT, data=data)

####################################################################
## Message Responds
####################################################################
Expand Down
8 changes: 8 additions & 0 deletions chimerapy/engine/records/audio_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,11 @@ def close(self):

# Close the audio writer
self.audio_writer.close()

def get_meta(self):
return {
"name": self.name,
"path": self.audio_file_path,
"mime_type": "audio/wav",
"glob": None,
}
9 changes: 9 additions & 0 deletions chimerapy/engine/records/image_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,12 @@ def write(self, data_chunk: Dict[str, Any]):

def close(self):
...

def get_meta(self):
"""Get metadata."""
return {
"name": self.name,
"path": self.save_loc,
"glob": "*.png",
"mime_type": "image/png",
}
Loading

0 comments on commit a9dd21b

Please sign in to comment.