From 7cf254eec8f7fc6ddf8ca0f40719e5f0ad6efa2f Mon Sep 17 00:00:00 2001 From: Ankita Katiyar Date: Mon, 23 Sep 2024 12:34:48 +0100 Subject: [PATCH 1/4] Replace type checking with CatalogProtocol Signed-off-by: Ankita Katiyar --- kedro-airflow/kedro_airflow/grouping.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/kedro-airflow/kedro_airflow/grouping.py b/kedro-airflow/kedro_airflow/grouping.py index 26c931f8d..7ac1a8339 100644 --- a/kedro-airflow/kedro_airflow/grouping.py +++ b/kedro-airflow/kedro_airflow/grouping.py @@ -1,6 +1,6 @@ from __future__ import annotations -from kedro.io import DataCatalog +from kedro.io import CatalogProtocol from kedro.pipeline.node import Node from kedro.pipeline.pipeline import Pipeline @@ -11,7 +11,7 @@ def _is_memory_dataset(catalog, dataset_name: str) -> bool: return False -def get_memory_datasets(catalog: DataCatalog, pipeline: Pipeline) -> set[str]: +def get_memory_datasets(catalog: CatalogProtocol, pipeline: Pipeline) -> set[str]: """Gather all datasets in the pipeline that are of type MemoryDataset, excluding 'parameters'.""" return { dataset_name @@ -21,7 +21,7 @@ def get_memory_datasets(catalog: DataCatalog, pipeline: Pipeline) -> set[str]: def create_adjacency_list( - catalog: DataCatalog, pipeline: Pipeline + catalog: CatalogProtocol, pipeline: Pipeline ) -> tuple[dict[str, set], dict[str, set]]: """ Builds adjacency list (adj_list) to search connected components - undirected graph, @@ -48,7 +48,7 @@ def create_adjacency_list( def group_memory_nodes( - catalog: DataCatalog, pipeline: Pipeline + catalog: CatalogProtocol, pipeline: Pipeline ) -> tuple[dict[str, list[Node]], dict[str, list[str]]]: """ Nodes that are connected through MemoryDatasets cannot be distributed across From e66ce2546357ff8830171f7e9903b3efaf6e1de4 Mon Sep 17 00:00:00 2001 From: Ankita Katiyar Date: Tue, 24 Sep 2024 12:20:28 +0100 Subject: [PATCH 2/4] Add try except for import Signed-off-by: Ankita Katiyar --- kedro-airflow/kedro_airflow/grouping.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/kedro-airflow/kedro_airflow/grouping.py b/kedro-airflow/kedro_airflow/grouping.py index 7ac1a8339..913d6d817 100644 --- a/kedro-airflow/kedro_airflow/grouping.py +++ b/kedro-airflow/kedro_airflow/grouping.py @@ -1,9 +1,16 @@ from __future__ import annotations -from kedro.io import CatalogProtocol +from typing import Any + +from kedro.io import DataCatalog from kedro.pipeline.node import Node from kedro.pipeline.pipeline import Pipeline +try: + from kedro.io import CatalogProtocol +except ImportError: # pragma: no cover + pass + def _is_memory_dataset(catalog, dataset_name: str) -> bool: if dataset_name not in catalog: @@ -11,7 +18,9 @@ def _is_memory_dataset(catalog, dataset_name: str) -> bool: return False -def get_memory_datasets(catalog: CatalogProtocol, pipeline: Pipeline) -> set[str]: +def get_memory_datasets( + catalog: CatalogProtocol[Any] | DataCatalog, pipeline: Pipeline +) -> set[str]: """Gather all datasets in the pipeline that are of type MemoryDataset, excluding 'parameters'.""" return { dataset_name @@ -21,7 +30,7 @@ def get_memory_datasets(catalog: CatalogProtocol, pipeline: Pipeline) -> set[str def create_adjacency_list( - catalog: CatalogProtocol, pipeline: Pipeline + catalog: CatalogProtocol[Any] | DataCatalog, pipeline: Pipeline ) -> tuple[dict[str, set], dict[str, set]]: """ Builds adjacency list (adj_list) to search connected components - undirected graph, @@ -48,7 +57,7 @@ def create_adjacency_list( def group_memory_nodes( - catalog: CatalogProtocol, pipeline: Pipeline + catalog: CatalogProtocol[Any] | DataCatalog, pipeline: Pipeline ) -> tuple[dict[str, list[Node]], dict[str, list[str]]]: """ Nodes that are connected through MemoryDatasets cannot be distributed across From 9e594ed75d980accb47b3e7dd1c54904785cf0a8 Mon Sep 17 00:00:00 2001 From: Ankita Katiyar Date: Wed, 25 Sep 2024 11:32:06 +0100 Subject: [PATCH 3/4] Ignore bandit warnings Signed-off-by: Ankita Katiyar --- .../kedro_datasets_experimental/pytorch/pytorch_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kedro-datasets/kedro_datasets_experimental/pytorch/pytorch_dataset.py b/kedro-datasets/kedro_datasets_experimental/pytorch/pytorch_dataset.py index 914fdb6b7..15c10a93d 100644 --- a/kedro-datasets/kedro_datasets_experimental/pytorch/pytorch_dataset.py +++ b/kedro-datasets/kedro_datasets_experimental/pytorch/pytorch_dataset.py @@ -96,11 +96,11 @@ def _describe(self) -> dict[str, Any]: def _load(self) -> Any: load_path = get_filepath_str(self._get_load_path(), self._protocol) - return torch.load(load_path, **self._fs_open_args_load) + return torch.load(load_path, **self._fs_open_args_load) #nosec: B614 def _save(self, data: torch.nn.Module) -> None: save_path = get_filepath_str(self._get_save_path(), self._protocol) - torch.save(data.state_dict(), save_path, **self._fs_open_args_save) + torch.save(data.state_dict(), save_path, **self._fs_open_args_save) #nosec: B614 self._invalidate_cache() From 2651901338a5d548b7945585d8c0bf4a56d2bad5 Mon Sep 17 00:00:00 2001 From: Ankita Katiyar Date: Wed, 25 Sep 2024 12:52:13 +0100 Subject: [PATCH 4/4] Remove any Signed-off-by: Ankita Katiyar --- kedro-airflow/kedro_airflow/grouping.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/kedro-airflow/kedro_airflow/grouping.py b/kedro-airflow/kedro_airflow/grouping.py index 913d6d817..3890804ae 100644 --- a/kedro-airflow/kedro_airflow/grouping.py +++ b/kedro-airflow/kedro_airflow/grouping.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Any - from kedro.io import DataCatalog from kedro.pipeline.node import Node from kedro.pipeline.pipeline import Pipeline @@ -19,7 +17,7 @@ def _is_memory_dataset(catalog, dataset_name: str) -> bool: def get_memory_datasets( - catalog: CatalogProtocol[Any] | DataCatalog, pipeline: Pipeline + catalog: CatalogProtocol | DataCatalog, pipeline: Pipeline ) -> set[str]: """Gather all datasets in the pipeline that are of type MemoryDataset, excluding 'parameters'.""" return { @@ -30,7 +28,7 @@ def get_memory_datasets( def create_adjacency_list( - catalog: CatalogProtocol[Any] | DataCatalog, pipeline: Pipeline + catalog: CatalogProtocol | DataCatalog, pipeline: Pipeline ) -> tuple[dict[str, set], dict[str, set]]: """ Builds adjacency list (adj_list) to search connected components - undirected graph, @@ -57,7 +55,7 @@ def create_adjacency_list( def group_memory_nodes( - catalog: CatalogProtocol[Any] | DataCatalog, pipeline: Pipeline + catalog: CatalogProtocol | DataCatalog, pipeline: Pipeline ) -> tuple[dict[str, list[Node]], dict[str, list[str]]]: """ Nodes that are connected through MemoryDatasets cannot be distributed across