From 26f714d1b673ca7365b2fc430ca05ce19399cf94 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Fri, 18 Oct 2024 00:49:16 +0530 Subject: [PATCH 01/18] added the skeleton for the Safetensors experimental dataset Signed-off-by: Minura Punchihewa --- .../safetensors/__init__.py | 0 .../safetensors/safetensors_dataset.py | 35 +++++++++++++++++++ 2 files changed, 35 insertions(+) create mode 100644 kedro-datasets/kedro_datasets_experimental/safetensors/__init__.py create mode 100644 kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py diff --git a/kedro-datasets/kedro_datasets_experimental/safetensors/__init__.py b/kedro-datasets/kedro_datasets_experimental/safetensors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py new file mode 100644 index 000000000..9923e24f4 --- /dev/null +++ b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import importlib +from copy import deepcopy +from pathlib import PurePosixPath +from typing import Any + +import fsspec +from kedro.io.core import ( + AbstractVersionedDataset, + DatasetError, + Version, + get_filepath_str, + get_protocol_and_path, +) + + +class SafetensorsDataset(AbstractVersionedDataset[Any, Any]): + + def __init__( # noqa: PLR0913 + self, + *, + filepath: str, + backend: str = "torch", + ) -> None: + pass + + def load(self) -> Any: + pass + + def save(self, data: Any) -> None: + pass + + def _exists(self) -> bool: + pass \ No newline at end of file From c2a980e69943ea0f1896179acfff1d28a694ea5d Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Fri, 18 Oct 2024 01:10:58 +0530 Subject: [PATCH 02/18] implemented the save() and load() funcs Signed-off-by: Minura Punchihewa --- .../safetensors/safetensors_dataset.py | 76 
+++++++++++++++++-- 1 file changed, 71 insertions(+), 5 deletions(-) diff --git a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py index 9923e24f4..0e8a671d5 100644 --- a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py +++ b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py @@ -17,19 +17,85 @@ class SafetensorsDataset(AbstractVersionedDataset[Any, Any]): + DEFAULT_LOAD_ARGS: dict[str, Any] = {} + DEFAULT_SAVE_ARGS: dict[str, Any] = {} + DEFAULT_FS_ARGS: dict[str, Any] = {"open_args_save": {"mode": "wb"}} + def __init__( # noqa: PLR0913 self, *, filepath: str, - backend: str = "torch", + backend: str = "pickle", + version: Version | None = None, + credentials: dict[str, Any] | None = None, + fs_args: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, ) -> None: - pass + try: + importlib.import_module(f"safetensors.{backend}") + except ImportError as exc: + raise ImportError( + f"Selected backend '{backend}' could not be imported. " + "Make sure it is installed and importable." 
+ ) from exc + + _fs_args = deepcopy(fs_args) or {} + _fs_open_args_load = _fs_args.pop("open_args_load", {}) + _fs_open_args_save = _fs_args.pop("open_args_save", {}) + _credentials = deepcopy(credentials) or {} + + protocol, path = get_protocol_and_path(filepath, version) + if protocol == "file": + _fs_args.setdefault("auto_mkdir", True) + + self._protocol = protocol + self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) + + self.metadata = metadata + + super().__init__( + filepath=PurePosixPath(path), + version=version, + exists_function=self._fs.exists, + glob_function=self._fs.glob, + ) + + self._backend = backend + + self._fs_open_args_load = { + **self.DEFAULT_FS_ARGS.get("open_args_load", {}), + **(_fs_open_args_load or {}), + } + self._fs_open_args_save = { + **self.DEFAULT_FS_ARGS.get("open_args_save", {}), + **(_fs_open_args_save or {}), + } def load(self) -> Any: - pass + load_path = get_filepath_str(self._get_load_path(), self._protocol) + + with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: + imported_backend = importlib.import_module(f"safetensors.{self._backend}") + return imported_backend.load(fs_file) def save(self, data: Any) -> None: - pass + save_path = get_filepath_str(self._get_save_path(), self._protocol) + + with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: + try: + imported_backend = importlib.import_module(f"safetensors.{self._backend}") + imported_backend.save_file(data, fs_file) + except Exception as exc: + raise DatasetError( + f"{data.__class__} was not serialised due to: {exc}" + ) from exc + + self._invalidate_cache() def _exists(self) -> bool: - pass \ No newline at end of file + pass + + def _invalidate_cache(self) -> None: + """Invalidate underlying filesystem caches.""" + filepath = get_filepath_str(self._filepath, self._protocol) + self._fs.invalidate_cache(filepath) \ No newline at end of file From 58a258df0f16121c72ceefba6f839b6be9ee62ab Mon Sep 17 00:00:00 2001 From: 
Minura Punchihewa Date: Fri, 18 Oct 2024 09:36:04 +0530 Subject: [PATCH 03/18] updated the default backend Signed-off-by: Minura Punchihewa --- .../safetensors/safetensors_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py index 0e8a671d5..295843bf9 100644 --- a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py +++ b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py @@ -25,7 +25,7 @@ def __init__( # noqa: PLR0913 self, *, filepath: str, - backend: str = "pickle", + backend: str = "torch", version: Version | None = None, credentials: dict[str, Any] | None = None, fs_args: dict[str, Any] | None = None, From 3723b02a2bad1d9f69afad51c314d25cd6116c7b Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sat, 19 Oct 2024 00:11:03 +0530 Subject: [PATCH 04/18] implemented the describe() and exists() funcs Signed-off-by: Minura Punchihewa --- .../safetensors/safetensors_dataset.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py index 295843bf9..a1d7bf293 100644 --- a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py +++ b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py @@ -92,8 +92,21 @@ def save(self, data: Any) -> None: self._invalidate_cache() + def _describe(self) -> dict[str, Any]: + return { + "filepath": self._filepath, + "backend": self._backend, + "protocol": self._protocol, + "version": self._version, + } + def _exists(self) -> bool: - pass + try: + load_path = get_filepath_str(self._get_load_path(), self._protocol) + except DatasetError: + return False + + return 
self._fs.exists(load_path) def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" From 843e815539caf5caabc535cac756be4f1c680af7 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sat, 19 Oct 2024 00:11:14 +0530 Subject: [PATCH 05/18] imported the dataset to main pkg Signed-off-by: Minura Punchihewa --- .../safetensors/__init__.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/kedro-datasets/kedro_datasets_experimental/safetensors/__init__.py b/kedro-datasets/kedro_datasets_experimental/safetensors/__init__.py index e69de29bb..a9e221689 100644 --- a/kedro-datasets/kedro_datasets_experimental/safetensors/__init__.py +++ b/kedro-datasets/kedro_datasets_experimental/safetensors/__init__.py @@ -0,0 +1,11 @@ +"""``AbstractDataset`` implementation to load/save tensors using the SafeTensors library.""" + +from typing import Any + +import lazy_loader as lazy + +SafetensorsDataset: Any + +__getattr__, __dir__, __all__ = lazy.attach( + __name__, submod_attrs={"safetensors_dataset": ["SafetensorsDataset"]} +) From 9a5493099779a6a38d483784712cca432f3d075e Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sat, 19 Oct 2024 01:04:01 +0530 Subject: [PATCH 06/18] fixed how data is passed to load() Signed-off-by: Minura Punchihewa --- .../safetensors/safetensors_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py index a1d7bf293..a54ce3d50 100644 --- a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py +++ b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py @@ -76,7 +76,7 @@ def load(self) -> Any: with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: imported_backend = importlib.import_module(f"safetensors.{self._backend}") - return imported_backend.load(fs_file) + return 
imported_backend.load(fs_file.read()) def save(self, data: Any) -> None: save_path = get_filepath_str(self._get_save_path(), self._protocol) From 8a5d522903b859cd331bc1da1dea5b8d095544a9 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sat, 19 Oct 2024 01:06:05 +0530 Subject: [PATCH 07/18] fixed save() to access the file path Signed-off-by: Minura Punchihewa --- .../safetensors/safetensors_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py index a54ce3d50..67effbd27 100644 --- a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py +++ b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py @@ -84,7 +84,7 @@ def save(self, data: Any) -> None: with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: try: imported_backend = importlib.import_module(f"safetensors.{self._backend}") - imported_backend.save_file(data, fs_file) + imported_backend.save_file(data, fs_file.name) except Exception as exc: raise DatasetError( f"{data.__class__} was not serialised due to: {exc}" From 1bd4117bb2c54f66de464f0c00931da01ffdbf63 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sat, 19 Oct 2024 01:07:37 +0530 Subject: [PATCH 08/18] added a release() func Signed-off-by: Minura Punchihewa --- .../safetensors/safetensors_dataset.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py index 67effbd27..21ccf83c9 100644 --- a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py +++ b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py @@ -107,6 +107,10 @@ def _exists(self) -> bool: return False return self._fs.exists(load_path) + + 
def _release(self) -> None: + super()._release() + self._invalidate_cache() def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" From eba18df8f658b120726b31a42498041052f607f5 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sun, 20 Oct 2024 00:41:48 +0530 Subject: [PATCH 09/18] added the docstrings for the dataset Signed-off-by: Minura Punchihewa --- .../safetensors/safetensors_dataset.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py index 21ccf83c9..48786fc70 100644 --- a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py +++ b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py @@ -16,6 +16,40 @@ class SafetensorsDataset(AbstractVersionedDataset[Any, Any]): + """``SafetensorsDataset`` loads/saves data from/to a Safetensors file using an underlying + filesystem (e.g.: local, S3, GCS). The underlying functionality is supported by + the specified backend library passed in (defaults to the ``torch`` library), so it + supports all allowed options for loading and Safetensors pickle files. + + Example usage for the + `YAML API `_: + + .. code-block:: yaml + + test_model: # simple example without compression + type: safetensors.SafetensorsDataset + filepath: data/07_model_output/test_model.safetensors + backend: torch + + Example usage for the + `Python API `_: + + .. code-block:: pycon + + >>> from kedro_datasets.safetensors import SafetensorsDataset + >>> import torch + >>> + >>> data = {"embeddings": torch.zeros((10, 100))} + >>> dataset = SafetensorsDataset( + ... filepath="test.safetensors", + ... backend="torch" + ... 
) + >>> dataset.save(data) + >>> reloaded = dataset.load() + >>> assert torch.equal(data["embeddings"], reloaded["embeddings"]) + """ DEFAULT_LOAD_ARGS: dict[str, Any] = {} DEFAULT_SAVE_ARGS: dict[str, Any] = {} @@ -31,6 +65,43 @@ def __init__( # noqa: PLR0913 fs_args: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None, ) -> None: + """Creates a new instance of ``SafetensorsDataset`` pointing to a concrete Safetensors + file on a specific filesystem. ``SafetensorsDataset`` supports custom backends to + serialise/deserialise objects. + + The following backends are supported: + * `torch` + * `tensorflow` + * `paddle` + * `flax` + * `numpy` + + Args: + filepath: Filepath in POSIX format to a Safetensors file prefixed with a protocol like + `s3://`. If prefix is not provided, `file` protocol (local filesystem) will be used. + The prefix should be any protocol supported by ``fsspec``. + Note: `http(s)` doesn't support versioning. + backend: The backend library to use for serialising/deserialising objects. + The default backend is 'torch'. + version: If specified, should be an instance of + ``kedro.io.core.Version``. If its ``load`` attribute is + None, the latest version will be loaded. If its ``save`` + attribute is None, save version will be autogenerated. + credentials: Credentials required to get access to the underlying filesystem. + E.g. for ``GCSFileSystem`` it should look like `{"token": None}`. + fs_args: Extra arguments to pass into underlying filesystem class constructor + (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as + to pass to the filesystem's `open` method through nested keys + `open_args_load` and `open_args_save`. + Here you can find all available arguments for `open`: + https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open + All defaults are preserved, except `mode`, which is set to `wb` when saving. + metadata: Any arbitrary metadata. 
+ This is ignored by Kedro, but may be consumed by users or external plugins. + + Raises: + ImportError: If the ``backend`` module could not be imported. + """ try: importlib.import_module(f"safetensors.{backend}") except ImportError as exc: From 98bb71949209adf7b7e0ab527b3ae210ddb35177 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sun, 20 Oct 2024 00:45:22 +0530 Subject: [PATCH 10/18] fixed lint issues Signed-off-by: Minura Punchihewa --- .../safetensors/safetensors_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py index 48786fc70..a4385bb95 100644 --- a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py +++ b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py @@ -109,7 +109,7 @@ def __init__( # noqa: PLR0913 f"Selected backend '{backend}' could not be imported. " "Make sure it is installed and importable." 
) from exc - + _fs_args = deepcopy(fs_args) or {} _fs_open_args_load = _fs_args.pop("open_args_load", {}) _fs_open_args_save = _fs_args.pop("open_args_save", {}) @@ -160,7 +160,7 @@ def save(self, data: Any) -> None: raise DatasetError( f"{data.__class__} was not serialised due to: {exc}" ) from exc - + self._invalidate_cache() def _describe(self) -> dict[str, Any]: @@ -178,7 +178,7 @@ def _exists(self) -> bool: return False return self._fs.exists(load_path) - + def _release(self) -> None: super()._release() self._invalidate_cache() @@ -186,4 +186,4 @@ def _release(self) -> None: def _invalidate_cache(self) -> None: """Invalidate underlying filesystem caches.""" filepath = get_filepath_str(self._filepath, self._protocol) - self._fs.invalidate_cache(filepath) \ No newline at end of file + self._fs.invalidate_cache(filepath) From 5d0e347bf338057f1a36d2bd622b9adc48436b7f Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sun, 20 Oct 2024 09:59:13 +0530 Subject: [PATCH 11/18] added unit tests Signed-off-by: Minura Punchihewa --- .../tests/safetensors/__init__.py | 0 .../safetensors/test_safetensors_dataset.py | 47 +++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 kedro-datasets/kedro_datasets_experimental/tests/safetensors/__init__.py create mode 100644 kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py diff --git a/kedro-datasets/kedro_datasets_experimental/tests/safetensors/__init__.py b/kedro-datasets/kedro_datasets_experimental/tests/safetensors/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py b/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py new file mode 100644 index 000000000..bcc0cdb8d --- /dev/null +++ b/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py @@ -0,0 +1,47 @@ +import pytest +import torch + +from 
kedro_datasets_experimental.safetensors import SafetensorsDataset + + +@pytest.fixture +def filepath(tmp_path): + return (tmp_path / "test.safetensors").as_posix() + + +@pytest.fixture(params=["torch"]) +def backend(request): + return request.param + + +@pytest.fixture +def safetensors_dataset(filepath, backend, fs_args): + return SafetensorsDataset( + filepath=filepath, + backend=backend, + fs_args=fs_args, + ) + +@pytest.fixture +def dummy_data(): + return {"embeddings": torch.zeros((10, 100))} + + +class TestSafetensorsDataset: + @pytest.mark.parametrize( + "backend", + [ + "torch", + ], + indirect=True, + ) + def test_save_and_load(self, safetensors_dataset, dummy_data): + """Test saving and reloading the dataset.""" + safetensors_dataset.save(dummy_data) + reloaded = safetensors_dataset.load() + + if safetensors_dataset._backend == "torch": + assert torch.equal(dummy_data["embeddings"], reloaded["embeddings"]) + + assert safetensors_dataset._fs_open_args_load == {} + assert safetensors_dataset._fs_open_args_save == {"mode": "wb"} \ No newline at end of file From 58c1ac7a7b8f67a667b4a44ce93ad35284f718cb Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sun, 20 Oct 2024 11:58:55 +0530 Subject: [PATCH 12/18] added a few more unit tests Signed-off-by: Minura Punchihewa --- .../safetensors/test_safetensors_dataset.py | 187 +++++++++++++++++- 1 file changed, 186 insertions(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py b/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py index bcc0cdb8d..dfb1e3b81 100644 --- a/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py +++ b/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py @@ -1,4 +1,11 @@ +from pathlib import Path, PurePosixPath + +from fsspec.implementations.http import HTTPFileSystem +from fsspec.implementations.local import 
LocalFileSystem +from gcsfs import GCSFileSystem +from kedro.io.core import DatasetError, PROTOCOL_DELIMITER, Version import pytest +from s3fs.core import S3FileSystem import torch from kedro_datasets_experimental.safetensors import SafetensorsDataset @@ -22,6 +29,14 @@ def safetensors_dataset(filepath, backend, fs_args): fs_args=fs_args, ) + +@pytest.fixture +def versioned_safetensors_dataset(filepath, load_version, save_version): + return SafetensorsDataset( + filepath=filepath, version=Version(load_version, save_version) + ) + + @pytest.fixture def dummy_data(): return {"embeddings": torch.zeros((10, 100))} @@ -44,4 +59,174 @@ def test_save_and_load(self, safetensors_dataset, dummy_data): assert torch.equal(dummy_data["embeddings"], reloaded["embeddings"]) assert safetensors_dataset._fs_open_args_load == {} - assert safetensors_dataset._fs_open_args_save == {"mode": "wb"} \ No newline at end of file + assert safetensors_dataset._fs_open_args_save == {"mode": "wb"} + + def test_exists(self, safetensors_dataset, dummy_data): + """Test `exists` method invocation for both existing and + nonexistent dataset.""" + assert not safetensors_dataset.exists() + safetensors_dataset.save(dummy_data) + assert safetensors_dataset.exists() + + @pytest.mark.parametrize( + "fs_args", + [{"open_args_load": {"mode": "rb", "compression": "gzip"}}], + indirect=True, + ) + def test_open_extra_args(self, safetensors_dataset, fs_args): + assert safetensors_dataset._fs_open_args_load == fs_args["open_args_load"] + assert safetensors_dataset._fs_open_args_save == {"mode": "wb"} # default unchanged + + def test_load_missing_file(self, safetensors_dataset): + """Check the error when trying to load missing file.""" + pattern = r"Failed while loading data from dataset SafetensorsDataset\(.*\)" + with pytest.raises(DatasetError, match=pattern): + safetensors_dataset.load() + + @pytest.mark.parametrize( + "filepath,instance_type", + [ + ("s3://bucket/file.safetensors", S3FileSystem), + 
("file:///tmp/test.safetensors", LocalFileSystem), + ("/tmp/test.safetensors", LocalFileSystem), + ("gcs://bucket/file.safetensors", GCSFileSystem), + ("https://example.com/file.safetensors", HTTPFileSystem), + ], + ) + def test_protocol_usage(self, filepath, instance_type): + dataset = SafetensorsDataset(filepath=filepath) + assert isinstance(dataset._fs, instance_type) + + path = filepath.split(PROTOCOL_DELIMITER, 1)[-1] + + assert str(dataset._filepath) == path + assert isinstance(dataset._filepath, PurePosixPath) + + def test_catalog_release(self, mocker): + fs_mock = mocker.patch("fsspec.filesystem").return_value + filepath = "test.safetensors" + dataset = SafetensorsDataset(filepath=filepath) + dataset.release() + fs_mock.invalidate_cache.assert_called_once_with(filepath) + + def test_invalid_backend(self, mocker): + pattern = ( + r"Selected backend 'invalid' could not be imported. " + r"Make sure it is installed and importable." + ) + mocker.patch( + "kedro_datasets_experimental.safetensors.safetensors_dataset.importlib.import_module", + return_value=object, + ) + with pytest.raises(ImportError, match=pattern): + SafetensorsDataset(filepath="test.safetensors", backend="invalid") + + def test_copy(self, safetensors_dataset): + safetensors_dataset_copy = safetensors_dataset._copy() + assert safetensors_dataset_copy is not safetensors_dataset + assert safetensors_dataset_copy._describe() == safetensors_dataset._describe() + + +class TestSafetensorsDatasetVersioned: + def test_version_str_repr(self, load_version, save_version): + """Test that version is in string representation of the class instance + when applicable.""" + filepath = "test.safetensors" + ds = SafetensorsDataset(filepath=filepath) + ds_versioned = SafetensorsDataset( + filepath=filepath, version=Version(load_version, save_version) + ) + assert filepath in str(ds) + assert "version" not in str(ds) + + assert filepath in str(ds_versioned) + ver_str = f"version=Version(load={load_version}, 
save='{save_version}')" + assert ver_str in str(ds_versioned) + assert "SafetensorsDataset" in str(ds_versioned) + assert "SafetensorsDataset" in str(ds) + assert "protocol" in str(ds_versioned) + assert "protocol" in str(ds) + assert "backend" in str(ds_versioned) + assert "backend" in str(ds) + + def test_save_and_load(self, versioned_safetensors_dataset, dummy_data): + """Test that saved and reloaded data matches the original one for + the versioned dataset.""" + versioned_safetensors_dataset.save(dummy_data) + reloaded_df = versioned_safetensors_dataset.load() + + assert torch.equal(dummy_data["embeddings"], reloaded_df["embeddings"]) + + def test_no_versions(self, versioned_safetensors_dataset): + """Check the error if no versions are available for load.""" + pattern = r"Did not find any versions for SafetensorsDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_safetensors_dataset.load() + + def test_exists(self, versioned_safetensors_dataset, dummy_data): + """Test `exists` method invocation for versioned dataset.""" + assert not versioned_safetensors_dataset.exists() + versioned_safetensors_dataset.save(dummy_data) + assert versioned_safetensors_dataset.exists() + + def test_prevent_overwrite(self, versioned_safetensors_dataset, dummy_data): + """Check the error when attempting to override the dataset if the + corresponding Safetensors file for a given save version already exists.""" + versioned_safetensors_dataset.save(dummy_data) + pattern = ( + r"Save path \'.+\' for SafetensorsDataset\(.+\) must " + r"not exist if versioning is enabled\." 
+ ) + with pytest.raises(DatasetError, match=pattern): + versioned_safetensors_dataset.save(dummy_data) + + @pytest.mark.parametrize( + "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True + ) + @pytest.mark.parametrize( + "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True + ) + def test_save_version_warning( + self, versioned_safetensors_dataset, load_version, save_version, dummy_data + ): + """Check the warning when saving to the path that differs from + the subsequent load path.""" + pattern = ( + rf"Save version '{save_version}' did not match load version " + rf"'{load_version}' for SafetensorsDataset\(.+\)" + ) + with pytest.warns(UserWarning, match=pattern): + versioned_safetensors_dataset.save(dummy_data) + + def test_http_filesystem_no_versioning(self): + pattern = "Versioning is not supported for HTTP protocols." + + with pytest.raises(DatasetError, match=pattern): + SafetensorsDataset( + filepath="https://example.com/file.safetensors", version=Version(None, None) + ) + + def test_versioning_existing_dataset( + self, safetensors_dataset, versioned_safetensors_dataset, dummy_data + ): + """Check the error when attempting to save a versioned dataset on top of an + already existing (non-versioned) dataset.""" + safetensors_dataset.save(dummy_data) + assert safetensors_dataset.exists() + assert safetensors_dataset._filepath == versioned_safetensors_dataset._filepath + pattern = ( + f"(?=.*file with the same name already exists in the directory)" + f"(?=.*{versioned_safetensors_dataset._filepath.parent.as_posix()})" + ) + with pytest.raises(DatasetError, match=pattern): + versioned_safetensors_dataset.save(dummy_data) + + # Remove non-versioned dataset and try again + Path(safetensors_dataset._filepath.as_posix()).unlink() + versioned_safetensors_dataset.save(dummy_data) + assert versioned_safetensors_dataset.exists() + + def test_copy(self, versioned_safetensors_dataset): + safetensors_dataset_copy = versioned_safetensors_dataset._copy() + 
assert safetensors_dataset_copy is not versioned_safetensors_dataset + assert safetensors_dataset_copy._describe() == versioned_safetensors_dataset._describe() \ No newline at end of file From 33501615bde4822984fa16ef9118e4340570f498 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sun, 20 Oct 2024 12:08:28 +0530 Subject: [PATCH 13/18] fixed broken unit test Signed-off-by: Minura Punchihewa --- .../tests/safetensors/test_safetensors_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py b/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py index dfb1e3b81..7105aeeb9 100644 --- a/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py +++ b/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py @@ -116,7 +116,7 @@ def test_invalid_backend(self, mocker): ) mocker.patch( "kedro_datasets_experimental.safetensors.safetensors_dataset.importlib.import_module", - return_value=object, + side_effect=ImportError, ) with pytest.raises(ImportError, match=pattern): SafetensorsDataset(filepath="test.safetensors", backend="invalid") @@ -229,4 +229,4 @@ def test_versioning_existing_dataset( def test_copy(self, versioned_safetensors_dataset): safetensors_dataset_copy = versioned_safetensors_dataset._copy() assert safetensors_dataset_copy is not versioned_safetensors_dataset - assert safetensors_dataset_copy._describe() == versioned_safetensors_dataset._describe() \ No newline at end of file + assert safetensors_dataset_copy._describe() == versioned_safetensors_dataset._describe() From aa6dd9c06bb2d8e911fded644c3317d33ca6d91d Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sun, 20 Oct 2024 12:14:10 +0530 Subject: [PATCH 14/18] fixed lint issues Signed-off-by: Minura Punchihewa --- .../tests/safetensors/test_safetensors_dataset.py | 8 ++++---- 1 file 
changed, 4 insertions(+), 4 deletions(-) diff --git a/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py b/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py index 7105aeeb9..ec7288226 100644 --- a/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py +++ b/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py @@ -1,12 +1,12 @@ from pathlib import Path, PurePosixPath +import pytest +import torch from fsspec.implementations.http import HTTPFileSystem from fsspec.implementations.local import LocalFileSystem from gcsfs import GCSFileSystem -from kedro.io.core import DatasetError, PROTOCOL_DELIMITER, Version -import pytest +from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version from s3fs.core import S3FileSystem -import torch from kedro_datasets_experimental.safetensors import SafetensorsDataset @@ -154,7 +154,7 @@ def test_save_and_load(self, versioned_safetensors_dataset, dummy_data): the versioned dataset.""" versioned_safetensors_dataset.save(dummy_data) reloaded_df = versioned_safetensors_dataset.load() - + assert torch.equal(dummy_data["embeddings"], reloaded_df["embeddings"]) def test_no_versions(self, versioned_safetensors_dataset): From 697f34fee008b608a53f392b767862222a3a9d66 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Sun, 20 Oct 2024 19:25:24 +0530 Subject: [PATCH 15/18] fixed use of insecure temp files Signed-off-by: Minura Punchihewa --- .../tests/safetensors/test_safetensors_dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py b/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py index ec7288226..dc461cdd7 100644 --- a/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py +++ 
b/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py @@ -1,3 +1,4 @@ +import tempfile from pathlib import Path, PurePosixPath import pytest @@ -87,8 +88,8 @@ def test_load_missing_file(self, safetensors_dataset): "filepath,instance_type", [ ("s3://bucket/file.safetensors", S3FileSystem), - ("file:///tmp/test.safetensors", LocalFileSystem), - ("/tmp/test.safetensors", LocalFileSystem), + (tempfile.NamedTemporaryFile(suffix=".safetensors").name, LocalFileSystem), + (tempfile.NamedTemporaryFile(suffix=".safetensors").name, LocalFileSystem), ("gcs://bucket/file.safetensors", GCSFileSystem), ("https://example.com/file.safetensors", HTTPFileSystem), ], From 4c27e7358ca75a75f72347f9fed3bb34f82e4eee Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Mon, 21 Oct 2024 17:34:27 +0530 Subject: [PATCH 16/18] added the dataset to the documentation Signed-off-by: Minura Punchihewa --- kedro-datasets/docs/source/api/kedro_datasets_experimental.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst b/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst index 219510954..6b2fd56f9 100644 --- a/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst +++ b/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst @@ -19,3 +19,4 @@ kedro_datasets_experimental prophet.ProphetModelDataset pytorch.PyTorchDataset rioxarray.GeoTIFFDataset + safetensors.SafeTensorsDataset From 7922ed8f9849e2ed19d7895f261ef19045effd58 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Mon, 21 Oct 2024 17:56:25 +0530 Subject: [PATCH 17/18] listed the dependencies for the dataset Signed-off-by: Minura Punchihewa --- .../safetensors/safetensors_dataset.py | 4 ++-- kedro-datasets/pyproject.toml | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py 
b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py index a4385bb95..42fc1069a 100644 --- a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py +++ b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py @@ -19,7 +19,7 @@ class SafetensorsDataset(AbstractVersionedDataset[Any, Any]): """``SafetensorsDataset`` loads/saves data from/to a Safetensors file using an underlying filesystem (e.g.: local, S3, GCS). The underlying functionality is supported by the specified backend library passed in (defaults to the ``torch`` library), so it - supports all allowed options for loading and Safetensors pickle files. + supports all allowed options for loading and Safetensors files. Example usage for the `YAML API =0.15.0"] rioxarray = ["kedro-datasets[rioxarray-geotiffdataset]"] +safetensors-safetensorsdataset = ["safetensors"] +safetensors = ["kedro-datasets[safetensors-safetensorsdataset]"] + # Docs requirements docs = [ "kedro-sphinx-theme==2024.10.2", From 3c5981359875622071ef79577c7cc4099b6e609a Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Mon, 21 Oct 2024 18:04:02 +0530 Subject: [PATCH 18/18] fixed typo in dataset reference Signed-off-by: Minura Punchihewa --- kedro-datasets/docs/source/api/kedro_datasets_experimental.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst b/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst index 6b2fd56f9..62772baa8 100644 --- a/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst +++ b/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst @@ -19,4 +19,4 @@ kedro_datasets_experimental prophet.ProphetModelDataset pytorch.PyTorchDataset rioxarray.GeoTIFFDataset - safetensors.SafeTensorsDataset + safetensors.SafetensorsDataset