Skip to content

Commit

Permalink
Merge pull request #304 from carterbox/transform
Browse files Browse the repository at this point in the history
NEW: Mitigate scan position global drift
  • Loading branch information
carterbox authored Apr 18, 2024
2 parents e4685fd + 7988bf5 commit 9075301
Show file tree
Hide file tree
Showing 8 changed files with 435 additions and 60 deletions.
197 changes: 181 additions & 16 deletions src/tike/operators/cupy/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .operator import Operator
from .patch import Patch
from .shift import Shift


class Convolution(Operator):
Expand Down Expand Up @@ -40,6 +41,7 @@ class Convolution(Operator):
first, horizontal coordinates second.
"""

def __init__(self, probe_shape, nz, n, ntheta=None,
detector_shape=None, **kwargs): # yapf: disable
self.probe_shape = probe_shape
Expand All @@ -65,14 +67,22 @@ def fwd(self, psi, scan, probe):
if self.detector_shape == self.probe_shape:
patches = self.xp.empty_like(
psi,
shape=(*scan.shape[:-2], scan.shape[-2] * probe.shape[-3],
self.detector_shape, self.detector_shape),
shape=(
*scan.shape[:-2],
scan.shape[-2] * probe.shape[-3],
self.detector_shape,
self.detector_shape,
),
)
else:
patches = self.xp.zeros_like(
psi,
shape=(*scan.shape[:-2], scan.shape[-2] * probe.shape[-3],
self.detector_shape, self.detector_shape),
shape=(
*scan.shape[:-2],
scan.shape[-2] * probe.shape[-3],
self.detector_shape,
self.detector_shape,
),
)
patches = self.patch.fwd(
patches=patches,
Expand All @@ -81,8 +91,12 @@ def fwd(self, psi, scan, probe):
patch_width=self.probe_shape,
nrepeat=probe.shape[-3],
)
patches = patches.reshape((*scan.shape[:-1], probe.shape[-3],
self.detector_shape, self.detector_shape))
patches = patches.reshape((
*scan.shape[:-1],
probe.shape[-3],
self.detector_shape,
self.detector_shape,
))
patches[..., self.pad:self.end, self.pad:self.end] *= probe
return patches

Expand All @@ -101,9 +115,11 @@ def adj(self, nearplane, scan, probe, psi=None, overwrite=False):
)
assert psi.shape[:-2] == scan.shape[:-2]
return self.patch.adj(
patches=nearplane.reshape(
(*scan.shape[:-2], scan.shape[-2] * nearplane.shape[-3],
*nearplane.shape[-2:])),
patches=nearplane.reshape((
*scan.shape[:-2],
scan.shape[-2] * nearplane.shape[-3],
*nearplane.shape[-2:],
)),
images=psi,
positions=scan,
patch_width=self.probe_shape,
Expand All @@ -117,8 +133,12 @@ def adj_probe(self, nearplane, scan, psi, overwrite=False):
assert psi.shape[:-2] == scan.shape[:-2], (psi.shape, scan.shape)
patches = self.xp.zeros_like(
psi,
shape=(*scan.shape[:-2], scan.shape[-2] * nearplane.shape[-3],
self.probe_shape, self.probe_shape),
shape=(
*scan.shape[:-2],
scan.shape[-2] * nearplane.shape[-3],
self.probe_shape,
self.probe_shape,
),
)
patches = self.patch.fwd(
patches=patches,
Expand All @@ -145,8 +165,12 @@ def adj_all(self, nearplane, scan, probe, psi, overwrite=False, rpie=False):
# Could be xp.empty if scan positions are all in bounds
patches=self.xp.zeros_like(
psi,
shape=(*scan.shape[:-2], scan.shape[-2] * nearplane.shape[-3],
self.probe_shape, self.probe_shape),
shape=(
*scan.shape[:-2],
scan.shape[-2] * nearplane.shape[-3],
self.probe_shape,
self.probe_shape,
),
),
images=psi,
positions=scan,
Expand Down Expand Up @@ -186,9 +210,11 @@ def adj_all(self, nearplane, scan, probe, psi, overwrite=False, rpie=False):
)

apsi = self.patch.adj(
patches=nearplane.reshape(
(*scan.shape[:-2], scan.shape[-2] * nearplane.shape[-3],
*nearplane.shape[-2:])),
patches=nearplane.reshape((
*scan.shape[:-2],
scan.shape[-2] * nearplane.shape[-3],
*nearplane.shape[-2:],
)),
images=self.xp.zeros_like(
psi,
shape=(*scan.shape[:-2], self.nz, self.n),
Expand All @@ -202,3 +228,142 @@ def adj_all(self, nearplane, scan, probe, psi, overwrite=False, rpie=False):
return apsi, patches, patches_amp, probe_amp
else:
return apsi, patches


class ConvolutionFFT(Operator):
"""A 2D Convolution operator with linear interpolation.
Compute the product two arrays at specific relative positions.
Attributes
----------
nscan : int
The number of scan positions at each angular view.
probe_shape : int
The pixel width and height of the (square) probe illumination.
nz, n : int
The pixel width and height of the reconstructed grid.
ntheta : int
The number of angular partitions of the data.
Parameters
----------
psi : (..., nz, n) complex64
The complex wavefront modulation of the object.
probe : complex64
The (..., nscan, nprobe, probe_shape, probe_shape) or
(..., 1, nprobe, probe_shape, probe_shape) complex illumination
function.
nearplane: complex64
The (...., nscan, nprobe, probe_shape, probe_shape)
wavefronts after exiting the object.
scan : (..., nscan, 2) float32
Coordinates of the minimum corner of the probe grid for each
measurement in the coordinate system of psi. Vertical coordinates
first, horizontal coordinates second.
"""

def __init__(self, probe_shape, nz, n, ntheta=None,
detector_shape=None, **kwargs): # yapf: disable
self.probe_shape = probe_shape
self.nz = nz
self.n = n
if detector_shape is None:
self.detector_shape = probe_shape
else:
self.detector_shape = detector_shape
self.pad = (self.detector_shape - self.probe_shape) // 2
self.end = self.probe_shape + self.pad
self.patch = Patch()
self.shift = Shift()

def __enter__(self):
self.shift.__enter__()
return self

def __exit__(self, type, value, traceback):
self.shift.__exit__(type, value, traceback)

def fwd(self, psi, scan, probe):
"""Extract probe shaped patches from the psi at each scan position.
The patches within the bounds of psi are linearly interpolated, and
indices outside the bounds of psi are not allowed.
"""
assert psi.shape[:-2] == scan.shape[:-2], (psi.shape, scan.shape)
assert probe.shape[:-4] == scan.shape[:-2], (probe.shape, scan.shape)
assert probe.shape[-4] == 1 or probe.shape[-4] == scan.shape[-2]
if self.detector_shape == self.probe_shape:
patches = self.xp.empty_like(
psi,
shape=(
*scan.shape[:-2],
scan.shape[-2] * probe.shape[-3],
self.detector_shape,
self.detector_shape,
),
)
else:
patches = self.xp.zeros_like(
psi,
shape=(
*scan.shape[:-2],
scan.shape[-2] * probe.shape[-3],
self.detector_shape,
self.detector_shape,
),
)
index, shift = self.xp.divmod(scan, 1.0)
shift = shift.reshape((*scan.shape[:-1], 1, 2))

patches = self.patch.fwd(
patches=patches,
images=psi,
positions=index,
patch_width=self.probe_shape,
nrepeat=probe.shape[-3],
)

patches = patches.reshape((
*scan.shape[:-1],
probe.shape[-3],
self.detector_shape,
self.detector_shape,
))
patches = self.shift.adj(patches, shift, overwrite=False)

patches[..., self.pad:self.end, self.pad:self.end] *= probe
return patches

def adj(self, nearplane, scan, probe, psi=None, overwrite=False):
"""Combine probe shaped patches into a psi shaped grid by addition."""
assert probe.shape[:-4] == scan.shape[:-2]
assert probe.shape[-4] == 1 or probe.shape[-4] == scan.shape[-2]
assert nearplane.shape[:-3] == scan.shape[:-1]
if not overwrite:
nearplane = nearplane.copy()
nearplane[..., self.pad:self.end, self.pad:self.end] *= probe.conj()

index, shift = self.xp.divmod(scan, 1.0)
shift = shift.reshape((*scan.shape[:-1], 1, 2))

nearplane = self.shift.fwd(nearplane, shift, overwrite=True)

if psi is None:
psi = self.xp.zeros_like(
nearplane,
shape=(*scan.shape[:-2], self.nz, self.n),
)
assert psi.shape[:-2] == scan.shape[:-2]
return self.patch.adj(
patches=nearplane.reshape((
*scan.shape[:-2],
scan.shape[-2] * nearplane.shape[-3],
*nearplane.shape[-2:],
)),
images=psi,
positions=index,
patch_width=self.probe_shape,
nrepeat=nearplane.shape[-3],
)
8 changes: 5 additions & 3 deletions src/tike/operators/cupy/shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def fwd(self, a, shift, overwrite=False, cval=None):
if shift is None:
return a
shape = a.shape
padded = a.reshape(-1, *shape[-2:])
padded = a.reshape(*shape)
padded = self._fft2(
padded,
axes=(-2, -1),
Expand All @@ -33,8 +33,10 @@ def fwd(self, a, shift, overwrite=False, cval=None):
self.xp.fft.fftfreq(padded.shape[-2]).astype(shift.dtype),
)
padded *= self.xp.exp(
-2j * self.xp.pi *
(x * shift[..., 1, None, None] + y * shift[..., 0, None, None]))
-2j
* self.xp.pi
* (x * shift[..., 1, None, None] + y * shift[..., 0, None, None])
)
padded = self._ifft2(padded, axes=(-2, -1), overwrite_x=True)
return padded.reshape(*shape)

Expand Down
43 changes: 38 additions & 5 deletions src/tike/ptycho/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def estimate_global_transformation_ransac(
positions1: np.ndarray,
weights: np.ndarray = None,
transform: AffineTransform = AffineTransform(),
min_sample: int = 4,
min_sample: int = 4, # must be 4 because we are solving for a 2x2 matrix
max_error: float = 32,
min_consensus: float = 0.75,
max_iter: int = 20,
Expand All @@ -282,6 +282,7 @@ def estimate_global_transformation_ransac(
"""
best_fitness = np.inf # small fitness is good
# Choose a subset
# FIXME: Use montecarlo sampling to decide when to stop RANSAC instead of minimum consensus
for subset in tike.random.randomizer_np.choice(
a=len(positions0),
size=(max_iter, min_sample),
Expand Down Expand Up @@ -353,11 +354,14 @@ class PositionOptions:
)
"""A rating of the confidence of position information around each position."""

update_start: int = 0
"""Start position updates at this epoch."""

def __post_init__(self):
self.initial_scan = self.initial_scan.astype(tike.precision.floating)
if self.confidence is None:
self.confidence = np.ones(
shape=(*self.initial_scan.shape[:-1], 1),
shape=self.initial_scan.shape,
dtype=tike.precision.floating,
)
if self.use_adaptive_moment:
Expand Down Expand Up @@ -445,7 +449,7 @@ def join(self, other, indices):
self.initial_scan = new_initial_scan
if self.confidence is not None:
new_confidence = np.empty(
(*self.initial_scan.shape[:-2], max_index, 1),
(*self.initial_scan.shape[:-2], max_index, 2),
dtype=self.initial_scan.dtype,
)
new_confidence[..., :len_scan, :] = self.confidence
Expand Down Expand Up @@ -657,13 +661,34 @@ def _gaussian_frequency(sigma, size):
return arr


def _affine_position_helper(
scan,
position_options: PositionOptions,
max_error,
relax=0.1,
):
predicted_positions = position_options.transform(
position_options.initial_scan)
err = predicted_positions - position_options.initial_scan
# constrain more the probes in flat regions
W = relax * (1 - (position_options.confidence /
(1 + position_options.confidence)))
# penalize positions that are further than max_error from origin; avoid travel larger than max error
W = cp.minimum(10 * relax,
W + cp.maximum(0, err - max_error)**2 / max_error**2)
# allow free movement in depenence on realibility and max allowed error
new_scan = scan * (1 - W) + W * predicted_positions
return new_scan


# TODO: What is a good default value for max_error?
def affine_position_regularization(
comm: tike.communicators.Comm,
updated: typing.List[cp.ndarray],
position_options: typing.List[PositionOptions],
max_error: float = 32,
) -> typing.List[PositionOptions]:
regularization_enabled: bool = False,
) -> typing.Tuple[typing.List[cp.ndarray], typing.List[PositionOptions]]:
"""Regularize position updates with an affine deformation constraint.
Assume that the true position updates are a global affine transformation
Expand Down Expand Up @@ -706,7 +731,15 @@ def affine_position_regularization(
for i in range(len(position_options)):
position_options[i].transform = new_transform

return position_options
if regularization_enabled:
updated = comm.pool.map(
_affine_position_helper,
updated,
position_options,
max_error=max_error,
)

return updated, position_options


def gaussian_gradient(
Expand Down
Loading

0 comments on commit 9075301

Please sign in to comment.