From c1be430f1f5d34d6521ef0510cef2df1a795b953 Mon Sep 17 00:00:00 2001 From: "hpkfft.com" Date: Thu, 19 Sep 2024 17:55:25 -0700 Subject: [PATCH] Add ndarray tests return_jax and return_tensorflow. (#728) --- tests/test_ndarray.cpp | 26 ++++++++ tests/test_ndarray.py | 117 ++++++++++++++++++++------------- tests/test_ndarray_ext.pyi.ref | 6 ++ 3 files changed, 105 insertions(+), 44 deletions(-) diff --git a/tests/test_ndarray.cpp b/tests/test_ndarray.cpp index 392de8d4..42b9e0d1 100644 --- a/tests/test_ndarray.cpp +++ b/tests/test_ndarray.cpp @@ -298,6 +298,32 @@ NB_MODULE(test_ndarray_ext, m) { deleter); }); + m.def("ret_jax", []() { + float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 }; + size_t shape[2] = { 2, 4 }; + + nb::capsule deleter(f, [](void *data) noexcept { + destruct_count++; + delete[] (float *) data; + }); + + return nb::ndarray<nb::jax, float, nb::shape<2, 4>>(f, 2, shape, + deleter); + }); + + m.def("ret_tensorflow", []() { + float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 }; + size_t shape[2] = { 2, 4 }; + + nb::capsule deleter(f, [](void *data) noexcept { + destruct_count++; + delete[] (float *) data; + }); + + return nb::ndarray<nb::tensorflow, float, nb::shape<2, 4>>(f, 2, shape, + deleter); + }); + m.def("ret_array_scalar", []() { float* f = new float[1] { 1 }; size_t shape[1] = {}; diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py index 7bd639c3..aa433d74 100644 --- a/tests/test_ndarray.py +++ b/tests/test_ndarray.py @@ -148,7 +148,7 @@ def test04_constrain_shape(): @needs_numpy -def test04_constrain_order(): +def test05_constrain_order(): assert t.check_order(np.zeros((3, 5, 4, 6), order='C')) == 'C' assert t.check_order(np.zeros((3, 5, 4, 6), order='F')) == 'F' assert t.check_order(np.zeros((3, 5, 4, 6), order='C')[:, 2, :, :]) == '?' 
@@ -156,7 +156,7 @@ def test04_constrain_order(): @needs_jax -def test05_constrain_order_jax(): +def test06_constrain_order_jax(): with warnings.catch_warnings(): warnings.simplefilter("ignore") try: @@ -170,7 +170,7 @@ def test05_constrain_order_jax(): @needs_torch @pytest.mark.filterwarnings -def test06_constrain_order_pytorch(): +def test07_constrain_order_pytorch(): try: c = torch.zeros(3, 5) c.__dlpack__() @@ -188,7 +188,7 @@ def test06_constrain_order_pytorch(): @needs_tensorflow -def test07_constrain_order_tensorflow(): +def test08_constrain_order_tensorflow(): with warnings.catch_warnings(): warnings.simplefilter("ignore") try: @@ -200,7 +200,7 @@ def test07_constrain_order_tensorflow(): @needs_numpy -def test08_write_from_cpp(): +def test09_write_from_cpp(): x = np.zeros(10, dtype=np.float32) t.initialize(x) assert np.all(x == np.arange(10, dtype=np.float32)) @@ -211,7 +211,7 @@ def test08_write_from_cpp(): @needs_numpy -def test09_implicit_conversion(): +def test10_implicit_conversion(): t.implicit(np.zeros((2, 2), dtype=np.uint32)) t.implicit(np.zeros((2, 2, 10), dtype=np.float32)[:, :, 4]) t.implicit(np.zeros((2, 2, 10), dtype=np.uint32)[:, :, 4]) @@ -228,7 +228,7 @@ def test09_implicit_conversion(): @needs_torch -def test10_implicit_conversion_pytorch(): +def test11_implicit_conversion_pytorch(): with warnings.catch_warnings(): warnings.simplefilter("ignore") try: @@ -249,7 +249,7 @@ def test10_implicit_conversion_pytorch(): @needs_tensorflow -def test11_implicit_conversion_tensorflow(): +def test12_implicit_conversion_tensorflow(): with warnings.catch_warnings(): warnings.simplefilter("ignore") try: @@ -270,7 +270,7 @@ def test11_implicit_conversion_tensorflow(): @needs_jax -def test12_implicit_conversion_jax(): +def test13_implicit_conversion_jax(): with warnings.catch_warnings(): warnings.simplefilter("ignore") try: @@ -290,7 +290,7 @@ def test12_implicit_conversion_jax(): t.noimplicit(jnp.zeros((2, 2), dtype=jnp.uint8)) -def 
test13_destroy_capsule(): +def test14_destroy_capsule(): collect() dc = t.destruct_count() a = t.return_dlpack() @@ -301,7 +301,7 @@ def test13_destroy_capsule(): @needs_numpy -def test14_consume_numpy(): +def test15_consume_numpy(): collect() class wrapper: def __init__(self, value): @@ -328,7 +328,7 @@ def __dlpack__(self): @needs_numpy -def test15_passthrough(): +def test16_passthrough(): a = t.ret_numpy() b = t.passthrough(a) assert a is b @@ -346,7 +346,7 @@ def test15_passthrough(): @needs_numpy -def test16_return_numpy(): +def test17_return_numpy(): collect() dc = t.destruct_count() x = t.ret_numpy() @@ -358,7 +358,7 @@ def test16_return_numpy(): @needs_torch -def test17_return_pytorch(): +def test18_return_pytorch(): try: c = torch.zeros(3, 5) except: @@ -373,8 +373,33 @@ def test17_return_pytorch(): assert t.destruct_count() - dc == 1 +@needs_jax +def test19_return_jax(): + collect() + dc = t.destruct_count() + x = t.ret_jax() + assert x.shape == (2, 4) + assert jnp.all(x == jnp.array([[1,2,3,4], [5,6,7,8]], dtype=jnp.float32)) + del x + collect() + assert t.destruct_count() - dc == 1 + + +@needs_tensorflow +def test20_return_tensorflow(): + collect() + dc = t.destruct_count() + x = t.ret_tensorflow() + assert x.get_shape().as_list() == [2, 4] + assert tf.math.reduce_all( + x == tf.constant([[1,2,3,4], [5,6,7,8]], dtype=tf.float32)) + del x + collect() + assert t.destruct_count() - dc == 1 + + @needs_numpy -def test18_return_array_scalar(): +def test21_return_array_scalar(): collect() dc = t.destruct_count() x = t.ret_array_scalar() @@ -386,7 +411,7 @@ def test18_return_array_scalar(): # See PR #162 @needs_torch -def test19_single_and_empty_dimension_pytorch(): +def test22_single_and_empty_dimension_pytorch(): a = torch.ones((1,100,1025), dtype=torch.float32) t.noop_3d_c_contig(a) a = torch.ones((100,1,1025), dtype=torch.float32) @@ -405,7 +430,7 @@ def test19_single_and_empty_dimension_pytorch(): # See PR #162 @needs_numpy -def 
test20_single_and_empty_dimension_numpy(): +def test23_single_and_empty_dimension_numpy(): a = np.ones((1,100,1025), dtype=np.float32) t.noop_3d_c_contig(a) a = np.ones((100,1,1025), dtype=np.float32) @@ -424,7 +449,7 @@ def test20_single_and_empty_dimension_numpy(): # See PR #162 @needs_torch -def test21_single_and_empty_dimension_fortran_order_pytorch(): +def test24_single_and_empty_dimension_fortran_order_pytorch(): # This idiom creates a pytorch 2D tensor in column major (aka, 'F') ordering a = torch.ones((0,100), dtype=torch.float32).t().contiguous().t() t.noop_2d_f_contig(a) @@ -437,7 +462,7 @@ def test21_single_and_empty_dimension_fortran_order_pytorch(): @needs_numpy -def test22_ro_array(): +def test25_ro_array(): a = np.array([1, 2], dtype=np.float32) assert t.accept_ro(a) == 1 assert t.accept_rw(a) == 1 @@ -449,7 +474,7 @@ def test22_ro_array(): @needs_numpy -def test22_return_ro(): +def test26_return_ro(): x = t.ret_numpy_const_ref() y = t.ret_numpy_const_ref_f() assert t.ret_numpy_const_ref.__doc__ == 'ret_numpy_const_ref() -> numpy.ndarray[dtype=float32, shape=(2, 4), order=\'C\', writable=False]' @@ -467,27 +492,27 @@ def test22_return_ro(): @needs_numpy -def test23_check_numpy(): +def test27_check_numpy(): assert t.check(np.zeros(1)) @needs_torch -def test24_check_torch(): +def test28_check_torch(): assert t.check(torch.zeros((1))) @needs_tensorflow -def test25_check_tensorflow(): +def test29_check_tensorflow(): assert t.check(tf.zeros((1))) @needs_jax -def test26_check_jax(): +def test30_check_jax(): assert t.check(jnp.zeros((1))) @needs_numpy -def test27_rv_policy(): +def test31_rv_policy(): def p(a): return a.__array_interface__['data'] @@ -513,7 +538,7 @@ def p(a): @needs_numpy -def test28_reference_internal(): +def test32_reference_internal(): collect() dc = t.destruct_count() c = t.Cls() @@ -592,7 +617,7 @@ def test28_reference_internal(): @needs_numpy -def test29_force_contig_numpy(): +def test33_force_contig_numpy(): a = np.array([[1, 2, 3], 
[4, 5, 6], [7, 8, 9]]) b = t.make_contig(a) assert b is a @@ -603,7 +628,7 @@ def test29_force_contig_numpy(): @needs_torch @pytest.mark.filterwarnings -def test30_force_contig_pytorch(): +def test34_force_contig_pytorch(): a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) b = t.make_contig(a) assert b is a @@ -614,7 +639,7 @@ def test30_force_contig_pytorch(): @needs_numpy -def test31_view(): +def test35_view(): # 1 x1 = np.array([[1,2],[3,4]], dtype=np.float32) x2 = np.array([[1,2],[3,4]], dtype=np.float64) @@ -650,7 +675,7 @@ def test31_view(): @needs_numpy -def test32_half(): +def test36_half(): if not hasattr(t, 'ret_numpy_half'): pytest.skip('half precision test is missing') x = t.ret_numpy_half() @@ -659,7 +684,7 @@ def test32_half(): assert np.all(x == [[1, 2, 3, 4], [5, 6, 7, 8]]) @needs_numpy -def test33_cast(): +def test37_cast(): a = t.cast(False) b = t.cast(True) assert a.ndim == 0 and b.ndim == 0 @@ -668,7 +693,7 @@ def test33_cast(): @needs_numpy -def test34_complex_decompose(): +def test38_complex_decompose(): x1 = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64) assert np.all(x1.real == np.array([1, 3, 5], dtype=np.float32)) @@ -689,7 +714,7 @@ def test_uint32_complex_do_not_convert(variant): @needs_numpy -def test36_check_generic(): +def test40_check_generic(): class DLPackWrapper: def __init__(self, o): self.o = o @@ -701,7 +726,7 @@ def __dlpack__(self): @needs_numpy -def test37_noninteger_stride(): +def test41_noninteger_stride(): a = np.array([[1, 2, 3, 4, 0, 0], [5, 6, 7, 8, 0, 0]], dtype=np.float32) s = a[:, 0:4] # slice t.pass_float32(s) @@ -730,7 +755,7 @@ def test37_noninteger_stride(): @needs_numpy -def test38_const_qualifiers_numpy(): +def test42_const_qualifiers_numpy(): a = np.array([0, 0, 0, 3.14159, 0], dtype=np.float64) assert t.check_rw_by_value(a); assert a[1] == 1.414214; @@ -775,7 +800,7 @@ def test38_const_qualifiers_numpy(): @needs_torch -def test39_const_qualifiers_pytorch(): +def test43_const_qualifiers_pytorch(): a 
= torch.tensor([0, 0, 0, 3.14159, 0], dtype=torch.float64) assert t.check_rw_by_value(a); assert a[1] == 1.414214; @@ -812,7 +837,7 @@ def test39_const_qualifiers_pytorch(): @needs_cupy @pytest.mark.filterwarnings -def test40_constrain_order_cupy(): +def test44_constrain_order_cupy(): try: c = cp.zeros((3, 5)) c.__dlpack__() @@ -828,7 +853,7 @@ def test40_constrain_order_cupy(): @needs_cupy -def test41_implicit_conversion_cupy(): +def test45_implicit_conversion_cupy(): with warnings.catch_warnings(): warnings.simplefilter("ignore") try: @@ -849,7 +874,7 @@ def test41_implicit_conversion_cupy(): @needs_numpy -def test42_implicit_conversion_contiguous_complex(): +def test46_implicit_conversion_contiguous_complex(): # Test fix for issue #709 import numpy as np @@ -878,13 +903,14 @@ def test_conv(x): @needs_numpy -def test_43_ret_infer(): +def test_47_ret_infer(): import numpy as np assert np.all(t.ret_infer_c() == [[1, 2, 3, 4], [5, 6, 7, 8]]) assert np.all(t.ret_infer_f() == [[1, 3, 5, 7], [2, 4, 6, 8]]) + @needs_numpy -def test44_test_matrix4f(): +def test48_test_matrix4f(): a = t.Matrix4f() ad = a.data() bd = a.data() @@ -894,8 +920,9 @@ def test44_test_matrix4f(): for i in range(16): assert bd[i%4, i//4] == i + @needs_numpy -def test45_test_matrix4f_ref(): +def test49_test_matrix4f_ref(): assert t.Matrix4f.data_ref.__doc__.replace('data_ref', 'data') == t.Matrix4f.data.__doc__ a = t.Matrix4f() @@ -907,8 +934,9 @@ def test45_test_matrix4f_ref(): for i in range(16): assert bd[i%4, i//4] == i + @needs_numpy -def test46_test_matrix4f_copy(): +def test50_test_matrix4f_copy(): assert t.Matrix4f.data_ref.__doc__.replace('data_ref', 'data') == t.Matrix4f.data.__doc__ a = t.Matrix4f() @@ -922,8 +950,9 @@ def test46_test_matrix4f_copy(): for i in range(16): assert bd[i%4, i//4] == i + @needs_numpy -def test47_return_from_stack(): +def test51_return_from_stack(): import numpy as np assert np.all(t.ret_from_stack_1() == [1,2,3]) assert np.all(t.ret_from_stack_2() == [1,2,3]) 
diff --git a/tests/test_ndarray_ext.pyi.ref b/tests/test_ndarray_ext.pyi.ref index efe9c174..ad2d87e5 100644 --- a/tests/test_ndarray_ext.pyi.ref +++ b/tests/test_ndarray_ext.pyi.ref @@ -1,6 +1,8 @@ from typing import Annotated, overload +import jaxlib.xla_extension from numpy.typing import ArrayLike +import tensorflow.python.framework.ops class Cls: @@ -158,6 +160,8 @@ def ret_infer_c() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4), or def ret_infer_f() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4), order='F')]: ... +def ret_jax() -> jaxlib.xla_extension.DeviceArray[dtype=float32, shape=(2, 4)]: ... + def ret_numpy() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4))]: ... def ret_numpy_const() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4), writable=False)]: ... @@ -170,6 +174,8 @@ def ret_numpy_half() -> Annotated[ArrayLike, dict(dtype='float16', shape=(2, 4)) def ret_pytorch() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4))]: ... +def ret_tensorflow() -> tensorflow.python.framework.ops.EagerTensor[dtype=float32, shape=(2, 4)]: ... + def return_dlpack() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4))]: ... @overload