diff --git a/kedro-airflow/kedro_airflow/grouping.py b/kedro-airflow/kedro_airflow/grouping.py
index 26c931f8d..3890804ae 100644
--- a/kedro-airflow/kedro_airflow/grouping.py
+++ b/kedro-airflow/kedro_airflow/grouping.py
@@ -4,6 +4,11 @@
 from kedro.pipeline.node import Node
 from kedro.pipeline.pipeline import Pipeline
 
+try:
+    from kedro.io import CatalogProtocol
+except ImportError:  # pragma: no cover
+    pass
+
 
 def _is_memory_dataset(catalog, dataset_name: str) -> bool:
     if dataset_name not in catalog:
@@ -11,7 +16,9 @@ def _is_memory_dataset(catalog, dataset_name: str) -> bool:
     return False
 
 
-def get_memory_datasets(catalog: DataCatalog, pipeline: Pipeline) -> set[str]:
+def get_memory_datasets(
+    catalog: CatalogProtocol | DataCatalog, pipeline: Pipeline
+) -> set[str]:
     """Gather all datasets in the pipeline that are of type MemoryDataset, excluding 'parameters'."""
     return {
         dataset_name
@@ -21,7 +28,7 @@ def get_memory_datasets(catalog: DataCatalog, pipeline: Pipeline) -> set[str]:
 
 
 def create_adjacency_list(
-    catalog: DataCatalog, pipeline: Pipeline
+    catalog: CatalogProtocol | DataCatalog, pipeline: Pipeline
 ) -> tuple[dict[str, set], dict[str, set]]:
     """
     Builds adjacency list (adj_list) to search connected components - undirected graph,
@@ -48,7 +55,7 @@ def create_adjacency_list(
 
 
 def group_memory_nodes(
-    catalog: DataCatalog, pipeline: Pipeline
+    catalog: CatalogProtocol | DataCatalog, pipeline: Pipeline
 ) -> tuple[dict[str, list[Node]], dict[str, list[str]]]:
     """
     Nodes that are connected through MemoryDatasets cannot be distributed across
diff --git a/kedro-datasets/kedro_datasets_experimental/pytorch/pytorch_dataset.py b/kedro-datasets/kedro_datasets_experimental/pytorch/pytorch_dataset.py
index 914fdb6b7..15c10a93d 100644
--- a/kedro-datasets/kedro_datasets_experimental/pytorch/pytorch_dataset.py
+++ b/kedro-datasets/kedro_datasets_experimental/pytorch/pytorch_dataset.py
@@ -96,11 +96,11 @@ def _describe(self) -> dict[str, Any]:
 
     def _load(self) -> Any:
         load_path = get_filepath_str(self._get_load_path(), self._protocol)
-        return torch.load(load_path, **self._fs_open_args_load)
+        return torch.load(load_path, **self._fs_open_args_load)  #nosec: B614
 
     def _save(self, data: torch.nn.Module) -> None:
         save_path = get_filepath_str(self._get_save_path(), self._protocol)
-        torch.save(data.state_dict(), save_path, **self._fs_open_args_save)
+        torch.save(data.state_dict(), save_path, **self._fs_open_args_save)  #nosec: B614
 
         self._invalidate_cache()
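
Reviewer note on the kedro-airflow hunk: annotating with `CatalogProtocol | DataCatalog` remains safe on Kedro versions where the guarded import fails only because `grouping.py` uses postponed annotation evaluation, so the name is never looked up at runtime. Below is a minimal, illustrative sketch of that same pattern, assuming Kedro is installed; the `describe_catalog` helper is hypothetical and not part of the patch.

# Illustrative sketch only -- not part of the patch.
from __future__ import annotations

from kedro.io import DataCatalog

try:
    from kedro.io import CatalogProtocol
except ImportError:  # older Kedro releases do not ship the protocol
    pass  # the name is only referenced inside (unevaluated) annotations


def describe_catalog(catalog: CatalogProtocol | DataCatalog) -> str:
    # Hypothetical helper: accepts a classic DataCatalog or any object
    # implementing the newer catalog protocol. Because annotations stay
    # plain strings, this definition succeeds even when the import failed.
    return type(catalog).__name__


print(describe_catalog(DataCatalog()))  # -> "DataCatalog"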