diff --git a/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst b/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst
index 72328e47e..def36d4ef 100644
--- a/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst
+++ b/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst
@@ -20,4 +20,5 @@ kedro_datasets_experimental
     prophet.ProphetModelDataset
     pytorch.PyTorchDataset
     rioxarray.GeoTIFFDataset
+    safetensors.SafetensorsDataset
     video.VideoDataset
diff --git a/kedro-datasets/kedro_datasets_experimental/safetensors/__init__.py b/kedro-datasets/kedro_datasets_experimental/safetensors/__init__.py
new file mode 100644
index 000000000..a9e221689
--- /dev/null
+++ b/kedro-datasets/kedro_datasets_experimental/safetensors/__init__.py
@@ -0,0 +1,11 @@
+"""``AbstractDataset`` implementation to load/save tensors using the SafeTensors library."""
+
+from typing import Any
+
+import lazy_loader as lazy
+
+SafetensorsDataset: Any
+
+__getattr__, __dir__, __all__ = lazy.attach(
+    __name__, submod_attrs={"safetensors_dataset": ["SafetensorsDataset"]}
+)
diff --git a/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py
new file mode 100644
index 000000000..42fc1069a
--- /dev/null
+++ b/kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py
@@ -0,0 +1,189 @@
+from __future__ import annotations
+
+import importlib
+from copy import deepcopy
+from pathlib import PurePosixPath
+from typing import Any
+
+import fsspec
+from kedro.io.core import (
+    AbstractVersionedDataset,
+    DatasetError,
+    Version,
+    get_filepath_str,
+    get_protocol_and_path,
+)
+
+
+class SafetensorsDataset(AbstractVersionedDataset[Any, Any]):
+    """``SafetensorsDataset`` loads/saves data from/to a Safetensors file using an underlying
+    filesystem (e.g.: local, S3, GCS). The underlying functionality is supported by
+    the specified backend library passed in (defaults to the ``torch`` library), so it
+    supports all allowed options for loading and saving Safetensors files.
+
+    Example usage for the
+    `YAML API <https://docs.kedro.org/en/stable/data/data_catalog_yaml_examples.html>`_:
+
+    .. code-block:: yaml
+
+        test_model:  # simple example
+          type: safetensors.SafetensorsDataset
+          filepath: data/07_model_output/test_model.safetensors
+          backend: torch
+
+    Example usage for the
+    `Python API <https://docs.kedro.org/en/stable/data/advanced_data_catalog_usage.html>`_:
+
+    .. code-block:: pycon
+
+        >>> from kedro_datasets_experimental.safetensors import SafetensorsDataset
+        >>> import torch
+        >>>
+        >>> data = {"embeddings": torch.zeros((10, 100))}
+        >>> dataset = SafetensorsDataset(
+        ...     filepath="test.safetensors",
+        ...     backend="torch"
+        ... )
+        >>> dataset.save(data)
+        >>> reloaded = dataset.load()
+        >>> assert torch.equal(data["embeddings"], reloaded["embeddings"])
+    """
+
+    DEFAULT_LOAD_ARGS: dict[str, Any] = {}
+    DEFAULT_SAVE_ARGS: dict[str, Any] = {}
+    DEFAULT_FS_ARGS: dict[str, Any] = {"open_args_save": {"mode": "wb"}}
+
+    def __init__(  # noqa: PLR0913
+        self,
+        *,
+        filepath: str,
+        backend: str = "torch",
+        version: Version | None = None,
+        credentials: dict[str, Any] | None = None,
+        fs_args: dict[str, Any] | None = None,
+        metadata: dict[str, Any] | None = None,
+    ) -> None:
+        """Creates a new instance of ``SafetensorsDataset`` pointing to a concrete Safetensors
+        file on a specific filesystem. ``SafetensorsDataset`` supports custom backends to
+        serialise/deserialise objects.
+
+        The following backends are supported:
+            * `torch`
+            * `tensorflow`
+            * `paddle`
+            * `flax`
+            * `numpy`
+
+        Args:
+            filepath: Filepath in POSIX format to a Safetensors file prefixed with a protocol like
+                `s3://`. If prefix is not provided, `file` protocol (local filesystem) will be used.
+                The prefix should be any protocol supported by ``fsspec``.
+                Note: `http(s)` doesn't support versioning.
+            backend: The backend library to use for serialising/deserialising objects.
+                The default backend is 'torch'.
+            version: If specified, should be an instance of
+                ``kedro.io.core.Version``. If its ``load`` attribute is
+                None, the latest version will be loaded. If its ``save``
+                attribute is None, save version will be autogenerated.
+            credentials: Credentials required to get access to the underlying filesystem.
+                E.g. for ``GCSFileSystem`` it should look like `{"token": None}`.
+            fs_args: Extra arguments to pass into underlying filesystem class constructor
+                (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as
+                to pass to the filesystem's `open` method through nested keys
+                `open_args_load` and `open_args_save`.
+                Here you can find all available arguments for `open`:
+                https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
+                All defaults are preserved, except `mode`, which is set to `wb` when saving.
+            metadata: Any arbitrary metadata.
+                This is ignored by Kedro, but may be consumed by users or external plugins.
+
+        Raises:
+            ImportError: If the ``backend`` module could not be imported.
+        """
+        try:
+            importlib.import_module(f"safetensors.{backend}")
+        except ImportError as exc:
+            raise ImportError(
+                f"Selected backend '{backend}' could not be imported. "
+                "Make sure it is installed and importable."
+            ) from exc
+
+        _fs_args = deepcopy(fs_args) or {}
+        _fs_open_args_load = _fs_args.pop("open_args_load", {})
+        _fs_open_args_save = _fs_args.pop("open_args_save", {})
+        _credentials = deepcopy(credentials) or {}
+
+        protocol, path = get_protocol_and_path(filepath, version)
+        if protocol == "file":
+            _fs_args.setdefault("auto_mkdir", True)
+
+        self._protocol = protocol
+        self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args)
+
+        self.metadata = metadata
+
+        super().__init__(
+            filepath=PurePosixPath(path),
+            version=version,
+            exists_function=self._fs.exists,
+            glob_function=self._fs.glob,
+        )
+
+        self._backend = backend
+
+        self._fs_open_args_load = {
+            **self.DEFAULT_FS_ARGS.get("open_args_load", {}),
+            **(_fs_open_args_load or {}),
+        }
+        self._fs_open_args_save = {
+            **self.DEFAULT_FS_ARGS.get("open_args_save", {}),
+            **(_fs_open_args_save or {}),
+        }
+
+    def load(self) -> Any:
+        load_path = get_filepath_str(self._get_load_path(), self._protocol)
+
+        with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
+            imported_backend = importlib.import_module(f"safetensors.{self._backend}")
+            return imported_backend.load(fs_file.read())
+
+    def save(self, data: Any) -> None:
+        save_path = get_filepath_str(self._get_save_path(), self._protocol)
+
+        with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
+            try:
+                imported_backend = importlib.import_module(f"safetensors.{self._backend}")
+                # Serialise to bytes and write through the fsspec file object so
+                # remote filesystems (e.g. S3, GCS) work, not only local paths.
+                fs_file.write(imported_backend.save(data))
+            except Exception as exc:
+                raise DatasetError(
+                    f"{data.__class__} was not serialised due to: {exc}"
+                ) from exc
+
+        self._invalidate_cache()
+
+    def _describe(self) -> dict[str, Any]:
+        return {
+            "filepath": self._filepath,
+            "backend": self._backend,
+            "protocol": self._protocol,
+            "version": self._version,
+        }
+
+    def _exists(self) -> bool:
+        try:
+            load_path = get_filepath_str(self._get_load_path(), self._protocol)
+        except DatasetError:
+            return False
+
+        return self._fs.exists(load_path)
+
+    def _release(self) -> None:
+        super()._release()
+        self._invalidate_cache()
+
+    def _invalidate_cache(self) -> None:
+        """Invalidate underlying filesystem caches."""
+        filepath = get_filepath_str(self._filepath, self._protocol)
+        self._fs.invalidate_cache(filepath)
diff --git a/kedro-datasets/kedro_datasets_experimental/tests/safetensors/__init__.py b/kedro-datasets/kedro_datasets_experimental/tests/safetensors/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py b/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py
new file mode 100644
index 000000000..dc461cdd7
--- /dev/null
+++ b/kedro-datasets/kedro_datasets_experimental/tests/safetensors/test_safetensors_dataset.py
@@ -0,0 +1,233 @@
+import tempfile
+from pathlib import Path, PurePosixPath
+
+import pytest
+import torch
+from fsspec.implementations.http import HTTPFileSystem
+from fsspec.implementations.local import LocalFileSystem
+from gcsfs import GCSFileSystem
+from kedro.io.core import PROTOCOL_DELIMITER, DatasetError, Version
+from s3fs.core import S3FileSystem
+
+from kedro_datasets_experimental.safetensors import SafetensorsDataset
+
+
+@pytest.fixture
+def filepath(tmp_path):
+    return (tmp_path / "test.safetensors").as_posix()
+
+
+@pytest.fixture(params=["torch"])
+def backend(request):
+    return request.param
+
+
+@pytest.fixture
+def safetensors_dataset(filepath, backend, fs_args):
+    return SafetensorsDataset(
+        filepath=filepath,
+        backend=backend,
+        fs_args=fs_args,
+    )
+
+
+@pytest.fixture
+def versioned_safetensors_dataset(filepath, load_version, save_version):
+    return SafetensorsDataset(
+        filepath=filepath, version=Version(load_version, save_version)
+    )
+
+
+@pytest.fixture
+def dummy_data():
+    return {"embeddings": torch.zeros((10, 100))}
+
+
+class TestSafetensorsDataset:
+    @pytest.mark.parametrize(
+        "backend",
+        [
+            "torch",
+        ],
+        indirect=True,
+    )
+    def test_save_and_load(self, safetensors_dataset, dummy_data):
+        """Test saving and reloading the dataset."""
+        safetensors_dataset.save(dummy_data)
+        reloaded = safetensors_dataset.load()
+
+        if safetensors_dataset._backend == "torch":
+            assert torch.equal(dummy_data["embeddings"], reloaded["embeddings"])
+
+        assert safetensors_dataset._fs_open_args_load == {}
+        assert safetensors_dataset._fs_open_args_save == {"mode": "wb"}
+
+    def test_exists(self, safetensors_dataset, dummy_data):
+        """Test `exists` method invocation for both existing and
+        nonexistent dataset."""
+        assert not safetensors_dataset.exists()
+        safetensors_dataset.save(dummy_data)
+        assert safetensors_dataset.exists()
+
+    @pytest.mark.parametrize(
+        "fs_args",
+        [{"open_args_load": {"mode": "rb", "compression": "gzip"}}],
+        indirect=True,
+    )
+    def test_open_extra_args(self, safetensors_dataset, fs_args):
+        assert safetensors_dataset._fs_open_args_load == fs_args["open_args_load"]
+        assert safetensors_dataset._fs_open_args_save == {"mode": "wb"}  # default unchanged
+
+    def test_load_missing_file(self, safetensors_dataset):
+        """Check the error when trying to load missing file."""
+        pattern = r"Failed while loading data from dataset SafetensorsDataset\(.*\)"
+        with pytest.raises(DatasetError, match=pattern):
+            safetensors_dataset.load()
+
+    @pytest.mark.parametrize(
+        "filepath,instance_type",
+        [
+            ("s3://bucket/file.safetensors", S3FileSystem),
+            (tempfile.NamedTemporaryFile(suffix=".safetensors").name, LocalFileSystem),
+            ("file://" + tempfile.NamedTemporaryFile(suffix=".safetensors").name, LocalFileSystem),
+            ("gcs://bucket/file.safetensors", GCSFileSystem),
+            ("https://example.com/file.safetensors", HTTPFileSystem),
+        ],
+    )
+    def test_protocol_usage(self, filepath, instance_type):
+        dataset = SafetensorsDataset(filepath=filepath)
+        assert isinstance(dataset._fs, instance_type)
+
+        path = filepath.split(PROTOCOL_DELIMITER, 1)[-1]
+
+        assert str(dataset._filepath) == path
+        assert isinstance(dataset._filepath, PurePosixPath)
+
+    def test_catalog_release(self, mocker):
+        fs_mock = mocker.patch("fsspec.filesystem").return_value
+        filepath = "test.safetensors"
+        dataset = SafetensorsDataset(filepath=filepath)
+        dataset.release()
+        fs_mock.invalidate_cache.assert_called_once_with(filepath)
+
+    def test_invalid_backend(self, mocker):
+        pattern = (
+            r"Selected backend 'invalid' could not be imported. "
+            r"Make sure it is installed and importable."
+        )
+        mocker.patch(
+            "kedro_datasets_experimental.safetensors.safetensors_dataset.importlib.import_module",
+            side_effect=ImportError,
+        )
+        with pytest.raises(ImportError, match=pattern):
+            SafetensorsDataset(filepath="test.safetensors", backend="invalid")
+
+    def test_copy(self, safetensors_dataset):
+        safetensors_dataset_copy = safetensors_dataset._copy()
+        assert safetensors_dataset_copy is not safetensors_dataset
+        assert safetensors_dataset_copy._describe() == safetensors_dataset._describe()
+
+
+class TestSafetensorsDatasetVersioned:
+    def test_version_str_repr(self, load_version, save_version):
+        """Test that version is in string representation of the class instance
+        when applicable."""
+        filepath = "test.safetensors"
+        ds = SafetensorsDataset(filepath=filepath)
+        ds_versioned = SafetensorsDataset(
+            filepath=filepath, version=Version(load_version, save_version)
+        )
+        assert filepath in str(ds)
+        assert "version" not in str(ds)
+
+        assert filepath in str(ds_versioned)
+        ver_str = f"version=Version(load={load_version}, save='{save_version}')"
+        assert ver_str in str(ds_versioned)
+        assert "SafetensorsDataset" in str(ds_versioned)
+        assert "SafetensorsDataset" in str(ds)
+        assert "protocol" in str(ds_versioned)
+        assert "protocol" in str(ds)
+        assert "backend" in str(ds_versioned)
+        assert "backend" in str(ds)
+
+    def test_save_and_load(self, versioned_safetensors_dataset, dummy_data):
+        """Test that saved and reloaded data matches the original one for
+        the versioned dataset."""
+        versioned_safetensors_dataset.save(dummy_data)
+        reloaded = versioned_safetensors_dataset.load()
+
+        assert torch.equal(dummy_data["embeddings"], reloaded["embeddings"])
+
+    def test_no_versions(self, versioned_safetensors_dataset):
+        """Check the error if no versions are available for load."""
+        pattern = r"Did not find any versions for SafetensorsDataset\(.+\)"
+        with pytest.raises(DatasetError, match=pattern):
+            versioned_safetensors_dataset.load()
+
+    def test_exists(self, versioned_safetensors_dataset, dummy_data):
+        """Test `exists` method invocation for versioned dataset."""
+        assert not versioned_safetensors_dataset.exists()
+        versioned_safetensors_dataset.save(dummy_data)
+        assert versioned_safetensors_dataset.exists()
+
+    def test_prevent_overwrite(self, versioned_safetensors_dataset, dummy_data):
+        """Check the error when attempting to overwrite the dataset if the
+        corresponding Safetensors file for a given save version already exists."""
+        versioned_safetensors_dataset.save(dummy_data)
+        pattern = (
+            r"Save path \'.+\' for SafetensorsDataset\(.+\) must "
+            r"not exist if versioning is enabled\."
+        )
+        with pytest.raises(DatasetError, match=pattern):
+            versioned_safetensors_dataset.save(dummy_data)
+
+    @pytest.mark.parametrize(
+        "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True
+    )
+    @pytest.mark.parametrize(
+        "save_version", ["2019-01-02T00.00.00.000Z"], indirect=True
+    )
+    def test_save_version_warning(
+        self, versioned_safetensors_dataset, load_version, save_version, dummy_data
+    ):
+        """Check the warning when saving to the path that differs from
+        the subsequent load path."""
+        pattern = (
+            rf"Save version '{save_version}' did not match load version "
+            rf"'{load_version}' for SafetensorsDataset\(.+\)"
+        )
+        with pytest.warns(UserWarning, match=pattern):
+            versioned_safetensors_dataset.save(dummy_data)
+
+    def test_http_filesystem_no_versioning(self):
+        pattern = "Versioning is not supported for HTTP protocols."
+
+        with pytest.raises(DatasetError, match=pattern):
+            SafetensorsDataset(
+                filepath="https://example.com/file.safetensors", version=Version(None, None)
+            )
+
+    def test_versioning_existing_dataset(
+        self, safetensors_dataset, versioned_safetensors_dataset, dummy_data
+    ):
+        """Check the error when attempting to save a versioned dataset on top of an
+        already existing (non-versioned) dataset."""
+        safetensors_dataset.save(dummy_data)
+        assert safetensors_dataset.exists()
+        assert safetensors_dataset._filepath == versioned_safetensors_dataset._filepath
+        pattern = (
+            f"(?=.*file with the same name already exists in the directory)"
+            f"(?=.*{versioned_safetensors_dataset._filepath.parent.as_posix()})"
+        )
+        with pytest.raises(DatasetError, match=pattern):
+            versioned_safetensors_dataset.save(dummy_data)
+
+        # Remove the non-versioned dataset and try again
+        Path(safetensors_dataset._filepath.as_posix()).unlink()
+        versioned_safetensors_dataset.save(dummy_data)
+        assert versioned_safetensors_dataset.exists()
+
+    def test_copy(self, versioned_safetensors_dataset):
+        safetensors_dataset_copy = versioned_safetensors_dataset._copy()
+        assert safetensors_dataset_copy is not versioned_safetensors_dataset
+        assert safetensors_dataset_copy._describe() == versioned_safetensors_dataset._describe()
diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml
index 96ad0fd41..65d5eb272 100644
--- a/kedro-datasets/pyproject.toml
+++ b/kedro-datasets/pyproject.toml
@@ -189,6 +189,9 @@ pytorch = ["kedro-datasets[pytorch-dataset]"]
 
 rioxarray-geotiffdataset = ["rioxarray>=0.15.0"]
 rioxarray = ["kedro-datasets[rioxarray-geotiffdataset]"]
 
+safetensors-safetensorsdataset = ["safetensors"]
+safetensors = ["kedro-datasets[safetensors-safetensorsdataset]"]
+
 video-videodataset = ["opencv-python~=4.5.5.64"]
 video = ["kedro-datasets[video-videodataset]"]
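
For reviewers, a minimal end-to-end sketch of the API this diff adds (not part of the diff itself). It round-trips a tensor dict through the new dataset, first unversioned and then versioned. It assumes kedro-datasets is installed with the new `safetensors` extra plus `torch`; the filepaths are hypothetical.

    import torch
    from kedro.io.core import Version

    from kedro_datasets_experimental.safetensors import SafetensorsDataset

    data = {"embeddings": torch.zeros((10, 100))}

    # Unversioned round trip on the local filesystem (backend defaults to "torch").
    dataset = SafetensorsDataset(filepath="data/07_model_output/model.safetensors")
    dataset.save(data)
    assert torch.equal(dataset.load()["embeddings"], data["embeddings"])

    # Versioned variant: Version(None, None) autogenerates the save version
    # and loads the latest one. A separate filepath is used because versioned
    # datasets store files under a directory named after the filepath.
    versioned = SafetensorsDataset(
        filepath="data/07_model_output/model_versioned.safetensors",
        version=Version(None, None),
    )
    versioned.save(data)
    assert torch.equal(versioned.load()["embeddings"], data["embeddings"])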