Skip to content

Commit

Permalink
FIX: Fix NumPy v2 compatibility issue
Browse files Browse the repository at this point in the history
  • Loading branch information
oyamad committed Aug 14, 2024
1 parent 4b0223b commit f48c505
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 1 deletion.
7 changes: 6 additions & 1 deletion quantecon/markov/gth_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import numpy as np
from numba import jit

from ..util.compat import copy_if_needed


def gth_solve(A, overwrite=False, use_jit=True):
r"""
This routine computes the stationary distribution of an irreducible
Expand Down Expand Up @@ -52,7 +55,9 @@ def gth_solve(A, overwrite=False, use_jit=True):
Simulation, Princeton University Press, 2009.
"""
A1 = np.array(A, dtype=float, copy=not overwrite, order='C')
copy = copy_if_needed if overwrite else True

A1 = np.array(A, dtype=float, copy=copy, order='C')
# `order='C'` is for use with Numba <= 0.18.2
# See issue github.com/numba/numba/issues/1103

Expand Down
7 changes: 7 additions & 0 deletions quantecon/markov/tests/test_gth_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,13 @@ def test_matrices_with_C_F_orders():
assert_array_equal(computed_F, stationary_dist)


def test_unable_to_avoid_copy():
A = np.array([[0, 1], [0, 1]]) # dtype=int
stationary_dist = [0., 1.]
x = gth_solve(A, overwrite=True)
assert_array_equal(x, stationary_dist)


def test_raises_value_error_non_2dim():
"""Test with non 2dim input"""
assert_raises(ValueError, gth_solve, np.array([0.4, 0.6]))
Expand Down
23 changes: 23 additions & 0 deletions quantecon/util/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
Utilities for compatibility
"""
from typing import Optional
import numpy as np


# From scipy/_lib/_util.py

copy_if_needed: Optional[bool]

if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
copy_if_needed = None
elif np.lib.NumpyVersion(np.__version__) < "1.28.0":
copy_if_needed = False
else:
# 2.0.0 dev versions, handle cases where copy may or may not exist
try:
np.array([1]).__array__(copy=None) # type: ignore[call-overload]
copy_if_needed = None
except TypeError:
copy_if_needed = False

0 comments on commit f48c505

Please sign in to comment.