Skip to content

Commit

Permalink
fix(numpy): add support for numpy 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
kmarchais authored Jul 2, 2024
1 parent 412090e commit b45403b
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions trame_client/encoders/numpy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import json

import numpy as np

np_version = tuple(map(int, np.__version__.split(".")))
if np_version < (2, 0, 0):
NP_FLOATS = (np.float_, np.float16, np.float32, np.float64)
NP_COMPLEX = (np.complex_, np.complex64, np.complex128)
else:
NP_FLOATS = (np.float16, np.float32, np.float64)
NP_COMPLEX = (np.complex64, np.complex128)


class NumpyEncoder(json.JSONEncoder):
"""Custom encoder for numpy data types"""
Expand All @@ -24,10 +33,10 @@ def default(self, obj):
):
return int(obj)

elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)):
elif isinstance(obj, NP_FLOATS):
return float(obj)

elif isinstance(obj, (np.complex_, np.complex64, np.complex128)):
elif isinstance(obj, NP_COMPLEX):
return {"real": obj.real, "imag": obj.imag}

elif isinstance(obj, (np.ndarray,)):
Expand Down

0 comments on commit b45403b

Please sign in to comment.