diff --git a/src/tike/operators/cupy/convolution.py b/src/tike/operators/cupy/convolution.py index 0a82bbb7..b2cb6fd1 100644 --- a/src/tike/operators/cupy/convolution.py +++ b/src/tike/operators/cupy/convolution.py @@ -5,6 +5,7 @@ from .operator import Operator from .patch import Patch +from .shift import Shift class Convolution(Operator): @@ -40,6 +41,7 @@ class Convolution(Operator): first, horizontal coordinates second. """ + def __init__(self, probe_shape, nz, n, ntheta=None, detector_shape=None, **kwargs): # yapf: disable self.probe_shape = probe_shape @@ -65,14 +67,22 @@ def fwd(self, psi, scan, probe): if self.detector_shape == self.probe_shape: patches = self.xp.empty_like( psi, - shape=(*scan.shape[:-2], scan.shape[-2] * probe.shape[-3], - self.detector_shape, self.detector_shape), + shape=( + *scan.shape[:-2], + scan.shape[-2] * probe.shape[-3], + self.detector_shape, + self.detector_shape, + ), ) else: patches = self.xp.zeros_like( psi, - shape=(*scan.shape[:-2], scan.shape[-2] * probe.shape[-3], - self.detector_shape, self.detector_shape), + shape=( + *scan.shape[:-2], + scan.shape[-2] * probe.shape[-3], + self.detector_shape, + self.detector_shape, + ), ) patches = self.patch.fwd( patches=patches, @@ -81,8 +91,12 @@ def fwd(self, psi, scan, probe): patch_width=self.probe_shape, nrepeat=probe.shape[-3], ) - patches = patches.reshape((*scan.shape[:-1], probe.shape[-3], - self.detector_shape, self.detector_shape)) + patches = patches.reshape(( + *scan.shape[:-1], + probe.shape[-3], + self.detector_shape, + self.detector_shape, + )) patches[..., self.pad:self.end, self.pad:self.end] *= probe return patches @@ -101,9 +115,11 @@ def adj(self, nearplane, scan, probe, psi=None, overwrite=False): ) assert psi.shape[:-2] == scan.shape[:-2] return self.patch.adj( - patches=nearplane.reshape( - (*scan.shape[:-2], scan.shape[-2] * nearplane.shape[-3], - *nearplane.shape[-2:])), + patches=nearplane.reshape(( + *scan.shape[:-2], + scan.shape[-2] * nearplane.shape[-3], + *nearplane.shape[-2:], + )), images=psi, positions=scan, patch_width=self.probe_shape, @@ -117,8 +133,12 @@ def adj_probe(self, nearplane, scan, psi, overwrite=False): assert psi.shape[:-2] == scan.shape[:-2], (psi.shape, scan.shape) patches = self.xp.zeros_like( psi, - shape=(*scan.shape[:-2], scan.shape[-2] * nearplane.shape[-3], - self.probe_shape, self.probe_shape), + shape=( + *scan.shape[:-2], + scan.shape[-2] * nearplane.shape[-3], + self.probe_shape, + self.probe_shape, + ), ) patches = self.patch.fwd( patches=patches, @@ -145,8 +165,12 @@ def adj_all(self, nearplane, scan, probe, psi, overwrite=False, rpie=False): # Could be xp.empty if scan positions are all in bounds patches=self.xp.zeros_like( psi, - shape=(*scan.shape[:-2], scan.shape[-2] * nearplane.shape[-3], - self.probe_shape, self.probe_shape), + shape=( + *scan.shape[:-2], + scan.shape[-2] * nearplane.shape[-3], + self.probe_shape, + self.probe_shape, + ), ), images=psi, positions=scan, @@ -186,9 +210,11 @@ def adj_all(self, nearplane, scan, probe, psi, overwrite=False, rpie=False): ) apsi = self.patch.adj( - patches=nearplane.reshape( - (*scan.shape[:-2], scan.shape[-2] * nearplane.shape[-3], - *nearplane.shape[-2:])), + patches=nearplane.reshape(( + *scan.shape[:-2], + scan.shape[-2] * nearplane.shape[-3], + *nearplane.shape[-2:], + )), images=self.xp.zeros_like( psi, shape=(*scan.shape[:-2], self.nz, self.n), @@ -202,3 +228,142 @@ def adj_all(self, nearplane, scan, probe, psi, overwrite=False, rpie=False): return apsi, patches, patches_amp, probe_amp else: return apsi, patches + + +class ConvolutionFFT(Operator): + """A 2D Convolution operator with linear interpolation. + + Compute the product two arrays at specific relative positions. + + Attributes + ---------- + nscan : int + The number of scan positions at each angular view. + probe_shape : int + The pixel width and height of the (square) probe illumination. + nz, n : int + The pixel width and height of the reconstructed grid. + ntheta : int + The number of angular partitions of the data. + + Parameters + ---------- + psi : (..., nz, n) complex64 + The complex wavefront modulation of the object. + probe : complex64 + The (..., nscan, nprobe, probe_shape, probe_shape) or + (..., 1, nprobe, probe_shape, probe_shape) complex illumination + function. + nearplane: complex64 + The (...., nscan, nprobe, probe_shape, probe_shape) + wavefronts after exiting the object. + scan : (..., nscan, 2) float32 + Coordinates of the minimum corner of the probe grid for each + measurement in the coordinate system of psi. Vertical coordinates + first, horizontal coordinates second. + + """ + + def __init__(self, probe_shape, nz, n, ntheta=None, + detector_shape=None, **kwargs): # yapf: disable + self.probe_shape = probe_shape + self.nz = nz + self.n = n + if detector_shape is None: + self.detector_shape = probe_shape + else: + self.detector_shape = detector_shape + self.pad = (self.detector_shape - self.probe_shape) // 2 + self.end = self.probe_shape + self.pad + self.patch = Patch() + self.shift = Shift() + + def __enter__(self): + self.shift.__enter__() + return self + + def __exit__(self, type, value, traceback): + self.shift.__exit__(type, value, traceback) + + def fwd(self, psi, scan, probe): + """Extract probe shaped patches from the psi at each scan position. + + The patches within the bounds of psi are linearly interpolated, and + indices outside the bounds of psi are not allowed. + """ + assert psi.shape[:-2] == scan.shape[:-2], (psi.shape, scan.shape) + assert probe.shape[:-4] == scan.shape[:-2], (probe.shape, scan.shape) + assert probe.shape[-4] == 1 or probe.shape[-4] == scan.shape[-2] + if self.detector_shape == self.probe_shape: + patches = self.xp.empty_like( + psi, + shape=( + *scan.shape[:-2], + scan.shape[-2] * probe.shape[-3], + self.detector_shape, + self.detector_shape, + ), + ) + else: + patches = self.xp.zeros_like( + psi, + shape=( + *scan.shape[:-2], + scan.shape[-2] * probe.shape[-3], + self.detector_shape, + self.detector_shape, + ), + ) + index, shift = self.xp.divmod(scan, 1.0) + shift = shift.reshape((*scan.shape[:-1], 1, 2)) + + patches = self.patch.fwd( + patches=patches, + images=psi, + positions=index, + patch_width=self.probe_shape, + nrepeat=probe.shape[-3], + ) + + patches = patches.reshape(( + *scan.shape[:-1], + probe.shape[-3], + self.detector_shape, + self.detector_shape, + )) + patches = self.shift.adj(patches, shift, overwrite=False) + + patches[..., self.pad:self.end, self.pad:self.end] *= probe + return patches + + def adj(self, nearplane, scan, probe, psi=None, overwrite=False): + """Combine probe shaped patches into a psi shaped grid by addition.""" + assert probe.shape[:-4] == scan.shape[:-2] + assert probe.shape[-4] == 1 or probe.shape[-4] == scan.shape[-2] + assert nearplane.shape[:-3] == scan.shape[:-1] + if not overwrite: + nearplane = nearplane.copy() + nearplane[..., self.pad:self.end, self.pad:self.end] *= probe.conj() + + index, shift = self.xp.divmod(scan, 1.0) + shift = shift.reshape((*scan.shape[:-1], 1, 2)) + + nearplane = self.shift.fwd(nearplane, shift, overwrite=True) + + if psi is None: + psi = self.xp.zeros_like( + nearplane, + shape=(*scan.shape[:-2], self.nz, self.n), + ) + assert psi.shape[:-2] == scan.shape[:-2] + return self.patch.adj( + patches=nearplane.reshape(( + *scan.shape[:-2], + scan.shape[-2] * nearplane.shape[-3], + *nearplane.shape[-2:], + )), + images=psi, + positions=index, + patch_width=self.probe_shape, + nrepeat=nearplane.shape[-3], + ) diff --git a/src/tike/operators/cupy/shift.py b/src/tike/operators/cupy/shift.py index c4a91608..8bc5a728 100644 --- a/src/tike/operators/cupy/shift.py +++ b/src/tike/operators/cupy/shift.py @@ -22,7 +22,7 @@ def fwd(self, a, shift, overwrite=False, cval=None): if shift is None: return a shape = a.shape - padded = a.reshape(-1, *shape[-2:]) + padded = a.reshape(*shape) padded = self._fft2( padded, axes=(-2, -1), @@ -33,8 +33,10 @@ def fwd(self, a, shift, overwrite=False, cval=None): self.xp.fft.fftfreq(padded.shape[-2]).astype(shift.dtype), ) padded *= self.xp.exp( - -2j * self.xp.pi * - (x * shift[..., 1, None, None] + y * shift[..., 0, None, None])) + -2j + * self.xp.pi + * (x * shift[..., 1, None, None] + y * shift[..., 0, None, None]) + ) padded = self._ifft2(padded, axes=(-2, -1), overwrite_x=True) return padded.reshape(*shape) diff --git a/src/tike/ptycho/position.py b/src/tike/ptycho/position.py index 26b5636c..9855f4bd 100644 --- a/src/tike/ptycho/position.py +++ b/src/tike/ptycho/position.py @@ -264,7 +264,7 @@ def estimate_global_transformation_ransac( positions1: np.ndarray, weights: np.ndarray = None, transform: AffineTransform = AffineTransform(), - min_sample: int = 4, + min_sample: int = 4, # must be 4 because we are solving for a 2x2 matrix max_error: float = 32, min_consensus: float = 0.75, max_iter: int = 20, @@ -282,6 +282,7 @@ def estimate_global_transformation_ransac( """ best_fitness = np.inf # small fitness is good # Choose a subset + # FIXME: Use montecarlo sampling to decide when to stop RANSAC instead of minimum consensus for subset in tike.random.randomizer_np.choice( a=len(positions0), size=(max_iter, min_sample), @@ -353,11 +354,14 @@ class PositionOptions: ) """A rating of the confidence of position information around each position.""" + update_start: int = 0 + """Start position updates at this epoch.""" + def __post_init__(self): self.initial_scan = self.initial_scan.astype(tike.precision.floating) if self.confidence is None: self.confidence = np.ones( - shape=(*self.initial_scan.shape[:-1], 1), + shape=self.initial_scan.shape, dtype=tike.precision.floating, ) if self.use_adaptive_moment: @@ -445,7 +449,7 @@ def join(self, other, indices): self.initial_scan = new_initial_scan if self.confidence is not None: new_confidence = np.empty( - (*self.initial_scan.shape[:-2], max_index, 1), + (*self.initial_scan.shape[:-2], max_index, 2), dtype=self.initial_scan.dtype, ) new_confidence[..., :len_scan, :] = self.confidence @@ -657,13 +661,34 @@ def _gaussian_frequency(sigma, size): return arr +def _affine_position_helper( + scan, + position_options: PositionOptions, + max_error, + relax=0.1, +): + predicted_positions = position_options.transform( + position_options.initial_scan) + err = predicted_positions - position_options.initial_scan + # constrain more the probes in flat regions + W = relax * (1 - (position_options.confidence / + (1 + position_options.confidence))) + # penalize positions that are further than max_error from origin; avoid travel larger than max error + W = cp.minimum(10 * relax, + W + cp.maximum(0, err - max_error)**2 / max_error**2) + # allow free movement in depenence on realibility and max allowed error + new_scan = scan * (1 - W) + W * predicted_positions + return new_scan + + # TODO: What is a good default value for max_error? def affine_position_regularization( comm: tike.communicators.Comm, updated: typing.List[cp.ndarray], position_options: typing.List[PositionOptions], max_error: float = 32, -) -> typing.List[PositionOptions]: + regularization_enabled: bool = False, +) -> typing.Tuple[typing.List[cp.ndarray], typing.List[PositionOptions]]: """Regularize position updates with an affine deformation constraint. Assume that the true position updates are a global affine transformation @@ -706,7 +731,15 @@ def affine_position_regularization( for i in range(len(position_options)): position_options[i].transform = new_transform - return position_options + if regularization_enabled: + updated = comm.pool.map( + _affine_position_helper, + updated, + position_options, + max_error=max_error, + ) + + return updated, position_options def gaussian_gradient( diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index 267a3f92..f22b4a95 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -498,6 +498,7 @@ def iterate(self, num_iter: int) -> None: data=self.data, batches=self.batches, parameters=self.parameters, + epoch=len(self.parameters.algorithm_options.times), ) if self.parameters.object_options.positivity_constraint: @@ -566,14 +567,17 @@ def iterate(self, num_iter: int) -> None: self.parameters.eigen_weights, ) - if (self.parameters.position_options and self.parameters - .position_options[0].use_position_regularization): - - (self.parameters.position_options + if self.parameters.position_options: + ( + self.parameters.scan, + self.parameters.position_options, ) = affine_position_regularization( self.comm, updated=self.parameters.scan, position_options=self.parameters.position_options, + regularization_enabled=self.parameters.position_options[ + 0 + ].use_position_regularization, ) self.parameters.algorithm_options.times.append(time.perf_counter() - @@ -602,6 +606,13 @@ def iterate(self, num_iter: int) -> None: np.mean(self.parameters.algorithm_options.costs[-1]), ) + def get_scan(self): + reorder = np.argsort(np.concatenate(self.comm.order)) + return self.comm.pool.gather_host( + self.parameters.scan, + axis=-2, + )[reorder] + def get_result(self): """Return the current parameter estimates.""" reorder = np.argsort(np.concatenate(self.comm.order)) diff --git a/src/tike/ptycho/solvers/dm.py b/src/tike/ptycho/solvers/dm.py index c4c43225..e4b76fcb 100644 --- a/src/tike/ptycho/solvers/dm.py +++ b/src/tike/ptycho/solvers/dm.py @@ -25,6 +25,7 @@ def dm( batches: typing.List[typing.List[npt.NDArray[cp.intc]]], *, parameters: PtychoParameters, + epoch: int, ) -> PtychoParameters: """Solve the ptychography problem using the difference map approach. diff --git a/src/tike/ptycho/solvers/lstsq.py b/src/tike/ptycho/solvers/lstsq.py index 49d09e1a..c8d6c1da 100644 --- a/src/tike/ptycho/solvers/lstsq.py +++ b/src/tike/ptycho/solvers/lstsq.py @@ -2,6 +2,8 @@ import typing import cupy as cp +import cupyx.scipy.stats +import numpy as np import numpy.typing as npt import tike.communicators @@ -27,6 +29,7 @@ def lstsq_grad( batches: typing.List[npt.NDArray[cp.intc]], *, parameters: PtychoParameters, + epoch: int, ): """Solve the ptychography problem using Odstrcil et al's approach. @@ -125,6 +128,7 @@ def lstsq_grad( patches, position_update_numerator, position_update_denominator, + position_options, ) = (list(a) for a in zip(*comm.pool.map( _get_nearplane_gradients, data, @@ -136,8 +140,11 @@ def lstsq_grad( batches, position_update_numerator, position_update_denominator, + [None] * comm.pool.num_workers if position_options is + None else position_options, comm.streams, exitwave_options.measured_pixels, + object_options.preconditioner, batch_index=batch_index, num_batch=algorithm_options.num_batch, exitwave_options=exitwave_options, @@ -146,6 +153,8 @@ def lstsq_grad( recover_probe=recover_probe, recover_positions=position_options is not None, ))) + position_options = None if position_options[ + 0] is None else position_options if object_options is not None: object_upd_sum = comm.Allreduce(object_upd_sum) @@ -291,6 +300,7 @@ def lstsq_grad( position_options, position_update_numerator, position_update_denominator, + epoch=epoch, )) algorithm_options.costs.append(batch_cost) @@ -454,8 +464,10 @@ def _get_nearplane_gradients( batches, position_update_numerator, position_update_denominator, + position_options: PositionOptions, streams: typing.List[cp.cuda.Stream], measured_pixels: npt.NDArray, + object_preconditioner: npt.NDArray[cp.csingle], *, batch_index: int, num_batch: int, @@ -619,30 +631,74 @@ def keep_some_args_constant( m_probe_update = None bpatches = None - if recover_positions: + if position_options: m = 0 + # TODO: Try adjusting gradient sigma property grad_x, grad_y = tike.ptycho.position.gaussian_gradient( bpatches[blo:bhi]) + # start section to compute position certainty metric + crop = probe.shape[-1] // 4 + total_illumination = op.diffraction.patch.fwd( + images=object_preconditioner, + positions=scan[lo:hi], + patch_width=probe.shape[-1], + )[:, crop:-crop, crop:-crop].real + + power = cp.abs(probe[0, 0, 0, crop:-crop, crop:-crop])**2 + + dX = cp.mean( + cp.abs(grad_x[:, 0, 0, crop:-crop, crop:-crop]).real * + total_illumination * power, + axis=(-2, -1), + keepdims=False, + ) + dY = cp.mean( + cp.abs(grad_y[:, 0, 0, crop:-crop, crop:-crop]).real * + total_illumination * power, + axis=(-2, -1), + keepdims=False, + ) + + total_variation = cp.sqrt(cp.stack( + [dX, dY], + axis=1, + )) + mean_variation = (cp.mean( + total_variation**4, + axis=0, + ) + 1e-6) + position_options.confidence[ + lo:hi] = total_variation**4 / mean_variation + # end section to compute position certainty metric + position_update_numerator[lo:hi, ..., 0] = cp.sum( cp.real( - cp.conj(grad_x * bunique_probe[blo:bhi, ..., m:m + 1, :, :]) - * bchi[blo:bhi, ..., m:m + 1, :, :]), + cp.conj(grad_x[..., crop:-crop, crop:-crop] * + bunique_probe[blo:bhi, ..., m:m + 1, crop:-crop, + crop:-crop]) * + bchi[blo:bhi, ..., m:m + 1, crop:-crop, crop:-crop]), axis=(-4, -3, -2, -1), ) position_update_denominator[lo:hi, ..., 0] = cp.sum( - cp.abs(grad_x * bunique_probe[blo:bhi, ..., m:m + 1, :, :])**2, + cp.abs(grad_x[..., crop:-crop, crop:-crop] * + bunique_probe[blo:bhi, ..., m:m + 1, crop:-crop, + crop:-crop])**2, axis=(-4, -3, -2, -1), ) position_update_numerator[lo:hi, ..., 1] = cp.sum( cp.real( - cp.conj(grad_y * bunique_probe[blo:bhi, ..., m:m + 1, :, :]) - * bchi[blo:bhi, ..., m:m + 1, :, :]), + cp.conj(grad_y[..., crop:-crop, crop:-crop] * + bunique_probe[blo:bhi, ..., m:m + 1, crop:-crop, + crop:-crop]) * + bchi[blo:bhi, ..., m:m + 1, crop:-crop, crop:-crop]), axis=(-4, -3, -2, -1), ) position_update_denominator[lo:hi, ..., 1] = cp.sum( - cp.abs(grad_y * bunique_probe[blo:bhi, ..., m:m + 1, :, :])**2, + cp.abs(grad_y[..., crop:-crop, crop:-crop] * + bunique_probe[blo:bhi, ..., m:m + 1, crop:-crop, + crop:-crop])**2, axis=(-4, -3, -2, -1), ) @@ -666,6 +722,7 @@ def keep_some_args_constant( bpatches, position_update_numerator, position_update_denominator, + position_options, ) @@ -873,7 +930,11 @@ def _update_position( *, alpha=0.05, max_shift=1, + epoch=0, ): + if epoch < position_options.update_start: + return scan, position_options + step = (position_update_numerator) / ( (1 - alpha) * position_update_denominator + alpha * max(position_update_denominator.max(), 1e-6)) @@ -885,6 +946,9 @@ def _update_position( a_max=position_options.update_magnitude_limit, ) + # Remove outliars and subtract the mean + step = step - cupyx.scipy.stats.trim_mean(step, 0.05) + if position_options.use_adaptive_moment: ( step, diff --git a/src/tike/ptycho/solvers/rpie.py b/src/tike/ptycho/solvers/rpie.py index b5a83aa4..2b2baea3 100644 --- a/src/tike/ptycho/solvers/rpie.py +++ b/src/tike/ptycho/solvers/rpie.py @@ -1,6 +1,7 @@ import logging import cupy as cp +import cupyx.scipy.stats import numpy.typing as npt import tike.communicators @@ -27,6 +28,7 @@ def rpie( batches: typing.List[typing.List[npt.NDArray[cp.intc]]], *, parameters: PtychoParameters, + epoch: int, ) -> PtychoParameters: """Solve the ptychography problem using regularized ptychographical engine. @@ -176,6 +178,7 @@ def rpie( position_update_denominator, max_shift=probe[0].shape[-1] * 0.1, alpha=algorithm_options.alpha, + epoch=epoch, ))) if algorithm_options.batch_method == 'compact': @@ -489,23 +492,44 @@ def keep_some_args_constant( ) # yapf: disable if position_options: - grad_x, grad_y = tike.ptycho.position.gaussian_gradient(patches) + crop = probe.shape[-1] // 4 + position_update_numerator[lo:hi, ..., 0] = cp.sum( - cp.real(cp.conj(grad_x * unique_probe) * diff), + cp.real( + cp.conj( + grad_x[..., crop:-crop, crop:-crop] + * unique_probe[..., crop:-crop, crop:-crop] + ) + * diff[..., crop:-crop, crop:-crop] + ), axis=(-4, -3, -2, -1), ) position_update_denominator[lo:hi, ..., 0] = cp.sum( - cp.abs(grad_x * unique_probe)**2, + cp.abs( + grad_x[..., crop:-crop, crop:-crop] + * unique_probe[..., crop:-crop, crop:-crop] + ) + ** 2, axis=(-4, -3, -2, -1), ) position_update_numerator[lo:hi, ..., 1] = cp.sum( - cp.real(cp.conj(grad_y * unique_probe) * diff), + cp.real( + cp.conj( + grad_y[..., crop:-crop, crop:-crop] + * unique_probe[..., crop:-crop, crop:-crop] + ) + * diff[..., crop:-crop, crop:-crop] + ), axis=(-4, -3, -2, -1), ) position_update_denominator[lo:hi, ..., 1] = cp.sum( - cp.abs(grad_y * unique_probe)**2, + cp.abs( + grad_y[..., crop:-crop, crop:-crop] + * unique_probe[..., crop:-crop, crop:-crop] + ) + ** 2, axis=(-4, -3, -2, -1), ) @@ -537,7 +561,11 @@ def _update_position( *, alpha=0.05, max_shift=1, + epoch=0, ): + if epoch < position_options.update_start: + return scan, position_options + step = (position_update_numerator) / ( (1 - alpha) * position_update_denominator + alpha * max(position_update_denominator.max(), 1e-6)) @@ -549,6 +577,9 @@ def _update_position( a_max=position_options.update_magnitude_limit, ) + # Remove outliars and subtract the mean + step = step - cupyx.scipy.stats.trim_mean(step, 0.05) + if position_options.use_adaptive_moment: ( step, diff --git a/tests/operators/test_convolution.py b/tests/operators/test_convolution.py index b32caf0b..ff5f40b4 100644 --- a/tests/operators/test_convolution.py +++ b/tests/operators/test_convolution.py @@ -5,7 +5,7 @@ import unittest import numpy as np -from tike.operators import Convolution +from tike.operators import Convolution, ConvolutionFFT import tike.precision import tike.linalg import tike.random @@ -14,7 +14,7 @@ __author__ = "Daniel Ching" __copyright__ = "Copyright (c) 2020, UChicago Argonne, LLC." -__docformat__ = 'restructuredtext en' +__docformat__ = "restructuredtext en" class TestConvolution(unittest.TestCase, OperatorTests): @@ -44,31 +44,35 @@ def setUp(self): np.random.seed(0) scan = np.random.rand(self.ntheta, self.nscan, 2) * (127 - 15 - 1) original = tike.random.numpy_complex(*self.original_shape) - nearplane = tike.random.numpy_complex(self.ntheta, self.nscan, - self.nprobe, self.detector_shape, - self.detector_shape) + nearplane = tike.random.numpy_complex( + self.ntheta, + self.nscan, + self.nprobe, + self.detector_shape, + self.detector_shape, + ) kernel = tike.random.numpy_complex(self.ntheta, self.nscan, self.nprobe, self.probe_shape, self.probe_shape) self.m = self.xp.asarray(original) - self.m_name = 'psi' + self.m_name = "psi" self.kwargs = { - 'scan': self.xp.asarray(scan, dtype=tike.precision.floating), - 'probe': self.xp.asarray(kernel) + "scan": self.xp.asarray(scan, dtype=tike.precision.floating), + "probe": self.xp.asarray(kernel), } self.m1 = self.xp.asarray(kernel) - self.m1_name = 'probe' + self.m1_name = "probe" self.kwargs1 = { - 'scan': self.xp.asarray(scan, dtype=tike.precision.floating), - 'psi': self.xp.asarray(original) + "scan": self.xp.asarray(scan, dtype=tike.precision.floating), + "psi": self.xp.asarray(original), } self.kwargs2 = { - 'scan': self.xp.asarray(scan, dtype=tike.precision.floating), + "scan": self.xp.asarray(scan, dtype=tike.precision.floating), } self.d = self.xp.asarray(nearplane) - self.d_name = 'nearplane' + self.d_name = "nearplane" print(self.operator) @@ -81,8 +85,8 @@ def test_adjoint_probe(self): a = tike.linalg.inner(d, self.d) b = tike.linalg.inner(self.m1, m) print() - print(' = {:.6f}{:+.6f}j'.format(a.real.item(), a.imag.item())) - print('< d, F*d> = {:.6f}{:+.6f}j'.format(b.real.item(), b.imag.item())) + print(" = {:.6f}{:+.6f}j".format(a.real.item(), a.imag.item())) + print("< d, F*d> = {:.6f}{:+.6f}j".format(b.real.item(), b.imag.item())) self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-3, atol=0) self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-3, atol=0) @@ -93,7 +97,7 @@ def test_adj_probe_time(self): elapsed = time.perf_counter() - start print(f"\n{elapsed:1.3e} seconds") - @unittest.skip('FIXME: This operator is not scaled.') + @unittest.skip("FIXME: This operator is not scaled.") def test_scaled(self): pass @@ -121,11 +125,11 @@ def test_adjoint_all(self): b = tike.linalg.inner(self.m, m) c = tike.linalg.inner(self.m1, m1) print() - print('< Fm, m> = {:.6f}{:+.6f}j'.format(a.real.item(), + print("< Fm, m> = {:.6f}{:+.6f}j".format(a.real.item(), a.imag.item())) - print('< d0, F*d0> = {:.6f}{:+.6f}j'.format(b.real.item(), + print("< d0, F*d0> = {:.6f}{:+.6f}j".format(b.real.item(), b.imag.item())) - print('< d1, F*d1> = {:.6f}{:+.6f}j'.format(c.real.item(), + print("< d1, F*d1> = {:.6f}{:+.6f}j".format(c.real.item(), c.imag.item())) self.xp.testing.assert_allclose(a.real, b.real, rtol=1e-3, atol=0) self.xp.testing.assert_allclose(a.imag, b.imag, rtol=1e-3, atol=0) @@ -133,5 +137,69 @@ def test_adjoint_all(self): self.xp.testing.assert_allclose(a.imag, c.imag, rtol=1e-3, atol=0) -if __name__ == '__main__': +class TestConvolutionFFT(unittest.TestCase, OperatorTests): + """Test the ConvolutionFFT operator.""" + + def setUp(self): + """Load a dataset for reconstruction.""" + + self.ntheta = 3 + self.nscan = 27 + self.nprobe = 3 + self.original_shape = (self.ntheta, 128, 128) + self.probe_shape = 15 + self.detector_shape = self.probe_shape * 3 + + self.operator = ConvolutionFFT( + ntheta=self.ntheta, + nscan=self.nscan, + nz=self.original_shape[-2], + n=self.original_shape[-1], + probe_shape=self.probe_shape, + detector_shape=self.detector_shape, + ) + self.operator.__enter__() + self.xp = self.operator.xp + + np.random.seed(0) + scan = np.random.rand(self.ntheta, self.nscan, 2) * (127 - 15 - 1) + original = tike.random.numpy_complex(*self.original_shape) + nearplane = tike.random.numpy_complex( + self.ntheta, + self.nscan, + self.nprobe, + self.detector_shape, + self.detector_shape, + ) + kernel = tike.random.numpy_complex(self.ntheta, self.nscan, self.nprobe, + self.probe_shape, self.probe_shape) + + self.m = self.xp.asarray(original) + self.m_name = "psi" + self.kwargs = { + "scan": self.xp.asarray(scan, dtype=tike.precision.floating), + "probe": self.xp.asarray(kernel), + } + + self.m1 = self.xp.asarray(kernel) + self.m1_name = "probe" + self.kwargs1 = { + "scan": self.xp.asarray(scan, dtype=tike.precision.floating), + "psi": self.xp.asarray(original), + } + self.kwargs2 = { + "scan": self.xp.asarray(scan, dtype=tike.precision.floating), + } + + self.d = self.xp.asarray(nearplane) + self.d_name = "nearplane" + + print(self.operator) + + @unittest.skip("FIXME: This operator is not scaled.") + def test_scaled(self): + pass + + +if __name__ == "__main__": unittest.main()