From 06429d86e4e4a09054d709ad3a0358b4b11716bd Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Tue, 7 Jan 2025 10:47:41 -0500 Subject: [PATCH 1/8] Added yaml include support. --- nvflare/lighter/utils.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/nvflare/lighter/utils.py b/nvflare/lighter/utils.py index ae71e409fa..86ddd77edc 100644 --- a/nvflare/lighter/utils.py +++ b/nvflare/lighter/utils.py @@ -180,11 +180,26 @@ def sign_all(content_folder, signing_pri_key): return signatures +class YamlLoader(yaml.SafeLoader): + + def __init__(self, stream): + + self._root = os.path.split(stream.name)[0] + super(YamlLoader, self).__init__(stream) + + def include(self, node): + + filename = os.path.join(self._root, self.construct_scalar(node)) + with open(filename, "r") as f: + return yaml.load(f, YamlLoader) + + def load_yaml(file): + YamlLoader.add_constructor("!include", YamlLoader.include) if isinstance(file, str): - return yaml.safe_load(open(file, "r")) + return yaml.load(open(file, "r"), YamlLoader) elif isinstance(file, bytes): - return yaml.safe_load(file) + return yaml.load(file, YamlLoader) else: return None From f7c46e9b04b744aea0359406c51b7f07e48aec7a Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Tue, 7 Jan 2025 12:31:49 -0500 Subject: [PATCH 2/8] Added check to avoid duplicate yaml constructor adding. --- nvflare/lighter/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nvflare/lighter/utils.py b/nvflare/lighter/utils.py index 86ddd77edc..1715e5fe7e 100644 --- a/nvflare/lighter/utils.py +++ b/nvflare/lighter/utils.py @@ -195,7 +195,8 @@ def include(self, node): def load_yaml(file): - YamlLoader.add_constructor("!include", YamlLoader.include) + if "!include" not in YamlLoader.yaml_constructors: + YamlLoader.add_constructor("!include", YamlLoader.include) if isinstance(file, str): return yaml.load(open(file, "r"), YamlLoader) elif isinstance(file, bytes): From a2d5d3afce6be2ac8c099b67827d49235b677bed Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Wed, 8 Jan 2025 10:56:59 -0500 Subject: [PATCH 3/8] Added unit tests for load_yaml with include another yaml file. --- tests/unit_test/lighter/0.yml | 4 ++++ tests/unit_test/lighter/1.yml | 1 + tests/unit_test/lighter/utils_test.py | 6 +++++- 3 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 tests/unit_test/lighter/0.yml create mode 100644 tests/unit_test/lighter/1.yml diff --git a/tests/unit_test/lighter/0.yml b/tests/unit_test/lighter/0.yml new file mode 100644 index 0000000000..609438e9a0 --- /dev/null +++ b/tests/unit_test/lighter/0.yml @@ -0,0 +1,4 @@ +api_version: 3 +name: example_project + +server: !include 1.yml \ No newline at end of file diff --git a/tests/unit_test/lighter/1.yml b/tests/unit_test/lighter/1.yml new file mode 100644 index 0000000000..4dece82e9e --- /dev/null +++ b/tests/unit_test/lighter/1.yml @@ -0,0 +1 @@ +server_name: server \ No newline at end of file diff --git a/tests/unit_test/lighter/utils_test.py b/tests/unit_test/lighter/utils_test.py index f1bdb7266a..c81e7e8699 100644 --- a/tests/unit_test/lighter/utils_test.py +++ b/tests/unit_test/lighter/utils_test.py @@ -25,7 +25,7 @@ from cryptography.x509.oid import NameOID from nvflare.lighter.impl.cert import serialize_cert -from nvflare.lighter.utils import sign_folders, verify_folder_signature +from nvflare.lighter.utils import sign_folders, verify_folder_signature, load_yaml folders = ["folder1", "folder2"] files = ["file1", "file2"] @@ -144,3 +144,7 @@ def test_verify_updated_folder(self): os.unlink("client.crt") os.unlink("root.crt") shutil.rmtree(folder) + + def test_load_yaml(self): + data = load_yaml("0.yml") + assert data.get("server").get("server_name") == "server" \ No newline at end of file From 00ae6c9bc37396e682b8ccd7d75b9c36a4d4bfd9 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Wed, 8 Jan 2025 11:05:30 -0500 Subject: [PATCH 4/8] codestyle fix. --- tests/unit_test/lighter/utils_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit_test/lighter/utils_test.py b/tests/unit_test/lighter/utils_test.py index c81e7e8699..9d553c818f 100644 --- a/tests/unit_test/lighter/utils_test.py +++ b/tests/unit_test/lighter/utils_test.py @@ -25,7 +25,7 @@ from cryptography.x509.oid import NameOID from nvflare.lighter.impl.cert import serialize_cert -from nvflare.lighter.utils import sign_folders, verify_folder_signature, load_yaml +from nvflare.lighter.utils import load_yaml, sign_folders, verify_folder_signature folders = ["folder1", "folder2"] files = ["file1", "file2"] @@ -147,4 +147,4 @@ def test_verify_updated_folder(self): def test_load_yaml(self): data = load_yaml("0.yml") - assert data.get("server").get("server_name") == "server" \ No newline at end of file + assert data.get("server").get("server_name") == "server" From 8a9d404ca3fa0394cc611fda3c867f917bace613 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Wed, 8 Jan 2025 11:25:58 -0500 Subject: [PATCH 5/8] Added the folder path to the testing yaml file. --- tests/unit_test/lighter/utils_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit_test/lighter/utils_test.py b/tests/unit_test/lighter/utils_test.py index 9d553c818f..835f935784 100644 --- a/tests/unit_test/lighter/utils_test.py +++ b/tests/unit_test/lighter/utils_test.py @@ -146,5 +146,6 @@ def test_verify_updated_folder(self): shutil.rmtree(folder) def test_load_yaml(self): - data = load_yaml("0.yml") + dir_path = os.path.dirname(os.path.realpath(__file__)) + data = load_yaml(os.path.join(dir_path, "0.yml")) assert data.get("server").get("server_name") == "server" From 1e443df8c176766c21d4068dd3a2e54512111dc1 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 17 Jan 2025 14:03:25 -0500 Subject: [PATCH 6/8] Changed the way how to do yaml include. --- nvflare/lighter/utils.py | 52 ++++++++++++++++----------- tests/unit_test/lighter/0.yml | 10 +++++- tests/unit_test/lighter/2.yml | 1 + tests/unit_test/lighter/utils_test.py | 14 +++++++- 4 files changed, 55 insertions(+), 22 deletions(-) create mode 100644 tests/unit_test/lighter/2.yml diff --git a/nvflare/lighter/utils.py b/nvflare/lighter/utils.py index 1715e5fe7e..a571759b9e 100644 --- a/nvflare/lighter/utils.py +++ b/nvflare/lighter/utils.py @@ -180,29 +180,41 @@ def sign_all(content_folder, signing_pri_key): return signatures -class YamlLoader(yaml.SafeLoader): - - def __init__(self, stream): - - self._root = os.path.split(stream.name)[0] - super(YamlLoader, self).__init__(stream) - - def include(self, node): - - filename = os.path.join(self._root, self.construct_scalar(node)) - with open(filename, "r") as f: - return yaml.load(f, YamlLoader) - - def load_yaml(file): - if "!include" not in YamlLoader.yaml_constructors: - YamlLoader.add_constructor("!include", YamlLoader.include) + + root = os.path.split(file)[0] + yaml_data = None if isinstance(file, str): - return yaml.load(open(file, "r"), YamlLoader) + yaml_data = yaml.safe_load(open(file, "r")) elif isinstance(file, bytes): - return yaml.load(file, YamlLoader) - else: - return None + yaml_data = yaml.safe_load(file) + + yaml_data = load_yaml_include(root, yaml_data) + + return yaml_data + + +def load_yaml_include(root, yaml_data): + new_data = {} + for k, v in yaml_data.items(): + if k == "include": + if isinstance(v, str): + includes = [v] + elif isinstance(v, list): + includes = v + for item in includes: + new_data.update(load_yaml(os.path.join(root, item))) + elif isinstance(v, list): + new_list = [] + for item in v: + if isinstance(item, dict): + item = load_yaml_include(root, item) + new_list.append(item) + new_data[k] = new_list + else: + new_data[k] = v + + return new_data def sh_replace(src, mapping_dict): diff --git a/tests/unit_test/lighter/0.yml b/tests/unit_test/lighter/0.yml index 609438e9a0..3290ded4e6 100644 --- a/tests/unit_test/lighter/0.yml +++ b/tests/unit_test/lighter/0.yml @@ -1,4 +1,12 @@ api_version: 3 name: example_project -server: !include 1.yml \ No newline at end of file +include: 1.yml + +participants: + - name: server + port: 123 + include: [1.yml] + - name: client + port: 234 + include: 2.yml diff --git a/tests/unit_test/lighter/2.yml b/tests/unit_test/lighter/2.yml new file mode 100644 index 0000000000..18d2519c61 --- /dev/null +++ b/tests/unit_test/lighter/2.yml @@ -0,0 +1 @@ +client_name: client-1 \ No newline at end of file diff --git a/tests/unit_test/lighter/utils_test.py b/tests/unit_test/lighter/utils_test.py index 835f935784..01c4b4df2e 100644 --- a/tests/unit_test/lighter/utils_test.py +++ b/tests/unit_test/lighter/utils_test.py @@ -145,7 +145,19 @@ def test_verify_updated_folder(self): os.unlink("root.crt") shutil.rmtree(folder) + def _get_participant(self, name, participants): + for p in participants: + if p.get("name") == name: + return p + def test_load_yaml(self): dir_path = os.path.dirname(os.path.realpath(__file__)) data = load_yaml(os.path.join(dir_path, "0.yml")) - assert data.get("server").get("server_name") == "server" + + assert data.get("server_name") == "server" + + participant = self._get_participant("server", data.get("participants")) + assert participant.get("server_name") == "server" + + participant = self._get_participant("client", data.get("participants")) + assert participant.get("client_name") == "client-1" From d97d8e928a28ee30f1f685b92f1b2aaf0360710a Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Tue, 21 Jan 2025 09:58:28 -0500 Subject: [PATCH 7/8] Added dict value for yaml include solution. --- nvflare/lighter/utils.py | 2 ++ tests/unit_test/lighter/0.yml | 3 +++ tests/unit_test/lighter/3.yml | 2 ++ tests/unit_test/lighter/utils_test.py | 1 + 4 files changed, 8 insertions(+) create mode 100644 tests/unit_test/lighter/3.yml diff --git a/nvflare/lighter/utils.py b/nvflare/lighter/utils.py index a571759b9e..e2a88bb565 100644 --- a/nvflare/lighter/utils.py +++ b/nvflare/lighter/utils.py @@ -211,6 +211,8 @@ def load_yaml_include(root, yaml_data): item = load_yaml_include(root, item) new_list.append(item) new_data[k] = new_list + elif isinstance(v, dict): + new_data[k] = load_yaml_include(root, v) else: new_data[k] = v diff --git a/tests/unit_test/lighter/0.yml b/tests/unit_test/lighter/0.yml index 3290ded4e6..bc9ee7e07c 100644 --- a/tests/unit_test/lighter/0.yml +++ b/tests/unit_test/lighter/0.yml @@ -7,6 +7,9 @@ participants: - name: server port: 123 include: [1.yml] + extra: + location: "east" + include: 3.yml - name: client port: 234 include: 2.yml diff --git a/tests/unit_test/lighter/3.yml b/tests/unit_test/lighter/3.yml new file mode 100644 index 0000000000..08117a92cc --- /dev/null +++ b/tests/unit_test/lighter/3.yml @@ -0,0 +1,2 @@ +size: 4 +gpus: large \ No newline at end of file diff --git a/tests/unit_test/lighter/utils_test.py b/tests/unit_test/lighter/utils_test.py index 01c4b4df2e..60c3eaa9c8 100644 --- a/tests/unit_test/lighter/utils_test.py +++ b/tests/unit_test/lighter/utils_test.py @@ -158,6 +158,7 @@ def test_load_yaml(self): participant = self._get_participant("server", data.get("participants")) assert participant.get("server_name") == "server" + assert participant.get("extra").get("gpus") == "large" participant = self._get_participant("client", data.get("participants")) assert participant.get("client_name") == "client-1" From b755013f592c927873f369e290384a3154d748d6 Mon Sep 17 00:00:00 2001 From: Yuhong Wen Date: Fri, 24 Jan 2025 12:42:27 -0500 Subject: [PATCH 8/8] Changed for safe open file. --- nvflare/lighter/utils.py | 41 +++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/nvflare/lighter/utils.py b/nvflare/lighter/utils.py index e2a88bb565..8cee919cca 100644 --- a/nvflare/lighter/utils.py +++ b/nvflare/lighter/utils.py @@ -40,7 +40,8 @@ def serialize_cert(cert): def load_crt(path): - return load_crt_bytes(open(path, "rb").read()) + with open(path, "rb") as f: + return load_crt_bytes(f.read()) def load_crt_bytes(data: bytes): @@ -116,17 +117,19 @@ def sign_folders(folder, signing_pri_key, crt_path, max_depth=9999): for file in files: if file == NVFLARE_SIG_FILE or file == NVFLARE_SUBMITTER_CRT_FILE: continue - signatures[file] = sign_content( - content=open(os.path.join(root, file), "rb").read(), - signing_pri_key=signing_pri_key, - ) + with open(os.path.join(root, file), "rb") as f: + signatures[file] = sign_content( + content=f.read(), + signing_pri_key=signing_pri_key, + ) for folder in folders: signatures[folder] = sign_content( content=folder, signing_pri_key=signing_pri_key, ) - json.dump(signatures, open(os.path.join(root, NVFLARE_SIG_FILE), "wt")) + with open(os.path.join(root, NVFLARE_SIG_FILE), "wt") as f: + json.dump(signatures, f) shutil.copyfile(crt_path, os.path.join(root, NVFLARE_SUBMITTER_CRT_FILE)) if depth >= max_depth: break @@ -138,7 +141,8 @@ def verify_folder_signature(src_folder, root_ca_path): root_ca_public_key = root_ca_cert.public_key() for root, folders, files in os.walk(src_folder): try: - signatures = json.load(open(os.path.join(root, NVFLARE_SIG_FILE), "rt")) + with open(os.path.join(root, NVFLARE_SIG_FILE), "rt") as f: + signatures = json.load(f) cert = load_crt(os.path.join(root, NVFLARE_SUBMITTER_CRT_FILE)) public_key = cert.public_key() except: @@ -150,11 +154,12 @@ def verify_folder_signature(src_folder, root_ca_path): continue signature = signatures.get(file) if signature: - verify_content( - content=open(os.path.join(root, file), "rb").read(), - signature=signature, - public_key=public_key, - ) + with open(os.path.join(root, file), "rb") as f: + verify_content( + content=f.read(), + signature=signature, + public_key=public_key, + ) for folder in folders: signature = signatures.get(folder) if signature: @@ -173,10 +178,11 @@ def sign_all(content_folder, signing_pri_key): for f in os.listdir(content_folder): path = os.path.join(content_folder, f) if os.path.isfile(path): - signatures[f] = sign_content( - content=open(path, "rb").read(), - signing_pri_key=signing_pri_key, - ) + with open(path, "rb") as file: + signatures[f] = sign_content( + content=file.read(), + signing_pri_key=signing_pri_key, + ) return signatures @@ -185,7 +191,8 @@ def load_yaml(file): root = os.path.split(file)[0] yaml_data = None if isinstance(file, str): - yaml_data = yaml.safe_load(open(file, "r")) + with open(file, "r") as f: + yaml_data = yaml.safe_load(f) elif isinstance(file, bytes): yaml_data = yaml.safe_load(file)