diff --git a/trame_client/encoders/numpy.py b/trame_client/encoders/numpy.py index 5dfdb2c..00a102c 100644 --- a/trame_client/encoders/numpy.py +++ b/trame_client/encoders/numpy.py @@ -1,6 +1,15 @@ import json + import numpy as np +np_version = tuple(map(int, np.__version__.split("."))) +if np_version < (2, 0, 0): + NP_FLOATS = (np.float_, np.float16, np.float32, np.float64) + NP_COMPLEX = (np.complex_, np.complex64, np.complex128) +else: + NP_FLOATS = (np.float16, np.float32, np.float64) + NP_COMPLEX = (np.complex64, np.complex128) + class NumpyEncoder(json.JSONEncoder): """Custom encoder for numpy data types""" @@ -24,10 +33,10 @@ def default(self, obj): ): return int(obj) - elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)): + elif isinstance(obj, NP_FLOATS): return float(obj) - elif isinstance(obj, (np.complex_, np.complex64, np.complex128)): + elif isinstance(obj, NP_COMPLEX): return {"real": obj.real, "imag": obj.imag} elif isinstance(obj, (np.ndarray,)):