Commit

Backport PR #1273 on branch 0.10.x (Fix IO error reporting) (#1281)
Co-authored-by: Philipp A <[email protected]>
meeseeksmachine and flying-sheep authored Jan 4, 2024
1 parent 2216c8c commit 5cd69a4
Showing 13 changed files with 140 additions and 105 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
@@ -11,6 +11,8 @@ repos:
    rev: v4.0.0-alpha.8
    hooks:
      - id: prettier
+        exclude_types:
+          - markdown
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
2 changes: 1 addition & 1 deletion anndata/_core/anndata.py
@@ -74,7 +74,7 @@ class StorageType(Enum):
    DaskArray = DaskArray
    CupyArray = CupyArray
    CupySparseMatrix = CupySparseMatrix
-    BackedSparseMAtrix = BaseCompressedSparseDataset
+    BackedSparseMatrix = BaseCompressedSparseDataset

    @classmethod
    def classes(cls):
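Reviewer note: the one-character rename matters because StorageType members are looked up by name and their values feed isinstance() checks. A minimal sketch of the pattern, using hypothetical stand-in classes rather than anndata's real storage types:

from enum import Enum


class MockDaskArray: ...      # hypothetical stand-in storage classes
class MockBackedSparse: ...


class StorageType(Enum):
    DaskArray = MockDaskArray
    BackedSparseMatrix = MockBackedSparse  # was misspelled "BackedSparseMAtrix"

    @classmethod
    def classes(cls):
        # Collect the storage classes, e.g. for isinstance() validation
        return tuple(member.value for member in cls)


assert StorageType["BackedSparseMatrix"].value is MockBackedSparse
assert isinstance(MockBackedSparse(), StorageType.classes())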
19 changes: 13 additions & 6 deletions anndata/_io/h5ad.py
@@ -6,6 +6,7 @@
from types import MappingProxyType
from typing import (
    TYPE_CHECKING,
+    Any,
    Callable,
    Literal,
    TypeVar,
@@ -112,7 +113,13 @@ def write_h5ad(

@report_write_key_on_error
@write_spec(IOSpec("array", "0.2.0"))
-def write_sparse_as_dense(f, key, value, dataset_kwargs=MappingProxyType({})):
+def write_sparse_as_dense(
+    f: h5py.Group,
+    key: str,
+    value: sparse.spmatrix | BaseCompressedSparseDataset,
+    *,
+    dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
+):
    real_key = None  # Flag for if temporary key was used
    if key in f:
        if isinstance(value, BaseCompressedSparseDataset) and (
@@ -269,7 +276,7 @@ def callback(func, elem_name: str, elem, iospec):
def _read_raw(
    f: h5py.File | AnnDataFileManager,
    as_sparse: Collection[str] = (),
-    rdasp: Callable[[h5py.Dataset], sparse.spmatrix] = None,
+    rdasp: Callable[[h5py.Dataset], sparse.spmatrix] | None = None,
    *,
    attrs: Collection[str] = ("X", "var", "varm"),
) -> dict:
@@ -286,7 +293,7 @@ def _read_raw(


@report_read_key_on_error
-def read_dataframe_legacy(dataset) -> pd.DataFrame:
+def read_dataframe_legacy(dataset: h5py.Dataset) -> pd.DataFrame:
    """Read pre-anndata 0.7 dataframes."""
    warn(
        f"'{dataset.name}' was written with a very old version of AnnData. "
@@ -305,7 +312,7 @@ def read_dataframe_legacy(dataset: h5py.Dataset) -> pd.DataFrame:
    return df


-def read_dataframe(group) -> pd.DataFrame:
+def read_dataframe(group: h5py.Group | h5py.Dataset) -> pd.DataFrame:
    """Backwards compat function"""
    if not isinstance(group, h5py.Group):
        return read_dataframe_legacy(group)
@@ -352,7 +359,7 @@ def read_dense_as_sparse(
raise ValueError(f"Cannot read dense array as type: {sparse_format}")


def read_dense_as_csr(dataset, axis_chunk=6000):
def read_dense_as_csr(dataset: h5py.Dataset, axis_chunk: int = 6000):
sub_matrices = []
for idx in idx_chunks_along_axis(dataset.shape, 0, axis_chunk):
dense_chunk = dataset[idx]
@@ -361,7 +368,7 @@ def read_dense_as_csr(dataset, axis_chunk=6000):
    return sparse.vstack(sub_matrices, format="csr")


-def read_dense_as_csc(dataset, axis_chunk=6000):
+def read_dense_as_csc(dataset: h5py.Dataset, axis_chunk: int = 6000):
    sub_matrices = []
    for idx in idx_chunks_along_axis(dataset.shape, 1, axis_chunk):
        sub_matrix = sparse.csc_matrix(dataset[idx])
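Reviewer note: read_dense_as_csr and read_dense_as_csc convert a dense HDF5 dataset to a sparse matrix one slab at a time, so the full dense array never has to be resident in memory. A self-contained sketch of the same chunking idea; this idx_chunks_along_axis is a simplified re-implementation, not the anndata helper:

import h5py
import numpy as np
from scipy import sparse


def idx_chunks_along_axis(shape, axis, chunk_size):
    """Yield tuples of slices covering `shape` in chunks along `axis`."""
    for start in range(0, shape[axis], chunk_size):
        idx = [slice(None)] * len(shape)
        idx[axis] = slice(start, min(start + chunk_size, shape[axis]))
        yield tuple(idx)


def dense_to_csr(dataset: h5py.Dataset, axis_chunk: int = 6000) -> sparse.csr_matrix:
    # Convert each row slab to CSR, then stack; peak memory stays ~ one slab
    chunks = [
        sparse.csr_matrix(dataset[idx])
        for idx in idx_chunks_along_axis(dataset.shape, 0, axis_chunk)
    ]
    return sparse.vstack(chunks, format="csr")


with h5py.File("example.h5", "w") as f:
    f["X"] = np.eye(10, dtype="float32")
    assert dense_to_csr(f["X"], axis_chunk=4).nnz == 10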
81 changes: 36 additions & 45 deletions anndata/_io/specs/registry.py
@@ -1,6 +1,6 @@
from __future__ import annotations

-from collections.abc import Callable, Iterable, Mapping
+from collections.abc import Mapping
from dataclasses import dataclass
from functools import singledispatch, wraps
from types import MappingProxyType
@@ -10,12 +10,13 @@
from anndata.compat import _read_attr

if TYPE_CHECKING:
+    from collections.abc import Callable, Generator, Iterable
+
    from anndata._types import GroupStorageType, StorageType


# TODO: This probably should be replaced by a hashable Mapping due to conversion b/w "_" and "-"
# TODO: Should filetype be included in the IOSpec if it changes the encoding? Or does the intent that these things be "the same" overrule that?


@dataclass(frozen=True)
class IOSpec:
    encoding_type: str
@@ -25,7 +26,9 @@ class IOSpec:
# TODO: Should this subclass from LookupError?
class IORegistryError(Exception):
    @classmethod
-    def _from_write_parts(cls, dest_type, typ, modifiers) -> IORegistryError:
+    def _from_write_parts(
+        cls, dest_type: type, typ: type, modifiers: frozenset[str]
+    ) -> IORegistryError:
        msg = f"No method registered for writing {typ} into {dest_type}"
        if modifiers:
            msg += f" with {modifiers}"
@@ -36,7 +39,7 @@ def _from_read_parts(
        cls,
        method: str,
        registry: Mapping,
-        src_typ: StorageType,
+        src_typ: type[StorageType],
        spec: IOSpec,
    ) -> IORegistryError:
        # TODO: Improve error message if type exists, but version does not
@@ -50,7 +53,7 @@ def _from_read_parts(
def write_spec(spec: IOSpec):
    def decorator(func: Callable):
        @wraps(func)
-        def wrapper(g, k, *args, **kwargs):
+        def wrapper(g: GroupStorageType, k: str, *args, **kwargs):
            result = func(g, k, *args, **kwargs)
            g[k].attrs.setdefault("encoding-type", spec.encoding_type)
            g[k].attrs.setdefault("encoding-version", spec.encoding_version)
@@ -193,12 +196,12 @@ def proc_spec(spec) -> IOSpec:


@proc_spec.register(IOSpec)
-def proc_spec_spec(spec) -> IOSpec:
+def proc_spec_spec(spec: IOSpec) -> IOSpec:
    return spec


@proc_spec.register(Mapping)
-def proc_spec_mapping(spec) -> IOSpec:
+def proc_spec_mapping(spec: Mapping[str, str]) -> IOSpec:
    return IOSpec(**{k.replace("-", "_"): v for k, v in spec.items()})


@@ -213,7 +216,9 @@ def get_spec(
    )


-def _iter_patterns(elem):
+def _iter_patterns(
+    elem,
+) -> Generator[tuple[type, type | str] | tuple[type, type, str], None, None]:
    """Iterates over possible patterns for an element in order of precedence."""
    from anndata.compat import DaskArray

@@ -236,40 +241,27 @@ def __init__(self, registry: IORegistry, callback: Callable | None = None) -> None:
    def read_elem(
        self,
        elem: StorageType,
-        modifiers: frozenset(str) = frozenset(),
+        modifiers: frozenset[str] = frozenset(),
    ) -> Any:
        """Read an element from a store. See exported function for more details."""
        from functools import partial

-        read_func = self.registry.get_reader(
-            type(elem), get_spec(elem), frozenset(modifiers)
+        iospec = get_spec(elem)
+        read_func = partial(
+            self.registry.get_reader(type(elem), iospec, modifiers),
+            _reader=self,
        )
-        read_func = partial(read_func, _reader=self)
-        if self.callback is not None:
-            return self.callback(read_func, elem.name, elem, iospec=get_spec(elem))
-        else:
+        if self.callback is None:
            return read_func(elem)
+        return self.callback(read_func, elem.name, elem, iospec=iospec)


class Writer:
-    def __init__(
-        self,
-        registry: IORegistry,
-        callback: Callable[
-            [
-                GroupStorageType,
-                str,
-                StorageType,
-                dict,
-            ],
-            None,
-        ]
-        | None = None,
-    ):
+    def __init__(self, registry: IORegistry, callback: Callable | None = None):
        self.registry = registry
        self.callback = callback

-    def find_writer(self, dest_type, elem, modifiers):
+    def find_writer(self, dest_type: type, elem, modifiers: frozenset[str]):
        for pattern in _iter_patterns(elem):
            if self.registry.has_writer(dest_type, pattern, modifiers):
                return self.registry.get_writer(dest_type, pattern, modifiers)
@@ -281,10 +273,10 @@ def write_elem(
        self,
        store: GroupStorageType,
        k: str,
-        elem,
+        elem: Any,
        *,
-        dataset_kwargs=MappingProxyType({}),
-        modifiers=frozenset(),
+        dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
+        modifiers: frozenset[str] = frozenset(),
    ):
        from functools import partial
        from pathlib import PurePosixPath
@@ -313,17 +305,16 @@ def write_elem(
            _writer=self,
        )

-        if self.callback is not None:
-            return self.callback(
-                write_func,
-                store,
-                k,
-                elem,
-                dataset_kwargs=dataset_kwargs,
-                iospec=self.registry.get_spec(elem),
-            )
-        else:
+        if self.callback is None:
            return write_func(store, k, elem, dataset_kwargs=dataset_kwargs)
+        return self.callback(
+            write_func,
+            store,
+            k,
+            elem,
+            dataset_kwargs=dataset_kwargs,
+            iospec=self.registry.get_spec(elem),
+        )


def read_elem(elem: StorageType) -> Any:
@@ -346,7 +337,7 @@ def write_elem(
    k: str,
    elem: Any,
    *,
-    dataset_kwargs: Mapping = MappingProxyType({}),
+    dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
) -> None:
    """
    Write an element to a storage group using anndata encoding.
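Reviewer note: the `callback` that Reader.read_elem dispatches to is the same hook exposed through anndata.experimental.read_dispatched; it receives the partially applied read function, the element's name, the element itself, and its IOSpec. A sketch of the pattern, assuming an existing file previously written by anndata; the skip-`uns` filter is purely illustrative:

import h5py
from anndata.experimental import read_dispatched


def callback(read_func, elem_name: str, elem, iospec):
    if elem_name.startswith("/uns"):
        return None           # skip this subtree (illustrative policy)
    return read_func(elem)    # fall through to the registered reader


# "example.h5ad" stands in for any file previously written by anndata
with h5py.File("example.h5ad", "r") as f:
    adata = read_dispatched(f, callback=callback)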
63 changes: 34 additions & 29 deletions anndata/_io/utils.py
@@ -1,15 +1,19 @@
from __future__ import annotations

from functools import wraps
-from typing import Callable, Literal
+from typing import TYPE_CHECKING, Callable, Literal, Union, cast
from warnings import warn

import h5py
from packaging.version import Version

-from anndata.compat import H5Group, ZarrGroup, add_note
-
from .._core.sparse_dataset import BaseCompressedSparseDataset
+from ..compat import H5Group, ZarrGroup, add_note, pairwise
+
+if TYPE_CHECKING:
+    from .._types import StorageType
+
+    Storage = Union[StorageType, BaseCompressedSparseDataset]

# For allowing h5py v3
# https://github.com/scverse/anndata/issues/442
@@ -151,31 +155,26 @@ class AnnDataReadError(OSError):
    pass


-def _get_parent(elem):
-    try:
-        import zarr
-    except ImportError:
-        zarr = None
-    if zarr and isinstance(elem, (zarr.Group, zarr.Array)):
-        parent = elem.store  # Not sure how to always get a name out of this
-    elif isinstance(elem, BaseCompressedSparseDataset):
-        parent = elem.group.file.name
-    else:
-        parent = elem.file.name
-    return parent
+def _get_display_path(store: Storage) -> str:
+    """Return an absolute path of an element (always starts with “/”)."""
+    if isinstance(store, BaseCompressedSparseDataset):
+        store = store.group
+    path = store.name or "??"  # can be None
+    return f'/{path.removeprefix("/")}'


-def add_key_note(e: BaseException, elem, key, op=Literal["read", "writ"]) -> None:
+def add_key_note(
+    e: BaseException, store: Storage, path: str, key: str, op: Literal["read", "writ"]
+) -> None:
    if any(
        f"Error raised while {op}ing key" in note
        for note in getattr(e, "__notes__", [])
    ):
        return
-    parent = _get_parent(elem)
-    add_note(
-        e,
-        f"Error raised while {op}ing key {key!r} of {type(elem)} to " f"{parent}",
-    )
+    dir = "to" if op == "writ" else "from"
+    msg = f"Error raised while {op}ing key {key!r} of {type(store)} {dir} {path}"
+    add_note(e, msg)


def report_read_key_on_error(func):
@@ -198,13 +197,17 @@ def func_wrapper(*args, **kwargs):
        from anndata._io.specs import Reader

        # Figure out signature (method vs function) by going through args
-        for elem in args:
-            if not isinstance(elem, Reader):
+        for arg in args:
+            if not isinstance(arg, Reader):
+                store = cast("Storage", arg)
                break
        else:
            raise ValueError("No element found in args.")
        try:
            return func(*args, **kwargs)
        except Exception as e:
-            add_key_note(e, elem, elem.name, "read")
+            path, key = _get_display_path(store).rsplit("/", 1)
+            add_key_note(e, store, path or "/", key, "read")
            raise

    return func_wrapper
@@ -230,15 +233,17 @@ def func_wrapper(*args, **kwargs):
        from anndata._io.specs import Writer

        # Figure out signature (method vs function) by going through args
-        for i in range(len(args)):
-            elem = args[i]
-            key = args[i + 1]
-            if not isinstance(elem, Writer):
+        for arg, key in pairwise(args):
+            if not isinstance(arg, Writer):
+                store = cast("Storage", arg)
                break
        else:
            raise ValueError("No element found in args.")
        try:
            return func(*args, **kwargs)
        except Exception as e:
-            add_key_note(e, elem, key, "writ")
+            path = _get_display_path(store)
+            add_key_note(e, store, path, key, "writ")
            raise

    return func_wrapper
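Reviewer note: the effect of the new helpers is easiest to see by exercising the decorator directly. On failure it now appends a note naming the key, the store type, a clean "/"-prefixed path, and the right "to"/"from" direction. A small demo against the code above; the failing reader is contrived:

import h5py
import numpy as np
from anndata._io.utils import report_read_key_on_error


@report_read_key_on_error
def read_attr(elem):
    return elem.attrs["missing-attr"]  # raises KeyError


with h5py.File("demo.h5", "w") as f:
    f["X"] = np.zeros(3)
    try:
        read_attr(f["X"])
    except KeyError as e:
        # Expect a note like:
        # "Error raised while reading key 'X' of <class 'h5py._hl.dataset.Dataset'> from /"
        print(getattr(e, "__notes__", []))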
2 changes: 1 addition & 1 deletion anndata/_io/zarr.py
@@ -133,7 +133,7 @@ def read_dataframe_legacy(dataset: zarr.Array) -> pd.DataFrame:


@report_read_key_on_error
-def read_dataframe(group) -> pd.DataFrame:
+def read_dataframe(group: zarr.Group | zarr.Array) -> pd.DataFrame:
    # Fast paths
    if isinstance(group, zarr.Array):
        return read_dataframe_legacy(group)
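Reviewer note: end to end, the payoff of this PR is that a failure buried inside a store now reports which key was being read and where. A hedged sketch using the experimental element API; deliberately corrupting "encoding-type" is just one way to force a read failure, and the exact note text may differ:

import h5py
import numpy as np
from anndata.experimental import read_elem, write_elem

with h5py.File("demo.h5", "w") as f:
    write_elem(f, "obsm", {"X_pca": np.zeros((3, 2))})
    f["obsm/X_pca"].attrs["encoding-type"] = "not-a-real-encoding"  # force a failure
    try:
        read_elem(f["obsm"])
    except Exception as e:
        # With this fix the note should read something like:
        # "Error raised while reading key 'X_pca' of <class 'h5py._hl.dataset.Dataset'> from /obsm"
        print(getattr(e, "__notes__", []))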