Skip to content

Commit

Permalink
Merge pull request #330 from carterbox/deprecations
Browse files Browse the repository at this point in the history
BUG: Fix errors from deprecated API use with NumPy 2.0 API
  • Loading branch information
carterbox authored Jul 29, 2024
2 parents ccd45b9 + d92c636 commit 6bc53b0
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 32 deletions.
9 changes: 7 additions & 2 deletions src/tike/ptycho/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,16 @@ def copy_to_device(self) -> ObjectOptions:
if self.v is not None:
options.v = cp.asarray(
self.v,
dtype=tike.precision.floating,
dtype=tike.precision.cfloating
if np.iscomplexobj(self.v)
else tike.precision.floating,
)
if self.m is not None:
options.m = cp.asarray(
self.m,
dtype=tike.precision.floating,
dtype=tike.precision.cfloating
if np.iscomplexobj(self.m)
else tike.precision.floating,
)
if self.preconditioner is not None:
options.preconditioner = cp.asarray(
Expand Down Expand Up @@ -254,6 +258,7 @@ def get_padded_object(scan, probe, extra: int = 0):
return np.full_like(
probe,
shape=span.astype(tike.precision.integer),
dtype=tike.precision.cfloating,
fill_value=tike.precision.cfloating(0.5 + 0j),
), scan + 1 - min_corner + extra

Expand Down
5 changes: 2 additions & 3 deletions src/tike/ptycho/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@
import dataclasses
import logging
import typing
import copy

import cupy as cp
import cupyx.scipy.ndimage
Expand Down Expand Up @@ -458,9 +457,9 @@ def insert(self, other, indices):

@staticmethod
def join(
x: typing.Iterable[PositionOptions | None],
x: typing.Iterable[typing.Union[PositionOptions, None]],
reorder: npt.NDArray[np.intc],
) -> PositionOptions | None:
) -> typing.Union[PositionOptions, None]:
if None in x:
return None
new = PositionOptions(
Expand Down
8 changes: 6 additions & 2 deletions src/tike/ptycho/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,16 @@ def copy_to_device(self) -> ProbeOptions:
if self.v is not None:
options.v = cp.asarray(
self.v,
dtype=tike.precision.floating,
dtype=tike.precision.cfloating
if np.iscomplexobj(self.v)
else tike.precision.floating,
)
if self.m is not None:
options.m = cp.asarray(
self.m,
dtype=tike.precision.floating,
dtype=tike.precision.cfloating
if np.iscomplexobj(self.m)
else tike.precision.floating,
)
if self.preconditioner is not None:
options.preconditioner = cp.asarray(
Expand Down
3 changes: 2 additions & 1 deletion src/tike/ptycho/solvers/lstsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,8 @@ def keep_some_args_constant(

step_length = cp.full(
shape=(farplane.shape[0], 1, farplane.shape[2], 1, 1),
fill_value=exitwave_options.step_length_start,
fill_value=tike.precision.floating(exitwave_options.step_length_start),
dtype=tike.precision.floating,
)

if exitwave_options.step_length_usemodes == 'dominant_mode':
Expand Down
5 changes: 3 additions & 2 deletions src/tike/ptycho/solvers/rpie.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def _get_nearplane_gradients(
position_options: typing.Union[None, PositionOptions],
exitwave_options: ExitWaveOptions,
) -> typing.Tuple[
float, npt.NDArray, npt.NDArray, npt.NDArray, npt.NDArray, npt.NDArray | None
float, npt.NDArray, npt.NDArray, npt.NDArray, npt.NDArray, typing.Union[npt.NDArray, None]
]:
cost: float = 0.0
count: float = 1.0 / len(batches[n])
Expand Down Expand Up @@ -382,7 +382,8 @@ def keep_some_args_constant(

step_length = cp.full(
shape=(farplane.shape[0], 1, farplane.shape[2], 1, 1),
fill_value=exitwave_options.step_length_start,
fill_value=tike.precision.floating(exitwave_options.step_length_start),
dtype=tike.precision.floating,
)

if exitwave_options.step_length_usemodes == 'dominant_mode':
Expand Down
5 changes: 3 additions & 2 deletions tests/ptycho/test_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os.path
import unittest

import tike.precision
from tike.ptycho.exitwave import ExitWaveOptions
from tike.ptycho.object import ObjectOptions
from tike.ptycho.position import PositionOptions
Expand Down Expand Up @@ -338,7 +339,7 @@ def test_consistent_rpie_unmeasured_detector_regions(self):
measured_pixels = np.logical_not(unmeasured_pixels)

# Zero out these regions on the diffraction measurement data
self.data = self.data.astype(np.floating)
self.data = self.data.astype(tike.precision.floating)
self.data[:, unmeasured_pixels] = np.nan

params = tike.ptycho.PtychoParameters(
Expand Down Expand Up @@ -380,7 +381,7 @@ def test_consistent_lstsq_grad_unmeasured_detector_regions(self):
measured_pixels = np.logical_not(unmeasured_pixels)

# Zero out these regions on the diffraction measurement data
self.data = self.data.astype(np.floating)
self.data = self.data.astype(tike.precision.floating)
self.data[:, unmeasured_pixels] = np.nan

params = tike.ptycho.PtychoParameters(
Expand Down
20 changes: 0 additions & 20 deletions tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,6 @@ class TestWobblyCenter(ClusterTests, unittest.TestCase):

cluster_method = staticmethod(tike.cluster.wobbly_center)

def test_simple_cluster(self):
references = [
np.array([2, 3, 4, 9]),
np.array([0, 5, 8]),
np.array([1, 6, 7]),
]
result = tike.cluster.wobbly_center(np.arange(10)[:, None], 3)
for a, b in zip(references, result):
np.testing.assert_array_equal(a, b)

def test_same_mean(self):
"""Test that wobbly center generates better samples of the population.
Expand Down Expand Up @@ -118,16 +108,6 @@ class TestWobblyCenterRandomBootstrap(ClusterTests, unittest.TestCase):

cluster_method = staticmethod(tike.cluster.wobbly_center_random_bootstrap)

def test_simple_cluster(self):
references = [
np.array([2, 3, 4, 9]),
np.array([0, 5, 8]),
np.array([1, 6, 7]),
]
result = tike.cluster.wobbly_center(np.arange(10)[:, None], 3)
for a, b in zip(references, result):
np.testing.assert_array_equal(a, b)

def test_same_mean(self):
"""Test that wobbly center generates better samples of the population.
Expand Down

0 comments on commit 6bc53b0

Please sign in to comment.