Is there any way to pass a numpy.float16 dtype directly to pybind11?
Answered by jiwaszki, Feb 22, 2022
Replies: 1 comment 1 reply
There are two ways of doing that:

    #include <pybind11/pybind11.h>
    #include <pybind11/numpy.h>  // provides py::dtype

    namespace py = pybind11;

    PYBIND11_MODULE(mymodule, m)
    {
        // ...

        // Overload 1: the py::dtype parameter accepts only numpy.dtype instances.
        m.def("is_float16", [](const py::dtype &type)
        {
            // is() checks object identity, like Python's `is`.
            if (type.is(py::dtype("float16")))
            {
                py::print("First overload.");
            }
            else
            {
                py::print("Not float16!");
            }
        });

        // Overload 2: any other Python object (e.g. np.float16, a scalar type).
        // pybind11 tries overloads in registration order, so objects rejected
        // by overload 1 end up here.
        m.def("is_float16", [](const py::object &type)
        {
            // Helper that converts the object to a dtype; throws if it cannot.
            if (py::dtype::from_args(type).is(py::dtype("float16")))
            {
                py::print("Second overload.");
            }
            else
            {
                py::print("Not float16!");
            }
        });
    }

Usage:

    In [1]: import numpy as np
       ...: import mymodule as m

    In [2]: m.is_float16(np.dtype("float16"))
    First overload.

    In [3]: m.is_float16(np.dtype("float32"))
    Not float16!

    In [4]: m.is_float16(np.float16)
    Second overload.

    In [5]: m.is_float16(np.int32)
    Not float16!

Hope it helps! :)
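A pure-NumPy sketch of the same dispatch may help show why `np.float16` and `np.int32` reach the second overload: they are scalar *types*, not `np.dtype` instances, so the `py::dtype` parameter of the first overload rejects them and pybind11 moves on to the next overload. This is an illustration, not pybind11's actual machinery. The name `is_float16_sketch` is hypothetical, it returns strings instead of printing, it compares with `==` instead of the C++ identity check, and `np.dtype(obj)` is assumed here as the rough Python-level counterpart of `py::dtype::from_args`.

```python
import numpy as np

FLOAT16 = np.dtype("float16")


def is_float16_sketch(obj):
    """Hypothetical pure-Python mirror of the two pybind11 overloads."""
    # Overload 1: only np.dtype instances are accepted here.
    if isinstance(obj, np.dtype):
        return "First overload." if obj == FLOAT16 else "Not float16!"
    # Overload 2: any other object is converted first; np.dtype(obj) stands in
    # for py::dtype::from_args and raises TypeError for input it cannot convert.
    return "Second overload." if np.dtype(obj) == FLOAT16 else "Not float16!"


for arg in (np.dtype("float16"), np.dtype("float32"), np.float16, np.int32, "f2"):
    print(repr(arg), "->", is_float16_sketch(arg))
```

In this sketch a type string such as `"f2"` also takes the second path, since anything `np.dtype()` can interpret is accepted there.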
Answer selected by FeixLiu