Skip to content

Commit

Permalink
fixing the bool problem
Browse files Browse the repository at this point in the history
  • Loading branch information
IAlibay committed Sep 11, 2023
1 parent a0e25ef commit 4ea0388
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 5 deletions.
3 changes: 3 additions & 0 deletions gufe/custom_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ def inherited_is_my_dict(dct, cls):
stored = gufe.tokenization.get_class(module, classname)
return cls in stored.mro()


def is_npy_dtype_dict(dct):
expected = ["dtype", "bytes"]
is_custom = all(exp in dct for exp in expected)
return is_custom and ("shape" not in dct)


def is_openff_unit_dict(dct):
expected = ["pint_unit_registry", "unit_name", ":is_custom:"]
is_custom = all(exp in dct for exp in expected)
Expand Down Expand Up @@ -80,6 +82,7 @@ def is_openff_quantity_dict(dct):
from_dict=lambda dct: np.frombuffer(
dct['bytes'], dtype=np.dtype(dct['dtype'])
)[0],
is_my_obj=lambda obj: isinstance(obj, np.generic),
is_my_dict=is_npy_dtype_dict,
)

Expand Down
29 changes: 24 additions & 5 deletions gufe/tests/test_custom_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ def test_not_mine(self):
class TestNumpyCoding(CustomJSONCodingTest):
def setup_method(self):
self.codec = NUMPY_CODEC
self.objs = [np.array([[1.0, 0.0], [2.0, 3.2]]), np.array([1, 0])]
shapes = [[2, 2], [2,]]
self.objs = [np.array([[1.0, 0.0], [2.0, 3.2]]), np.array([1, 0]),
np.array([1.0, 2.0, 3.0], dtype=np.float32),]
shapes = [[2, 2], [2,], [3,]]
dtypes = [str(arr.dtype) for arr in self.objs] # may change by system?
byte_reps = [arr.tobytes() for arr in self.objs]
self.dcts = [
Expand Down Expand Up @@ -123,11 +124,15 @@ def test_round_trip(self):
class TestNumpyGenericCodec(TestNumpyCoding):
def setup_method(self):
self.codec = NPY_DTYPE_CODEC
self.objs = [np.float16(1.0), np.float32(1.0), np.float64(1.0),
np.complex128(1.0), np.clongdouble(1.0), np.uint64(1),]
self.objs = [np.bool_(True), np.float16(1.0), np.float32(1.0),
np.float64(1.0), np.complex128(1.0),
np.clongdouble(1.0), np.uint64(1),]
dtypes = [str(a.dtype) for a in self.objs]
byte_reps = [a.tobytes() for a in self.objs]
classes = [str(a.dtype) for a in self.objs]
# Overly complicated extraction of the class name
# to deal with the bool_ -> bool dtype class name problem
classes = [str(a.__class__).split("'")[1].split('.')[1]
for a in self.objs]
self.dcts = [
{
":is_custom:": True,
Expand All @@ -139,6 +144,20 @@ def setup_method(self):
for dtype, byte_rep, classname in zip(dtypes, byte_reps, classes)
]

def test_round_trip(self):
encoder, decoder = custom_json_factory([self.codec, BYTES_CODEC])
for (obj, dct) in zip(self.objs, self.dcts):
print(dct)
print(encoder)
json_str = json.dumps(obj, cls=encoder)
print(json_str)
reconstructed = json.loads(json_str, cls=decoder)
print(type(reconstructed))
npt.assert_array_equal(reconstructed, obj)
assert reconstructed.dtype == obj.dtype
json_str_2 = json.dumps(obj, cls=encoder)
assert json_str == json_str_2

class TestPathCodec(CustomJSONCodingTest):
def setup_method(self):
self.codec = PATH_CODEC
Expand Down
2 changes: 2 additions & 0 deletions gufe/tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from gufe.custom_codecs import (
BYTES_CODEC,
DATETIME_CODEC,
NPY_DTYPE_CODEC,
NUMPY_CODEC,
OPENFF_QUANTITY_CODEC,
OPENFF_UNIT_CODEC,
Expand All @@ -26,6 +27,7 @@
_default_json_codecs = [
PATH_CODEC,
NUMPY_CODEC,
NPY_DTYPE_CODEC,
BYTES_CODEC,
DATETIME_CODEC,
SETTINGS_CODEC,
Expand Down

0 comments on commit 4ea0388

Please sign in to comment.