
Container Streaming/Retriever #3173

Open
wants to merge 12 commits into base: main
68 changes: 68 additions & 0 deletions examples/advanced/streaming/README.md
@@ -0,0 +1,68 @@
# Object Streaming Examples

## Overview
The examples here demonstrate how to use object streamers to send large files or objects in a memory-efficient way.

The object streamer uses less memory because it sends files in chunks (the default chunk size is 1MB) and
sends containers entry by entry.

For example, sending a dict with ten 1GB entries without streaming requires an extra 10GB of memory to serialize
the whole dict at once. With streaming, only about 1GB of extra memory is needed, because entries are serialized one at a time.
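A rough way to see the difference is sketched below. It uses `pickle` purely for illustration (not necessarily the serializer NVFlare uses): serializing the whole dict needs one buffer holding all entries, while streaming only ever needs a buffer for the single entry currently being sent.

```
import pickle

# Illustrative container: 10 entries of ~1MB each (scaled down from the 1GB example above).
entries = {f"layer-{i}": bytes(1_000_000) for i in range(10)}

whole = len(pickle.dumps(entries))                              # buffer needed to send the dict at once
largest = max(len(pickle.dumps(v)) for v in entries.values())   # buffer needed per entry when streaming
print(f"whole dict: ~{whole} bytes, largest single entry: ~{largest} bytes")
```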

## Concepts

### Object Streamer

`ObjectStreamer` is a base class to stream an object piece by piece. The `StreamableEngine` built into NVFlare can
stream any implementation of `ObjectStreamer`.

The following implementations are included in NVFlare:

* `FileStreamer`: Streams a file.
* `ContainerStreamer`: Streams a container entry by entry. Currently, dict, list, and set are supported.

The container streamer only streams the top-level entries. All sub-entries of a top-level entry are sent together
with that entry.
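As a hypothetical illustration (the keys and values below are made up), `ContainerStreamer` would send this dict as three entries; the nested dict under `"metadata"` travels whole as part of its top-level entry rather than being streamed separately:

```
model = {
    "layer-0": bytes(1024),                  # top-level entry, streamed on its own
    "layer-1": bytes(1024),                  # top-level entry, streamed on its own
    "metadata": {"epoch": 5, "lr": 0.01},    # nested dict is sent whole with this entry
}
```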

### Object Retriever

`ObjectRetriever` is designed to request an object to be streamed from a remote site. It automatically sets up the streaming
on both ends and handles the coordination.

Currently, the following implementations are available:

* `FileRetriever`: Retrieves a file from a remote site using `FileStreamer`.
* `ContainerRetriever`: Retrieves a container from a remote site using `ContainerStreamer`.

To use `ContainerRetriever`, the container must be given a name and registered on the sending site:

```
ContainerRetriever.add_container("model", model_dict)
```
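On the receiving site, the same name is used to request the container. Below is a minimal sketch of the retrieving side, modeled on the executor in this example; the `from_site` and `timeout` values are illustrative, and `dict_retriever` is assumed to be a configured `ContainerRetriever` component.

```
rc, result = dict_retriever.retrieve_container(
    from_site="server",   # site where add_container() was called
    fl_ctx=fl_ctx,        # current FLContext
    timeout=10.0,         # seconds to wait for the streamed container
    name="model",         # name passed to add_container()
)
if rc == ReturnCode.OK:
    print(f"received container with {len(result)} entries")
```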

## Example Jobs

### file_streaming job

This job uses the `FileStreamer` object to send a large file from the server to the client.

It demonstrates the following mechanisms:
1. It uses components to handle the file transfer. No training workflow is used.
   Since an executor is required by NVFlare, a dummy executor is created.
2. It shows how to use the streamer directly, without an object retriever.

The job creates a temporary file for testing. You can run the job in POC mode or with the simulator as follows:

```
nvflare simulator -n 1 -t 1 jobs/file_streaming
```
### dict_streaming job

This job demonstrates how to send a dict from the server to the client using an object retriever.

It creates a task called "retrieve_dict" to tell the client to get ready for the streaming.

The example can be run in the simulator like this:
```
nvflare simulator -n 1 -t 1 jobs/dict_streaming
```
@@ -0,0 +1,23 @@
{
"format_version": 2,
"cell_wait_timeout": 5.0,
"executors": [
{
"tasks": ["*"],
"executor": {
"path": "streaming_executor.StreamingExecutor",
"args": {
"dict_retriever_id": "dict_retriever"
}
}
}
],
"components": [
{
"id": "dict_retriever",
"path": "nvflare.app_common.streamers.container_retriever.ContainerRetriever",
"args": {
}
}
]
}
@@ -0,0 +1,20 @@
{
"format_version": 2,
"components": [
{
"id": "dict_retriever",
"path": "nvflare.app_common.streamers.container_retriever.ContainerRetriever",
"args": {
}
}
],
"workflows": [
{
"id": "controller",
"path": "streaming_controller.StreamingController",
"args": {
"dict_retriever_id": "dict_retriever"
}
}
]
}
@@ -0,0 +1,13 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,89 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from random import randbytes

from nvflare.apis.controller_spec import Client, ClientTask, Task
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.streamers.container_retriever import ContainerRetriever

STREAM_TOPIC = "rtr_file_stream"


class StreamingController(Controller):
    def __init__(self, dict_retriever_id=None, task_timeout=60, task_check_period: float = 0.5):
        Controller.__init__(self, task_check_period=task_check_period)
        self.dict_retriever_id = dict_retriever_id
        self.dict_retriever = None
        self.task_timeout = task_timeout

    def start_controller(self, fl_ctx: FLContext):
        # Register the test model with the retriever so clients can request it by name.
        model = self._get_test_model()
        self.dict_retriever.add_container("model", model)

    def stop_controller(self, fl_ctx: FLContext):
        pass

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        # Tell the clients which container to retrieve, then wait for their responses.
        s = Shareable()
        s["name"] = "model"
        task = Task(name="retrieve_dict", data=s, timeout=self.task_timeout)
        self.broadcast_and_wait(
            task=task,
            fl_ctx=fl_ctx,
            min_responses=1,
            abort_signal=abort_signal,
        )
        client_resps = {}
        for ct in task.client_tasks:
            assert isinstance(ct, ClientTask)
            resp = ct.result
            if resp is None:
                resp = "no answer"
            else:
                assert isinstance(resp, Shareable)
                self.log_info(fl_ctx, f"got resp {resp} from client {ct.client.name}")
                resp = resp.get_return_code()
            client_resps[ct.client.name] = resp
        return {"status": "OK", "data": client_resps}

    def process_result_of_unknown_task(
        self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext
    ):
        pass

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        if event_type == EventType.START_RUN:
            # Look up the ContainerRetriever component configured for this controller.
            engine = fl_ctx.get_engine()
            if self.dict_retriever_id:
                c = engine.get_component(self.dict_retriever_id)
                if not isinstance(c, ContainerRetriever):
                    self.system_panic(
                        f"invalid dict_retriever {self.dict_retriever_id}, wrong type: {type(c)}",
                        fl_ctx,
                    )
                    return
                self.dict_retriever = c

    @staticmethod
    def _get_test_model() -> dict:
        # Build a small test dict; random.randbytes requires Python 3.9+.
        model = {}
        for i in range(10):
            key = f"layer-{i}"
            model[key] = randbytes(1024)

        return model
@@ -0,0 +1,69 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.streamers.container_retriever import ContainerRetriever


class StreamingExecutor(Executor):
    def __init__(self, dict_retriever_id=None):
        Executor.__init__(self)
        self.dict_retriever_id = dict_retriever_id
        self.dict_retriever = None

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        if event_type == EventType.START_RUN:
            # Look up the ContainerRetriever component configured for this client.
            engine = fl_ctx.get_engine()
            if self.dict_retriever_id:
                c = engine.get_component(self.dict_retriever_id)
                if not isinstance(c, ContainerRetriever):
                    self.system_panic(
                        f"invalid dict_retriever {self.dict_retriever_id}, wrong type: {type(c)}",
                        fl_ctx,
                    )
                    return
                self.dict_retriever = c

    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
        self.log_info(fl_ctx, f"got task {task_name}: {shareable}")
        if task_name == "retrieve_dict":
            name = shareable.get("name")
            if not name:
                self.log_error(fl_ctx, "missing name in request")
                return make_reply(ReturnCode.BAD_TASK_DATA)
            if not self.dict_retriever:
                self.log_error(fl_ctx, "no container retriever")
                return make_reply(ReturnCode.SERVICE_UNAVAILABLE)

            assert isinstance(self.dict_retriever, ContainerRetriever)
            # Pull the named container from the server; streaming happens under the hood.
            rc, result = self.dict_retriever.retrieve_container(
                from_site="server",
                fl_ctx=fl_ctx,
                timeout=10.0,
                name=name,
            )
            if rc != ReturnCode.OK:
                self.log_error(fl_ctx, f"failed to retrieve dict {name}: {rc}")
                return make_reply(rc)

            self.log_info(fl_ctx, f"received container type: {type(result)} size: {len(result)}")
            return make_reply(ReturnCode.OK)
        else:
            self.log_error(fl_ctx, f"got unknown task {task_name}")
            return make_reply(ReturnCode.TASK_UNKNOWN)
10 changes: 10 additions & 0 deletions examples/advanced/streaming/jobs/dict_streaming/meta.json
@@ -0,0 +1,10 @@
{
"name": "file_streaming",
"resource_spec": {},
"min_clients" : 1,
"deploy_map": {
"app": [
"@ALL"
]
}
}
@@ -0,0 +1,23 @@
{
"format_version": 2,
"executors": [
{
"tasks": [
"train"
],
"executor": {
"path": "trainer.TestTrainer",
"args": {}
}
}
],
"task_result_filters": [],
"task_data_filters": [],
"components": [
{
"id": "sender",
"path": "file_streaming.FileSender",
"args": {}
}
]
}
@@ -0,0 +1,19 @@
{
"format_version": 2,
"task_data_filters": [],
"task_result_filters": [],
"components": [
{
"id": "receiver",
"path": "file_streaming.FileReceiver",
"args": {}
}
],
"workflows": [
{
"id": "controller",
"path": "controller.SimpleController",
"args": {}
}
]
}
@@ -0,0 +1,13 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.