Skip to content

Commit

Permalink
Merge pull request #736 from QuantEcon/npy2
Browse files Browse the repository at this point in the history
FIX: Fix NumPy v2 compatibility issue
  • Loading branch information
oyamad authored Aug 16, 2024
2 parents a83c2ae + f48c505 commit ae7fb26
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 1 deletion.
32 changes: 32 additions & 0 deletions .github/workflows/ci_np2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: conda-build (NumPy v2)

on: [push]

jobs:
build-linux:
runs-on: ubuntu-latest
strategy:
max-parallel: 5

steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.12'
- name: Add conda to system path
run: |
# $CONDA is an environment variable pointing to the root of the miniconda directory
echo $CONDA/bin >> $GITHUB_PATH
- name: Install dependencies
run: |
conda env update --file environment_np2.yml --name base
- name: Conda info
shell: bash -l {0}
run: |
conda info
conda list
- name: Test with pytest
run: |
conda install pytest
pytest
18 changes: 18 additions & 0 deletions environment_np2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
name: qe
channels:
- conda-forge
- defaults
dependencies:
- coverage
- numpy>=2
- scipy
- pandas
- numba
- sympy
- ipython
- flake8
- requests
- urllib3>=2
- flit
- chardet # python>3.9,osx
- pytest
7 changes: 6 additions & 1 deletion quantecon/markov/gth_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import numpy as np
from numba import jit

from ..util.compat import copy_if_needed


def gth_solve(A, overwrite=False, use_jit=True):
r"""
This routine computes the stationary distribution of an irreducible
Expand Down Expand Up @@ -52,7 +55,9 @@ def gth_solve(A, overwrite=False, use_jit=True):
Simulation, Princeton University Press, 2009.
"""
A1 = np.array(A, dtype=float, copy=not overwrite, order='C')
copy = copy_if_needed if overwrite else True

A1 = np.array(A, dtype=float, copy=copy, order='C')
# `order='C'` is for use with Numba <= 0.18.2
# See issue github.com/numba/numba/issues/1103

Expand Down
7 changes: 7 additions & 0 deletions quantecon/markov/tests/test_gth_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,13 @@ def test_matrices_with_C_F_orders():
assert_array_equal(computed_F, stationary_dist)


def test_unable_to_avoid_copy():
A = np.array([[0, 1], [0, 1]]) # dtype=int
stationary_dist = [0., 1.]
x = gth_solve(A, overwrite=True)
assert_array_equal(x, stationary_dist)


def test_raises_value_error_non_2dim():
"""Test with non 2dim input"""
assert_raises(ValueError, gth_solve, np.array([0.4, 0.6]))
Expand Down
23 changes: 23 additions & 0 deletions quantecon/util/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
Utilities for compatibility
"""
from typing import Optional
import numpy as np


# From scipy/_lib/_util.py

copy_if_needed: Optional[bool]

if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
copy_if_needed = None
elif np.lib.NumpyVersion(np.__version__) < "1.28.0":
copy_if_needed = False
else:
# 2.0.0 dev versions, handle cases where copy may or may not exist
try:
np.array([1]).__array__(copy=None) # type: ignore[call-overload]
copy_if_needed = None
except TypeError:
copy_if_needed = False

0 comments on commit ae7fb26

Please sign in to comment.