Skip to content

Commit

Permalink
Merge branch 'main' into costs-reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniel Ching committed Jul 29, 2024
2 parents 8c1561a + 6bc53b0 commit adc1693
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 84 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dependencies = [
"cupy >=10.0, !=10.3.0, !=13.0.*",
'importlib_resources; python_version<"3.9"',
"matplotlib ==3.*",
"numpy ~=1.17",
"numpy >=1.17",
"opencv-python >=3.4, <5.0",
"scipy >=1.6.0",
]
Expand Down
9 changes: 7 additions & 2 deletions src/tike/ptycho/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,16 @@ def copy_to_device(self) -> ObjectOptions:
if self.v is not None:
options.v = cp.asarray(
self.v,
dtype=tike.precision.floating,
dtype=tike.precision.cfloating
if np.iscomplexobj(self.v)
else tike.precision.floating,
)
if self.m is not None:
options.m = cp.asarray(
self.m,
dtype=tike.precision.floating,
dtype=tike.precision.cfloating
if np.iscomplexobj(self.m)
else tike.precision.floating,
)
if self.preconditioner is not None:
options.preconditioner = cp.asarray(
Expand Down Expand Up @@ -254,6 +258,7 @@ def get_padded_object(scan, probe, extra: int = 0):
return np.full_like(
probe,
shape=span.astype(tike.precision.integer),
dtype=tike.precision.cfloating,
fill_value=tike.precision.cfloating(0.5 + 0j),
), scan + 1 - min_corner + extra

Expand Down
5 changes: 2 additions & 3 deletions src/tike/ptycho/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@
import dataclasses
import logging
import typing
import copy

import cupy as cp
import cupyx.scipy.ndimage
Expand Down Expand Up @@ -458,9 +457,9 @@ def insert(self, other, indices):

@staticmethod
def join(
x: typing.Iterable[PositionOptions | None],
x: typing.Iterable[typing.Union[PositionOptions, None]],
reorder: npt.NDArray[np.intc],
) -> PositionOptions | None:
) -> typing.Union[PositionOptions, None]:
if None in x:
return None
new = PositionOptions(
Expand Down
37 changes: 26 additions & 11 deletions src/tike/ptycho/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,16 @@ def copy_to_device(self) -> ProbeOptions:
if self.v is not None:
options.v = cp.asarray(
self.v,
dtype=tike.precision.floating,
dtype=tike.precision.cfloating
if np.iscomplexobj(self.v)
else tike.precision.floating,
)
if self.m is not None:
options.m = cp.asarray(
self.m,
dtype=tike.precision.floating,
dtype=tike.precision.cfloating
if np.iscomplexobj(self.m)
else tike.precision.floating,
)
if self.preconditioner is not None:
options.preconditioner = cp.asarray(
Expand Down Expand Up @@ -809,18 +813,29 @@ def constrain_center_peak(probe):
stack = probe.reshape((-1, *probe.shape[-2:]))
intensity = cupyx.scipy.ndimage.gaussian_filter(
input=np.sum(np.square(np.abs(stack)), axis=0),
sigma=half,
mode='wrap',
sigma=(half[0] / 3, half[1] / 3),
mode="constant",
cval=0.0,
truncate=6.0,
)
# Find the maximum intensity in 2D.
center = np.argmax(intensity)
# Find the 2D coordinates of the maximum.
coords = cp.unravel_index(center, dims=probe.shape[-2:])
# Shift each of the probes so the max is in the center.
p = np.roll(stack, half[0] - coords[0], axis=-2)
stack = np.roll(p, half[1] - coords[1], axis=-1)
coords = cp.round(cupyx.scipy.ndimage.center_of_mass(intensity))
# Shift each of the probes so the max is in the center. Take integer steps
# only one pixel at a time.
shifted = cupyx.scipy.ndimage.shift(
stack,
shift=(
0,
min(1, max(-1, half[0] - coords[0])),
min(1, max(-1, half[1] - coords[1])),
),
mode="constant",
cval=0.0,
order=0,
)
assert shifted.dtype == stack.dtype, (shifted.dtype, stack.dtype)
# Reform to the original shape; make contiguous.
probe = stack.reshape(probe.shape)
probe = shifted.reshape(probe.shape)
return probe


Expand Down
22 changes: 22 additions & 0 deletions src/tike/ptycho/ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
get_varying_probe,
apply_median_filter_abs_probe,
orthogonalize_eig,
finite_probe_support,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -723,6 +724,27 @@ def _apply_probe_constraints(
if parameters.probe_options is not None:
if parameters.probe_options.recover_probe(epoch):

if parameters.probe_options.probe_support > 0:
b0 = finite_probe_support(
parameters.probe,
p=parameters.probe_options.probe_support,
radius=parameters.probe_options.probe_support_radius,
degree=parameters.probe_options.probe_support_degree,
)
parameters.probe -= b0 * cp.conj(b0 * parameters.probe)

if parameters.probe_options.additional_probe_penalty > 0:
b1 = (
parameters.probe_options.additional_probe_penalty
* cp.linspace(
0,
1,
parameters.probe.shape[-3],
dtype=tike.precision.floating,
)[..., None, None]
)
parameters.probe -= b1 * cp.conj(b1 * parameters.probe)

if parameters.probe_options.median_filter_abs_probe:
parameters.probe = apply_median_filter_abs_probe(
parameters.probe,
Expand Down
32 changes: 2 additions & 30 deletions src/tike/ptycho/solvers/lstsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,8 @@ def keep_some_args_constant(

step_length = cp.full(
shape=(farplane.shape[0], 1, farplane.shape[2], 1, 1),
fill_value=exitwave_options.step_length_start,
fill_value=tike.precision.floating(exitwave_options.step_length_start),
dtype=tike.precision.floating,
)

if exitwave_options.step_length_usemodes == 'dominant_mode':
Expand Down Expand Up @@ -711,35 +712,6 @@ def _precondition_nearplane_gradients(
A1 = cp.sum((dOP * dOP.conj()).real + eps, axis=(-2, -1))

if recover_probe:
b0 = tike.ptycho.probe.finite_probe_support(
unique_probe[..., m : m + 1, :, :],
p=probe_options.probe_support,
radius=probe_options.probe_support_radius,
degree=probe_options.probe_support_degree,
)

b1 = (
probe_options.additional_probe_penalty
* cp.linspace(
0,
1,
probe[0].shape[-3],
dtype=tike.precision.floating,
)[..., m : m + 1, None, None]
)

m_probe_update = m_probe_update - (b0 + b1) * probe[..., m : m + 1, :, :]
# / (
# (1 - alpha) * probe_update_denominator
# + alpha
# * probe_update_denominator.max(
# axis=(-2, -1),
# keepdims=True,
# )
# + b0
# + b1
# )

dPO = m_probe_update[..., m:m + 1, :, :] * patches
A4 = cp.sum((dPO * dPO.conj()).real + eps, axis=(-2, -1))

Expand Down
19 changes: 4 additions & 15 deletions src/tike/ptycho/solvers/rpie.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,26 +265,14 @@ def _update(
psi = psi + dpsi / deno

if recover_probe:
b0 = tike.ptycho.probe.finite_probe_support(
probe,
p=probe_options.probe_support,
radius=probe_options.probe_support_radius,
degree=probe_options.probe_support_degree,
)
b1 = (
probe_options.additional_probe_penalty
* cp.linspace(0, 1, probe.shape[-3], dtype="float32")[..., None, None]
)
dprobe = probe_update_numerator - (b1 + b0) * probe
dprobe = probe_update_numerator
deno = (
(1 - algorithm_options.alpha) * probe_options.preconditioner
+ algorithm_options.alpha
* probe_options.preconditioner.max(
axis=(-2, -1),
keepdims=True,
)
+ b0
+ b1
)
probe = probe + dprobe / deno
if probe_options.use_adaptive_moment:
Expand Down Expand Up @@ -343,7 +331,7 @@ def _get_nearplane_gradients(
position_options: typing.Union[None, PositionOptions],
exitwave_options: ExitWaveOptions,
) -> typing.Tuple[
float, npt.NDArray, npt.NDArray, npt.NDArray, npt.NDArray, npt.NDArray | None
float, npt.NDArray, npt.NDArray, npt.NDArray, npt.NDArray, typing.Union[npt.NDArray, None]
]:
batch_start = batches[n][0]
batch_size = len(batches[n])
Expand Down Expand Up @@ -398,7 +386,8 @@ def keep_some_args_constant(

step_length = cp.full(
shape=(farplane.shape[0], 1, farplane.shape[2], 1, 1),
fill_value=exitwave_options.step_length_start,
fill_value=tike.precision.floating(exitwave_options.step_length_start),
dtype=tike.precision.floating,
)

if exitwave_options.step_length_usemodes == 'dominant_mode':
Expand Down
5 changes: 3 additions & 2 deletions tests/ptycho/test_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os.path
import unittest

import tike.precision
from tike.ptycho.exitwave import ExitWaveOptions
from tike.ptycho.object import ObjectOptions
from tike.ptycho.position import PositionOptions
Expand Down Expand Up @@ -338,7 +339,7 @@ def test_consistent_rpie_unmeasured_detector_regions(self):
measured_pixels = np.logical_not(unmeasured_pixels)

# Zero out these regions on the diffraction measurement data
self.data = self.data.astype(np.floating)
self.data = self.data.astype(tike.precision.floating)
self.data[:, unmeasured_pixels] = np.nan

params = tike.ptycho.PtychoParameters(
Expand Down Expand Up @@ -380,7 +381,7 @@ def test_consistent_lstsq_grad_unmeasured_detector_regions(self):
measured_pixels = np.logical_not(unmeasured_pixels)

# Zero out these regions on the diffraction measurement data
self.data = self.data.astype(np.floating)
self.data = self.data.astype(tike.precision.floating)
self.data[:, unmeasured_pixels] = np.nan

params = tike.ptycho.PtychoParameters(
Expand Down
14 changes: 14 additions & 0 deletions tests/ptycho/test_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,20 @@ def test_hermite_modes():
np.rollaxis(inputs['result'], -1, 0)[None, ...],
)

def test_center_peak():

x = cp.ones((1, 1, 1, 7, 7), dtype=cp.complex64)

x[0,0,0, 3, 6] = 10 + 23j

print()
print(x.squeeze())

y = tike.ptycho.probe.constrain_center_peak(x)

print()
print(np.round(y.squeeze(), 1))


if __name__ == '__main__':
unittest.main()
20 changes: 0 additions & 20 deletions tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,6 @@ class TestWobblyCenter(ClusterTests, unittest.TestCase):

cluster_method = staticmethod(tike.cluster.wobbly_center)

def test_simple_cluster(self):
references = [
np.array([2, 3, 4, 9]),
np.array([0, 5, 8]),
np.array([1, 6, 7]),
]
result = tike.cluster.wobbly_center(np.arange(10)[:, None], 3)
for a, b in zip(references, result):
np.testing.assert_array_equal(a, b)

def test_same_mean(self):
"""Test that wobbly center generates better samples of the population.
Expand Down Expand Up @@ -118,16 +108,6 @@ class TestWobblyCenterRandomBootstrap(ClusterTests, unittest.TestCase):

cluster_method = staticmethod(tike.cluster.wobbly_center_random_bootstrap)

def test_simple_cluster(self):
references = [
np.array([2, 3, 4, 9]),
np.array([0, 5, 8]),
np.array([1, 6, 7]),
]
result = tike.cluster.wobbly_center(np.arange(10)[:, None], 3)
for a, b in zip(references, result):
np.testing.assert_array_equal(a, b)

def test_same_mean(self):
"""Test that wobbly center generates better samples of the population.
Expand Down

0 comments on commit adc1693

Please sign in to comment.