diff --git a/torchmdnet/utils.py b/torchmdnet/utils.py index 7ed94301..d01c3fb6 100644 --- a/torchmdnet/utils.py +++ b/torchmdnet/utils.py @@ -247,7 +247,7 @@ def make_splits( order=None, ): if splits is not None: - splits = np.load(splits) + splits = np.load(splits, allow_pickle=True) idx_train = splits["idx_train"] idx_val = splits["idx_val"] idx_test = splits["idx_test"]