Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: rm __array__, add __buffer__ #115

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/array-api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
numpy-version: ['1.26', 'dev']
python-version: ['3.12', '3.13']
numpy-version: ['1.26', '2.2', 'dev']
exclude:
- python-version: '3.8'
numpy-version: 'dev'
- python-version: '3.13'
numpy-version: '1.26'

steps:
- name: Checkout array-api-strict
Expand All @@ -38,7 +38,7 @@ jobs:
if [[ "${{ matrix.numpy-version }}" == "dev" ]]; then
python -m pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy;
else
python -m pip install 'numpy>=1.26,<2.0';
python -m pip install 'numpy=='${{ matrix.numpy-version }};
fi
python -m pip install ${GITHUB_WORKSPACE}/array-api-strict
python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt
Expand Down
39 changes: 16 additions & 23 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ def __hash__(self):

_default = object()

# See https://github.com/data-apis/array-api-strict/issues/67 and the comment
# on __array__ below.
_allow_array = True

class Array:
"""
Expand Down Expand Up @@ -157,26 +154,22 @@ def __repr__(self: Array, /) -> str:
# This was implemented historically for compatibility, and removing it has
# caused issues for some libraries (see
# https://github.com/data-apis/array-api-strict/issues/67).
def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None) -> npt.NDArray[Any]:
# We have to allow this to be internally enabled as there's no other
# easy way to parse a list of Array objects in asarray().
if _allow_array:
if self._device != CPU_DEVICE:
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
# copy keyword is new in 2.0.0; for older versions don't use it
# retry without that keyword.
if np.__version__[0] < '2':
return np.asarray(self._array, dtype=dtype)
elif np.__version__.startswith('2.0.0-dev0'):
# Handle dev version for which we can't know based on version
# number whether or not the copy keyword is supported.
try:
return np.asarray(self._array, dtype=dtype, copy=copy)
except TypeError:
return np.asarray(self._array, dtype=dtype)
else:
return np.asarray(self._array, dtype=dtype, copy=copy)
raise ValueError("Conversion from an array_api_strict array to a NumPy ndarray is not supported")

# Instead of `__array__` we now implement the buffer protocol.
# Note that it makes array-apis-strict requiring python>=3.12
def __buffer__(self, flags):
if self._device != CPU_DEVICE:
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
return memoryview(self._array)
def __release_buffer(self, buffer):
# XXX anything to do here?
pass

def __array__(self, *args, **kwds):
# a stub for python < 3.12; otherwise numpy silently produces object arrays
raise TypeError(
"Interoperation with NumPy requires python >= 3.12. Please upgrade."
)

# These are various helper functions to make the array behavior match the
# spec in places where it either deviates from or is more strict than
Expand Down
18 changes: 2 additions & 16 deletions array_api_strict/_creation_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

if TYPE_CHECKING:
Expand All @@ -16,19 +15,6 @@

import numpy as np

@contextmanager
def allow_array():
"""
Temporarily enable Array.__array__. This is needed for np.array to parse
list of lists of Array objects.
"""
from . import _array_object
original_value = _array_object._allow_array
try:
_array_object._allow_array = True
yield
finally:
_array_object._allow_array = original_value

def _check_valid_dtype(dtype):
# Note: Only spelling dtypes as the dtype objects is supported.
Expand Down Expand Up @@ -112,8 +98,8 @@ def asarray(
# Give a better error message in this case. NumPy would convert this
# to an object array. TODO: This won't handle large integers in lists.
raise OverflowError("Integer out of bounds for array dtypes")
with allow_array():
res = np.array(obj, dtype=_np_dtype, copy=copy)

res = np.array(obj, dtype=_np_dtype, copy=copy)
return Array._new(res, device=device)


Expand Down
35 changes: 19 additions & 16 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import operator
from builtins import all as all_

Expand Down Expand Up @@ -351,6 +352,10 @@ def test_array_properties():
assert b.mT.shape == (3, 2)


@pytest.mark.xfail(sys.version_info.major*100 + sys.version_info.minor < 312,
reason="array conversion relies on buffer protocol, and "
"requires python >= 3.12"
)
def test_array_conversion():
# Check that arrays on the CPU device can be converted to NumPy
# but arrays on other devices can't. Note this is testing the logic in
Expand All @@ -361,25 +366,23 @@ def test_array_conversion():

for device in ("device1", "device2"):
a = ones((2, 3), device=array_api_strict.Device(device))
with pytest.raises(RuntimeError, match="Can not convert array"):
with pytest.raises((RuntimeError, TypeError)):
asarray([a])

def test__array__():
# __array__ should work for now
# __buffer__ should work for now for conversion to numpy
a = ones((2, 3))
np.array(a)

# Test the _allow_array private global flag for disabling it in the
# future.
from .. import _array_object
original_value = _array_object._allow_array
try:
_array_object._allow_array = False
a = ones((2, 3))
with pytest.raises(ValueError, match="Conversion from an array_api_strict array to a NumPy ndarray is not supported"):
np.array(a)
finally:
_array_object._allow_array = original_value
na = np.array(a)
assert na.shape == (2, 3)
assert na.dtype == np.float64

@pytest.mark.skipif(not sys.version_info.major*100 + sys.version_info.minor < 312,
reason="conversion to numpy errors out unless python >= 3.12"
)
def test_array_conversion_2():
a = ones((2, 3))
with pytest.raises(TypeError):
np.array(a)


def test_allow_newaxis():
a = ones(5)
Expand Down
8 changes: 7 additions & 1 deletion array_api_strict/tests/test_creation_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import warnings

from numpy.testing import assert_raises
Expand Down Expand Up @@ -97,7 +98,12 @@ def test_asarray_copy():
a[0] = 0
assert all(b[0] == 0)

def test_asarray_list_of_lists():

@pytest.mark.xfail(sys.version_info.major*100 + sys.version_info.minor < 312,
reason="array conversion relies on buffer protocol, and "
"requires python >= 3.12"
)
def test_asarray_list_of_arrays():
a = asarray(1, dtype=int16)
b = asarray([1], dtype=int16)
res = asarray([a, a])
Expand Down
Loading