feat: support using S3Config.credentials_provider for writes (#3648)
BREAKING CHANGE: `S3Credentials.expiry` must be a timezoned datetime now

Resolves #3367
kevinzwang authored Jan 15, 2025
1 parent feab49a commit 0afc55f
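
As a usage sketch (not part of the diff): the bucket path, placeholder key values, and the assumption that `S3Credentials` is importable from `daft.io` are illustrative; the commit itself only establishes that `S3Config.credentials_provider` is now honored on the write path and that `expiry` must be a timezone-aware datetime.

```python
# Hedged example: names and import locations are assumptions for illustration.
from datetime import datetime, timedelta, timezone

import daft
from daft.io import IOConfig, S3Config, S3Credentials


def my_credentials_provider() -> S3Credentials:
    # Fetch short-lived credentials from your own source (e.g. STS, a vault, ...).
    # Breaking change in this commit: expiry must be timezone-aware.
    return S3Credentials(
        key_id="AKIA...",
        access_key="...",
        session_token="...",
        expiry=datetime.now(timezone.utc) + timedelta(hours=1),
    )


io_config = IOConfig(s3=S3Config(credentials_provider=my_credentials_provider))

df = daft.from_pydict({"a": [1, 2, 3]})
# Writes now pick up credentials from the provider, refreshing them after expiry.
df.write_parquet("s3://my-bucket/output/", io_config=io_config)
```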
Showing 9 changed files with 190 additions and 112 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -279,7 +279,7 @@ features = ['async']
path = "src/parquet2"

[workspace.dependencies.pyo3]
features = ["extension-module", "multiple-pymethods", "abi3-py39", "indexmap"]
features = ["extension-module", "multiple-pymethods", "abi3-py39", "indexmap", "chrono"]
version = "0.23.3"

[workspace.dependencies.pyo3-log]
4 changes: 4 additions & 0 deletions daft/daft/__init__.pyi
@@ -461,6 +461,10 @@ class S3Config:
"""Creates an S3Config, retrieving credentials and configurations from the current environment."""
...

def provide_cached_credentials(self) -> S3Credentials | None:
"""Wrapper around call to `S3Config.credentials_provider` to cache credentials until expiry."""
...

class S3Credentials:
key_id: str
access_key: str
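
The stub above only documents the behavior; the caching itself presumably lives in the Rust `common/io-config` crate (those diffs are not rendered here). A rough, hypothetical Python sketch of the described semantics — cache the provider's result and reuse it until its timezone-aware `expiry` passes:

```python
# Illustrative sketch only; the real logic is in Rust and may differ in detail.
from datetime import datetime, timezone
from typing import Callable, Optional


class _CredentialsCache:
    def __init__(self, provider: Callable[[], "S3Credentials"]) -> None:
        self._provider = provider
        self._cached: Optional["S3Credentials"] = None  # S3Credentials as in daft.daft

    def provide_cached_credentials(self) -> "S3Credentials":
        cached = self._cached
        # Reuse credentials while they have no expiry, or their expiry is still in the future.
        if cached is not None and (
            cached.expiry is None or cached.expiry > datetime.now(timezone.utc)
        ):
            return cached
        self._cached = self._provider()
        return self._cached
```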
62 changes: 38 additions & 24 deletions daft/filesystem.py
@@ -6,7 +6,8 @@
import pathlib
import sys
import urllib.parse
from typing import TYPE_CHECKING, Any, Literal
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any

from daft.convert import from_pydict
from daft.daft import FileFormat, FileInfos, IOConfig, io_glob
@@ -19,7 +20,14 @@

logger = logging.getLogger(__name__)

_CACHED_FSES: dict[tuple[str, IOConfig | None], pafs.FileSystem] = {}

@dataclasses.dataclass(frozen=True)
class PyArrowFSWithExpiry:
fs: pafs.FileSystem
expiry: datetime | None


_CACHED_FSES: dict[tuple[str, IOConfig | None], PyArrowFSWithExpiry] = {}


def _get_fs_from_cache(protocol: str, io_config: IOConfig | None) -> pafs.FileSystem | None:
@@ -29,22 +37,20 @@ def _get_fs_from_cache(protocol: str, io_config: IOConfig | None) -> pafs.FileSy
"""
global _CACHED_FSES

return _CACHED_FSES.get((protocol, io_config))
if (protocol, io_config) in _CACHED_FSES:
fs = _CACHED_FSES[(protocol, io_config)]

if fs.expiry is None or fs.expiry > datetime.now(timezone.utc):
return fs.fs

def _put_fs_in_cache(protocol: str, fs: pafs.FileSystem, io_config: IOConfig | None) -> None:
"""Put pyarrow filesystem in cache under provided protocol."""
global _CACHED_FSES
return None

_CACHED_FSES[(protocol, io_config)] = fs

def _put_fs_in_cache(protocol: str, fs: pafs.FileSystem, io_config: IOConfig | None, expiry: datetime | None) -> None:
"""Put pyarrow filesystem in cache under provided protocol."""
global _CACHED_FSES

@dataclasses.dataclass(frozen=True)
class ListingInfo:
path: str
size: int
type: Literal["file", "directory"]
rows: int | None = None
_CACHED_FSES[(protocol, io_config)] = PyArrowFSWithExpiry(fs, expiry)


def get_filesystem(protocol: str, **kwargs) -> fsspec.AbstractFileSystem:
@@ -154,10 +160,10 @@ def _resolve_paths_and_filesystem(
if resolved_filesystem is None:
# Resolve path and filesystem for the first path.
# We use this first resolved filesystem for validation on all other paths.
resolved_path, resolved_filesystem = _infer_filesystem(paths[0], io_config)
resolved_path, resolved_filesystem, expiry = _infer_filesystem(paths[0], io_config)

# Put resolved filesystem in cache under these paths' canonical protocol.
_put_fs_in_cache(protocol, resolved_filesystem, io_config)
_put_fs_in_cache(protocol, resolved_filesystem, io_config, expiry)
else:
resolved_path = _validate_filesystem(paths[0], resolved_filesystem, io_config)

@@ -175,7 +181,7 @@


def _validate_filesystem(path: str, fs: pafs.FileSystem, io_config: IOConfig | None) -> str:
resolved_path, inferred_fs = _infer_filesystem(path, io_config)
resolved_path, inferred_fs, _ = _infer_filesystem(path, io_config)
if not isinstance(fs, type(inferred_fs)):
raise RuntimeError(
f"Cannot read multiple paths with different inferred PyArrow filesystems. Expected: {fs} but received: {inferred_fs}"
@@ -186,8 +192,8 @@ def _validate_filesystem(path: str, fs: pafs.FileSystem, io_config: IOConfig | N
def _infer_filesystem(
path: str,
io_config: IOConfig | None,
) -> tuple[str, pafs.FileSystem]:
"""Resolves and normalizes the provided path and infers it's filesystem.
) -> tuple[str, pafs.FileSystem, datetime | None]:
"""Resolves and normalizes the provided path and infers its filesystem and expiry.
Also ensures that the inferred filesystem is compatible with the passed filesystem, if provided.
@@ -225,17 +231,25 @@ def _set_if_not_none(kwargs: dict[str, Any], key: str, val: Any | None):
except ImportError:
pass # Config does not exist in pyarrow 7.0.0

expiry = None
if (s3_creds := s3_config.provide_cached_credentials()) is not None:
_set_if_not_none(translated_kwargs, "access_key", s3_creds.key_id)
_set_if_not_none(translated_kwargs, "secret_key", s3_creds.access_key)
_set_if_not_none(translated_kwargs, "session_token", s3_creds.session_token)

expiry = s3_creds.expiry

resolved_filesystem = pafs.S3FileSystem(**translated_kwargs)
resolved_path = resolved_filesystem.normalize_path(_unwrap_protocol(path))
return resolved_path, resolved_filesystem
return resolved_path, resolved_filesystem, expiry

###
# Local
###
elif protocol == "file":
resolved_filesystem = pafs.LocalFileSystem()
resolved_path = resolved_filesystem.normalize_path(_unwrap_protocol(path))
return resolved_path, resolved_filesystem
return resolved_path, resolved_filesystem, None

###
# GCS
@@ -257,7 +271,7 @@ def _set_if_not_none(kwargs: dict[str, Any], key: str, val: Any | None):

resolved_filesystem = GcsFileSystem(**translated_kwargs)
resolved_path = resolved_filesystem.normalize_path(_unwrap_protocol(path))
return resolved_path, resolved_filesystem
return resolved_path, resolved_filesystem, None

###
# HTTP: Use FSSpec as a fallback
@@ -267,7 +281,7 @@ def _set_if_not_none(kwargs: dict[str, Any], key: str, val: Any | None):
fsspec_fs = fsspec_fs_cls()
resolved_filesystem, resolved_path = pafs._resolve_filesystem_and_path(path, fsspec_fs)
resolved_path = resolved_filesystem.normalize_path(resolved_path)
return resolved_path, resolved_filesystem
return resolved_path, resolved_filesystem, None

###
# Azure: Use FSSpec as a fallback
@@ -290,7 +304,7 @@ def _set_if_not_none(kwargs: dict[str, Any], key: str, val: Any | None):
fsspec_fs = fsspec_fs_cls()
resolved_filesystem, resolved_path = pafs._resolve_filesystem_and_path(path, fsspec_fs)
resolved_path = resolved_filesystem.normalize_path(_unwrap_protocol(resolved_path))
return resolved_path, resolved_filesystem
return resolved_path, resolved_filesystem, None

else:
raise NotImplementedError(f"Cannot infer PyArrow filesystem for protocol {protocol}: please file an issue!")
@@ -313,7 +327,7 @@ def glob_path_with_stats(
file_format: FileFormat | None,
io_config: IOConfig | None,
) -> FileInfos:
"""Glob a path, returning a list ListingInfo."""
"""Glob a path, returning a FileInfos."""
files = io_glob(path, io_config=io_config)
filepaths_to_infos = {f["path"]: {"size": f["size"], "type": f["type"]} for f in files}

4 changes: 3 additions & 1 deletion src/common/io-config/Cargo.toml
@@ -1,15 +1,17 @@
[dependencies]
aws-credential-types = {version = "0.55.3"}
chrono = {workspace = true}
common-error = {path = "../error", default-features = false}
common-py-serde = {path = "../py-serde", default-features = false}
derivative = {workspace = true}
derive_more = {workspace = true}
pyo3 = {workspace = true, optional = true}
secrecy = {version = "0.8.0", features = ["alloc"], default-features = false}
serde = {workspace = true}
typetag = {workspace = true}

[features]
python = ["dep:pyo3", "common-py-serde/python"]
python = ["dep:pyo3", "common-error/python", "common-py-serde/python"]

[lints]
workspace = true
2 changes: 2 additions & 0 deletions src/common/io-config/src/lib.rs
@@ -1,3 +1,5 @@
#![feature(let_chains)]

#[cfg(feature = "python")]
pub mod python;
