diff --git a/quantecon/markov/gth_solve.py b/quantecon/markov/gth_solve.py index b3257c7bf..76ca0438c 100644 --- a/quantecon/markov/gth_solve.py +++ b/quantecon/markov/gth_solve.py @@ -6,6 +6,9 @@ import numpy as np from numba import jit +from ..util.compat import copy_if_needed + + def gth_solve(A, overwrite=False, use_jit=True): r""" This routine computes the stationary distribution of an irreducible @@ -52,7 +55,9 @@ def gth_solve(A, overwrite=False, use_jit=True): Simulation, Princeton University Press, 2009. """ - A1 = np.array(A, dtype=float, copy=not overwrite, order='C') + copy = copy_if_needed if overwrite else True + + A1 = np.array(A, dtype=float, copy=copy, order='C') # `order='C'` is for use with Numba <= 0.18.2 # See issue github.com/numba/numba/issues/1103 diff --git a/quantecon/markov/tests/test_gth_solve.py b/quantecon/markov/tests/test_gth_solve.py index f5275e869..16651f7e8 100644 --- a/quantecon/markov/tests/test_gth_solve.py +++ b/quantecon/markov/tests/test_gth_solve.py @@ -196,6 +196,13 @@ def test_matrices_with_C_F_orders(): assert_array_equal(computed_F, stationary_dist) +def test_unable_to_avoid_copy(): + A = np.array([[0, 1], [0, 1]]) # dtype=int + stationary_dist = [0., 1.] + x = gth_solve(A, overwrite=True) + assert_array_equal(x, stationary_dist) + + def test_raises_value_error_non_2dim(): """Test with non 2dim input""" assert_raises(ValueError, gth_solve, np.array([0.4, 0.6])) diff --git a/quantecon/util/compat.py b/quantecon/util/compat.py new file mode 100644 index 000000000..5037eec84 --- /dev/null +++ b/quantecon/util/compat.py @@ -0,0 +1,23 @@ +""" +Utilities for compatibility + +""" +from typing import Optional +import numpy as np + + +# From scipy/_lib/_util.py + +copy_if_needed: Optional[bool] + +if np.lib.NumpyVersion(np.__version__) >= "2.0.0": + copy_if_needed = None +elif np.lib.NumpyVersion(np.__version__) < "1.28.0": + copy_if_needed = False +else: + # 2.0.0 dev versions, handle cases where copy may or may not exist + try: + np.array([1]).__array__(copy=None) # type: ignore[call-overload] + copy_if_needed = None + except TypeError: + copy_if_needed = False