diff --git a/examples/advanced/streaming/README.md b/examples/advanced/streaming/README.md
new file mode 100644
index 0000000000..6ab160c7d5
--- /dev/null
+++ b/examples/advanced/streaming/README.md
@@ -0,0 +1,68 @@
+# Object Streaming Examples
+
+## Overview
+The examples here demonstrate how to use object streamers to send large files or objects in a memory-efficient way.
+
+The object streamer uses less memory because it sends files in chunks (the default chunk size is 1 MB) and
+sends containers entry by entry.
+
+For example, if you have a dict with ten 1 GB entries, sending the dict without streaming takes an extra 10 GB of
+memory, while streaming only needs about 1 GB extra because entries are serialized one at a time.
+
+## Concepts
+
+### Object Streamer
+
+`ObjectStreamer` is the base class for streaming an object piece by piece. The `StreamableEngine` built into NVFlare can
+stream any implementation of `ObjectStreamer`.
+
+The following implementations are included in NVFlare:
+
+* `FileStreamer`: Streams a file.
+* `ContainerStreamer`: Streams a container entry by entry. Currently, dict, list, and set are supported.
+
+The container streamer only streams the top-level entries. All sub-entries of a top-level entry are sent at once with
+that top-level entry.
+
+### Object Retriever
+
+`ObjectRetriever` is designed to request an object to be streamed from a remote site. It automatically sets up the streaming
+on both ends and handles the coordination.
+
+Currently, the following implementations are available:
+
+* `FileRetriever`: Retrieves a file from a remote site using `FileStreamer`.
+* `ContainerRetriever`: Retrieves a container from a remote site using `ContainerStreamer`.
+
+To use `ContainerRetriever`, the container must be given a name and added on the sending site:
+
+```
+ContainerRetriever.add_container("model", model_dict)
+```
+
+## Example Jobs
+
+### file_streaming job
+
+This job uses the `FileStreamer` object to send a large file from the server to the client.
+
+It demonstrates the following mechanisms:
+1. It uses components to handle the file transfer. No training workflow is used.
+   Since NVFlare requires an executor, a dummy executor is created.
+2. It shows how to use the streamer directly, without an object retriever.
+
+The job creates a temporary file for testing. You can run the job in POC mode or with the simulator as follows:
+
+```
+nvflare simulator -n 1 -t 1 jobs/file_streaming
+```
+### dict_streaming job
+
+This job demonstrates how to send a dict from the server to the client using an object retriever.
+
+It creates a task called "retrieve_dict" to tell the client to get ready for the streaming.
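On the client side, the job's `streaming_executor.py` answers the "retrieve_dict" task by pulling the named container from the server. The call looks roughly like this (a sketch taken from the executor included in this job; `dict_retriever` is the `ContainerRetriever` component id wired up in the job configs):

```
rc, result = self.dict_retriever.retrieve_container(
    from_site="server",   # the site that called add_container()
    fl_ctx=fl_ctx,
    timeout=10.0,
    name="model",         # must match the name passed to add_container()
)
```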
+ +The example can be run in simulator like this, +``` +nvflare simulator -n 1 -t 1 jobs/dict_streaming +``` diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_client.json b/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_client.json new file mode 100755 index 0000000000..c2d85b8b48 --- /dev/null +++ b/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_client.json @@ -0,0 +1,23 @@ +{ + "format_version": 2, + "cell_wait_timeout": 5.0, + "executors": [ + { + "tasks": ["*"], + "executor": { + "path": "streaming_executor.StreamingExecutor", + "args": { + "dict_retriever_id": "dict_retriever" + } + } + } + ], + "components": [ + { + "id": "dict_retriever", + "path": "nvflare.app_common.streamers.container_retriever.ContainerRetriever", + "args": { + } + } + ] +} \ No newline at end of file diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_server.json b/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_server.json new file mode 100755 index 0000000000..fd847e0175 --- /dev/null +++ b/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_server.json @@ -0,0 +1,20 @@ +{ + "format_version": 2, + "components": [ + { + "id": "dict_retriever", + "path": "nvflare.app_common.streamers.container_retriever.ContainerRetriever", + "args": { + } + } + ], + "workflows": [ + { + "id": "controller", + "path": "streaming_controller.StreamingController", + "args": { + "dict_retriever_id": "dict_retriever" + } + } + ] +} \ No newline at end of file diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/custom/__init__.py b/examples/advanced/streaming/jobs/dict_streaming/app/custom/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/examples/advanced/streaming/jobs/dict_streaming/app/custom/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_controller.py b/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_controller.py new file mode 100644 index 0000000000..1f1700d1a8 --- /dev/null +++ b/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_controller.py @@ -0,0 +1,89 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from random import randbytes + +from nvflare.apis.controller_spec import Client, ClientTask, Task +from nvflare.apis.event_type import EventType +from nvflare.apis.fl_context import FLContext +from nvflare.apis.impl.controller import Controller +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal +from nvflare.app_common.streamers.container_retriever import ContainerRetriever + +STREAM_TOPIC = "rtr_file_stream" + + +class StreamingController(Controller): + def __init__(self, dict_retriever_id=None, task_timeout=60, task_check_period: float = 0.5): + Controller.__init__(self, task_check_period=task_check_period) + self.dict_retriever_id = dict_retriever_id + self.dict_retriever = None + self.task_timeout = task_timeout + + def start_controller(self, fl_ctx: FLContext): + model = self._get_test_model() + self.dict_retriever.add_container("model", model) + + def stop_controller(self, fl_ctx: FLContext): + pass + + def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): + s = Shareable() + s["name"] = "model" + task = Task(name="retrieve_dict", data=s, timeout=self.task_timeout) + self.broadcast_and_wait( + task=task, + fl_ctx=fl_ctx, + min_responses=1, + abort_signal=abort_signal, + ) + client_resps = {} + for ct in task.client_tasks: + assert isinstance(ct, ClientTask) + resp = ct.result + if resp is None: + resp = "no answer" + else: + assert isinstance(resp, Shareable) + self.log_info(fl_ctx, f"got resp {resp} from client {ct.client.name}") + resp = resp.get_return_code() + client_resps[ct.client.name] = resp + return {"status": "OK", "data": client_resps} + + def process_result_of_unknown_task( + self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext + ): + pass + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + engine = fl_ctx.get_engine() + if self.dict_retriever_id: + c = engine.get_component(self.dict_retriever_id) + if not isinstance(c, ContainerRetriever): + self.system_panic( + f"invalid dict_retriever {self.dict_retriever_id}, wrong type: {type(c)}", + fl_ctx, + ) + return + self.dict_retriever = c + + @staticmethod + def _get_test_model() -> dict: + model = {} + for i in range(10): + key = f"layer-{i}" + model[key] = randbytes(1024) + + return model diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_executor.py b/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_executor.py new file mode 100644 index 0000000000..a238f82aff --- /dev/null +++ b/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_executor.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
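(Note on the controller above: `random.randbytes` is only available on Python 3.9 and newer; on older interpreters, `os.urandom(1024)` is a drop-in way to generate the dummy layer payloads.)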
+ +from nvflare.apis.event_type import EventType +from nvflare.apis.executor import Executor +from nvflare.apis.fl_constant import ReturnCode +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable, make_reply +from nvflare.apis.signal import Signal +from nvflare.app_common.streamers.container_retriever import ContainerRetriever + + +class StreamingExecutor(Executor): + def __init__(self, dict_retriever_id=None): + Executor.__init__(self) + self.dict_retriever_id = dict_retriever_id + self.dict_retriever = None + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + engine = fl_ctx.get_engine() + if self.dict_retriever_id: + c = engine.get_component(self.dict_retriever_id) + if not isinstance(c, ContainerRetriever): + self.system_panic( + f"invalid dict_retriever {self.dict_retriever_id}, wrong type: {type(c)}", + fl_ctx, + ) + return + self.dict_retriever = c + + def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + self.log_info(fl_ctx, f"got task {task_name}: {shareable}") + if task_name == "retrieve_dict": + name = shareable.get("name") + if not name: + self.log_error(fl_ctx, "missing name in request") + return make_reply(ReturnCode.BAD_TASK_DATA) + if not self.dict_retriever: + self.log_error(fl_ctx, "no container retriever") + return make_reply(ReturnCode.SERVICE_UNAVAILABLE) + + assert isinstance(self.dict_retriever, ContainerRetriever) + rc, result = self.dict_retriever.retrieve_container( + from_site="server", + fl_ctx=fl_ctx, + timeout=10.0, + name=name, + ) + if rc != ReturnCode.OK: + self.log_error(fl_ctx, f"failed to retrieve dict {name}: {rc}") + return make_reply(rc) + + self.log_info(fl_ctx, f"received container type: {type(result)} size: {len(result)}") + return make_reply(ReturnCode.OK) + else: + self.log_error(fl_ctx, f"got unknown task {task_name}") + return make_reply(ReturnCode.TASK_UNKNOWN) diff --git a/examples/advanced/streaming/jobs/dict_streaming/meta.json b/examples/advanced/streaming/jobs/dict_streaming/meta.json new file mode 100644 index 0000000000..0fcb99272c --- /dev/null +++ b/examples/advanced/streaming/jobs/dict_streaming/meta.json @@ -0,0 +1,10 @@ +{ + "name": "file_streaming", + "resource_spec": {}, + "min_clients" : 1, + "deploy_map": { + "app": [ + "@ALL" + ] + } +} diff --git a/examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_client.json b/examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_client.json new file mode 100755 index 0000000000..5ac09cbb4f --- /dev/null +++ b/examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_client.json @@ -0,0 +1,23 @@ +{ + "format_version": 2, + "executors": [ + { + "tasks": [ + "train" + ], + "executor": { + "path": "trainer.TestTrainer", + "args": {} + } + } + ], + "task_result_filters": [], + "task_data_filters": [], + "components": [ + { + "id": "sender", + "path": "file_streaming.FileSender", + "args": {} + } + ] +} diff --git a/examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_server.json b/examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_server.json new file mode 100755 index 0000000000..1c0be95c54 --- /dev/null +++ b/examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_server.json @@ -0,0 +1,19 @@ +{ + "format_version": 2, + "task_data_filters": [], + "task_result_filters": [], + "components": [ + { + "id": "receiver", + "path": 
"file_streaming.FileReceiver", + "args": {} + } + ], + "workflows": [ + { + "id": "controller", + "path": "controller.SimpleController", + "args": {} + } + ] +} diff --git a/examples/advanced/streaming/jobs/file_streaming/app/custom/__init__.py b/examples/advanced/streaming/jobs/file_streaming/app/custom/__init__.py new file mode 100644 index 0000000000..341a77c5bc --- /dev/null +++ b/examples/advanced/streaming/jobs/file_streaming/app/custom/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/advanced/streaming/jobs/file_streaming/app/custom/controller.py b/examples/advanced/streaming/jobs/file_streaming/app/custom/controller.py new file mode 100644 index 0000000000..206346e8f2 --- /dev/null +++ b/examples/advanced/streaming/jobs/file_streaming/app/custom/controller.py @@ -0,0 +1,49 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time + +from nvflare.apis.client import Client +from nvflare.apis.fl_context import FLContext +from nvflare.apis.impl.controller import Controller +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal + +logger = logging.getLogger(__name__) + + +class SimpleController(Controller): + + def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): + logger.info(f"Entering control loop of {self.__class__.__name__}") + engine = fl_ctx.get_engine() + + # Wait till receiver is done. Otherwise, the job ends. + receiver = engine.get_component("receiver") + while not receiver.is_done(): + time.sleep(0.2) + + logger.info("Control flow ends") + + def start_controller(self, fl_ctx: FLContext): + logger.info("Start controller") + + def stop_controller(self, fl_ctx: FLContext): + logger.info("Stop controller") + + def process_result_of_unknown_task( + self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext + ): + raise RuntimeError(f"Unknown task: {task_name} from client {client.name}.") diff --git a/examples/advanced/streaming/jobs/file_streaming/app/custom/file_streaming.py b/examples/advanced/streaming/jobs/file_streaming/app/custom/file_streaming.py new file mode 100644 index 0000000000..9b49b230e5 --- /dev/null +++ b/examples/advanced/streaming/jobs/file_streaming/app/custom/file_streaming.py @@ -0,0 +1,109 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import tempfile +from threading import Thread + +from nvflare.apis.event_type import EventType +from nvflare.apis.fl_component import FLComponent +from nvflare.apis.fl_context import FLContext +from nvflare.app_common.streamers.file_streamer import FileStreamer + +CHANNEL = "_test_channel" +TOPIC = "_test_topic" +SIZE = 100 * 1024 * 1024 # 100 MB + + +class FileSender(FLComponent): + + def __init__(self): + super().__init__() + self.seq = 0 + self.aborted = False + self.file_name = None + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self.log_info(fl_ctx, "FileSender is started") + Thread(target=self._sending_file, args=(fl_ctx,), daemon=True).start() + elif event_type == EventType.ABORT_TASK: + self.log_info(fl_ctx, "Sender is aborted") + self.aborted = True + + def _sending_file(self, fl_ctx): + + # Create a temp file to send + tmp = tempfile.NamedTemporaryFile(delete=False) + try: + buf = bytearray(SIZE) + for i in range(len(buf)): + buf[i] = i % 256 + + tmp.write(buf) + finally: + tmp.close() + + self.file_name = tmp.name + + rc, result = FileStreamer.stream_file( + targets=["server"], + stream_ctx={}, + channel=CHANNEL, + topic=TOPIC, + file_name=self.file_name, + fl_ctx=fl_ctx, + optional=False, + secure=False, + ) + + self.log_info(fl_ctx, f"Sending finished with RC: {rc}") + os.remove(self.file_name) + + +class FileReceiver(FLComponent): + + def __init__(self): + super().__init__() + self.done = False + + def is_done(self): + return self.done + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + self._receive_file(fl_ctx) + self.log_info(fl_ctx, "FileReceiver is started") + + def _receive_file(self, fl_ctx): + FileStreamer.register_stream_processing( + fl_ctx=fl_ctx, + channel=CHANNEL, + topic=TOPIC, + stream_done_cb=self._done_cb, + ) + + def _done_cb(self, stream_ctx: dict, fl_ctx: FLContext): + self.log_info(fl_ctx, "File streaming is done") + self.done = True + + file_name = FileStreamer.get_file_location(stream_ctx) + file_size = FileStreamer.get_file_size(stream_ctx) + size = os.path.getsize(file_name) + + if size == file_size: + self.log_info(fl_ctx, f"File {file_name} has correct size {size} bytes") + else: + self.log_error(fl_ctx, f"File {file_name} sizes mismatch {size} <> {file_size} bytes") + + os.remove(file_name) diff --git a/examples/advanced/streaming/jobs/file_streaming/app/custom/trainer.py b/examples/advanced/streaming/jobs/file_streaming/app/custom/trainer.py new file mode 100644 index 0000000000..216d69f793 --- /dev/null +++ b/examples/advanced/streaming/jobs/file_streaming/app/custom/trainer.py @@ -0,0 +1,42 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.apis.dxo import DXO, DataKind +from nvflare.apis.event_type import EventType +from nvflare.apis.executor import Executor +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal + + +class TestTrainer(Executor): + def __init__(self): + super().__init__() + self.aborted = False + + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.ABORT_TASK: + self.log_info(fl_ctx, "Trainer is aborted") + self.aborted = True + + def execute( + self, + task_name: str, + shareable: Shareable, + fl_ctx: FLContext, + abort_signal: Signal, + ) -> Shareable: + # This is a dummy executor which does nothing + self.log_info(fl_ctx, f"Executor is called with task {task_name}") + dxo = DXO(data_kind=DataKind.WEIGHTS, data={}) + return dxo.to_shareable() diff --git a/examples/advanced/streaming/jobs/file_streaming/meta.json b/examples/advanced/streaming/jobs/file_streaming/meta.json new file mode 100644 index 0000000000..0fcb99272c --- /dev/null +++ b/examples/advanced/streaming/jobs/file_streaming/meta.json @@ -0,0 +1,10 @@ +{ + "name": "file_streaming", + "resource_spec": {}, + "min_clients" : 1, + "deploy_map": { + "app": [ + "@ALL" + ] + } +} diff --git a/nvflare/app_common/statistics/json_stats_file_persistor.py b/nvflare/app_common/statistics/json_stats_file_persistor.py index bf56fe3a6a..96c96f7b32 100644 --- a/nvflare/app_common/statistics/json_stats_file_persistor.py +++ b/nvflare/app_common/statistics/json_stats_file_persistor.py @@ -19,7 +19,7 @@ from nvflare.apis.storage import StorageException from nvflare.app_common.abstract.statistics_writer import StatisticsWriter from nvflare.app_common.utils.json_utils import ObjectEncoder -from nvflare.fuel.utils.class_utils import get_class +from nvflare.fuel.utils.class_loader import load_class class JsonStatsFileWriter(StatisticsWriter): @@ -34,7 +34,7 @@ def __init__(self, output_path: str, json_encoder_path: str = ""): self.json_encoder_class = ObjectEncoder else: self.json_encoder_path = json_encoder_path - self.json_encoder_class = get_class(json_encoder_path) + self.json_encoder_class = load_class(json_encoder_path) def save( self, diff --git a/nvflare/app_common/streamers/container_retriever.py b/nvflare/app_common/streamers/container_retriever.py new file mode 100644 index 0000000000..a2781b536e --- /dev/null +++ b/nvflare/app_common/streamers/container_retriever.py @@ -0,0 +1,141 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
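The `ContainerRetriever` below builds on `ContainerStreamer` (added later in this patch). For orientation, direct use of the streamer follows the same pattern as `FileStreamer` in the file_streaming job; a rough sketch, with hypothetical channel/topic names and callback:

```
# receiving side: register a consumer for the app channel/topic
ContainerStreamer.register_stream_processing(
    fl_ctx=fl_ctx,
    channel="demo_channel",
    topic="demo_topic",
    stream_done_cb=done_cb,   # read the result with ContainerStreamer.get_result(stream_ctx)
)

# sending side: stream a dict entry by entry (blocks until the streaming is done)
ok = ContainerStreamer.stream_container(
    channel="demo_channel",
    topic="demo_topic",
    stream_ctx={},
    targets=["server"],
    container={"layer-0": b"...", "layer-1": b"..."},
    fl_ctx=fl_ctx,
    entry_timeout=60.0,
)
```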
+from typing import Any + +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import ReturnCode, Shareable +from nvflare.apis.streaming import StreamContext + +from .container_streamer import ContainerStreamer +from .object_retriever import ObjectRetriever + + +class ContainerRetriever(ObjectRetriever): + def __init__( + self, + topic: str = None, + stream_msg_optional=False, + stream_msg_secure=False, + entry_timeout=None, + ): + ObjectRetriever.__init__(self, topic) + self.stream_msg_optional = stream_msg_optional + self.stream_msg_secure = stream_msg_secure + self.entry_timeout = entry_timeout + self.containers = {} + + def add_container(self, name: str, container: Any): + """Add a container to the retriever. This must be called on the sending side + + Args: + name: name for the container. + container: The container to be streamed + """ + self.containers[name] = container + + def register_stream_processing( + self, + channel: str, + topic: str, + fl_ctx: FLContext, + stream_done_cb, + **cb_kwargs, + ): + """Called on the stream sending side. + + Args: + channel: + topic: + fl_ctx: + stream_done_cb: + **cb_kwargs: + + Returns: + + """ + ContainerStreamer.register_stream_processing( + channel=channel, + topic=topic, + fl_ctx=fl_ctx, + stream_done_cb=stream_done_cb, + **cb_kwargs, + ) + + def validate_request(self, request: Shareable, fl_ctx: FLContext) -> (str, Any): + name = request.get("name") + if not name: + self.log_error(fl_ctx, "bad request: missing container name") + return ReturnCode.BAD_REQUEST_DATA, None + + container = self.containers.get(name, None) + if not container: + self.log_error(fl_ctx, f"bad request: requested container {name} doesn't exist") + return ReturnCode.BAD_REQUEST_DATA, None + + return ReturnCode.OK, container + + def retrieve_container(self, from_site: str, fl_ctx: FLContext, timeout: float, name: str) -> (str, Any): + """Retrieve a container from the specified site. + This method is to be called by the app. + + Args: + from_site: the site that has the container to be retrieved + fl_ctx: FLContext object + timeout: how long to wait for the file + name: name of the container + + Returns: a tuple of (ReturnCode, container) + + """ + return self.retrieve(from_site=from_site, fl_ctx=fl_ctx, timeout=timeout, name=name) + + def do_stream( + self, target: str, request: Shareable, fl_ctx: FLContext, stream_ctx: StreamContext, validated_data: Any + ): + """Stream the container to the peer. + Called on the stream sending side. + + Args: + target: the receiving site + request: data to be sent + fl_ctx: FLContext object + stream_ctx: the stream context + validated_data: the file full path returned from the validate_request method + + Returns: + + """ + ContainerStreamer.stream_container( + targets=[target], + stream_ctx=stream_ctx, + channel=self.stream_channel, + topic=self.topic, + container=validated_data, + fl_ctx=fl_ctx, + optional=self.stream_msg_optional, + secure=self.stream_msg_secure, + ) + + def get_result(self, stream_ctx: StreamContext) -> (str, Any): + """Called on the stream receiving side. + Get the final result of the streaming. + The result is the location of the received file. 
+ + Args: + stream_ctx: the StreamContext + + Returns: + + """ + return ContainerStreamer.get_rc(stream_ctx), ContainerStreamer.get_result(stream_ctx) diff --git a/nvflare/app_common/streamers/container_streamer.py b/nvflare/app_common/streamers/container_streamer.py new file mode 100644 index 0000000000..569ea0b257 --- /dev/null +++ b/nvflare/app_common/streamers/container_streamer.py @@ -0,0 +1,255 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Tuple + +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import ReturnCode, Shareable, make_reply +from nvflare.apis.streaming import ConsumerFactory, ObjectConsumer, ObjectProducer, StreamableEngine, StreamContext +from nvflare.app_common.streamers.streamer_base import StreamerBase +from nvflare.fuel.utils.class_loader import get_class_name, load_class +from nvflare.fuel.utils.log_utils import get_obj_logger +from nvflare.fuel.utils.validation_utils import check_positive_number + +_PREFIX = "ContainerStreamer." + +# Keys for StreamCtx +_CTX_TYPE = _PREFIX + "type" +_CTX_SIZE = _PREFIX + "size" +_CTX_RESULT = _PREFIX + "result" + +# Keys for Shareable +_KEY_ENTRY = _PREFIX + "entry" +_KEY_LAST = _PREFIX + "last" + + +class _EntryConsumer(ObjectConsumer): + def __init__(self, stream_ctx: StreamContext): + self.logger = get_obj_logger(self) + container_type = stream_ctx.get(_CTX_TYPE) + container_class = load_class(container_type) + self.container = container_class() + self.size = stream_ctx.get(_CTX_SIZE) + + def consume( + self, + shareable: Shareable, + stream_ctx: StreamContext, + fl_ctx: FLContext, + ) -> Tuple[bool, Shareable]: + + entry = shareable.get(_KEY_ENTRY) + try: + if isinstance(self.container, dict): + key, value = entry + self.container[key] = value + elif isinstance(self.container, set): + self.container.add(entry) + else: + self.container.append(entry) + except Exception as ex: + error = f"Unable to add entry ({type(entry)} to container ({type(self.container)}" + self.logger.error(error) + raise ValueError(error) + + last = shareable.get(_KEY_LAST) + if last: + # Check if all entries are added + if self.size != len(self.container): + err = f"Container size {len(self.container)} does not match expected size {self.size}" + self.logger.error(err) + raise ValueError(err) + else: + stream_ctx[_CTX_RESULT] = self.container + return False, make_reply(ReturnCode.OK) + else: + # continue streaming + return True, make_reply(ReturnCode.OK) + + def finalize(self, stream_ctx: StreamContext, fl_ctx: FLContext): + self.logger.debug(f"Container streaming is done for container type {type(self.container)}") + + +class _EntryConsumerFactory(ConsumerFactory): + + def get_consumer(self, stream_ctx: StreamContext, fl_ctx: FLContext) -> ObjectConsumer: + return _EntryConsumer(stream_ctx) + + +class _EntryProducer(ObjectProducer): + def __init__(self, container, entry_timeout): + self.logger = get_obj_logger(self) + if not 
container: + error = "Can't stream empty container" + self.logger.error(error) + raise ValueError(error) + + self.container = container + if isinstance(container, dict): + self.iterator = iter(container.items()) + else: + self.iterator = iter(container) + self.size = len(container) + self.count = 0 + self.last = False + self.entry_timeout = entry_timeout + + def produce( + self, + stream_ctx: StreamContext, + fl_ctx: FLContext, + ) -> Tuple[Shareable, float]: + + try: + entry = next(self.iterator) + self.count += 1 + self.last = self.count >= self.size + except StopIteration: + self.logger.error(f"Producer called too many times {self.count}/{self.size}") + self.last = True + return None, 0.0 + + result = Shareable() + result[_KEY_ENTRY] = entry + result[_KEY_LAST] = self.last + return result, self.entry_timeout + + def process_replies( + self, + replies: Dict[str, Shareable], + stream_ctx: StreamContext, + fl_ctx: FLContext, + ) -> Any: + has_error = False + for target, reply in replies.items(): + rc = reply.get_return_code(ReturnCode.OK) + if rc != ReturnCode.OK: + self.logger.error(f"error from target {target}: {rc}") + has_error = True + + if has_error: + # done - failed + return False + elif self.last: + # done - succeeded + return True + else: + # not done yet - continue streaming + return None + + +class ContainerStreamer(StreamerBase): + @staticmethod + def register_stream_processing( + fl_ctx: FLContext, + channel: str, + topic: str, + stream_done_cb=None, + **cb_kwargs, + ): + """Register for stream processing on the receiving side. + + Args: + fl_ctx: the FLContext object + channel: the app channel + topic: the app topic + stream_done_cb: if specified, the callback to be called when the file is completely received + **cb_kwargs: the kwargs for the stream_done_cb + + Returns: None + + Notes: the stream_done_cb must follow stream_done_cb_signature as defined in apis.streaming. + + """ + + engine = fl_ctx.get_engine() + if not isinstance(engine, StreamableEngine): + raise RuntimeError(f"engine must be StreamableEngine but got {type(engine)}") + + engine.register_stream_processing( + channel=channel, + topic=topic, + factory=_EntryConsumerFactory(), + stream_done_cb=stream_done_cb, + **cb_kwargs, + ) + + @staticmethod + def stream_container( + channel: str, + topic: str, + stream_ctx: StreamContext, + targets: List[str], + container: Any, + fl_ctx: FLContext, + entry_timeout=None, + optional=False, + secure=False, + ) -> bool: + """Stream a file to one or more targets. + + Args: + channel: the app channel + topic: the app topic + stream_ctx: context data of the stream + targets: targets that the file will be sent to + container: container to be streamed + fl_ctx: a FLContext object + entry_timeout: timeout for each entry sent to targets. + optional: whether the file is optional + secure: whether P2P security is required + + Returns: whether the streaming completed successfully + + Notes: this is a blocking call - only returns after the streaming is done. 
+ """ + if not entry_timeout: + entry_timeout = 60.0 + check_positive_number("entry_timeout", entry_timeout) + + producer = _EntryProducer(container, entry_timeout) + engine = fl_ctx.get_engine() + + if not isinstance(engine, StreamableEngine): + raise RuntimeError(f"engine must be StreamableEngine but got {type(engine)}") + + if not stream_ctx: + stream_ctx = {} + + stream_ctx[_CTX_TYPE] = get_class_name(type(container)) + stream_ctx[_CTX_SIZE] = len(container) + + return engine.stream_objects( + channel=channel, + topic=topic, + stream_ctx=stream_ctx, + targets=targets, + producer=producer, + fl_ctx=fl_ctx, + optional=optional, + secure=secure, + ) + + @staticmethod + def get_result(stream_ctx: StreamContext) -> Any: + """Get the received container + This method is intended to be used by the stream_done_cb() function of the receiving side. + + Args: + stream_ctx: the stream context + + Returns: The received container + + """ + return stream_ctx.get(_CTX_RESULT) diff --git a/nvflare/app_common/streamers/object_retriever.py b/nvflare/app_common/streamers/object_retriever.py index 01c8f989fb..9360206747 100644 --- a/nvflare/app_common/streamers/object_retriever.py +++ b/nvflare/app_common/streamers/object_retriever.py @@ -237,7 +237,7 @@ def _handle_stream_done(self, stream_ctx: StreamContext, fl_ctx: FLContext): waiter.result = result waiter.set() - self.log_info(fl_ctx, f"got result for RTR {tx_id}: {waiter.result}") + self.log_info(fl_ctx, f"got result for RTR {tx_id}: {type(waiter.result)}") def _handle_request(self, topic, request: Shareable, fl_ctx: FLContext) -> Shareable: # On request receiving side, which is also stream sending side. diff --git a/nvflare/fuel/f3/streaming/byte_streamer.py b/nvflare/fuel/f3/streaming/byte_streamer.py index faf160becd..3e678bd5a7 100644 --- a/nvflare/fuel/f3/streaming/byte_streamer.py +++ b/nvflare/fuel/f3/streaming/byte_streamer.py @@ -182,6 +182,9 @@ def send_pending_buffer(self, final=False): def stop(self, error: Optional[StreamError] = None, notify=True): + if self.stopped: + return + self.stopped = True if self.task_future: diff --git a/nvflare/fuel/utils/class_loader.py b/nvflare/fuel/utils/class_loader.py new file mode 100644 index 0000000000..2479f5c662 --- /dev/null +++ b/nvflare/fuel/utils/class_loader.py @@ -0,0 +1,53 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import builtins +import importlib +from typing import Type + + +# Those functions are extracted from class_utils module to share the code +# with FOBS and to avoid circular imports +def get_class_name(cls: Type) -> str: + """Get canonical class path or fully qualified name. The builtins module is removed + so common builtin class can be referenced with its normal name + + Args: + cls: The class type + Returns: + The canonical name + """ + module = cls.__module__ + if module == "builtins": + return cls.__qualname__ + return module + "." 
+ cls.__qualname__ + + +def load_class(class_path): + """Load class from fully qualified class name + + Args: + class_path: fully qualified class name + Returns: + The class type + """ + + try: + if "." in class_path: + module_name, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, class_name) + else: + return getattr(builtins, class_path) + except Exception as ex: + raise TypeError(f"Can't load class {class_path}: {ex}") diff --git a/nvflare/fuel/utils/class_utils.py b/nvflare/fuel/utils/class_utils.py index e82162ada2..be69d7339f 100644 --- a/nvflare/fuel/utils/class_utils.py +++ b/nvflare/fuel/utils/class_utils.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import importlib import inspect import pkgutil @@ -19,6 +18,7 @@ from nvflare.apis.fl_component import FLComponent from nvflare.fuel.common.excepts import ConfigError +from nvflare.fuel.utils.class_loader import load_class from nvflare.fuel.utils.components_utils import create_classes_table_static from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.security.logging import secure_format_exception @@ -26,22 +26,6 @@ DEPRECATED_PACKAGES = ["nvflare.app_common.pt", "nvflare.app_common.homomorphic_encryption"] -def get_class(class_path): - module_name, class_name = class_path.rsplit(".", 1) - - try: - module_ = importlib.import_module(module_name) - - try: - class_ = getattr(module_, class_name) - except AttributeError: - raise ValueError("Class {} does not exist".format(class_path)) - except AttributeError: - raise ValueError("Module {} does not exist".format(class_path)) - - return class_ - - def instantiate_class(class_path, init_params): """Method for creating an instance for the class. @@ -51,7 +35,7 @@ def instantiate_class(class_path, init_params): arguments. The transform name will be appended to `medical.common.transforms` to make a full name of the transform to be built. """ - c = get_class(class_path) + c = load_class(class_path) try: if init_params: instance = c(**init_params) @@ -80,7 +64,7 @@ def __init__(self, base_pkgs: List[str], module_names: List[str], exclude_libs=T self._class_table = create_classes_table_static() def create_classes_table(self): - class_table: Dict[str, str] = {} + class_table: Dict[str, list[str]] = {} for base in self.base_pkgs: package = importlib.import_module(base) @@ -123,7 +107,8 @@ def get_module_name(self, class_name) -> Optional[str]: """ if class_name not in self._class_table: raise ConfigError( - f"Cannot find class '{class_name}'. Please check its spelling. If the spelling is correct, specify the class using its full path." + f"Cannot find class '{class_name}'. Please check its spelling. If the spelling is correct, " + "specify the class using its full path." ) modules = self._class_table.get(class_name, None) diff --git a/nvflare/fuel/utils/fobs/fobs.py b/nvflare/fuel/utils/fobs/fobs.py index 7a69ea17e7..aaca8c8a7c 100644 --- a/nvflare/fuel/utils/fobs/fobs.py +++ b/nvflare/fuel/utils/fobs/fobs.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import builtins import importlib import inspect import logging @@ -23,6 +22,7 @@ import msgpack +from nvflare.fuel.utils.class_loader import get_class_name, load_class from nvflare.fuel.utils.fobs.datum import DatumManager from nvflare.fuel.utils.fobs.decomposer import DataClassDecomposer, Decomposer, EnumTypeDecomposer @@ -58,25 +58,6 @@ _data_auto_registration = True -def _get_type_name(cls: Type) -> str: - module = cls.__module__ - if module == "builtins": - return cls.__qualname__ - return module + "." + cls.__qualname__ - - -def _load_class(type_name: str): - try: - if "." in type_name: - module_name, class_name = type_name.rsplit(".", 1) - module = importlib.import_module(module_name) - return getattr(module, class_name) - else: - return getattr(builtins, type_name) - except Exception as ex: - raise TypeError(f"Can't load class {type_name}: {ex}") - - def register(decomposer: Union[Decomposer, Type[Decomposer]]) -> None: """Register a decomposer. It does nothing if decomposer is already registered for the type @@ -91,7 +72,7 @@ def register(decomposer: Union[Decomposer, Type[Decomposer]]) -> None: else: instance = decomposer - name = _get_type_name(instance.supported_type()) + name = get_class_name(instance.supported_type()) if name in _decomposers: return @@ -105,15 +86,15 @@ def register(decomposer: Union[Decomposer, Type[Decomposer]]) -> None: class Packer: def __init__(self, manager: DatumManager): self.manager = manager - self.enum_decomposer_name = _get_type_name(EnumTypeDecomposer) - self.data_decomposer_name = _get_type_name(DataClassDecomposer) + self.enum_decomposer_name = get_class_name(EnumTypeDecomposer) + self.data_decomposer_name = get_class_name(DataClassDecomposer) def pack(self, obj: Any) -> dict: if type(obj) in MSGPACK_TYPES: return obj - type_name = _get_type_name(obj.__class__) + type_name = get_class_name(obj.__class__) if type_name not in _decomposers: registered = False if isinstance(obj, Enum): @@ -136,7 +117,7 @@ def pack(self, obj: Any) -> dict: if self.manager: decomposed = self.manager.externalize(decomposed) - return {FOBS_TYPE: type_name, FOBS_DATA: decomposed, FOBS_DECOMPOSER: _get_type_name(type(decomposer))} + return {FOBS_TYPE: type_name, FOBS_DATA: decomposed, FOBS_DECOMPOSER: get_class_name(type(decomposer))} def unpack(self, obj: Any) -> Any: @@ -147,7 +128,7 @@ def unpack(self, obj: Any) -> Any: if type_name not in _decomposers: registered = False decomposer_name = obj.get(FOBS_DECOMPOSER) - cls = _load_class(type_name) + cls = load_class(type_name) if not decomposer_name: # Maintaining backward compatibility with auto enum registration if _enum_auto_registration: @@ -155,7 +136,7 @@ def unpack(self, obj: Any) -> Any: register_enum_types(cls) registered = True else: - decomposer_class = _load_class(decomposer_name) + decomposer_class = load_class(decomposer_name) if decomposer_name == self.enum_decomposer_name or decomposer_name == self.data_decomposer_name: # Generic decomposer's __init__ takes the target class as argument decomposer = decomposer_class(cls) diff --git a/nvflare/fuel/utils/wfconf.py b/nvflare/fuel/utils/wfconf.py index f2356c5469..c283d18dd6 100644 --- a/nvflare/fuel/utils/wfconf.py +++ b/nvflare/fuel/utils/wfconf.py @@ -22,7 +22,8 @@ from nvflare.security.logging import secure_format_exception from .argument_utils import parse_vars -from .class_utils import ModuleScanner, get_class, instantiate_class +from .class_loader import load_class +from .class_utils import ModuleScanner, instantiate_class from .dict_utils import 
extract_first_level_primitive, merge_dict from .json_scanner import JsonObjectProcessor, JsonScanner, Node @@ -362,7 +363,7 @@ def get_class_path(self, config_dict): return class_path def is_configured_subclass(self, config_dict, base_class): - return issubclass(get_class(self.get_class_path(config_dict)), base_class) + return issubclass(load_class(self.get_class_path(config_dict)), base_class) def start_config(self, config_ctx: ConfigContext): pass diff --git a/nvflare/private/json_configer.py b/nvflare/private/json_configer.py index 3a987fde69..aca538658b 100644 --- a/nvflare/private/json_configer.py +++ b/nvflare/private/json_configer.py @@ -15,7 +15,8 @@ from typing import List, Union from nvflare.fuel.common.excepts import ComponentNotAuthorized, ConfigError -from nvflare.fuel.utils.class_utils import ModuleScanner, get_class +from nvflare.fuel.utils.class_loader import load_class +from nvflare.fuel.utils.class_utils import ModuleScanner from nvflare.fuel.utils.component_builder import ComponentBuilder from nvflare.fuel.utils.config_factory import ConfigFactory from nvflare.fuel.utils.config_service import ConfigService @@ -150,7 +151,7 @@ def process_element(self, node: Node): self.process_config_element(self.config_ctx, node) def is_configured_subclass(self, config_dict, base_class): - return issubclass(get_class(self.get_class_path(config_dict)), base_class) + return issubclass(load_class(self.get_class_path(config_dict)), base_class) def start_config(self, config_ctx: ConfigContext): pass
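
As a quick illustration of the new `nvflare/fuel/utils/class_loader.py` helpers that the changes above migrate to (a minimal sketch, not part of the patch):

```
from collections import OrderedDict

from nvflare.fuel.utils.class_loader import get_class_name, load_class

name = get_class_name(OrderedDict)   # "collections.OrderedDict"
assert load_class(name) is OrderedDict

# builtin classes are referenced by their plain names
assert get_class_name(dict) == "dict"
assert load_class("dict") is dict
```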