-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2047249
commit e76e886
Showing
15 changed files
with
561 additions
and
462 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
278 changes: 130 additions & 148 deletions
278
chimerapy/engine/manager/artifacts_collector_service.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,173 +1,155 @@ | ||
import json | ||
import logging | ||
import os | ||
import pathlib | ||
from typing import Optional | ||
from typing import Any, Dict, Optional | ||
|
||
import aiofiles | ||
import aiohttp | ||
import aioshutil | ||
import zmq | ||
from aiohttp import ClientSession | ||
from rich.progress import Progress | ||
from zmq.asyncio import Context | ||
import zmq.asyncio | ||
|
||
from chimerapy.engine import config | ||
import chimerapy.engine.config as cpe_config | ||
from chimerapy.engine._logger import fork, getLogger | ||
from chimerapy.engine.states import ManagerState | ||
from chimerapy.engine.utils import get_ip_address | ||
|
||
|
||
async def download_task( | ||
context: zmq.Context, | ||
ip: str, | ||
port: int, | ||
filename: pathlib.Path, | ||
expected_size, | ||
progress=None, | ||
): | ||
"""Download a file from a worker.""" | ||
dealer = context.socket(zmq.DEALER) | ||
dealer.sndhwm = dealer.rcvhwm = config.get("file-transfer.max-chunks") | ||
dealer.connect(f"tcp://{ip}:{port}") | ||
|
||
f = await aiofiles.open(filename, "wb") | ||
credit = config.get("file-transfer.max-chunks") | ||
chunk_size = config.get("file-transfer.chunk-size") | ||
|
||
total = 0 | ||
chunks = 0 | ||
offset = 0 | ||
seq_no = 0 | ||
|
||
# Create a progress bar | ||
human_size = round(expected_size / 1024 / 1024, 2) | ||
update_task = None | ||
if progress: | ||
update_task = progress.add_task( | ||
f"[cyan]Downloading ({filename.name}-{human_size}MB...)", total=100 | ||
) | ||
|
||
while True: | ||
while credit: | ||
await dealer.send_multipart( | ||
[b"fetch", b"%i" % offset, b"%i" % chunk_size, b"%i" % seq_no] | ||
) | ||
offset += chunk_size | ||
seq_no += 1 | ||
credit -= 1 | ||
|
||
try: | ||
chunk, seq_no_recv_str = await dealer.recv_multipart() | ||
await f.write(chunk) | ||
except zmq.ZMQError as e: | ||
if e.errno == zmq.ETERM: | ||
return | ||
else: | ||
raise | ||
from chimerapy.engine.networking.zmq_file_transfer_client import ZMQFileClient | ||
from chimerapy.engine.utils import get_progress_bar | ||
|
||
chunks += 1 | ||
credit += 1 | ||
size = len(chunk) | ||
total += size | ||
from ..eventbus import Event, EventBus, TypedObserver | ||
from ..service import Service | ||
from ..states import ManagerState | ||
from .events import UpdateSendArchiveEvent | ||
|
||
if update_task: | ||
progress.update(update_task, completed=(total / expected_size) * 100) | ||
|
||
if size < chunk_size: | ||
await f.close() | ||
break | ||
|
||
|
||
class ArtifactsCollector: | ||
"""A utility class to collect artifacts recorded by the nodes.""" | ||
|
||
class ArtifactsCollectorService(Service): | ||
def __init__( | ||
self, | ||
name: str, | ||
eventbus: EventBus, | ||
state: ManagerState, | ||
worker_id: str, | ||
parent_logger: Optional[logging.Logger] = None, | ||
unzip: bool = False, | ||
progressbar: Optional[Progress] = None, | ||
): | ||
worker_state = state.workers[worker_id] | ||
super().__init__(name=name) | ||
self.eventbus = eventbus | ||
self.observers: Dict[str, TypedObserver] = {} | ||
self.clients: Dict[str, ZMQFileClient] = {} | ||
self.state = state | ||
self.progressbar = get_progress_bar() | ||
|
||
if parent_logger is None: | ||
parent_logger = getLogger("chimerapy-engine") | ||
|
||
self.logger = fork( | ||
parent_logger, | ||
f"ArtifactsCollector-[Worker({worker_state.name})]", | ||
) | ||
self.logger = fork(parent_logger, self.__class__.__name__) | ||
|
||
self.state = state | ||
self.worker_id = worker_id | ||
self.base_url = ( | ||
f"http://{self.state.workers[self.worker_id].ip}:" | ||
f"{self.state.workers[self.worker_id].port}" | ||
) | ||
self.unzip = unzip | ||
self.progressbar = progressbar | ||
|
||
async def _artifact_info(self, session: aiohttp.ClientSession): | ||
self.logger.info(f"Requesting artifact info from {self.base_url}") | ||
data = { | ||
"initiate_remote_transfer": get_ip_address() | ||
!= self.state.workers[self.worker_id].ip | ||
async def async_init(self): | ||
self.observers = { | ||
"artifacts_transfer_ready": TypedObserver( | ||
"artifacts_transfer_ready", on_asend=self.collect, handle_event="pass" | ||
) | ||
} | ||
|
||
async with session.post( | ||
"/nodes/request_collect", | ||
data=json.dumps(data), | ||
) as resp: | ||
if resp.ok: | ||
data = await resp.json() | ||
return data["zip_path"], data["port"], data["size"] | ||
else: | ||
# FixMe: Handle this error properly | ||
raise ConnectionError( | ||
f"Artifacts Collection Failed: {resp.status} {resp.reason}" | ||
) | ||
for name, observer in self.observers.items(): | ||
await self.eventbus.asubscribe(observer) | ||
|
||
async def collect(self, event: Event) -> None: | ||
method = event.data["method"] | ||
print(event.data) | ||
if method == "zmq": | ||
self.logger.debug("Collecting artifacts over ZMQ") | ||
await self._collect_zmq( | ||
worker_id=event.data["worker_id"], | ||
host=event.data["ip"], | ||
port=event.data["port"], | ||
artifacts=event.data["data"], | ||
) | ||
else: | ||
self.logger.debug("Collecting artifacts locally") | ||
await self._collect_local( | ||
worker_id=event.data["worker_id"], artifacts=event.data["data"] | ||
) | ||
|
||
async def collect(self) -> bool: | ||
client_session = ClientSession(base_url=self.base_url) | ||
async def _collect_zmq( | ||
self, worker_id: str, host: str, port: int, artifacts: Dict[str, Any] | ||
): | ||
files = {} | ||
self.logger.debug("Preparing files to download") | ||
for node_id, artifacts in artifacts.items(): | ||
out_dir = self._create_node_dir(worker_id, node_id) | ||
for artifact in artifacts: | ||
key = f"{node_id}-{artifact['name']}" | ||
files[key] = { | ||
"name": artifact["filename"], | ||
"size": artifact["size"], | ||
"outdir": out_dir, | ||
} | ||
context = zmq.asyncio.Context.instance() | ||
client = ZMQFileClient( | ||
context=context, | ||
host=host, | ||
port=port, | ||
credit=cpe_config.get("file-transfer.max-chunks"), | ||
chunk_size=cpe_config.get("file-transfer.chunk-size"), | ||
files=files, | ||
parent_logger=self.logger, | ||
progressbar=self.progressbar, | ||
) | ||
self.clients[worker_id] = client | ||
try: | ||
zip_path, port, size = await self._artifact_info(client_session) | ||
await client.async_init() | ||
await client.download_files() | ||
event_data = UpdateSendArchiveEvent(worker_id=worker_id, success=True) | ||
except Exception as e: | ||
self.logger.error(f"Failed to get artifact info: {e}") | ||
return False | ||
save_name = f"{self.state.workers[self.worker_id].name}_{self.worker_id[:8]}" | ||
zip_save_path = self.state.logdir / f"{save_name}.zip" | ||
if port is not None: | ||
try: | ||
await download_task( | ||
Context(), | ||
self.state.workers[self.worker_id].ip, | ||
int(port), | ||
zip_save_path, | ||
size, | ||
self.progressbar, | ||
) | ||
except Exception as e: | ||
self.logger.error(f"Failed to download artifacts: {e}") | ||
return False | ||
else: | ||
self.logger.info(f"Copying {zip_path} to {zip_save_path}") | ||
try: | ||
await aioshutil.copyfile(zip_path, self.state.logdir / zip_save_path) | ||
except Exception as e: | ||
self.logger.error(f"Failed to copy artifacts: {e}") | ||
return False | ||
|
||
if self.unzip: | ||
self.logger.info(f"Unzipping {zip_save_path}") | ||
try: | ||
await aioshutil.unpack_archive( | ||
zip_save_path, self.state.logdir / save_name | ||
event_data = UpdateSendArchiveEvent( | ||
worker_id=worker_id, | ||
success=False, | ||
) | ||
self.logger.error( | ||
f"Error while collecting artifacts for worker {worker_id}: {e}" | ||
) | ||
finally: | ||
await self.eventbus.asend(Event("update_send_archive", event_data)) | ||
self.logger.info(f"Successfully collected artifacts for worker {worker_id}") | ||
|
||
async def _collect_local(self, worker_id: str, artifacts: Dict[str, Any]) -> None: | ||
try: | ||
for node_id, node_artifacts in artifacts.items(): | ||
node_dir = self._create_node_dir(worker_id, node_id) | ||
|
||
for artifact in node_artifacts: | ||
artifact_path = pathlib.Path(artifact["path"]) | ||
self.logger.debug(f"Copying {artifact_path} to {node_dir}") | ||
await aioshutil.copyfile( | ||
artifact_path, node_dir / artifact["filename"] | ||
) | ||
|
||
await self.eventbus.asend( | ||
Event( | ||
"update_send_archive", | ||
UpdateSendArchiveEvent(worker_id=worker_id, success=True), | ||
) | ||
self.logger.info(f"Removing {zip_save_path}") | ||
os.remove(zip_save_path) | ||
except Exception as e: | ||
self.logger.error(f"Failed to unzip artifacts: {e}") | ||
return False | ||
) | ||
event_data = UpdateSendArchiveEvent( | ||
worker_id=worker_id, | ||
success=True, | ||
) | ||
self.logger.info(f"Successfully collected artifacts for worker {worker_id}") | ||
except Exception as e: | ||
event_data = UpdateSendArchiveEvent( | ||
worker_id=worker_id, | ||
success=False, | ||
) | ||
self.logger.error( | ||
f"Error while collecting artifacts for worker {worker_id}: {e}" | ||
) | ||
finally: | ||
await self.eventbus.asend(Event("update_send_archive", event_data)) | ||
|
||
def _create_worker_dir(self, worker_id): | ||
worker_name = self.state.workers[worker_id].name | ||
worker_dir = self.state.logdir / f"{worker_name}-{worker_id[:10]}" | ||
worker_dir.mkdir(parents=True, exist_ok=True) | ||
|
||
return worker_dir | ||
|
||
def _create_node_dir(self, worker_id, node_id): | ||
worker_dir = self._create_worker_dir(worker_id) | ||
nodes = self.state.workers[worker_id].nodes | ||
node_dir = worker_dir / nodes[node_id].name | ||
node_dir.mkdir(parents=True, exist_ok=True) | ||
|
||
return True | ||
return node_dir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.