Skip to content

Commit

Permalink
Rich Progress Handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
umesh-timalsina committed Oct 20, 2023
1 parent 2047249 commit e76e886
Show file tree
Hide file tree
Showing 15 changed files with 561 additions and 462 deletions.
3 changes: 2 additions & 1 deletion chimerapy/engine/chimerapyrc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ config:
worker-shutdown: 10 # seconds
node-creation: 130 # seconds
reset: 30
collect: 1800 # 30 minutes
retry:
data-collection: 30 # seconds
logs-sink:
Expand Down Expand Up @@ -41,4 +42,4 @@ config:
logging-enabled: false
file-transfer:
chunk-size: 250000 # bytes
max-chunks: 1 # Number of chunks to send at once
278 changes: 130 additions & 148 deletions chimerapy/engine/manager/artifacts_collector_service.py
Original file line number Diff line number Diff line change
@@ -1,173 +1,155 @@
import json
import logging
import os
import pathlib
from typing import Optional
from typing import Any, Dict, Optional

import aiofiles
import aiohttp
import aioshutil
import zmq
from aiohttp import ClientSession
from rich.progress import Progress
from zmq.asyncio import Context
import zmq.asyncio

from chimerapy.engine import config
import chimerapy.engine.config as cpe_config
from chimerapy.engine._logger import fork, getLogger
from chimerapy.engine.states import ManagerState
from chimerapy.engine.utils import get_ip_address


async def download_task(
context: zmq.Context,
ip: str,
port: int,
filename: pathlib.Path,
expected_size,
progress=None,
):
"""Download a file from a worker."""
dealer = context.socket(zmq.DEALER)
dealer.sndhwm = dealer.rcvhwm = config.get("file-transfer.max-chunks")
dealer.connect(f"tcp://{ip}:{port}")

f = await aiofiles.open(filename, "wb")
credit = config.get("file-transfer.max-chunks")
chunk_size = config.get("file-transfer.chunk-size")

total = 0
chunks = 0
offset = 0
seq_no = 0

# Create a progress bar
human_size = round(expected_size / 1024 / 1024, 2)
update_task = None
if progress:
update_task = progress.add_task(
f"[cyan]Downloading ({filename.name}-{human_size}MB...)", total=100
)

while True:
while credit:
await dealer.send_multipart(
[b"fetch", b"%i" % offset, b"%i" % chunk_size, b"%i" % seq_no]
)
offset += chunk_size
seq_no += 1
credit -= 1

try:
chunk, seq_no_recv_str = await dealer.recv_multipart()
await f.write(chunk)
except zmq.ZMQError as e:
if e.errno == zmq.ETERM:
return
else:
raise
from chimerapy.engine.networking.zmq_file_transfer_client import ZMQFileClient
from chimerapy.engine.utils import get_progress_bar

chunks += 1
credit += 1
size = len(chunk)
total += size
from ..eventbus import Event, EventBus, TypedObserver
from ..service import Service
from ..states import ManagerState
from .events import UpdateSendArchiveEvent

if update_task:
progress.update(update_task, completed=(total / expected_size) * 100)

if size < chunk_size:
await f.close()
break


class ArtifactsCollector:
"""A utility class to collect artifacts recorded by the nodes."""

class ArtifactsCollectorService(Service):
def __init__(
    self,
    name: str,
    eventbus: EventBus,
    state: ManagerState,
    parent_logger: Optional[logging.Logger] = None,
):
    """Service that collects recorded node artifacts from workers.

    Args:
        name: Service name registered with the service framework.
        eventbus: Bus used to receive transfer-ready events and to publish
            ``update_send_archive`` results.
        state: Shared manager state (worker registry, log directory).
        parent_logger: Logger to fork from; defaults to the engine root logger.
    """
    super().__init__(name=name)
    self.eventbus = eventbus
    self.observers: Dict[str, TypedObserver] = {}
    # One ZMQ file client per worker currently transferring artifacts.
    self.clients: Dict[str, ZMQFileClient] = {}
    self.state = state
    self.progressbar = get_progress_bar()

    if parent_logger is None:
        parent_logger = getLogger("chimerapy-engine")
    self.logger = fork(parent_logger, self.__class__.__name__)

async def async_init(self):
    """Subscribe this service's observers to the eventbus.

    Registers a single observer that forwards ``artifacts_transfer_ready``
    events (posted by the HTTP server route) to :meth:`collect`.
    """
    self.observers = {
        "artifacts_transfer_ready": TypedObserver(
            "artifacts_transfer_ready", on_asend=self.collect, handle_event="pass"
        )
    }

    # Keys are only labels; subscription needs just the observer objects.
    for observer in self.observers.values():
        await self.eventbus.asubscribe(observer)

async def collect(self, event: Event) -> None:
method = event.data["method"]
print(event.data)
if method == "zmq":
self.logger.debug("Collecting artifacts over ZMQ")
await self._collect_zmq(
worker_id=event.data["worker_id"],
host=event.data["ip"],
port=event.data["port"],
artifacts=event.data["data"],
)
else:
self.logger.debug("Collecting artifacts locally")
await self._collect_local(
worker_id=event.data["worker_id"], artifacts=event.data["data"]
)

async def _collect_zmq(
    self, worker_id: str, host: str, port: int, artifacts: Dict[str, Any]
) -> None:
    """Download a worker's artifacts over a ZMQ file-transfer client.

    Builds the download manifest from *artifacts*, streams every file into
    the per-node output directories, then reports the outcome on the
    eventbus via an ``update_send_archive`` event.

    Args:
        worker_id: ID of the worker the artifacts belong to.
        host: Address of the worker's file-transfer server.
        port: Port of the worker's file-transfer server.
        artifacts: Mapping of node id -> list of artifact descriptors
            (each with ``name``, ``filename`` and ``size`` keys).
    """
    files: Dict[str, Dict[str, Any]] = {}
    self.logger.debug("Preparing files to download")
    # NOTE: loop variable must not shadow the ``artifacts`` parameter,
    # otherwise the mapping being iterated is rebound mid-loop.
    for node_id, node_artifacts in artifacts.items():
        out_dir = self._create_node_dir(worker_id, node_id)
        for artifact in node_artifacts:
            key = f"{node_id}-{artifact['name']}"
            files[key] = {
                "name": artifact["filename"],
                "size": artifact["size"],
                "outdir": out_dir,
            }

    context = zmq.asyncio.Context.instance()
    client = ZMQFileClient(
        context=context,
        host=host,
        port=port,
        credit=cpe_config.get("file-transfer.max-chunks"),
        chunk_size=cpe_config.get("file-transfer.chunk-size"),
        files=files,
        parent_logger=self.logger,
        progressbar=self.progressbar,
    )
    self.clients[worker_id] = client
    try:
        await client.async_init()
        await client.download_files()
        event_data = UpdateSendArchiveEvent(worker_id=worker_id, success=True)
        # Log success only on the success path, not unconditionally.
        self.logger.info(
            f"Successfully collected artifacts for worker {worker_id}"
        )
    except Exception as e:
        event_data = UpdateSendArchiveEvent(
            worker_id=worker_id,
            success=False,
        )
        self.logger.error(
            f"Error while collecting artifacts for worker {worker_id}: {e}"
        )
    finally:
        await self.eventbus.asend(Event("update_send_archive", event_data))

async def _collect_local(self, worker_id: str, artifacts: Dict[str, Any]) -> None:
    """Copy a worker's artifacts that already live on this filesystem.

    Used when the worker runs on the manager's machine, so artifacts can be
    copied directly instead of streamed over the network. Always reports the
    outcome on the eventbus via an ``update_send_archive`` event.

    Args:
        worker_id: ID of the worker the artifacts belong to.
        artifacts: Mapping of node id -> list of artifact descriptors
            (each with ``path`` and ``filename`` keys).
    """
    # Default to failure so the finally-block always has a result to send,
    # even if something other than Exception escapes the except clause.
    event_data = UpdateSendArchiveEvent(worker_id=worker_id, success=False)
    try:
        for node_id, node_artifacts in artifacts.items():
            node_dir = self._create_node_dir(worker_id, node_id)

            for artifact in node_artifacts:
                artifact_path = pathlib.Path(artifact["path"])
                self.logger.debug(f"Copying {artifact_path} to {node_dir}")
                await aioshutil.copyfile(
                    artifact_path, node_dir / artifact["filename"]
                )

        event_data = UpdateSendArchiveEvent(worker_id=worker_id, success=True)
        self.logger.info(f"Successfully collected artifacts for worker {worker_id}")
    except Exception as e:
        self.logger.error(
            f"Error while collecting artifacts for worker {worker_id}: {e}"
        )
    finally:
        await self.eventbus.asend(Event("update_send_archive", event_data))

def _create_worker_dir(self, worker_id):
worker_name = self.state.workers[worker_id].name
worker_dir = self.state.logdir / f"{worker_name}-{worker_id[:10]}"
worker_dir.mkdir(parents=True, exist_ok=True)

return worker_dir

def _create_node_dir(self, worker_id, node_id):
worker_dir = self._create_worker_dir(worker_id)
nodes = self.state.workers[worker_id].nodes
node_dir = worker_dir / nodes[node_id].name
node_dir.mkdir(parents=True, exist_ok=True)

return True
return node_dir
10 changes: 10 additions & 0 deletions chimerapy/engine/manager/http_server_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def __init__(
web.post("/workers/deregister", self._deregister_worker_route),
web.post("/workers/node_status", self._update_nodes_status),
web.post("/workers/send_archive", self._update_send_archive),
web.post(
"/workers/artifacts_transfer_ready",
self._file_transfer_server_ready,
),
],
)

Expand Down Expand Up @@ -192,6 +196,12 @@ async def _update_send_archive(self, request: web.Request):
await self.eventbus.asend(Event("update_send_archive", event_data))
return web.HTTPOk()

async def _file_transfer_server_ready(self, request: web.Request):
    """Forward a worker's artifacts-transfer-ready notification to the eventbus.

    The JSON payload is passed through unchanged as the event data for the
    ``artifacts_transfer_ready`` event consumed by the artifacts collector.
    """
    event_data = await request.json()
    await self.eventbus.asend(Event("artifacts_transfer_ready", event_data))
    return web.HTTPOk()

#####################################################################################
## Front-End API
#####################################################################################
Expand Down
12 changes: 10 additions & 2 deletions chimerapy/engine/manager/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# Eventbus
from ..eventbus import Event, EventBus, make_evented
from ..networking.async_loop_thread import AsyncLoopThread
from .artifacts_collector_service import ArtifactsCollectorService
from .distributed_logging_service import DistributedLoggingService

# Services
Expand Down Expand Up @@ -110,13 +111,20 @@ async def aserve(self) -> bool:
state=self.state,
# **self.kwargs,
)
self.artifacts_collector = ArtifactsCollectorService(
name="artifacts_collector",
eventbus=self.eventbus,
state=self.state,
parent_logger=logger,
)

# Initialize services
await self.http_server.async_init()
await self.worker_handler.async_init()
await self.zeroconf_service.async_init()
await self.session_record.async_init()
await self.distributed_logging.async_init()
await self.artifacts_collector.async_init()

# Start all services
await self.eventbus.asend(Event("start"))
Expand Down Expand Up @@ -336,8 +344,8 @@ async def async_stop(self) -> bool:
async def async_collect(self) -> bool:
    """Collect recorded artifacts from all workers via the worker handler."""
    collected = await self.worker_handler.collect()
    return collected

async def async_collect_v2(self) -> bool:
    """Collect artifacts via the event-driven v2 pipeline.

    Returns the worker handler's status so the declared ``-> bool``
    contract holds (previously the coroutine awaited without returning,
    so callers always received ``None``).
    """
    return await self.worker_handler.collect_v2()

async def async_reset(self, keep_workers: bool = True):
    # Delegates to the worker handler; keep_workers presumably retains
    # registered workers across the reset — confirm in worker_handler.reset.
    return await self.worker_handler.reset(keep_workers)
Expand Down
Loading

0 comments on commit e76e886

Please sign in to comment.