Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: modify namespace based on backend #61

Merged
merged 5 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ Install with `pip`.
pip install marray
```

The only public function is `get_namespace`:
Use the `from...import...as` syntax to get a masked array namespace.

```python3
import numpy as xp # use any Array API compatible library, installed separately
import marray
mxp = marray.get_namespace(xp)
# use with any Array API compatible library, installed separately
from marray import numpy as mxp
import numpy as xp # optional (if the non-masked namespace is desired)
```

The resulting `mxp` namespace has all the features of `xp` that are specified
in the Array API standard, but they are modified to be mask-aware. Typically, the
signatures of functions in the `mxp` namespace match those in the `xp` namespace;
signatures of functions in the `mxp` namespace match those in the standard;
the one notable exception is the addition of a `mask` keyword argument to `asarray`.

```python3
Expand All @@ -39,4 +39,4 @@ Documentation provided by attributes of `xp` is exposed in the `mxp`
namespace and are accessible via `help`. For more information, please see
[the tutorial](https://mdhaber.github.io/marray/tutorial.html).

[^1]: The MArray logo is a nod to NumPy, but MArray is not affiliated with the NumPy project.
[^1]: The MArray logo is a nod to NumPy's logo, but MArray is not affiliated with the NumPy project.
22 changes: 17 additions & 5 deletions marray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,30 @@

import collections
import dataclasses
import importlib
import inspect
import sys
import types


def get_namespace(xp):
def __getattr__(name):
    """Lazily create the masked namespace for the backend called `name`.

    This is the module-level attribute hook (PEP 562): Python calls it only
    when `name` is not found on the `marray` module by normal lookup, which
    is what makes ``from marray import numpy as mxp`` work.

    Parameters
    ----------
    name : str
        Importable name of the backend library, e.g. ``'numpy'``.

    Returns
    -------
    mod : module
        The namespace produced by `_get_namespace` for the imported backend.
        It is also stored in `sys.modules` under ``f"marray.{name}"`` so that
        ``from marray.<name> import ...`` resolves afterwards.

    Raises
    ------
    AttributeError
        If `name` is a dunder name or no module called `name` can be imported.
    """
    # Dunder lookups (`__path__`, `__all__`, `__wrapped__`, ...) are protocol
    # probes made by the import system and introspection tools, never backend
    # requests. Answer them without importing anything; otherwise names such
    # as `__main__` or `__future__` would be imported and wrapped as backends.
    if name.startswith('__') and name.endswith('__'):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Keep the `try` limited to the import itself so that a
    # ModuleNotFoundError raised while *building* the namespace is not
    # misreported as a missing attribute.
    try:
        xp = importlib.import_module(name)
    except ModuleNotFoundError as e:
        raise AttributeError(str(e)) from e

    mod = _get_namespace(xp)
    sys.modules[f"marray.{name}"] = mod
    return mod


def _get_namespace(xp):
"""Returns a masked array namespace for an Array API Standard compatible backend

Examples
--------
>>> import numpy as xp
>>> from marray import get_namespace
>>> mxp = get_namespace(xp)
>>> from marray import _get_namespace
>>> mxp = _get_namespace(xp)
>>> A = mxp.eye(3)
>>> A.mask[0, ...] = True
>>> x = mxp.asarray([1, 2, 3], mask=[False, False, True])
Expand Down Expand Up @@ -207,8 +218,9 @@ def fun(self, other, name=name, **kwargs):
return self
setattr(MArray, name, fun)

mod = types.ModuleType('mxp')
sys.modules['mxp'] = mod
mod_name = f'marray.{xp.__name__}'
mod = types.ModuleType(mod_name)
sys.modules[mod_name] = mod

mod.MArray = MArray

Expand Down
77 changes: 41 additions & 36 deletions marray/tests/test_marray.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


def get_arrays(n_arrays, *, dtype, xp, ndim=(1, 4), seed=None):
xpm = marray.get_namespace(xp)
xpm = marray._get_namespace(xp)

entropy = np.random.SeedSequence(seed).entropy
rng = np.random.default_rng(entropy)
Expand Down Expand Up @@ -240,7 +240,7 @@ def test_array_binary(f, dtype, xp, seed=None):
@pytest.mark.parametrize('xp', xps)
def test_bitwise_unary(f_name_fun, dtype, xp, seed=None):
f_name, f = f_name_fun
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed)

res = f(~marrays[0])
Expand All @@ -260,7 +260,7 @@ def test_bitwise_unary(f_name_fun, dtype, xp, seed=None):
"Only integer dtypes are allowed in "])
def test_bitwise_binary(f_name_fun, dtype, xp, seed=None):
f_name, f = f_name_fun
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
marrays, masked_arrays, seed = get_arrays(2, dtype=dtype, xp=xp, seed=seed)

res = f(marrays[0], marrays[1])
Expand All @@ -276,7 +276,7 @@ def test_bitwise_binary(f_name_fun, dtype, xp, seed=None):
@pytest.mark.parametrize('mask', [False, True])
@pytest.mark.parametrize('xp', xps)
def test_scalar_conversion(type_val, mask, xp):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
type, val = type_val
x = mxp.asarray(val)
assert type(x) == val
Expand All @@ -293,7 +293,7 @@ def test_indexing(xp):
# This does not make them easy to test exhaustively, but it does make
# them easy to fix if a shortcoming is identified. Include a very basic
# test for now, and improve as needed.
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
x = mxp.asarray(xp.arange(3), mask=[False, True, False])

# Test `__setitem__`/`__getitem__` roundtrip
Expand Down Expand Up @@ -327,7 +327,7 @@ def test_indexing(xp):
@pytest.mark.parametrize('xp', xps)
def test_dlpack(dtype, xp, seed=None):
# This is a placeholder for a real test when there is a real implementation
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
marrays, _, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed)
assert isinstance(marrays[0].__dlpack__(), type(marrays[0].data.__dlpack__()))
assert marrays[0].__dlpack_device__() == marrays[0].data.__dlpack_device__()
Expand Down Expand Up @@ -382,7 +382,7 @@ def test_inplace(f, arg2_masked, dtype, xp, seed=None):
@pass_exceptions(allowed=["Only numeric dtypes are allowed in matmul"])
def test_inplace_array_binary(f, dtype, xp, seed=None):
# very restrictive operator -> limited test
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
rng = np.random.default_rng(seed)
data = (rng.random((3, 10, 10))*10).astype(dtype)
mask = rng.random((3, 10, 10)) > 0.5
Expand Down Expand Up @@ -425,7 +425,7 @@ def test_rarithmetic_binary(f, dtype, xp, type_, seed=None):
@pass_exceptions(allowed=["Only numeric dtypes are allowed in __matmul__"])
def test_rarray_binary(dtype, xp, seed=None):
# very restrictive operator -> limited test
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
rng = np.random.default_rng(seed)
data = (rng.random((3, 10, 10))*10).astype(dtype)
mask = rng.random((3, 10, 10)) > 0.5
Expand Down Expand Up @@ -467,7 +467,7 @@ def test_attributes(dtype, xp, seed=None):

@pytest.mark.parametrize('xp', xps)
def test_constants(xp):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
assert mxp.e == xp.e
assert mxp.inf == xp.inf
assert np.isnan(mxp.nan) == np.isnan(xp.nan)
Expand All @@ -478,7 +478,7 @@ def test_constants(xp):
@pytest.mark.parametrize("f", data_type + inspection + version)
@pytest.mark.parametrize('xp', xps)
def test_dtype_funcs_inspection(f, xp):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
getattr(mxp, f) is getattr(xp, f)


Expand All @@ -487,7 +487,7 @@ def test_dtype_funcs_inspection(f, xp):
def test_dtypes(dtype, xp):
if xp == np:
pytest.xfail("NumPy fails... unclear whether NumPy follows standard here.")
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
getattr(mxp, dtype).__eq__(getattr(xp, dtype))


Expand All @@ -504,7 +504,7 @@ def test_dtypes(dtype, xp):
"Only boolean dtypes are allowed",
"Only complex floating-point dtypes are allowed"])
def test_elementwise_unary(f_name, dtype, xp, seed=None):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed)
f = getattr(mxp, f_name)
f2 = getattr(xp, f_name)
Expand All @@ -527,7 +527,7 @@ def test_elementwise_unary(f_name, dtype, xp, seed=None):
"Only numeric dtypes are allowed",
"Only boolean dtypes are allowed",])
def test_elementwise_binary(f_name, dtype, xp, seed=None):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
marrays, masked_arrays, seed = get_arrays(2, dtype=dtype, xp=xp, seed=seed)
f = getattr(mxp, f_name)
f2 = getattr(np, f_name)
Expand All @@ -550,7 +550,7 @@ def test_statistical_array(f_name, keepdims, xp, dtype, seed=None):
# should fix this and ensure strict check at the end
pytest.skip("`np.ma` can't provide reference due to numpy/numpy#27885")

mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed)
rng = np.random.default_rng(seed)
axes = list(range(marrays[0].ndim))
Expand Down Expand Up @@ -595,7 +595,7 @@ def test_statistical_array(f_name, keepdims, xp, dtype, seed=None):
@pass_exceptions(allowed=[r"arange() is only supported for booleans when"])
def test_creation(f_name, args, kwargs, dtype, xp, seed=None):
dtype = getattr(xp, dtype)
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
f_xp = getattr(xp, f_name)
f_mxp = getattr(mxp, f_name)
if f_name.endswith('like'):
Expand All @@ -614,7 +614,7 @@ def test_creation(f_name, args, kwargs, dtype, xp, seed=None):
@pytest.mark.parametrize("dtype", dtypes_all + [None])
@pytest.mark.parametrize('xp', xps)
def test_creation_like(f_name, dtype, xp, seed=None):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
f_mxp = getattr(mxp, f_name)
f_np = getattr(np, f_name) # np.ma doesn't have full_like
args = (2,) if f_name == "full_like" else ()
Expand All @@ -633,7 +633,7 @@ def test_creation_like(f_name, dtype, xp, seed=None):
@pytest.mark.parametrize('dtype', dtypes_all)
@pytest.mark.parametrize('xp', xps)
def test_tri(f_name, dtype, xp, seed=None):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
f_xp = getattr(xp, f_name)
f_mxp = getattr(mxp, f_name)
marrays, _, seed = get_arrays(1, ndim=(2, 4), dtype=dtype, xp=xp, seed=seed)
Expand All @@ -649,7 +649,7 @@ def test_tri(f_name, dtype, xp, seed=None):
@pytest.mark.parametrize('dtype', dtypes_all)
@pytest.mark.parametrize('xp', xps)
def test_meshgrid(indexing, dtype, xp, seed=None):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
rng = np.random.default_rng(seed)
n = rng.integers(1, 4)
marrays, masked_arrays, seed = get_arrays(n, ndim=1, dtype=dtype, xp=xp, seed=seed)
Expand All @@ -667,7 +667,7 @@ def test_meshgrid(indexing, dtype, xp, seed=None):
@pytest.mark.parametrize('dtype', dtypes_integral + dtypes_real)
@pytest.mark.parametrize('xp', xps)
def test_searchsorted(side, dtype, xp, seed=None):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)

rng = np.random.default_rng(seed)
n = 20
Expand Down Expand Up @@ -711,7 +711,7 @@ def test_searchsorted(side, dtype, xp, seed=None):
@pytest.mark.parametrize('dtype', dtypes_all)
@pytest.mark.parametrize('xp', xps)
def test_where(dtype, xp, seed=None):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
marrays, masked_arrays, seed = get_arrays(2, dtype=dtype, xp=xp, seed=seed)
rng = np.random.default_rng(seed)
cond = rng.random(marrays[0].shape) > 0.5
Expand All @@ -723,7 +723,7 @@ def test_where(dtype, xp, seed=None):
@pytest.mark.parametrize('dtype', dtypes_all)
@pytest.mark.parametrize('xp', xps)
def test_nonzero(dtype, xp, seed=None):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed)
x, y = marrays[0], masked_arrays[0]
rng = np.random.default_rng(seed)
Expand Down Expand Up @@ -757,7 +757,7 @@ def test_nonzero(dtype, xp, seed=None):
@pytest.mark.parametrize('dtype', dtypes_all)
@pytest.mark.parametrize('xp', xps)
def test_manipulation(f_name, n_arrays, n_dims, args, kwargs, dtype, xp, seed=None):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
marrays, _, seed = get_arrays(n_arrays, ndim=n_dims, dtype=dtype, xp=xp, seed=seed)
if f_name in {'broadcast_to', 'squeeze'}:
original_shape = marrays[0].shape
Expand Down Expand Up @@ -792,7 +792,7 @@ def test_manipulation(f_name, n_arrays, n_dims, args, kwargs, dtype, xp, seed=No
@pytest.mark.parametrize('copy', [False, True])
@pytest.mark.parametrize('xp', xps)
def test_astype(dtype_in, dtype_out, copy, xp, seed=None):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
marrays, masked_arrays, seed = get_arrays(1, dtype=dtype_in, xp=xp, seed=seed)

res = mxp.astype(marrays[0], getattr(xp, dtype_out), copy=copy)
Expand All @@ -809,7 +809,7 @@ def test_astype(dtype_in, dtype_out, copy, xp, seed=None):

@pytest.mark.parametrize('xp', xps)
def test_asarray_device(xp):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
message = "`device` argument is not implemented"
with pytest.raises(NotImplementedError, match=message):
mxp.asarray(xp.asarray([1, 2, 3]), device='coconut')
Expand All @@ -819,7 +819,7 @@ def test_asarray_device(xp):
@pytest.mark.parametrize('xp', xps)
@pass_exceptions(allowed=["Only real numeric dtypes are allowed"])
def test_clip(dtype, xp, seed=None):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
marrays, masked_arrays, seed = get_arrays(3, dtype=dtype, xp=xp, seed=seed)
min = mxp.minimum(marrays[1], marrays[2])
max = mxp.maximum(marrays[1], marrays[2])
Expand All @@ -835,7 +835,7 @@ def test_clip(dtype, xp, seed=None):
@pytest.mark.parametrize('dtype', dtypes_all)
@pytest.mark.parametrize('xp', xps)
def test_set(f_name, dtype, xp, seed=None):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
marrays, _, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed)
f_mxp = getattr(mxp, f_name)

Expand Down Expand Up @@ -882,7 +882,7 @@ def test_set(f_name, dtype, xp, seed=None):
@pytest.mark.parametrize('dtype', dtypes_real + dtypes_integral)
@pytest.mark.parametrize('xp', xps)
def test_sorting(f_name, descending, stable, dtype, xp, seed=None):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
marrays, masked_arrays, seed = get_arrays(1, dtype=dtype, xp=xp, seed=seed)
f_mxp = getattr(mxp, f_name)
f_xp = getattr(np.ma, f_name)
Expand Down Expand Up @@ -931,7 +931,7 @@ def test_sorting(f_name, descending, stable, dtype, xp, seed=None):

@pytest.mark.parametrize('xp', xps)
def test_array_namespace(xp):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
x = mxp.asarray([1, 2, 3])
assert x.__array_namespace__() is mxp
assert x.__array_namespace__("2023.12") is mxp
Expand All @@ -940,34 +940,39 @@ def test_array_namespace(xp):
x.__array_namespace__("shrubbery")


@pytest.mark.parametrize('xp', xps)
def test_import(xp):
mxp = marray.get_namespace(xp) # noqa: F841
from mxp import asarray
def test_import():
    """Backends are importable both as attributes and as submodules of `marray`."""
    # NumPy backend: attribute-style import first, which creates the
    # namespace, then the submodule-style import of a member.
    from marray import numpy as masked_numpy
    assert masked_numpy.__name__ == 'marray.numpy'
    from marray.numpy import asarray
    asarray(10, mask=True)

    # Same two import forms for the array_api_strict backend.
    from marray import array_api_strict as masked_strict
    assert masked_strict.__name__ == 'marray.array_api_strict'
    from marray.array_api_strict import asarray
    asarray(10, mask=True)

@pytest.mark.parametrize('xp', xps)
def test_str(xp):
mxp = marray.get_namespace(xp)
mxp = marray._get_namespace(xp)
x = mxp.asarray(1, mask=True)
ref = "MArray(1, True)"
assert str(x) == ref

def test_repr():
mxp = marray.get_namespace(strict)
mxp = marray._get_namespace(strict)
x = mxp.asarray(1, mask=True)
ref = ("MArray(\n Array(1, dtype=array_api_strict.int64),"
"\n Array(True, dtype=array_api_strict.bool)\n)")
assert repr(x) == ref

mxp = marray.get_namespace(np)
mxp = marray._get_namespace(np)
x = mxp.asarray(1, mask=True)
ref = "MArray(array(1), array(True))"
assert repr(x) == ref

def test_signature_docs():
# Rough test that signatures were replaced where possible
mxp = marray.get_namespace(np)
mxp = marray._get_namespace(np)
assert mxp.sum.__signature__ == inspect.signature(np.sum)
assert np.sum.__doc__ in mxp.sum.__doc__

Expand Down
Loading
Loading