Skip to content

Commit

Permalink
Add ndarray tests return_jax and return_tensorflow. (#728)
Browse files Browse the repository at this point in the history
  • Loading branch information
hpkfft authored Sep 20, 2024
1 parent 3925f57 commit c1be430
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 44 deletions.
26 changes: 26 additions & 0 deletions tests/test_ndarray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,32 @@ NB_MODULE(test_ndarray_ext, m) {
deleter);
});

m.def("ret_jax", []() {
float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
size_t shape[2] = { 2, 4 };

nb::capsule deleter(f, [](void *data) noexcept {
destruct_count++;
delete[] (float *) data;
});

return nb::ndarray<nb::jax, float, nb::shape<2, 4>>(f, 2, shape,
deleter);
});

m.def("ret_tensorflow", []() {
float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
size_t shape[2] = { 2, 4 };

nb::capsule deleter(f, [](void *data) noexcept {
destruct_count++;
delete[] (float *) data;
});

return nb::ndarray<nb::tensorflow, float, nb::shape<2, 4>>(f, 2, shape,
deleter);
});

m.def("ret_array_scalar", []() {
float* f = new float[1] { 1 };
size_t shape[1] = {};
Expand Down
Loading

0 comments on commit c1be430

Please sign in to comment.