Skip to content

Commit

Permalink
REF: Reimplement centered probe constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
carterbox committed Jul 17, 2024
1 parent a7a40ec commit 03d5faa
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 9 deletions.
29 changes: 20 additions & 9 deletions src/tike/ptycho/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,18 +809,29 @@ def constrain_center_peak(probe):
stack = probe.reshape((-1, *probe.shape[-2:]))
intensity = cupyx.scipy.ndimage.gaussian_filter(
input=np.sum(np.square(np.abs(stack)), axis=0),
sigma=half,
mode='wrap',
sigma=(half[0] / 3, half[1] / 3),
mode="constant",
cval=0.0,
truncate=6.0,
)
# Find the maximum intensity in 2D.
center = np.argmax(intensity)
# Find the 2D coordinates of the maximum.
coords = cp.unravel_index(center, dims=probe.shape[-2:])
# Shift each of the probes so the max is in the center.
p = np.roll(stack, half[0] - coords[0], axis=-2)
stack = np.roll(p, half[1] - coords[1], axis=-1)
coords = cp.round(cupyx.scipy.ndimage.center_of_mass(intensity))
# Shift each of the probes so the max is in the center. Take integer steps
# only one pixel at a time.
shifted = cupyx.scipy.ndimage.shift(
stack,
shift=(
0,
min(1, max(-1, half[0] - coords[0])),
min(1, max(-1, half[1] - coords[1])),
),
mode="constant",
cval=0.0,
order=0,
)
assert shifted.dtype == stack.dtype, (shifted.dtype, stack.dtype)
# Reform to the original shape; make contiguous.
probe = stack.reshape(probe.shape)
probe = shifted.reshape(probe.shape)
return probe


Expand Down
14 changes: 14 additions & 0 deletions tests/ptycho/test_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,20 @@ def test_hermite_modes():
np.rollaxis(inputs['result'], -1, 0)[None, ...],
)

def test_center_peak():

x = cp.ones((1, 1, 1, 7, 7), dtype=cp.complex64)

x[0,0,0, 3, 6] = 10 + 23j

print()
print(x.squeeze())

y = tike.ptycho.probe.constrain_center_peak(x)

print()
print(np.round(y.squeeze(), 1))


if __name__ == '__main__':
unittest.main()

0 comments on commit 03d5faa

Please sign in to comment.