From 6fe3330fdc3d7f83e6cfabc7f7850d0e7173da59 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Thu, 20 Jun 2024 14:46:43 -0500 Subject: [PATCH 01/31] DOC: Add missing type hints --- src/tike/cluster.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/tike/cluster.py b/src/tike/cluster.py index 484fe053..92ef6c2f 100644 --- a/src/tike/cluster.py +++ b/src/tike/cluster.py @@ -306,7 +306,10 @@ def stripes_equal_count( ) -def wobbly_center(population, num_cluster): +def wobbly_center( + population: npt.ArrayLike, + num_cluster: int, +) -> typing.List[npt.NDArray]: """Return the indices that divide population into heterogenous clusters. Uses a contrarian approach to clustering by maximizing the heterogeneity @@ -382,7 +385,7 @@ def wobbly_center(population, num_cluster): def wobbly_center_random_bootstrap( - population, + population: npt.ArrayLike, num_cluster: int, boot_fraction: float = 0.95, ) -> typing.List[npt.NDArray]: @@ -466,7 +469,11 @@ def wobbly_center_random_bootstrap( return [cp.asnumpy(xp.flatnonzero(labels == c)) for c in range(num_cluster)] -def compact(population, num_cluster, max_iter=500): +def compact( + population: npt.ArrayLike, + num_cluster: int, + max_iter: int = 500, +) -> typing.List[npt.NDArray]: """Return the indices that divide population into compact clusters. Uses an approach that is inspired by the naive k-means algorithm, but it From fb87432bf26dea8284182c62c0fdd9a156287cdc Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Thu, 20 Jun 2024 14:47:25 -0500 Subject: [PATCH 02/31] REF: Separate batch computation from batch separation --- src/tike/cluster.py | 34 ++++++++-------------------------- 1 file changed, 8 insertions(+), 26 deletions(-) diff --git a/src/tike/cluster.py b/src/tike/cluster.py index 92ef6c2f..44149a9a 100644 --- a/src/tike/cluster.py +++ b/src/tike/cluster.py @@ -171,17 +171,15 @@ def by_scan_stripes( def by_scan_stripes_contiguous( - *args, pool: tike.communicators.ThreadPool, shape: typing.Tuple[int], - dtype: typing.List[npt.DTypeLike], - destination: typing.List[str], scan: npt.NDArray[np.float32], - fly: int = 1, batch_method, num_batch: int, -) -> typing.Tuple[typing.List[npt.NDArray], - typing.List[typing.List[npt.NDArray]]]: +) -> typing.Tuple[ + typing.List[npt.NDArray], + typing.List[typing.List[npt.NDArray]], +]: """Split data by into stripes and create contiguously ordered batches. Divide the field of view into one stripe per devices; within each stripe, @@ -206,13 +204,10 @@ def by_scan_stripes_contiguous( Returns ------- order : List[array[int]] - The locations of the inputs in the original arrays. + For each worker in pool, the indices of the data batches : List[List[array[int]]] - The locations of the elements of each batch - scan : List[array[float32]] - The divided 2D coordinates of the scan positions. - args : List[array[float32]] or None - Each input divided into regions or None if arg was None. + For each worker in pool, for each batch, the indices of the elements of + each batch """ if len(shape) != 2: @@ -247,26 +242,13 @@ def by_scan_stripes_contiguous( batch_breaks, )) - split_args = [] - for arg, t, dest in zip([scan, *args], dtype, destination): - if arg is None: - split_args.append(None) - else: - split_args.append( - pool.map( - _split_gpu if dest == 'gpu' else _split_pinned, - map_to_gpu_contiguous, - x=arg, - dtype=t, - )) - if __debug__: for device in batches_contiguous: assert len(device) == num_batch, ( f"There should be {num_batch} batches, found {len(device)}" ) - return (map_to_gpu_contiguous, batches_contiguous, *split_args) + return (map_to_gpu_contiguous, batches_contiguous) def stripes_equal_count( From 56ef8ee98506ad6393e9315de05708b76b8cbe3e Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Thu, 20 Jun 2024 14:56:20 -0500 Subject: [PATCH 03/31] REF: Add copy_to_host and copy_to_device methods --- src/tike/ptycho/exitwave.py | 12 ++--- src/tike/ptycho/object.py | 7 ++- src/tike/ptycho/position.py | 10 ++-- src/tike/ptycho/probe.py | 11 ++--- src/tike/ptycho/solvers/options.py | 73 ++++++++++++++++++++++++++++++ 5 files changed, 94 insertions(+), 19 deletions(-) diff --git a/src/tike/ptycho/exitwave.py b/src/tike/ptycho/exitwave.py index eec3d0b3..377b1b6c 100644 --- a/src/tike/ptycho/exitwave.py +++ b/src/tike/ptycho/exitwave.py @@ -78,18 +78,18 @@ class ExitWaveOptions: """ - def copy_to_device(self, comm) -> ExitWaveOptions: + def copy_to_device(self) -> ExitWaveOptions: """Copy to the current GPU memory.""" options = copy.copy(self) if self.measured_pixels is not None: - options.measured_pixels = comm.pool.bcast([self.measured_pixels]) + options.measured_pixels = cp.asarray(self.measured_pixels) return options def copy_to_host(self) -> ExitWaveOptions: """Copy to the host CPU memory.""" options = copy.copy(self) if self.measured_pixels is not None: - options.measured_pixels = cp.asnumpy(self.measured_pixels[0]) + options.measured_pixels = cp.asnumpy(self.measured_pixels) return options def resample(self, factor: float) -> ExitWaveOptions: @@ -103,9 +103,9 @@ def resample(self, factor: float) -> ExitWaveOptions: self.measured_pixels, int(self.measured_pixels.shape[-1] * factor), ), - unmeasured_pixels_scaling=self.unmeasured_pixels_scaling, - propagation_normalization=self.propagation_normalization ) - + unmeasured_pixels_scaling=self.unmeasured_pixels_scaling, + propagation_normalization=self.propagation_normalization, + ) def poisson_steplength_all_modes( xi, diff --git a/src/tike/ptycho/object.py b/src/tike/ptycho/object.py index 1ca8f643..291ae777 100644 --- a/src/tike/ptycho/object.py +++ b/src/tike/ptycho/object.py @@ -80,7 +80,7 @@ class ObjectOptions: clip_magnitude: bool = False """Whether to force the object magnitude to remain <= 1.""" - def copy_to_device(self, comm) -> ObjectOptions: + def copy_to_device(self) -> ObjectOptions: """Copy to the current GPU memory.""" options = copy.copy(self) options.update_mnorm = copy.copy(self.update_mnorm) @@ -89,7 +89,7 @@ def copy_to_device(self, comm) -> ObjectOptions: if self.m is not None: options.m = cp.asarray(self.m) if self.preconditioner is not None: - options.preconditioner = comm.pool.bcast([self.preconditioner]) + options.preconditioner = cp.asarray(self.preconditioner) return options def copy_to_host(self) -> ObjectOptions: @@ -101,7 +101,7 @@ def copy_to_host(self) -> ObjectOptions: if self.m is not None: options.m = cp.asnumpy(self.m) if self.preconditioner is not None: - options.preconditioner = cp.asnumpy(self.preconditioner[0]) + options.preconditioner = cp.asnumpy(self.preconditioner) return options def resample(self, factor: float, interp) -> ObjectOptions: @@ -119,7 +119,6 @@ def resample(self, factor: float, interp) -> ObjectOptions: return options # Momentum reset to zero when grid scale changes - def positivity_constraint(x, r): """Constrains the amplitude of x to be positive with sum of abs(x) and x. diff --git a/src/tike/ptycho/position.py b/src/tike/ptycho/position.py index 0f02f835..3651d939 100644 --- a/src/tike/ptycho/position.py +++ b/src/tike/ptycho/position.py @@ -413,7 +413,7 @@ def empty(self): new._momentum = np.empty((0, 4)) return new - def split(self, indices): + def split(self, indices) -> PositionOptions: """Split the PositionOption meta-data along indices.""" new = PositionOptions( self.initial_scan[..., indices, :], @@ -569,8 +569,12 @@ def check_allowed_positions(scan: np.array, psi: np.array, probe_shape: tuple): valid_min_corner = (1, 1) valid_max_corner = (psi.shape[-2] - probe_shape[-2] - 1, psi.shape[-1] - probe_shape[-1] - 1) - if (np.any(min_corner < valid_min_corner) - or np.any(max_corner > valid_max_corner)): + if ( + min_corner[0] < valid_min_corner[0] + or min_corner[1] < valid_min_corner[1] + or max_corner[0] > valid_max_corner[0] + or max_corner[1] > valid_max_corner[1] + ): raise ValueError( "Scan positions must be >= 1 and " "scan positions + 1 + probe.shape must be <= psi.shape. " diff --git a/src/tike/ptycho/probe.py b/src/tike/ptycho/probe.py index a5d2e847..e0134a2d 100644 --- a/src/tike/ptycho/probe.py +++ b/src/tike/ptycho/probe.py @@ -139,7 +139,7 @@ class ProbeOptions: """ median_filter_abs_probe: bool = False - """Binary switch on whether to apply a median filter to absolute value of + """Binary switch on whether to apply a median filter to absolute value of each shared probe mode. """ @@ -163,7 +163,7 @@ class ProbeOptions: ) """The power of the primary probe modes at each iteration.""" - def copy_to_device(self, comm) -> ProbeOptions: + def copy_to_device(self) -> ProbeOptions: """Copy to the current GPU memory.""" options = copy.copy(self) if self.v is not None: @@ -171,7 +171,7 @@ def copy_to_device(self, comm) -> ProbeOptions: if self.m is not None: options.m = cp.asarray(self.m) if self.preconditioner is not None: - options.preconditioner = comm.pool.bcast([self.preconditioner]) + options.preconditioner = cp.asarray(self.preconditioner) return options def copy_to_host(self) -> ProbeOptions: @@ -182,7 +182,7 @@ def copy_to_host(self) -> ProbeOptions: if self.m is not None: options.m = cp.asnumpy(self.m) if self.preconditioner is not None: - options.preconditioner = cp.asnumpy(self.preconditioner[0]) + options.preconditioner = cp.asnumpy(self.preconditioner) return options def resample(self, factor: float, interp) -> ProbeOptions: @@ -202,13 +202,12 @@ def resample(self, factor: float, interp) -> ProbeOptions: probe_support=self.probe_support, probe_support_degree=self.probe_support_degree, probe_support_radius=self.probe_support_radius, - median_filter_abs_probe=self.median_filter_abs_probe, + median_filter_abs_probe=self.median_filter_abs_probe, median_filter_abs_probe_px=self.median_filter_abs_probe_px, ) return options # Momentum reset to zero when grid scale changes - def get_varying_probe(shared_probe, eigen_probe=None, weights=None): """Construct the varying probes. diff --git a/src/tike/ptycho/solvers/options.py b/src/tike/ptycho/solvers/options.py index d342ad6e..b8a24f34 100644 --- a/src/tike/ptycho/solvers/options.py +++ b/src/tike/ptycho/solvers/options.py @@ -6,6 +6,7 @@ import numpy as np import numpy.typing as npt import scipy.ndimage +import cupy as cp from tike.ptycho.object import ObjectOptions from tike.ptycho.position import PositionOptions, check_allowed_positions @@ -200,6 +201,78 @@ def resample( if self.exitwave_options is not None else None, ) + def copy_to_device(self) -> PtychoParameters: + """Copy to the current device.""" + return PtychoParameters( + probe=cp.asarray(self.probe), + psi=cp.asarray(self.psi), + scan=cp.asarray(self.scan), + eigen_probe=cp.asarray(self.eigen_probe) + if self.eigen_probe is not None + else None, + eigen_weights=cp.asarray(self.eigen_weights) + if self.eigen_weights is not None + else None, + algorithm_options=self.algorithm_options, + exitwave_options=self.exitwave_options.copy_to_device() + if self.exitwave_options is not None + else None, + probe_options=self.probe_options.copy_to_device() + if self.probe_options is not None + else None, + object_options=self.object_options.copy_to_device() + if self.object_options is not None + else None, + position_options=self.position_options.copy_to_device() + if self.position_options is not None + else None, + ) + + def copy_to_host(self) -> PtychoParameters: + """Copy to the host.""" + return PtychoParameters( + probe=cp.asnumpy(self.probe), + psi=cp.asnumpy(self.psi), + scan=cp.asnumpy(self.scan), + eigen_probe=cp.asnumpy(self.eigen_probe) + if self.eigen_probe is not None + else None, + eigen_weights=cp.asnumpy(self.eigen_weights) + if self.eigen_weights is not None + else None, + algorithm_options=self.algorithm_options, + exitwave_options=self.exitwave_options.copy_to_host() + if self.exitwave_options is not None + else None, + probe_options=self.probe_options.copy_to_host() + if self.probe_options is not None + else None, + object_options=self.object_options.copy_to_host() + if self.object_options is not None + else None, + position_options=self.position_options.copy_to_host() + if self.position_options is not None + else None, + ) + + def split(self, indices: npt.NDArray[np.int]) -> PtychoParameters: + """Return a new PtychoParameters with only the data from the indices""" + return PtychoParameters( + probe=self.probe, + psi=self.psi, + scan=self.scan[indices], + eigen_probe=self.eigen_probe, + eigen_weights=self.eigen_weights[indices] + if self.eigen_weights is not None + else None, + algorithm_options=self.algorithm_options, + exitwave_options=self.exitwave_options, + probe_options=self.probe_options, + object_options=self.object_options, + position_options=self.position_options, + ) + + def _resize_spline(x: np.ndarray, f: float) -> np.ndarray: return scipy.ndimage.zoom( From 60e9de5f318ced9d7410f8bfc2ae219eedfff5fd Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Thu, 20 Jun 2024 14:58:56 -0500 Subject: [PATCH 04/31] REF: Separate communications from DM algorithm --- src/tike/ptycho/solvers/dm.py | 69 +++++++++++++++++------------------ 1 file changed, 34 insertions(+), 35 deletions(-) diff --git a/src/tike/ptycho/solvers/dm.py b/src/tike/ptycho/solvers/dm.py index 48b3851a..4c0f0274 100644 --- a/src/tike/ptycho/solvers/dm.py +++ b/src/tike/ptycho/solvers/dm.py @@ -13,16 +13,21 @@ import tike.ptycho.probe import tike.random -from .options import * +from .options import ( + ExitWaveOptions, + ObjectOptions, + ProbeOptions, + PtychoParameters, +) logger = logging.getLogger(__name__) def dm( op: tike.operators.Ptycho, - comm: tike.communicators.Comm, - data: typing.List[npt.NDArray], - batches: typing.List[typing.List[npt.NDArray[cp.intc]]], + streams: typing.List[cp.cuda.Stream], + data: npt.NDArray, + batches: typing.List[npt.NDArray[cp.intc]], *, parameters: PtychoParameters, epoch: int, @@ -35,13 +40,13 @@ def dm( A ptychography operator. comm : :py:class:`tike.communicators.Comm` An object which manages communications between GPUs and nodes. - data : list((FRAME, WIDE, HIGH) float32, ...) - A list of unique CuPy arrays for each device containing + data : (FRAME, WIDE, HIGH) float32 + A unique CuPy array containing the intensity (square of the absolute value) of the propagated wavefront; i.e. what the detector records. FFT-shifted so the diffraction peak is at the corners. - batches : list(list((BATCH_SIZE, ) int, ...), ...) - A list of list of indices along the FRAME axis of `data` for + batches : list((BATCH_SIZE, ) int, ...) + A list of indices along the FRAME axis of `data` for each device which define the batches of `data` to process simultaneously. parameters : :py:class:`tike.ptycho.solvers.PtychoParameters` @@ -61,15 +66,19 @@ def dm( .. seealso:: :py:mod:`tike.ptycho` """ - psi_update_numerator = [None] * comm.pool.num_workers - probe_update_numerator = [None] * comm.pool.num_workers + assert isinstance(op, tike.operators.Operator) + assert isinstance(data, npt.ArrayLike) + assert isinstance(parameters, PtychoParameters) + assert isinstance(epoch, int) + assert isinstance(batches, list) + psi_update_numerator = None + probe_update_numerator = None ( cost, psi_update_numerator, probe_update_numerator, - ) = (list(a) for a in zip(*comm.pool.map( - _get_nearplane_gradients, + ) = _get_nearplane_gradients( data, parameters.scan, parameters.psi, @@ -77,20 +86,19 @@ def dm( parameters.exitwave_options.measured_pixels, psi_update_numerator, probe_update_numerator, - comm.streams, + streams, op=op, object_options=parameters.object_options, probe_options=parameters.probe_options, exitwave_options=parameters.exitwave_options, - ))) + ) - cost = comm.Allreduce_mean(cost).get() + cost = cost.get() ( parameters.psi, parameters.probe, ) = _apply_update( - comm, psi_update_numerator, probe_update_numerator, parameters.psi, @@ -104,7 +112,6 @@ def dm( def _apply_update( - comm, psi_update_numerator, probe_update_numerator, psi, @@ -112,49 +119,41 @@ def _apply_update( object_options, probe_options, ): - if object_options: - psi_update_numerator = comm.Allreduce_reduce_gpu( - psi_update_numerator)[0] - - new_psi = psi_update_numerator / (object_options.preconditioner[0] + - 1e-9) + new_psi = psi_update_numerator / (object_options.preconditioner + 1e-9) if object_options.use_adaptive_moment: ( dpsi, object_options.v, object_options.m, ) = tike.opt.adam( - g=(new_psi - psi[0]), + g=(new_psi - psi), v=object_options.v, m=object_options.m, vdecay=object_options.vdecay, mdecay=object_options.mdecay, ) - new_psi = dpsi + psi[0] - psi = comm.pool.bcast([new_psi]) + new_psi = dpsi + psi + psi = new_psi + else: + print('object update skipped') if probe_options: - - probe_update_numerator = comm.Allreduce_reduce_gpu( - probe_update_numerator)[0] - - new_probe = probe_update_numerator / (probe_options.preconditioner[0] + - 1e-9) + new_probe = probe_update_numerator / (probe_options.preconditioner + 1e-9) if probe_options.use_adaptive_moment: ( dprobe, probe_options.v, probe_options.m, ) = tike.opt.adam( - g=(new_probe - probe[0]), + g=(new_probe - probe), v=probe_options.v, m=probe_options.m, vdecay=probe_options.vdecay, mdecay=probe_options.mdecay, ) - new_probe = dprobe + probe[0] - probe = comm.pool.bcast([new_probe]) + new_probe = dprobe + probe + probe = new_probe return psi, probe From 01bc1909f6e53bade4a56f7eb8e954b2033d2616 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Mon, 24 Jun 2024 13:48:34 -0500 Subject: [PATCH 05/31] DEV: Multi-gpu DM --- src/tike/cluster.py | 9 +- src/tike/communicators/comm.py | 13 + src/tike/communicators/pool.py | 35 ++ src/tike/ptycho/position.py | 59 ++- src/tike/ptycho/ptycho.py | 567 ++++++++++----------- src/tike/ptycho/solvers/_preconditioner.py | 106 ++-- src/tike/ptycho/solvers/dm.py | 11 +- src/tike/ptycho/solvers/options.py | 68 ++- tests/communicators/test_pool.py | 17 + 9 files changed, 500 insertions(+), 385 deletions(-) diff --git a/src/tike/cluster.py b/src/tike/cluster.py index 44149a9a..18fb5876 100644 --- a/src/tike/cluster.py +++ b/src/tike/cluster.py @@ -179,6 +179,7 @@ def by_scan_stripes_contiguous( ) -> typing.Tuple[ typing.List[npt.NDArray], typing.List[typing.List[npt.NDArray]], + typing.List[int], ]: """Split data by into stripes and create contiguously ordered batches. @@ -208,7 +209,10 @@ def by_scan_stripes_contiguous( batches : List[List[array[int]]] For each worker in pool, for each batch, the indices of the elements of each batch - + stripe_start : List[int] + The coorinates of the leading edge of each stripe along the 0th + dimension in the scan coordinates. e.g the minimum coordinate of the + scan positions in each stripe. """ if len(shape) != 2: raise ValueError('The grid shape must have two dimensions.') @@ -224,6 +228,7 @@ def by_scan_stripes_contiguous( x=scan, dtype=scan.dtype, ) + stripe_start = [int(np.floor(np.min(x[:, 0]))) for x in split_scan] batches_noncontiguous: typing.List[typing.List[npt.NDArray]] = pool.map( getattr(tike.cluster, batch_method), split_scan, @@ -248,7 +253,7 @@ def by_scan_stripes_contiguous( f"There should be {num_batch} batches, found {len(device)}" ) - return (map_to_gpu_contiguous, batches_contiguous) + return (map_to_gpu_contiguous, batches_contiguous, stripe_start) def stripes_equal_count( diff --git a/src/tike/communicators/comm.py b/src/tike/communicators/comm.py index c80d7b01..058de609 100644 --- a/src/tike/communicators/comm.py +++ b/src/tike/communicators/comm.py @@ -139,3 +139,16 @@ def Allreduce( buf.append( self.mpi.Allreduce(src[self.pool.workers.index(worker)])) return buf + + def swap_edges( + self, + x: typing.List[cp.ndarray], + overlap: int, + edges: typing.List[int], + ) -> typing.List[cp.ndarray]: + # FIXME: Swap edges between MPI nodes + return self.pool.swap_edges( + x=x, + overlap=overlap, + edges=edges, + ) diff --git a/src/tike/communicators/pool.py b/src/tike/communicators/pool.py index 4a67cd4d..03b3dca2 100644 --- a/src/tike/communicators/pool.py +++ b/src/tike/communicators/pool.py @@ -404,3 +404,38 @@ def f(worker, *args): workers = self.workers if workers is None else workers return list(self.executor.map(f, workers, *iterables)) + + def swap_edges( + self, + x: typing.List[cp.ndarray], + overlap: int, + edges: typing.List[int], + ): + """Swap edge:(edge + overlap) between neighbors in-place + + For example, given overlap=1 and edges=[4, 8, 12, 16], the following + swap would be returned: + + ``` + [[0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0]] + [[1 1 1 0 0 1 1 2 2 1 1 1 1 1 1 1]] + [[2 2 2 2 2 2 2 1 1 2 2 3 3 2 2 2]] + [[3 3 3 3 3 3 3 3 3 3 3 2 2 3 3 3]] + ``` + + Note that the minimum swapped region is 2 wide. + + """ + if overlap < 1: + msg = f"Overlap for swap_edges cannot be less than 1: {overlap}" + raise ValueError(msg) + for i in range(self.num_workers - 1): + lo = edges[i + 1] + hi = lo + overlap + temp0 = self._copy_to(x[i][:, lo:hi], self.workers[i + 1]) + temp1 = self._copy_to(x[i + 1][:, lo:hi], self.workers[i]) + with self.Device(self.workers[i]): + x[i][:, lo:hi] = temp1 + with self.Device(self.workers[i + 1]): + x[i + 1][:, lo:hi] = temp0 + return x diff --git a/src/tike/ptycho/position.py b/src/tike/ptycho/position.py index 3651d939..a09145e7 100644 --- a/src/tike/ptycho/position.py +++ b/src/tike/ptycho/position.py @@ -124,6 +124,7 @@ import cupy as cp import cupyx.scipy.ndimage import numpy as np +import numpy.typing as npt import tike.communicators import tike.linalg @@ -413,7 +414,7 @@ def empty(self): new._momentum = np.empty((0, 4)) return new - def split(self, indices) -> PositionOptions: + def split(self, indices: npt.NDArray[np.intc]) -> PositionOptions: """Split the PositionOption meta-data along indices.""" new = PositionOptions( self.initial_scan[..., indices, :], @@ -439,34 +440,40 @@ def insert(self, other, indices): self._momentum[..., indices, :] = other._momentum return self - def join(self, other, indices): - """Replace the PositionOption meta-data with other data.""" - len_scan = self.initial_scan.shape[-2] - max_index = max(indices.max() + 1, len_scan) - new_initial_scan = np.empty( - (*self.initial_scan.shape[:-2], max_index, 2), - dtype=self.initial_scan.dtype, + @staticmethod + def join( + x: typing.Iterable[PositionOptions | None], + reorder: npt.NDArray[np.intc], + ) -> PositionOptions | None: + if None in x: + return None + new = PositionOptions( + initial_scan=np.concatenate( + [e.initial_scan for e in x], + axis=0, + )[reorder], + use_adaptive_moment=x[0].use_adaptive_moment, + vdecay=x[0].vdecay, + mdecay=x[0].mdecay, + use_position_regularization=x[0].use_position_regularization, + update_magnitude_limit=x[0].update_magnitude_limit, + transform=x[0].transform, ) - new_initial_scan[..., :len_scan, :] = self.initial_scan - new_initial_scan[..., indices, :] = other.initial_scan - self.initial_scan = new_initial_scan - if self.confidence is not None: - new_confidence = np.empty( - (*self.initial_scan.shape[:-2], max_index, 2), - dtype=self.initial_scan.dtype, + if x[0].confidence is not None: + new.confidence = ( + np.concatenate( + [e.confidence for e in x], + axis=0, + )[reorder], ) - new_confidence[..., :len_scan, :] = self.confidence - new_confidence[..., indices, :] = other.confidence - self.confidence = new_confidence - if self.use_adaptive_moment: - new_momentum = np.empty( - (*self.initial_scan.shape[:-2], max_index, 4), - dtype=self.initial_scan.dtype, + if x[0].use_adaptive_moment: + new._momentum = ( + np.concatenate( + [e._momentum for e in x], + axis=0, + )[reorder], ) - new_momentum[..., :len_scan, :] = self._momentum - new_momentum[..., indices, :] = other._momentum - self._momentum = new_momentum - return self + return new def copy_to_device(self): """Copy to the current GPU memory.""" diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index bffc5b2a..ca959480 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -83,6 +83,7 @@ constrain_probe_sparsity, get_varying_probe, apply_median_filter_abs_probe, + orthogonalize_eig, ) logger = logging.getLogger(__name__) @@ -333,8 +334,10 @@ def __init__( else: mpi = tike.communicators.NoMPIComm - self.data = data - self.parameters = copy.deepcopy(parameters) + self.data: typing.List[npt.ArrayLike] = [data] + self.parameters: typing.List[solvers.PtychoParameters] = [ + copy.deepcopy(parameters) + ] self.device = cp.cuda.Device( num_gpu[0] if isinstance(num_gpu, tuple) else None) self.operator = tike.operators.Ptycho( @@ -352,7 +355,7 @@ def __enter__(self): self.comm.__enter__() # Divide the inputs into regions - if (not np.all(np.isfinite(self.data)) or np.any(self.data < 0)): + if not np.all(np.isfinite(self.data[0])) or np.any(self.data[0] < 0): warnings.warn( "Diffraction patterns contain invalid data. " "All data should be non-negative and finite.", UserWarning) @@ -360,322 +363,258 @@ def __enter__(self): ( self.comm.order, self.batches, - self.parameters.scan, - self.data, - self.parameters.eigen_weights, + self.comm.stripe_start, ) = tike.cluster.by_scan_stripes_contiguous( - self.data, - self.parameters.eigen_weights, - scan=self.parameters.scan, + scan=self.parameters[0].scan, pool=self.comm.pool, shape=(self.comm.pool.num_workers, 1), - dtype=( - tike.precision.floating, - tike.precision.floating - if self.data.itemsize > 2 else self.data.dtype, - tike.precision.floating, - ), - destination=('gpu', 'pinned', 'gpu'), - batch_method=self.parameters.algorithm_options.batch_method, - num_batch=self.parameters.algorithm_options.num_batch, + batch_method=self.parameters[0].algorithm_options.batch_method, + num_batch=self.parameters[0].algorithm_options.num_batch, ) - self.parameters.psi = self.comm.pool.bcast( - [self.parameters.psi.astype(tike.precision.cfloating)]) - - self.parameters.probe = self.comm.pool.bcast( - [self.parameters.probe.astype(tike.precision.cfloating)]) - - if self.parameters.probe_options is not None: - self.parameters.probe_options = self.parameters.probe_options.copy_to_device( - self.comm,) - - if self.parameters.object_options is not None: - self.parameters.object_options = self.parameters.object_options.copy_to_device( - self.comm,) - - if self.parameters.exitwave_options is not None: - self.parameters.exitwave_options = self.parameters.exitwave_options.copy_to_device( - self.comm,) - - if self.parameters.eigen_probe is not None: - self.parameters.eigen_probe = self.comm.pool.bcast( - [self.parameters.eigen_probe.astype(tike.precision.cfloating)]) + self.data = self.comm.pool.map( + tike.cluster._split_pinned, + self.comm.order, + x=self.data[0], + dtype=tike.precision.floating + if self.data[0].itemsize > 2 + else self.data[0].dtype, + ) - if self.parameters.position_options is not None: - # TODO: Consider combining put/split, get/join operations? - self.parameters.position_options = self.comm.pool.map( - PositionOptions.copy_to_device, - (self.parameters.position_options.split(x) - for x in self.comm.order), - ) + self.parameters = self.comm.pool.map( + solvers.PtychoParameters.split, + self.comm.order, + x=self.parameters[0], + ) + assert len(self.parameters) == self.comm.pool.num_workers, ( + len(self.parameters), + self.comm.pool.num_workers, + ) - if self.parameters.probe_options is not None: + self.parameters = self.comm.pool.map( + solvers.PtychoParameters.copy_to_device, + self.parameters, + ) + assert len(self.parameters) == self.comm.pool.num_workers, ( + len(self.parameters), + self.comm.pool.num_workers, + ) - if self.parameters.probe_options.init_rescale_from_measurements: - self.parameters.probe = _rescale_probe( + if self.parameters[0].probe_options is not None: + if self.parameters[0].probe_options.init_rescale_from_measurements: + self.parameters = _rescale_probe( self.operator, self.comm, self.data, - self.parameters.exitwave_options, - self.parameters.psi, - self.parameters.scan, - self.parameters.probe, - num_batch=self.parameters.algorithm_options.num_batch, + self.parameters, ) - if np.isnan(self.parameters.probe_options.probe_photons): - self.parameters.probe_options.probe_photons = np.sum( - np.abs(self.parameters.probe[0].get())**2) - return self def iterate(self, num_iter: int) -> None: """Advance the reconstruction by num_iter epochs.""" start = time.perf_counter() - psi_previous = self.parameters.psi[0].copy() + # psi_previous = self.parameters[0].psi.copy() for i in range(num_iter): if ( - np.sum(self.parameters.algorithm_options.times) - > self.parameters.algorithm_options.time_limit + np.sum(self.parameters[0].algorithm_options.times) + > self.parameters[0].algorithm_options.time_limit ): logger.info("Maximum reconstruction time exceeded.") break - logger.info(f"{self.parameters.algorithm_options.name} epoch " - f"{len(self.parameters.algorithm_options.times):,d}") - - total_epochs = len(self.parameters.algorithm_options.times) - - if self.parameters.probe_options is not None: - self.parameters.probe_options.recover_probe = ( - total_epochs >= self.parameters.probe_options.update_start - and (total_epochs % self.parameters.probe_options.update_period) == 0 - ) # yapf: disable - - if self.parameters.probe_options is not None: - if self.parameters.probe_options.recover_probe: - - if self.parameters.probe_options.median_filter_abs_probe: - self.parameters.probe = self.comm.pool.map( - apply_median_filter_abs_probe, - self.parameters.probe, - med_filt_px = self.parameters.probe_options.median_filter_abs_probe_px - ) - - if self.parameters.probe_options.force_centered_intensity: - self.parameters.probe = self.comm.pool.map( - constrain_center_peak, - self.parameters.probe, - ) - - if self.parameters.probe_options.force_sparsity < 1: - self.parameters.probe = self.comm.pool.map( - constrain_probe_sparsity, - self.parameters.probe, - f=self.parameters.probe_options.force_sparsity, - ) - - if self.parameters.probe_options.force_orthogonality: - ( - self.parameters.probe, - power, - ) = (list(a) for a in zip(*self.comm.pool.map( - tike.ptycho.probe.orthogonalize_eig, - self.parameters.probe, - ))) - else: - power = self.comm.pool.map( - tike.ptycho.probe.power, - self.parameters.probe, - ) - - self.parameters.probe_options.power.append( - power[0].get()) - - ( - self.parameters.object_options, - self.parameters.probe_options, - ) = solvers.update_preconditioners( - comm=self.comm, - operator=self.operator, - scan=self.parameters.scan, - probe=self.parameters.probe, - psi=self.parameters.psi, - object_options=self.parameters.object_options, - probe_options=self.parameters.probe_options, - ) - - self.parameters = getattr( - solvers, - self.parameters.algorithm_options.name, - )( - self.operator, - self.comm, - data=self.data, - batches=self.batches, - parameters=self.parameters, - epoch=len(self.parameters.algorithm_options.times), + logger.info( + f"{self.parameters[0].algorithm_options.name} epoch " + f"{len(self.parameters[0].algorithm_options.times):,d}" ) - if self.parameters.object_options.positivity_constraint: - self.parameters.psi = self.comm.pool.map( - tike.ptycho.object.positivity_constraint, - self.parameters.psi, - r=self.parameters.object_options.positivity_constraint, - ) + total_epochs = len(self.parameters[0].algorithm_options.times) - if self.parameters.object_options.smoothness_constraint: - self.parameters.psi = self.comm.pool.map( - tike.ptycho.object.smoothness_constraint, - self.parameters.psi, - a=self.parameters.object_options.smoothness_constraint, - ) + # if self.parameters.probe_options is not None: + # self.parameters.probe_options.recover_probe = ( + # total_epochs >= self.parameters.probe_options.update_start + # and (total_epochs % self.parameters.probe_options.update_period) == 0 + # ) # yapf: disable - if self.parameters.object_options.clip_magnitude: - self.parameters.psi = self.comm.pool.map( - _clip_magnitude, - self.parameters.psi, - a_max=1.0, - ) + self.parameters = self.comm.pool.map( + _apply_probe_constraints, + self.parameters, + ) - if ( - self.parameters.algorithm_options.name != 'dm' - and self.parameters.algorithm_options.rescale_method == 'mean_of_abs_object' - and self.parameters.object_options.preconditioner is not None - and len(self.parameters.algorithm_options.costs) % self.parameters.algorithm_options.rescale_period == 0 - ): # yapf: disable - ( - self.parameters.psi, - self.parameters.probe, - ) = (list(a) for a in zip(*self.comm.pool.map( - tike.ptycho.object.remove_object_ambiguity, - self.parameters.psi, - self.parameters.probe, - self.parameters.object_options.preconditioner, - ))) - - elif self.parameters.probe_options is not None: - if ( - self.parameters.probe_options.recover_probe - and self.parameters.algorithm_options.rescale_method == 'constant_probe_photons' - and len(self.parameters.algorithm_options.costs) % self.parameters.algorithm_options.rescale_period == 0 - ): # yapf: disable - - self.parameters.probe = self.comm.pool.map( - tike.ptycho.probe - .rescale_probe_using_fixed_intensity_photons, - self.parameters.probe, - Nphotons=self.parameters.probe_options.probe_photons, - probe_power_fraction=None, - ) + self.parameters = solvers.update_preconditioners( + comm=self.comm, + parameters=self.parameters, + operator=self.operator, + ) - if ( - self.parameters.probe_options is not None - and self.parameters.eigen_probe is not None - and self.parameters.probe_options.recover_probe - ): #yapf: disable - ( - self.parameters.eigen_probe, - self.parameters.eigen_weights, - ) = tike.ptycho.probe.constrain_variable_probe( - self.comm, - self.parameters.eigen_probe, - self.parameters.eigen_weights, - ) + self.parameters = self.comm.pool.map( + solvers.dm, + self.parameters, + self.data, + self.batches, + self.comm.streams, + op=self.operator, + epoch=len(self.parameters[0].algorithm_options.times), + ) - if self.parameters.position_options: - ( - self.parameters.scan, - self.parameters.position_options, - ) = affine_position_regularization( - self.comm, - updated=self.parameters.scan, - position_options=self.parameters.position_options, - regularization_enabled=self.parameters.position_options[ - 0 - ].use_position_regularization, + for i, reduced_probe in enumerate( + self.comm.Allreduce_mean( + [e.probe[None, ...] for e in self.parameters], + axis=0, ) - - self.parameters.algorithm_options.times.append(time.perf_counter() - - start) + ): + self.parameters[i].probe = reduced_probe + + pw = self.parameters[0].probe.shape[-2] + for swapped, parameters in zip( + # TODO: Try blending edges during swap instead of replacing + self.comm.swap_edges( + [e.psi for e in self.parameters], + # reduce overlap to stay away from edge noise + overlap=pw // 5, + # The actual edge is centered on the probe + edges=[e + pw // 2 for e in self.comm.stripe_start], + ), + self.parameters, + ): + parameters.psi = swapped + + # if self.parameters.object_options.positivity_constraint: + # self.parameters.psi = self.comm.pool.map( + # tike.ptycho.object.positivity_constraint, + # self.parameters.psi, + # r=self.parameters.object_options.positivity_constraint, + # ) + + # if self.parameters.object_options.smoothness_constraint: + # self.parameters.psi = self.comm.pool.map( + # tike.ptycho.object.smoothness_constraint, + # self.parameters.psi, + # a=self.parameters.object_options.smoothness_constraint, + # ) + + # if self.parameters.object_options.clip_magnitude: + # self.parameters.psi = self.comm.pool.map( + # _clip_magnitude, + # self.parameters.psi, + # a_max=1.0, + # ) + + # if ( + # self.parameters.algorithm_options.name != 'dm' + # and self.parameters.algorithm_options.rescale_method == 'mean_of_abs_object' + # and self.parameters.object_options.preconditioner is not None + # and len(self.parameters.algorithm_options.costs) % self.parameters.algorithm_options.rescale_period == 0 + # ): # yapf: disable + # ( + # self.parameters.psi, + # self.parameters.probe, + # ) = (list(a) for a in zip(*self.comm.pool.map( + # tike.ptycho.object.remove_object_ambiguity, + # self.parameters.psi, + # self.parameters.probe, + # self.parameters.object_options.preconditioner, + # ))) + + # elif self.parameters.probe_options is not None: + # if ( + # self.parameters.probe_options.recover_probe + # and self.parameters.algorithm_options.rescale_method == 'constant_probe_photons' + # and len(self.parameters.algorithm_options.costs) % self.parameters.algorithm_options.rescale_period == 0 + # ): # yapf: disable + + # self.parameters.probe = self.comm.pool.map( + # tike.ptycho.probe + # .rescale_probe_using_fixed_intensity_photons, + # self.parameters.probe, + # Nphotons=self.parameters.probe_options.probe_photons, + # probe_power_fraction=None, + # ) + + # if ( + # self.parameters.probe_options is not None + # and self.parameters.eigen_probe is not None + # and self.parameters.probe_options.recover_probe + # ): #yapf: disable + # ( + # self.parameters.eigen_probe, + # self.parameters.eigen_weights, + # ) = tike.ptycho.probe.constrain_variable_probe( + # self.comm, + # self.parameters.eigen_probe, + # self.parameters.eigen_weights, + # ) + + # if self.parameters.position_options: + # ( + # self.parameters.scan, + # self.parameters.position_options, + # ) = affine_position_regularization( + # self.comm, + # updated=self.parameters.scan, + # position_options=self.parameters.position_options, + # regularization_enabled=self.parameters.position_options[ + # 0 + # ].use_position_regularization, + # ) + + self.parameters[0].algorithm_options.times.append( + time.perf_counter() - start + ) start = time.perf_counter() - update_norm = tike.linalg.mnorm(self.parameters.psi[0] - - psi_previous) + # update_norm = tike.linalg.mnorm(self.parameters.psi[0] - + # psi_previous) - self.parameters.object_options.update_mnorm.append( - update_norm.get()) + # self.parameters.object_options.update_mnorm.append( + # update_norm.get()) - logger.info(f"The object update mean-norm is {update_norm:.3e}") + # logger.info(f"The object update mean-norm is {update_norm:.3e}") - if (np.mean(self.parameters.object_options.update_mnorm[-5:]) - < self.parameters.object_options.convergence_tolerance): - logger.info( - f"The object seems converged. {update_norm:.3e} < " - f"{self.parameters.object_options.convergence_tolerance:.3e}" - ) - break + # if (np.mean(self.parameters.object_options.update_mnorm[-5:]) + # < self.parameters.object_options.convergence_tolerance): + # logger.info( + # f"The object seems converged. {update_norm:.3e} < " + # f"{self.parameters.object_options.convergence_tolerance:.3e}" + # ) + # break logger.info( - '%10s cost is %+1.3e', - self.parameters.exitwave_options.noise_model, - np.mean(self.parameters.algorithm_options.costs[-1]), + "%10s cost is %+1.3e", + self.parameters[0].exitwave_options.noise_model, + np.mean(self.parameters[0].algorithm_options.costs[-1]), ) - def get_scan(self): + def get_scan(self) -> npt.NDArray: reorder = np.argsort(np.concatenate(self.comm.order)) return self.comm.pool.gather_host( self.parameters.scan, axis=-2, )[reorder] - def get_result(self): + def get_result(self) -> solvers.PtychoParameters: """Return the current parameter estimates.""" reorder = np.argsort(np.concatenate(self.comm.order)) - parameters = solvers.PtychoParameters( - probe=self.parameters.probe[0].get(), - psi=self.parameters.psi[0].get(), - scan=self.comm.pool.gather_host( - self.parameters.scan, - axis=-2, - )[reorder], - algorithm_options=self.parameters.algorithm_options, - ) - if self.parameters.eigen_probe is not None: - parameters.eigen_probe = self.parameters.eigen_probe[0].get() - - if self.parameters.eigen_weights is not None: - parameters.eigen_weights = self.comm.pool.gather( - self.parameters.eigen_weights, - axis=-3, - )[reorder].get() - - if self.parameters.probe_options is not None: - parameters.probe_options = self.parameters.probe_options.copy_to_host( - ) + assert len(self.parameters) == self.comm.pool.num_workers, ( + len(self.parameters), + self.comm.pool.num_workers, + ) - if self.parameters.object_options is not None: - parameters.object_options = self.parameters.object_options.copy_to_host( - ) + parameters = self.comm.pool.map( + solvers.PtychoParameters.copy_to_host, + self.parameters, + ) - if self.parameters.exitwave_options is not None: - parameters.exitwave_options = self.parameters.exitwave_options.copy_to_host( - ) + parameters = solvers.PtychoParameters.join( + parameters, + reorder, + stripe_start=self.comm.stripe_start, + ) - if self.parameters.position_options is not None: - host_position_options = self.parameters.position_options[0].empty() - for x, o in zip( - self.comm.pool.map( - PositionOptions.copy_to_host, - self.parameters.position_options, - ), - self.comm.order, - ): - host_position_options = host_position_options.join(x, o) - parameters.position_options = host_position_options + print(np.array(parameters.algorithm_options.costs).shape) + print(np.array(parameters.algorithm_options.times).shape) return parameters @@ -806,18 +745,54 @@ def append_new_data( new_scan, ) +def _apply_probe_constraints( + parameters: solvers.PtychoParameters, +) -> solvers.PtychoParameters: + if parameters.probe_options is not None: + if parameters.probe_options.recover_probe: + + if parameters.probe_options.median_filter_abs_probe: + parameters.probe = apply_median_filter_abs_probe( + parameters.probe, + med_filt_px=parameters.probe_options.median_filter_abs_probe_px, + ) + + if parameters.probe_options.force_centered_intensity: + parameters.probe = constrain_center_peak( + parameters.probe, + ) + + if parameters.probe_options.force_sparsity < 1: + parameters.probe = constrain_probe_sparsity( + parameters.probe, + f=parameters.probe_options.force_sparsity, + ) + + if parameters[0].probe_options.force_orthogonality: + ( + parameters.probe, + power, + ) = orthogonalize_eig( + parameters.probe, + ) + else: + power = tike.ptycho.probe.power( + parameters.probe, + ) + + parameters.probe_options.power.append(power[0].get()) + + return parameters + def _order_join(a, b): return np.append(a, b + len(a)) def _get_rescale( - data, - measured_pixels, - psi, - scan, - probe, - streams, + data: npt.ArrayLike, + parameters: solvers.PtychoParameters, + streams: typing.List[cp.cuda.Stream], *, operator: tike.operators.Ptycho, ): @@ -833,17 +808,21 @@ def make_certain_args_constant( ( data, ) = ind_args - nonlocal sums, scan + nonlocal sums intensity, _ = operator._compute_intensity( None, - psi, - scan[lo:hi], - probe, + parameters.psi, + parameters.scan[lo:hi], + parameters.probe, ) - sums[0] += cp.sum(data[:, measured_pixels], dtype=np.double) - sums[1] += cp.sum(intensity[:, measured_pixels], dtype=np.double) + sums[0] += cp.sum( + data[:, parameters.exitwave_options.measured_pixels], dtype=np.double + ) + sums[1] += cp.sum( + intensity[:, parameters.exitwave_options.measured_pixels], dtype=np.double + ) tike.communicators.stream.stream_and_modify2( f=make_certain_args_constant, @@ -858,8 +837,12 @@ def make_certain_args_constant( return sums -def _rescale_probe(operator, comm, data, exitwave_options, psi, scan, probe, - num_batch): +def _rescale_probe( + operator: tike.operators.Ptycho, + comm: tike.communicators.Comm, + data: typing.List[npt.ArrayLike], + parameters: typing.List[solvers.PtychoParameters], +): """Rescale probe so model and measured intensity are similar magnitude. Rescales the probe so that the sum of modeled intensity at the detector is @@ -869,10 +852,7 @@ def _rescale_probe(operator, comm, data, exitwave_options, psi, scan, probe, n = comm.pool.map( _get_rescale, data, - exitwave_options.measured_pixels, - psi, - scan, - probe, + parameters, comm.streams, operator=operator, ) @@ -888,9 +868,26 @@ def _rescale_probe(operator, comm, data, exitwave_options, psi, scan, probe, logger.info("Probe rescaled by %f", rescale) - probe[0] *= rescale + rescale = comm.pool.bcast([rescale]) + return comm.pool.map( + _rescale_probe_helper, + parameters, + rescale, + ) + + +def _rescale_probe_helper( + parameters: solvers.PtychoParameters, + rescale: float, +) -> solvers.PtychoParameters: + parameters.probe = parameters.probe * rescale + + if np.isnan(parameters.probe_options.probe_photons): + parameters.probe_options.probe_photons = cp.sum( + cp.square(cp.abs(parameters.probe)) + ).get() - return comm.pool.bcast([probe[0]]) + return parameters def reconstruct_multigrid( diff --git a/src/tike/ptycho/solvers/_preconditioner.py b/src/tike/ptycho/solvers/_preconditioner.py index 8e4812ed..2ad803fb 100644 --- a/src/tike/ptycho/solvers/_preconditioner.py +++ b/src/tike/ptycho/solvers/_preconditioner.py @@ -7,12 +7,27 @@ import tike.operators import tike.precision -from .options import ObjectOptions, ProbeOptions +from .options import ObjectOptions, ProbeOptions, PtychoParameters -@cp.fuse() -def _rolling_average(old, new): - return 0.5 * (new + old) +def _rolling_average_object(parameters: PtychoParameters, new): + if parameters.object_options.preconditioner is None: + parameters.object_options.preconditioner = new + else: + parameters.object_options.preconditioner = 0.5 * ( + new + parameters.object_options.preconditioner + ) + return parameters + + +def _rolling_average_probe(parameters: PtychoParameters, new): + if parameters.probe_options.preconditioner is None: + parameters.probe_options.preconditioner = new + else: + parameters.probe_options.preconditioner = 0.5 * ( + new + parameters.probe_options.preconditioner + ) + return parameters @cp.fuse() @@ -24,9 +39,7 @@ def _probe_amp_sum(probe): def _psi_preconditioner( - psi: npt.NDArray[tike.precision.cfloating], - scan: npt.NDArray[tike.precision.floating], - probe: npt.NDArray[tike.precision.cfloating], + parameters: PtychoParameters, streams: typing.List[cp.cuda.Stream], *, operator: tike.operators.Ptycho, @@ -34,8 +47,8 @@ def _psi_preconditioner( # FIXME: Generated only one preconditioner for all slices psi_update_denominator = cp.zeros( - shape=psi.shape[-2:], - dtype=psi.dtype, + shape=parameters.psi.shape[-2:], + dtype=parameters.psi.dtype, ) def make_certain_args_constant( @@ -45,11 +58,11 @@ def make_certain_args_constant( ) -> None: nonlocal psi_update_denominator - probe_amp = _probe_amp_sum(probe)[:, 0] + probe_amp = _probe_amp_sum(parameters.probe)[:, 0] psi_update_denominator[...] = operator.diffraction.patch.adj( patches=probe_amp, images=psi_update_denominator, - positions=scan[lo:hi], + positions=parameters.scan[lo:hi], ) tike.communicators.stream.stream_and_modify2( @@ -57,7 +70,7 @@ def make_certain_args_constant( ind_args=[], streams=streams, lo=0, - hi=len(scan), + hi=len(parameters.scan), ) return psi_update_denominator @@ -73,17 +86,15 @@ def _patch_amp_sum(patches): def _probe_preconditioner( - psi: npt.NDArray[tike.precision.cfloating], - scan: npt.NDArray[tike.precision.floating], - probe: npt.NDArray[tike.precision.cfloating], + parameters: PtychoParameters, streams: typing.List[cp.cuda.Stream], *, operator: tike.operators.Ptycho, ) -> npt.NDArray: probe_update_denominator = cp.zeros( - shape=probe.shape[-2:], - dtype=probe.dtype, + shape=parameters.probe.shape[-2:], + dtype=parameters.probe.dtype, ) def make_certain_args_constant( @@ -95,9 +106,9 @@ def make_certain_args_constant( # FIXME: Only use the first slice for the probe preconditioner patches = operator.diffraction.patch.fwd( - images=psi[0], - positions=scan[lo:hi], - patch_width=probe.shape[-1], + images=parameters.psi[0], + positions=parameters.scan[lo:hi], + patch_width=parameters.probe.shape[-1], ) probe_update_denominator[...] += _patch_amp_sum(patches) assert probe_update_denominator.ndim == 2 @@ -107,7 +118,7 @@ def make_certain_args_constant( ind_args=[], streams=streams, lo=0, - hi=len(scan), + hi=len(parameters.scan), ) return probe_update_denominator @@ -115,56 +126,41 @@ def make_certain_args_constant( def update_preconditioners( comm: tike.communicators.Comm, + parameters: typing.List[PtychoParameters], operator: tike.operators.Ptycho, - scan, - probe, - psi, - object_options: typing.Optional[ObjectOptions] = None, - probe_options: typing.Optional[ProbeOptions] = None, -) -> typing.Tuple[ObjectOptions, ProbeOptions]: +) -> typing.List[PtychoParameters]: """Update the probe and object preconditioners.""" - if object_options: + if parameters[0].object_options: preconditioner = comm.pool.map( _psi_preconditioner, - psi, - scan, - probe, + parameters, comm.streams, operator=operator, ) - preconditioner = comm.Allreduce(preconditioner) - - if object_options.preconditioner is None: - object_options.preconditioner = preconditioner - else: - object_options.preconditioner = comm.pool.map( - _rolling_average, - object_options.preconditioner, - preconditioner, - ) + # preconditioner = comm.Allreduce(preconditioner) - if probe_options: + parameters = comm.pool.map( + _rolling_average_object, + parameters, + preconditioner, + ) + if parameters[0].probe_options: preconditioner = comm.pool.map( _probe_preconditioner, - psi, - scan, - probe, + parameters, comm.streams, operator=operator, ) - preconditioner = comm.Allreduce(preconditioner) + # preconditioner = comm.Allreduce(preconditioner) - if probe_options.preconditioner is None: - probe_options.preconditioner = preconditioner - else: - probe_options.preconditioner = comm.pool.map( - _rolling_average, - probe_options.preconditioner, - preconditioner, - ) + parameters = comm.pool.map( + _rolling_average_probe, + parameters, + preconditioner, + ) - return object_options, probe_options + return parameters diff --git a/src/tike/ptycho/solvers/dm.py b/src/tike/ptycho/solvers/dm.py index 4c0f0274..dc2b80ff 100644 --- a/src/tike/ptycho/solvers/dm.py +++ b/src/tike/ptycho/solvers/dm.py @@ -3,6 +3,7 @@ import cupy as cp import numpy.typing as npt +import numpy as np import tike.communicators import tike.linalg @@ -24,12 +25,12 @@ def dm( - op: tike.operators.Ptycho, - streams: typing.List[cp.cuda.Stream], + parameters: PtychoParameters, data: npt.NDArray, batches: typing.List[npt.NDArray[cp.intc]], + streams: typing.List[cp.cuda.Stream], *, - parameters: PtychoParameters, + op: tike.operators.Ptycho, epoch: int, ) -> PtychoParameters: """Solve the ptychography problem using the difference map approach. @@ -41,7 +42,7 @@ def dm( comm : :py:class:`tike.communicators.Comm` An object which manages communications between GPUs and nodes. data : (FRAME, WIDE, HIGH) float32 - A unique CuPy array containing + A CuPy pinned host memory array containing the intensity (square of the absolute value) of the propagated wavefront; i.e. what the detector records. FFT-shifted so the diffraction peak is at the corners. @@ -67,7 +68,7 @@ def dm( """ assert isinstance(op, tike.operators.Operator) - assert isinstance(data, npt.ArrayLike) + assert isinstance(data, np.ndarray), type(data) assert isinstance(parameters, PtychoParameters) assert isinstance(epoch, int) assert isinstance(batches, list) diff --git a/src/tike/ptycho/solvers/options.py b/src/tike/ptycho/solvers/options.py index b8a24f34..26c95605 100644 --- a/src/tike/ptycho/solvers/options.py +++ b/src/tike/ptycho/solvers/options.py @@ -2,6 +2,7 @@ import abc import dataclasses import typing +import copy import numpy as np import numpy.typing as npt @@ -255,23 +256,66 @@ def copy_to_host(self) -> PtychoParameters: else None, ) - def split(self, indices: npt.NDArray[np.int]) -> PtychoParameters: + @staticmethod + def split( + indices: npt.NDArray[np.intc], + *, + x: PtychoParameters, + ) -> PtychoParameters: """Return a new PtychoParameters with only the data from the indices""" return PtychoParameters( - probe=self.probe, - psi=self.psi, - scan=self.scan[indices], - eigen_probe=self.eigen_probe, - eigen_weights=self.eigen_weights[indices] - if self.eigen_weights is not None + probe=x.probe, + psi=x.psi, + scan=x.scan[indices], + eigen_probe=x.eigen_probe, + eigen_weights=x.eigen_weights[indices] + if x.eigen_weights is not None + else None, + algorithm_options=copy.deepcopy(x.algorithm_options), + exitwave_options=x.exitwave_options, + probe_options=x.probe_options, + object_options=x.object_options, + position_options=x.position_options.split(indices) + if x.position_options is not None else None, - algorithm_options=self.algorithm_options, - exitwave_options=self.exitwave_options, - probe_options=self.probe_options, - object_options=self.object_options, - position_options=self.position_options, ) + @staticmethod + def join( + x: typing.Iterable[PtychoParameters], + reorder: npt.NDArray[np.intc], + stripe_start: typing.List[int], + ) -> PtychoParameters: + joined_psi = x[0].psi + pw = x[0].probe.shape[-2] // 2 + for i in range(1, len(x)): + lo = stripe_start[i] + pw + hi = stripe_start[i + 1] + pw if i + 1 < len(x) else x[0].psi.shape[1] + joined_psi[:, lo:hi, :] = x[i].psi[:, lo:hi, :] + + return PtychoParameters( + probe=x[0].probe, + psi=joined_psi, + scan=np.concatenate( + [e.scan for e in x], + axis=0, + )[reorder], + eigen_probe=x[0].eigen_probe, + eigen_weights=np.concatenate( + [e.eigen_weights for e in x], + axis=0, + )[reorder] + if x[0].eigen_weights is not None + else None, + algorithm_options=x[0].algorithm_options, + exitwave_options=x[0].exitwave_options, + probe_options=x[0].probe_options, + object_options=x[0].object_options, + position_options=PositionOptions.join( + [e.position_options for e in x], + reorder, + ), + ) def _resize_spline(x: np.ndarray, f: float) -> np.ndarray: diff --git a/tests/communicators/test_pool.py b/tests/communicators/test_pool.py index 9f4234fe..575eb873 100644 --- a/tests/communicators/test_pool.py +++ b/tests/communicators/test_pool.py @@ -140,6 +140,23 @@ def test_reduce_mean(self): # print(result.shape, type(truth)) self.xp.testing.assert_array_equal(result, truth) + def test_swap_edges(self): + + def init(i): + return self.xp.ones((1, 4 * self.pool.num_workers), dtype=int) * i + + x = self.pool.map(init, list(range(self.pool.num_workers))) + + x1 = self.pool.swap_edges( + x, + overlap=1, + edges=np.arange(self.pool.num_workers, dtype=int) * 4, + ) + + print() + for element in x1: + print(element) + class TestSoloThreadPool(TestThreadPool): From a8ecdea9017dc8d8eaf1cca77b6697b5f6cbd71b Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Mon, 24 Jun 2024 14:36:46 -0500 Subject: [PATCH 06/31] DEV: Implement object joining function --- src/tike/ptycho/object.py | 49 ++++++++++++++++++++++++++++++ src/tike/ptycho/solvers/options.py | 13 +++----- 2 files changed, 54 insertions(+), 8 deletions(-) diff --git a/src/tike/ptycho/object.py b/src/tike/ptycho/object.py index 291ae777..f43e3456 100644 --- a/src/tike/ptycho/object.py +++ b/src/tike/ptycho/object.py @@ -119,6 +119,55 @@ def resample(self, factor: float, interp) -> ObjectOptions: return options # Momentum reset to zero when grid scale changes + @staticmethod + def join_psi( + x: typing.Iterable[np.ndarray], + stripe_start: typing.Iterable[int], + probe_width: int, + ) -> np.ndarray: + joined_psi = x[0] + for i in range(1, len(x)): + lo = stripe_start[i] + probe_width + hi = stripe_start[i + 1] + probe_width if i + 1 < len(x) else x[0].shape[1] + joined_psi[:, lo:hi, :] = x[i][:, lo:hi, :] + return joined_psi + + @staticmethod + def join( + x: typing.Iterable[ObjectOptions], + stripe_start: typing.Iterable[int], + probe_width: int, + ) -> ObjectOptions: + options = ObjectOptions( + convergence_tolerance=x[0].convergence_tolerance, + positivity_constraint=x[0].positivity_constraint, + smoothness_constraint=x[0].smoothness_constraint, + use_adaptive_moment=x[0].use_adaptive_moment, + vdecay=x[0].vdecay, + mdecay=x[0].mdecay, + clip_magnitude=x[0].clip_magnitude, + ) + options.update_mnorm = copy.copy(x[0].update_mnorm) + if x[0].v is not None: + options.v = ObjectOptions.join_psi( + [e.v for e in x], + stripe_start, + probe_width, + ) + if x[0].m is not None: + options.m = ObjectOptions.join_psi( + [e.m for e in x], + stripe_start, + probe_width, + ) + if x[0].preconditioner is not None: + options.preconditioner = ObjectOptions.join_psi( + [e.preconditioner for e in x], + stripe_start, + probe_width, + ) + + def positivity_constraint(x, r): """Constrains the amplitude of x to be positive with sum of abs(x) and x. diff --git a/src/tike/ptycho/solvers/options.py b/src/tike/ptycho/solvers/options.py index 26c95605..57d9ca7c 100644 --- a/src/tike/ptycho/solvers/options.py +++ b/src/tike/ptycho/solvers/options.py @@ -286,16 +286,13 @@ def join( reorder: npt.NDArray[np.intc], stripe_start: typing.List[int], ) -> PtychoParameters: - joined_psi = x[0].psi - pw = x[0].probe.shape[-2] // 2 - for i in range(1, len(x)): - lo = stripe_start[i] + pw - hi = stripe_start[i + 1] + pw if i + 1 < len(x) else x[0].psi.shape[1] - joined_psi[:, lo:hi, :] = x[i].psi[:, lo:hi, :] - return PtychoParameters( probe=x[0].probe, - psi=joined_psi, + psi=ObjectOptions.join_psi( + [e.psi for e in x], + probe_width=x[0].probe.shape[-2] // 2, + stripe_start=stripe_start, + ), scan=np.concatenate( [e.scan for e in x], axis=0, From 1986196d130d467efd17ffd46e2d49a902e31e32 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Tue, 25 Jun 2024 18:37:22 -0500 Subject: [PATCH 07/31] REF: Synchronize position updates --- src/tike/ptycho/position.py | 90 ++++++++-------- src/tike/ptycho/probe.py | 15 +-- src/tike/ptycho/ptycho.py | 198 ++++++++++++++++++++++-------------- 3 files changed, 171 insertions(+), 132 deletions(-) diff --git a/src/tike/ptycho/position.py b/src/tike/ptycho/position.py index a09145e7..029a2526 100644 --- a/src/tike/ptycho/position.py +++ b/src/tike/ptycho/position.py @@ -157,7 +157,15 @@ def resample(self, factor: float) -> AffineTransform: ) @classmethod - def fromarray(self, T: np.ndarray) -> AffineTransform: + def frombuffer(cls, buffer: np.ndarray) -> AffineTransform: + return AffineTransform(*buffer) + + def asbuffer(self) -> np.ndarray: + """Return the constructor parameters in a tuple.""" + return np.array(self.astuple()) + + @classmethod + def fromarray(cls, T: np.ndarray) -> AffineTransform: """Return an Affine Transfrom from a 2x2 matrix. Use decomposition method from Graphics Gems 2 Section 7.1 @@ -181,8 +189,8 @@ def fromarray(self, T: np.ndarray) -> AffineTransform: scale1=float(scale1), shear1=float(shear1), angle=float(angle), - t0=T[2, 0] if T.shape[0] > 2 else 0, - t1=T[2, 1] if T.shape[0] > 2 else 0, + t0=float(T[2, 0] if T.shape[0] > 2 else 0), + t1=float(T[2, 1] if T.shape[0] > 2 else 0), ) def asarray(self, xp=np) -> np.ndarray: @@ -348,7 +356,10 @@ class PositionOptions: transform: AffineTransform = AffineTransform() """Global transform of positions.""" - origin: tuple[float, float] = (0, 0) + origin: npt.ArrayLike = dataclasses.field( + init=True, + default_factory=lambda: np.zeros(2), + ) """The rotation center of the transformation. This shift is applied to the scan positions before computing the global transformation.""" @@ -361,6 +372,11 @@ class PositionOptions: update_start: int = 0 """Start position updates at this epoch.""" + _momentum: np.ndarray = dataclasses.field( + init=False, + default_factory=lambda: None, + ) + def __post_init__(self): self.initial_scan = self.initial_scan.astype(tike.precision.floating) if self.confidence is None: @@ -460,25 +476,24 @@ def join( transform=x[0].transform, ) if x[0].confidence is not None: - new.confidence = ( - np.concatenate( - [e.confidence for e in x], - axis=0, - )[reorder], - ) + new.confidence = np.concatenate( + [e.confidence for e in x], + axis=0, + )[reorder] + if x[0].use_adaptive_moment: - new._momentum = ( - np.concatenate( - [e._momentum for e in x], - axis=0, - )[reorder], - ) + new._momentum = np.concatenate( + [e._momentum for e in x], + axis=0, + )[reorder] + return new def copy_to_device(self): """Copy to the current GPU memory.""" options = copy.copy(self) options.initial_scan = cp.asarray(self.initial_scan) + options.origin = cp.array(self.origin) if self.confidence is not None: options.confidence = cp.asarray(self.confidence) if self.use_adaptive_moment: @@ -489,6 +504,7 @@ def copy_to_host(self): """Copy to the host CPU memory.""" options = copy.copy(self) options.initial_scan = cp.asnumpy(self.initial_scan) + options.origin = cp.asnumpy(self.origin) if self.confidence is not None: options.confidence = cp.asnumpy(self.confidence) if self.use_adaptive_moment: @@ -691,12 +707,10 @@ def _affine_position_helper( # TODO: What is a good default value for max_error? def affine_position_regularization( - comm: tike.communicators.Comm, - updated: typing.List[cp.ndarray], - position_options: typing.List[PositionOptions], + updated: cp.ndarray, + position_options: PositionOptions, max_error: float = 32, - regularization_enabled: bool = False, -) -> typing.Tuple[typing.List[cp.ndarray], typing.List[PositionOptions]]: +) -> typing.Tuple[cp.ndarray, PositionOptions]: """Regularize position updates with an affine deformation constraint. Assume that the true position updates are a global affine transformation @@ -718,30 +732,20 @@ def affine_position_regularization( """ # Gather all of the scanning positions on one host - positions0 = comm.pool.gather_host( - [x.initial_scan for x in position_options], axis=0) - positions1 = comm.pool.gather_host(updated, axis=0) - positions0 = comm.mpi.Gather(positions0, axis=0, root=0) - positions1 = comm.mpi.Gather(positions1, axis=0, root=0) - - if comm.mpi.rank == 0: - new_transform, _ = estimate_global_transformation_ransac( - positions0=positions0 - position_options[0].origin, - positions1=positions1 - position_options[0].origin, - transform=position_options[0].transform, - max_error=max_error, - ) - else: - new_transform = None - - new_transform = comm.mpi.bcast(new_transform, root=0) + positions0 = position_options.initial_scan + positions1 = updated + + new_transform, _ = estimate_global_transformation_ransac( + positions0=positions0 - position_options.origin, + positions1=positions1 - position_options.origin, + transform=position_options.transform, + max_error=max_error, + ) - for i in range(len(position_options)): - position_options[i].transform = new_transform + position_options.transform = new_transform - if regularization_enabled: - updated = comm.pool.map( - _affine_position_helper, + if position_options.use_position_regularization: + updated = _affine_position_helper( updated, position_options, max_error=max_error, diff --git a/src/tike/ptycho/probe.py b/src/tike/ptycho/probe.py index e0134a2d..191b1bfa 100644 --- a/src/tike/ptycho/probe.py +++ b/src/tike/ptycho/probe.py @@ -293,7 +293,7 @@ def _constrain_variable_probe2(variable_probe, weights, power): return variable_probe, weights -def constrain_variable_probe(comm, variable_probe, weights): +def constrain_variable_probe(variable_probe, weights): """Add the following constraints to variable probe weights 1. Remove outliars from weights @@ -307,21 +307,16 @@ def constrain_variable_probe(comm, variable_probe, weights): # sorting and synchronizing the weights with the host OR implementing # smoothing of non-gridded data with splines using device-local data only. - variable_probe, weights, power = zip(*comm.pool.map( - _constrain_variable_probe1, + variable_probe, weights, power = _constrain_variable_probe1( variable_probe, weights, - )) - - # reduce power by sum across all devices - power = comm.pool.allreduce(power) + ) - variable_probe, weights = (list(a) for a in zip(*comm.pool.map( - _constrain_variable_probe2, + variable_probe, weights = _constrain_variable_probe2( variable_probe, weights, power, - ))) + ) return variable_probe, weights diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index ca959480..73429254 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -77,6 +77,7 @@ PositionOptions, check_allowed_positions, affine_position_regularization, + AffineTransform, ) from .probe import ( constrain_center_peak, @@ -435,7 +436,7 @@ def iterate(self, num_iter: int) -> None: # self.parameters.probe_options.recover_probe = ( # total_epochs >= self.parameters.probe_options.update_start # and (total_epochs % self.parameters.probe_options.update_period) == 0 - # ) # yapf: disable + # ) self.parameters = self.comm.pool.map( _apply_probe_constraints, @@ -458,6 +459,16 @@ def iterate(self, num_iter: int) -> None: epoch=len(self.parameters[0].algorithm_options.times), ) + self.parameters = self.comm.pool.map( + _apply_object_constraints, + self.parameters, + ) + + self.parameters = self.comm.pool.map( + _apply_position_constraints, + self.parameters, + ) + for i, reduced_probe in enumerate( self.comm.Allreduce_mean( [e.probe[None, ...] for e in self.parameters], @@ -466,6 +477,15 @@ def iterate(self, num_iter: int) -> None: ): self.parameters[i].probe = reduced_probe + if self.parameters[0].eigen_probe is not None: + for i, reduced_probe in enumerate( + self.comm.Allreduce_mean( + [e.eigen_probe[None, ...] for e in self.parameters], + axis=0, + ) + ): + self.parameters[i].eigen_probe = reduced_probe + pw = self.parameters[0].probe.shape[-2] for swapped, parameters in zip( # TODO: Try blending edges during swap instead of replacing @@ -480,84 +500,21 @@ def iterate(self, num_iter: int) -> None: ): parameters.psi = swapped - # if self.parameters.object_options.positivity_constraint: - # self.parameters.psi = self.comm.pool.map( - # tike.ptycho.object.positivity_constraint, - # self.parameters.psi, - # r=self.parameters.object_options.positivity_constraint, - # ) - - # if self.parameters.object_options.smoothness_constraint: - # self.parameters.psi = self.comm.pool.map( - # tike.ptycho.object.smoothness_constraint, - # self.parameters.psi, - # a=self.parameters.object_options.smoothness_constraint, - # ) - - # if self.parameters.object_options.clip_magnitude: - # self.parameters.psi = self.comm.pool.map( - # _clip_magnitude, - # self.parameters.psi, - # a_max=1.0, - # ) - - # if ( - # self.parameters.algorithm_options.name != 'dm' - # and self.parameters.algorithm_options.rescale_method == 'mean_of_abs_object' - # and self.parameters.object_options.preconditioner is not None - # and len(self.parameters.algorithm_options.costs) % self.parameters.algorithm_options.rescale_period == 0 - # ): # yapf: disable - # ( - # self.parameters.psi, - # self.parameters.probe, - # ) = (list(a) for a in zip(*self.comm.pool.map( - # tike.ptycho.object.remove_object_ambiguity, - # self.parameters.psi, - # self.parameters.probe, - # self.parameters.object_options.preconditioner, - # ))) - - # elif self.parameters.probe_options is not None: - # if ( - # self.parameters.probe_options.recover_probe - # and self.parameters.algorithm_options.rescale_method == 'constant_probe_photons' - # and len(self.parameters.algorithm_options.costs) % self.parameters.algorithm_options.rescale_period == 0 - # ): # yapf: disable - - # self.parameters.probe = self.comm.pool.map( - # tike.ptycho.probe - # .rescale_probe_using_fixed_intensity_photons, - # self.parameters.probe, - # Nphotons=self.parameters.probe_options.probe_photons, - # probe_power_fraction=None, - # ) - - # if ( - # self.parameters.probe_options is not None - # and self.parameters.eigen_probe is not None - # and self.parameters.probe_options.recover_probe - # ): #yapf: disable - # ( - # self.parameters.eigen_probe, - # self.parameters.eigen_weights, - # ) = tike.ptycho.probe.constrain_variable_probe( - # self.comm, - # self.parameters.eigen_probe, - # self.parameters.eigen_weights, - # ) - - # if self.parameters.position_options: - # ( - # self.parameters.scan, - # self.parameters.position_options, - # ) = affine_position_regularization( - # self.comm, - # updated=self.parameters.scan, - # position_options=self.parameters.position_options, - # regularization_enabled=self.parameters.position_options[ - # 0 - # ].use_position_regularization, - # ) + if self.parameters[0].position_options is not None: + for i, reduced_transform in enumerate( + self.comm.Allreduce_mean( + [ + e.position_options.transform.asbuffer()[None, ...] + for e in self.parameters + ], + axis=0, + ) + ): + self.parameters[ + i + ].position_options.transform = AffineTransform.frombuffer( + reduced_transform + ) self.parameters[0].algorithm_options.times.append( time.perf_counter() - start @@ -782,6 +739,89 @@ def _apply_probe_constraints( parameters.probe_options.power.append(power[0].get()) + if parameters.algorithm_options.rescale_method == "constant_probe_photons" and ( + len(parameters.algorithm_options.costs) + % parameters.algorithm_options.rescale_period + == 0 + ): + parameters.probe = ( + tike.ptycho.probe.rescale_probe_using_fixed_intensity_photons( + parameters.probe, + Nphotons=parameters.probe_options.probe_photons, + probe_power_fraction=None, + ) + ) + + if ( + parameters.eigen_probe is not None + and parameters.probe_options.recover_probe + ): + ( + parameters.eigen_probe, + parameters.eigen_weights, + ) = tike.ptycho.probe.constrain_variable_probe( + parameters.eigen_probe, + parameters.eigen_weights, + ) + + return parameters + + +def _apply_object_constraints( + parameters: solvers.PtychoParameters, +) -> solvers.PtychoParameters: + if parameters.object_options.positivity_constraint: + parameters.psi = tike.ptycho.object.positivity_constraint( + parameters.psi, + r=parameters.object_options.positivity_constraint, + ) + + if parameters.object_options.smoothness_constraint: + parameters.psi = tike.ptycho.object.smoothness_constraint( + parameters.psi, + a=parameters.object_options.smoothness_constraint, + ) + + if parameters.object_options.clip_magnitude: + parameters.psi = _clip_magnitude( + parameters.psi, + a_max=1.0, + ) + + if ( + parameters.algorithm_options.name != "dm" + and parameters.algorithm_options.rescale_method == "mean_of_abs_object" + and parameters.object_options.preconditioner is not None + and ( + len(parameters.algorithm_options.costs) + % parameters.algorithm_options.rescale_period + == 0 + ) + ): + ( + parameters.psi, + parameters.probe, + ) = tike.ptycho.object.remove_object_ambiguity( + parameters.psi, + parameters.probe, + parameters.object_options.preconditioner, + ) + + return parameters + + +def _apply_position_constraints( + parameters: solvers.PtychoParameters, +) -> solvers.PtychoParameters: + if parameters.position_options: + ( + parameters.scan, + parameters.position_options, + ) = affine_position_regularization( + updated=parameters.scan, + position_options=parameters.position_options, + ) + return parameters From 9a481c9709e1ab17679e4b631ae73f95a77976b9 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Wed, 26 Jun 2024 13:26:37 -0500 Subject: [PATCH 08/31] DEV: Refactor RPIE for new parallelism --- src/tike/ptycho/probe.py | 8 +- src/tike/ptycho/ptycho.py | 24 +- src/tike/ptycho/solvers/dm.py | 2 - src/tike/ptycho/solvers/lstsq.py | 2 +- src/tike/ptycho/solvers/rpie.py | 448 ++++++++++++++----------------- 5 files changed, 224 insertions(+), 260 deletions(-) diff --git a/src/tike/ptycho/probe.py b/src/tike/ptycho/probe.py index 191b1bfa..802ca94e 100644 --- a/src/tike/ptycho/probe.py +++ b/src/tike/ptycho/probe.py @@ -935,7 +935,13 @@ def constrain_probe_sparsity(probe, f): return probe -def finite_probe_support(probe, *, radius=0.5, degree=5, p=1.0): +def finite_probe_support( + probe, + *, + radius: float = 0.5, + degree: float = 5.0, + p: float = 1.0, +): """Returns a supergaussian penalty function for finite probe support. A mask which provides an illumination penalty is determined by the equation: diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index 73429254..32ca81d6 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -501,21 +501,24 @@ def iterate(self, num_iter: int) -> None: parameters.psi = swapped if self.parameters[0].position_options is not None: - for i, reduced_transform in enumerate( - self.comm.Allreduce_mean( - [ - e.position_options.transform.asbuffer()[None, ...] - for e in self.parameters - ], - axis=0, - ) - ): + # FIXME: Synchronize across nodes + reduced_transform = np.mean( + [e.position_options.transform.asbuffer() for e in self.parameters], + axis=0, + ) + for i in range(len(self.parameters)): self.parameters[ i ].position_options.transform = AffineTransform.frombuffer( reduced_transform ) + reduced_cost = np.mean( + [e.algorithm_options.costs[-1] for e in self.parameters], + ) + for i in range(len(self.parameters)): + self.parameters[i].algorithm_options.costs[-1] = [reduced_cost] + self.parameters[0].algorithm_options.times.append( time.perf_counter() - start ) @@ -570,9 +573,6 @@ def get_result(self) -> solvers.PtychoParameters: stripe_start=self.comm.stripe_start, ) - print(np.array(parameters.algorithm_options.costs).shape) - print(np.array(parameters.algorithm_options.times).shape) - return parameters def __exit__(self, type, value, traceback): diff --git a/src/tike/ptycho/solvers/dm.py b/src/tike/ptycho/solvers/dm.py index dc2b80ff..1b469ae1 100644 --- a/src/tike/ptycho/solvers/dm.py +++ b/src/tike/ptycho/solvers/dm.py @@ -136,8 +136,6 @@ def _apply_update( ) new_psi = dpsi + psi psi = new_psi - else: - print('object update skipped') if probe_options: new_probe = probe_update_numerator / (probe_options.preconditioner + 1e-9) diff --git a/src/tike/ptycho/solvers/lstsq.py b/src/tike/ptycho/solvers/lstsq.py index ece6f3fd..204ff838 100644 --- a/src/tike/ptycho/solvers/lstsq.py +++ b/src/tike/ptycho/solvers/lstsq.py @@ -978,7 +978,7 @@ def _momentum_checked( ).real.flatten() if np.all(previous_update_correlation > 0): friction, _ = tike.opt.fit_line_least_squares( - x=np.arange(len(previous_update_correlation) + 1), + x=list(range(len(previous_update_correlation) + 1)), y=[ 0, ] + np.log(previous_update_correlation).tolist(), diff --git a/src/tike/ptycho/solvers/rpie.py b/src/tike/ptycho/solvers/rpie.py index 220cf1a7..005c3307 100644 --- a/src/tike/ptycho/solvers/rpie.py +++ b/src/tike/ptycho/solvers/rpie.py @@ -1,8 +1,10 @@ import logging +import typing import cupy as cp import cupyx.scipy.stats import numpy.typing as npt +import numpy as np import tike.communicators import tike.linalg @@ -15,19 +17,26 @@ import tike.precision import tike.random -from .options import * +from .options import ( + ExitWaveOptions, + ObjectOptions, + PositionOptions, + ProbeOptions, + PtychoParameters, + RpieOptions, +) from .lstsq import _momentum_checked logger = logging.getLogger(__name__) def rpie( - op: tike.operators.Ptycho, - comm: tike.communicators.Comm, - data: typing.List[npt.NDArray], - batches: typing.List[typing.List[npt.NDArray[cp.intc]]], - *, parameters: PtychoParameters, + data: npt.NDArray, + batches: typing.List[npt.NDArray[cp.intc]], + streams: typing.List[cp.cuda.Stream], + *, + op: tike.operators.Ptycho, epoch: int, ) -> PtychoParameters: """Solve the ptychography problem using regularized ptychographical engine. @@ -74,144 +83,80 @@ def rpie( .. seealso:: :py:mod:`tike.ptycho` """ - probe = parameters.probe - scan = parameters.scan - psi = parameters.psi - algorithm_options = parameters.algorithm_options - exitwave_options = parameters.exitwave_options - probe_options = parameters.probe_options - if probe_options is None: - recover_probe = False - else: - recover_probe = probe_options.recover_probe - - position_options = parameters.position_options - object_options = parameters.object_options - eigen_probe = parameters.eigen_probe - eigen_weights = parameters.eigen_weights - - if eigen_probe is None: - beigen_probe = [None] * comm.pool.num_workers - else: - beigen_probe = eigen_probe - - if eigen_weights is None: - beigen_weights = [None] * comm.pool.num_workers - else: - beigen_weights = eigen_weights - if parameters.algorithm_options.batch_method == 'compact': order = range else: order = tike.random.randomizer_np.permutation - psi_update_numerator = [None] * comm.pool.num_workers - probe_update_numerator = [None] * comm.pool.num_workers - position_update_numerator = [None] * comm.pool.num_workers - position_update_denominator = [None] * comm.pool.num_workers - - batch_cost: typing.List[float] = [] - for n in order(algorithm_options.num_batch): + psi_update_numerator: None | cp.ndarray = None + probe_update_numerator: None | cp.ndarray = None + position_update_numerator: None | cp.ndarray = None + position_update_denominator: None | cp.ndarray = None + for n in order(parameters.algorithm_options.num_batch): ( cost, psi_update_numerator, probe_update_numerator, position_update_numerator, position_update_denominator, - beigen_weights, - ) = (list(a) for a in zip(*comm.pool.map( - _get_nearplane_gradients, + parameters, + ) = _get_nearplane_gradients( data, - scan, - psi, - probe, - exitwave_options.measured_pixels, + parameters, psi_update_numerator, probe_update_numerator, position_update_numerator, position_update_denominator, - beigen_probe, - beigen_weights, batches, - comm.streams, + streams, n=n, op=op, - object_options=object_options, - probe_options=probe_options, - recover_probe=recover_probe, - position_options=position_options, - exitwave_options=exitwave_options, - ))) - - batch_cost.append(comm.Allreduce_mean(cost, axis=None).get()) - - if algorithm_options.batch_method != 'compact': - ( - psi, - probe, - ) = _update( - comm, - psi, - probe, + epoch=epoch, + ) + + if parameters.algorithm_options.batch_method != "compact": + parameters = _update( + parameters, psi_update_numerator, probe_update_numerator, - object_options, - probe_options, - recover_probe, - algorithm_options, + recover_probe=parameters.probe_options.update_start >= epoch, ) - psi_update_numerator = [None] * comm.pool.num_workers - probe_update_numerator = [None] * comm.pool.num_workers + psi_update_numerator = None + probe_update_numerator = None - algorithm_options.costs.append(batch_cost) + parameters.algorithm_options.costs.append([cost]) - if position_options is not None: + if parameters.position_options is not None: ( - scan, - position_options, - ) = (list(a) for a in zip(*comm.pool.map( - _update_position, - scan, - position_options, + parameters.scan, + parameters.position_options, + ) = _update_position( + parameters.scan, + parameters.position_options, position_update_numerator, position_update_denominator, - max_shift=probe[0].shape[-1] * 0.1, - alpha=algorithm_options.alpha, + max_shift=parameters.probe.shape[-1] * 0.1, + alpha=parameters.algorithm_options.alpha, epoch=epoch, - ))) + ) - if algorithm_options.batch_method == 'compact': - ( - psi, - probe, - ) = _update( - comm, - psi, - probe, + if parameters.algorithm_options.batch_method == "compact": + parameters = _update( + parameters, psi_update_numerator, probe_update_numerator, - object_options, - probe_options, - recover_probe, - algorithm_options, - errors=list(np.mean(x) for x in algorithm_options.costs[-3:]), + recover_probe=parameters.probe_options.update_start >= epoch, + errors=list( + float(np.mean(x)) for x in parameters.algorithm_options.costs[-3:] + ), ) - if eigen_weights is not None: - eigen_weights = comm.pool.map( - _normalize_eigen_weights, - beigen_weights, + if parameters.eigen_weights is not None: + parameters.eigen_weights = _normalize_eigen_weights( + parameters.eigen_weights, ) - parameters.probe = probe - parameters.psi = psi - parameters.scan = scan - parameters.algorithm_options = algorithm_options - parameters.probe_options = probe_options - parameters.object_options = object_options - parameters.position_options = position_options - parameters.eigen_weights = eigen_weights return parameters @@ -224,147 +169,153 @@ def _normalize_eigen_weights(eigen_weights): def _update( - comm: tike.communicators.Comm, - psi: npt.NDArray[cp.csingle], - probe: npt.NDArray[cp.csingle], + parameters: PtychoParameters, psi_update_numerator: npt.NDArray[cp.csingle], probe_update_numerator: npt.NDArray[cp.csingle], - object_options: ObjectOptions, - probe_options: ProbeOptions, recover_probe: bool, - algorithm_options: RpieOptions, errors: typing.Union[None, typing.List[float]] = None, -): - if object_options: - psi_update_numerator = comm.Allreduce_reduce_gpu( - psi_update_numerator)[0] - dpsi = psi_update_numerator - deno = ( - (1 - algorithm_options.alpha) * object_options.preconditioner[0] + - algorithm_options.alpha * object_options.preconditioner[0].max( +) -> PtychoParameters: + if parameters.object_options: + dpsi = psi_update_numerator / ( + (1 - parameters.algorithm_options.alpha) + * parameters.object_options.preconditioner + + parameters.algorithm_options.alpha + * parameters.object_options.preconditioner.max( axis=(-2, -1), keepdims=True, - )) - psi[0] = psi[0] + dpsi / deno - if object_options.use_adaptive_moment: - if errors: + ) + ) + if parameters.object_options.use_adaptive_moment: + if errors is not None: ( dpsi, - object_options.v, - object_options.m, + parameters.object_options.v, + parameters.object_options.m, ) = _momentum_checked( g=dpsi, - v=object_options.v, - m=object_options.m, - mdecay=object_options.mdecay, + v=parameters.object_options.v, + m=parameters.object_options.m, + mdecay=parameters.object_options.mdecay, errors=errors, memory_length=3, ) else: ( dpsi, - object_options.v, - object_options.m, + parameters.object_options.v, + parameters.object_options.m, ) = tike.opt.adam( g=dpsi, - v=object_options.v, - m=object_options.m, - vdecay=object_options.vdecay, - mdecay=object_options.mdecay, + v=parameters.object_options.v, + m=parameters.object_options.m, + vdecay=parameters.object_options.vdecay, + mdecay=parameters.object_options.mdecay, ) - psi[0] = psi[0] + dpsi / deno - psi = comm.pool.bcast([psi[0]]) - - if recover_probe: + parameters.psi = parameters.psi + dpsi - probe_update_numerator = comm.Allreduce_reduce_gpu( - probe_update_numerator)[0] + if recover_probe and parameters.probe_options is not None: b0 = tike.ptycho.probe.finite_probe_support( - probe[0], - p=probe_options.probe_support, - radius=probe_options.probe_support_radius, - degree=probe_options.probe_support_degree, + parameters.probe, + p=parameters.probe_options.probe_support, + radius=parameters.probe_options.probe_support_radius, + degree=parameters.probe_options.probe_support_degree, + ) + b1 = ( + parameters.probe_options.additional_probe_penalty + * cp.linspace( + start=0, + stop=1, + num=parameters.probe.shape[-3], + dtype="float32", + )[..., None, None] ) - b1 = probe_options.additional_probe_penalty * cp.linspace( - 0, 1, probe[0].shape[-3], dtype='float32')[..., None, None] - dprobe = (probe_update_numerator - (b1 + b0) * probe[0]) - deno = ( - (1 - algorithm_options.alpha) * probe_options.preconditioner[0] + - algorithm_options.alpha * probe_options.preconditioner[0].max( + dprobe = (probe_update_numerator - (b1 + b0) * parameters.probe) / ( + (1 - parameters.algorithm_options.alpha) + * parameters.probe_options.preconditioner + + parameters.algorithm_options.alpha + * parameters.probe_options.preconditioner.max( axis=(-2, -1), keepdims=True, - ) + b0 + b1) - probe[0] = probe[0] + dprobe / deno - if probe_options.use_adaptive_moment: + ) + + b0 + + b1 + ) + if parameters.probe_options.use_adaptive_moment: # ptychoshelves only applies momentum to the main probe mode = 0 if errors: ( dprobe[0, 0, mode, :, :], - probe_options.v, - probe_options.m, + parameters.probe_options.v, + parameters.probe_options.m, ) = _momentum_checked( - g=(dprobe)[0, 0, mode, :, :], - v=probe_options.v, - m=probe_options.m, - mdecay=probe_options.mdecay, + g=dprobe[0, 0, mode, :, :], + v=parameters.probe_options.v, + m=parameters.probe_options.m, + mdecay=parameters.probe_options.mdecay, errors=errors, memory_length=3, ) else: ( dprobe[0, 0, mode, :, :], - probe_options.v, - probe_options.m, + parameters.probe_options.v, + parameters.probe_options.m, ) = tike.opt.adam( - g=(dprobe)[0, 0, mode, :, :], - v=probe_options.v, - m=probe_options.m, - vdecay=probe_options.vdecay, - mdecay=probe_options.mdecay, + g=dprobe[0, 0, mode, :, :], + v=parameters.probe_options.v, + m=parameters.probe_options.m, + vdecay=parameters.probe_options.vdecay, + mdecay=parameters.probe_options.mdecay, ) - probe[0] = probe[0] + dprobe / deno - probe = comm.pool.bcast([probe[0]]) + parameters.probe = parameters.probe + dprobe - return psi, probe + return parameters def _get_nearplane_gradients( data: npt.NDArray, - scan: npt.NDArray, - psi: npt.NDArray, - probe: npt.NDArray, - measured_pixels: npt.NDArray, + parameters: PtychoParameters, psi_update_numerator: typing.Union[None, npt.NDArray], probe_update_numerator: typing.Union[None, npt.NDArray], position_update_numerator: typing.Union[None, npt.NDArray], position_update_denominator: typing.Union[None, npt.NDArray], - eigen_probe: typing.Union[None, npt.NDArray], - eigen_weights: typing.Union[None, npt.NDArray], - batches: typing.List[typing.List[int]], + batches: typing.List[npt.NDArray[np.intc]], streams: typing.List[cp.cuda.Stream], *, n: int, op: tike.operators.Ptycho, - object_options: typing.Union[None, ObjectOptions] = None, - probe_options: typing.Union[None, ProbeOptions] = None, - recover_probe: bool, - position_options: typing.Union[None, PositionOptions], - exitwave_options: ExitWaveOptions, -) -> typing.List[npt.NDArray]: - - cost = 0.0 - count = 1.0 / len(batches[n]) - psi_update_numerator = cp.zeros_like( - psi) if psi_update_numerator is None else psi_update_numerator - probe_update_numerator = cp.zeros_like( - probe) if probe_update_numerator is None else probe_update_numerator - position_update_numerator = cp.empty_like( - scan - ) if position_update_numerator is None else position_update_numerator - position_update_denominator = cp.empty_like( - scan - ) if position_update_denominator is None else position_update_denominator + epoch: int, +) -> typing.Tuple[ + float, + npt.ArrayLike, + npt.ArrayLike, + npt.ArrayLike, + npt.ArrayLike, + PtychoParameters, +]: + cost = cp.zeros(1) + count = cp.array(1.0 / len(batches[n])) + psi_update_numerator = ( + cp.zeros_like(parameters.psi) + if psi_update_numerator is None + else psi_update_numerator + ) + probe_update_numerator = ( + cp.zeros_like(parameters.probe) + if probe_update_numerator is None + else probe_update_numerator + ) + position_update_numerator = ( + cp.empty_like(parameters.scan) + if position_update_numerator is None + else position_update_numerator + ) + position_update_denominator = ( + cp.empty_like(parameters.scan) + if position_update_denominator is None + else position_update_denominator + ) def keep_some_args_constant( ind_args, @@ -374,47 +325,50 @@ def keep_some_args_constant( (data,) = ind_args nonlocal cost, psi_update_numerator, probe_update_numerator nonlocal position_update_numerator, position_update_denominator - nonlocal eigen_weights, scan unique_probe = tike.ptycho.probe.get_varying_probe( - probe, - eigen_probe, - eigen_weights[lo:hi] if eigen_weights is not None else None, + parameters.probe, + parameters.eigen_probe, + parameters.eigen_weights[lo:hi] + if parameters.eigen_weights is not None + else None, ) - farplane = op.fwd(probe=unique_probe, scan=scan[lo:hi], psi=psi) + farplane = op.fwd( + probe=unique_probe, + scan=parameters.scan[lo:hi], + psi=parameters.psi, + ) intensity = cp.sum( cp.square(cp.abs(farplane)), axis=list(range(1, farplane.ndim - 2)), ) each_cost = getattr( tike.operators, - f'{exitwave_options.noise_model}_each_pattern', + f"{parameters.exitwave_options.noise_model}_each_pattern", )( - data[:, measured_pixels][:, None, :], - intensity[:, measured_pixels][:, None, :], + data[:, parameters.exitwave_options.measured_pixels][:, None, :], + intensity[:, parameters.exitwave_options.measured_pixels][:, None, :], ) cost += cp.sum(each_cost) * count - if exitwave_options.noise_model == 'poisson': - + if parameters.exitwave_options.noise_model == "poisson": xi = (1 - data / intensity)[:, None, None, :, :] grad_cost = farplane * xi step_length = cp.full( shape=(farplane.shape[0], 1, farplane.shape[2], 1, 1), - fill_value=exitwave_options.step_length_start, + fill_value=parameters.exitwave_options.step_length_start, ) - if exitwave_options.step_length_usemodes == 'dominant_mode': - + if parameters.exitwave_options.step_length_usemodes == "dominant_mode": step_length = tike.ptycho.exitwave.poisson_steplength_dominant_mode( xi, intensity, data, - measured_pixels, + parameters.exitwave_options.measured_pixels, step_length, - exitwave_options.step_length_weight, + parameters.exitwave_options.step_length_weight, ) else: @@ -424,61 +378,67 @@ def keep_some_args_constant( cp.square(cp.abs(farplane)), intensity, data, - measured_pixels, + parameters.exitwave_options.measured_pixels, step_length, - exitwave_options.step_length_weight, + parameters.exitwave_options.step_length_weight, ) - farplane[..., measured_pixels] = (-step_length * - grad_cost)[..., measured_pixels] + farplane[..., parameters.exitwave_options.measured_pixels] = ( + -step_length * grad_cost + )[..., parameters.exitwave_options.measured_pixels] else: # Gaussian noise model for exitwave updates, steplength = 1 # TODO: optimal step lengths using 2nd order taylor expansion - farplane[..., measured_pixels] = -getattr( - tike.operators, f'{exitwave_options.noise_model}_grad')( - data, - farplane, - intensity, - )[..., measured_pixels] + farplane[..., parameters.exitwave_options.measured_pixels] = -getattr( + tike.operators, f"{parameters.exitwave_options.noise_model}_grad" + )( + data, + farplane, + intensity, + )[..., parameters.exitwave_options.measured_pixels] - unmeasured_pixels = cp.logical_not(measured_pixels) + unmeasured_pixels = cp.logical_not(parameters.exitwave_options.measured_pixels) farplane[..., unmeasured_pixels] *= ( - exitwave_options.unmeasured_pixels_scaling - 1.0) + parameters.exitwave_options.unmeasured_pixels_scaling - 1.0 + ) pad, end = op.diffraction.pad, op.diffraction.end diff = op.propagation.adj(farplane, overwrite=True)[..., pad:end, pad:end] - if object_options: - grad_psi = (cp.conj(unique_probe) * diff / probe.shape[-3]).reshape( - scan[lo:hi].shape[0] * probe.shape[-3], *probe.shape[-2:]) - psi_update_numerator[0] = op.diffraction.patch.adj( + if parameters.object_options: + grad_psi = ( + cp.conj(unique_probe) * diff / parameters.probe.shape[-3] + ).reshape( + parameters.scan[lo:hi].shape[0] * parameters.probe.shape[-3], + *parameters.probe.shape[-2:], + ) + psi_update_numerator = op.diffraction.patch.adj( patches=grad_psi, - images=psi_update_numerator[0], - positions=scan[lo:hi], - nrepeat=probe.shape[-3], + images=psi_update_numerator, + positions=parameters.scan[lo:hi], + nrepeat=parameters.probe.shape[-3], ) - if position_options or probe_options: - + if parameters.position_options or parameters.probe_options: patches = op.diffraction.patch.fwd( patches=cp.zeros_like(diff[..., 0, 0, :, :]), - images=psi[0], - positions=scan[lo:hi], + images=parameters.psi, + positions=parameters.scan[lo:hi], )[..., None, None, :, :] - if recover_probe: + if parameters.probe_options and parameters.probe_options.update_start >= epoch: probe_update_numerator += cp.sum( cp.conj(patches) * diff, axis=-5, keepdims=True, ) - if eigen_weights is not None: + if parameters.eigen_weights: m: int = 0 - OP = patches * probe[..., m:m + 1, :, :] + OP = patches * parameters.probe[..., m : m + 1, :, :] eigen_numerator = cp.sum( cp.real(cp.conj(OP) * diff[..., m:m + 1, :, :]), axis=(-1, -2), @@ -487,14 +447,14 @@ def keep_some_args_constant( cp.abs(OP)**2, axis=(-1, -2), ) - eigen_weights[lo:hi, ..., 0:1, m:m+1] += ( + parameters.eigen_weights[lo:hi, ..., 0:1, m:m+1] += ( 0.1 * (eigen_numerator / eigen_denominator) ) # yapf: disable - if position_options: + if parameters.position_options: grad_x, grad_y = tike.ptycho.position.gaussian_gradient(patches) - crop = probe.shape[-1] // 4 + crop = parameters.probe.shape[-1] // 4 position_update_numerator[lo:hi, ..., 0] = cp.sum( cp.real( @@ -544,12 +504,12 @@ def keep_some_args_constant( ) return ( - cost, + float(cost.get()), psi_update_numerator, probe_update_numerator, position_update_numerator, position_update_denominator, - eigen_weights, + parameters, ) From 438501905beeff11aaf8c71f2b268daa95bb6233 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Thu, 27 Jun 2024 12:26:09 -0500 Subject: [PATCH 09/31] DEV: Start transitioning lstsq method and fix bugs --- src/tike/opt.py | 6 +- src/tike/ptycho/ptycho.py | 2 +- src/tike/ptycho/solvers/lstsq.py | 429 +++++++++++++---------------- src/tike/ptycho/solvers/options.py | 4 +- src/tike/ptycho/solvers/rpie.py | 20 +- 5 files changed, 213 insertions(+), 248 deletions(-) diff --git a/src/tike/opt.py b/src/tike/opt.py index ac063e8e..34b7961f 100644 --- a/src/tike/opt.py +++ b/src/tike/opt.py @@ -381,9 +381,9 @@ def conjugate_gradient( def fit_line_least_squares( - y: typing.List[float], - x: typing.List[float], -) -> typing.Tuple[float, float]: + y: npt.NDArray[np.floating], + x: npt.NDArray[np.floating], +) -> typing.Tuple[np.floating, np.floating]: """Return the `slope`, `intercept` pair that best fits `y`, `x` to a line. y = slope * x + intercept diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index 32ca81d6..f7d4a994 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -450,7 +450,7 @@ def iterate(self, num_iter: int) -> None: ) self.parameters = self.comm.pool.map( - solvers.dm, + getattr(solvers, self.parameters[0].algorithm_options.name), self.parameters, self.data, self.batches, diff --git a/src/tike/ptycho/solvers/lstsq.py b/src/tike/ptycho/solvers/lstsq.py index 204ff838..e2de5116 100644 --- a/src/tike/ptycho/solvers/lstsq.py +++ b/src/tike/ptycho/solvers/lstsq.py @@ -17,18 +17,25 @@ import tike.ptycho.exitwave import tike.precision -from .options import * +from .options import ( + ExitWaveOptions, + ObjectOptions, + PositionOptions, + ProbeOptions, + PtychoParameters, + LstsqOptions, +) logger = logging.getLogger(__name__) def lstsq_grad( - op: tike.operators.Ptycho, - comm: tike.communicators.Comm, - data: typing.List[npt.NDArray], + parameters: PtychoParameters, + data: npt.NDArray, batches: typing.List[npt.NDArray[cp.intc]], + streams: typing.List[cp.cuda.Stream], *, - parameters: PtychoParameters, + op: tike.operators.Ptycho, epoch: int, ): """Solve the ptychography problem using Odstrcil et al's approach. @@ -69,55 +76,25 @@ def lstsq_grad( .. seealso:: :py:mod:`tike.ptycho` """ - probe = parameters.probe - scan = parameters.scan - psi = parameters.psi - - algorithm_options = parameters.algorithm_options - - probe_options = parameters.probe_options - if probe_options is None: - recover_probe = False - else: - recover_probe = probe_options.recover_probe - - position_options = parameters.position_options - object_options = parameters.object_options - exitwave_options = parameters.exitwave_options + print('helloworld3') - eigen_probe = parameters.eigen_probe - eigen_weights = parameters.eigen_weights - - position_update_numerator = [None] * comm.pool.num_workers - position_update_denominator = [None] * comm.pool.num_workers - - if eigen_probe is None: - beigen_probe = [None] * comm.pool.num_workers - else: - beigen_probe = eigen_probe - - if eigen_weights is None: - beigen_weights = [None] * comm.pool.num_workers - else: - beigen_weights = eigen_weights - - if object_options is not None: - if algorithm_options.batch_method == 'compact': - object_options.combined_update = cp.zeros_like(psi[0]) - - if recover_probe: - probe_options.probe_update_sum = cp.zeros_like(probe[0]) if parameters.algorithm_options.batch_method == 'compact': order = range else: order = tike.random.randomizer_np.permutation + psi_combined_update: None | cp.ndarray = None + probe_combined_update: None | cp.ndarray = None + position_update_numerator: None | cp.ndarray = None + position_update_denominator: None | cp.ndarray = None + + recover_probe: bool = parameters.probe_options is not None + batch_cost = [] beta_object = [] beta_probe = [] - for batch_index in order(algorithm_options.num_batch): - + for batch_index in order(parameters.algorithm_options.num_batch): ( diff, unique_probe, @@ -128,59 +105,45 @@ def lstsq_grad( patches, position_update_numerator, position_update_denominator, - position_options, - ) = (list(a) for a in zip(*comm.pool.map( - _get_nearplane_gradients, + parameters.position_options, + ) = _get_nearplane_gradients( data, - psi, - scan, - probe, - beigen_probe, - beigen_weights, + parameters.psi, + parameters.scan, + parameters.probe, + parameters.eigen_probe, + parameters.eigen_weights, batches, position_update_numerator, position_update_denominator, - [None] * comm.pool.num_workers if position_options is - None else position_options, - comm.streams, - exitwave_options.measured_pixels, - object_options.preconditioner, + parameters.position_options, + streams, + parameters.exitwave_options.measured_pixels, + parameters.object_options.preconditioner, batch_index=batch_index, - num_batch=algorithm_options.num_batch, - exitwave_options=exitwave_options, + num_batch=parameters.algorithm_options.num_batch, + exitwave_options=parameters.exitwave_options, op=op, - recover_psi=object_options is not None, + recover_psi=parameters.object_options is not None, recover_probe=recover_probe, - recover_positions=position_options is not None, - ))) - position_options = None if position_options[ - 0] is None else position_options - - if object_options is not None: - object_upd_sum = comm.Allreduce(object_upd_sum) - - if recover_probe: - m_probe_update = comm.pool.bcast( - [comm.Allreduce_mean( - m_probe_update, - axis=-5, - )]) + recover_positions=parameters.position_options is not None, + ) + if parameters.probe_options: ( - beigen_probe, - beigen_weights, + parameters.eigen_probe, + parameters.eigen_weights, ) = _update_nearplane( - comm, diff, probe_update, m_probe_update, - probe, - beigen_probe, - beigen_weights, + parameters.probe, + parameters.eigen_probe, + parameters.eigen_weights, patches, batches, batch_index=batch_index, - num_batch=algorithm_options.num_batch, + num_batch=parameters.algorithm_options.num_batch, ) ( @@ -190,40 +153,38 @@ def lstsq_grad( A4, b1, b2, - ) = (list(a) for a in zip(*comm.pool.map( - _precondition_nearplane_gradients, + ) = _precondition_nearplane_gradients( diff, - scan, + parameters.scan, unique_probe, - probe, + parameters.probe, object_upd_sum, m_probe_update, - object_options.preconditioner, + parameters.object_options.preconditioner, patches, batches, batch_index=batch_index, op=op, m=0, - recover_psi=object_options is not None, + recover_psi=parameters.object_options is not None, recover_probe=recover_probe, - probe_options=probe_options, - ))) + probe_options=parameters.probe_options, + ) - if object_options is not None: - A1_delta = comm.pool.bcast([comm.Allreduce_mean(A1, axis=-3)]) + if parameters.object_options is not None: + A1_delta = cp.mean(A1, axis=-3) else: - A1_delta = [None] * comm.pool.num_workers + A1_delta = None if recover_probe: - A4_delta = comm.pool.bcast([comm.Allreduce_mean(A4, axis=-3)]) + A4_delta = cp.mean(A4, axis=-3) else: - A4_delta = [None] * comm.pool.num_workers + A4_delta = None ( weighted_step_psi, weighted_step_probe, - ) = (list(a) for a in zip(*comm.pool.map( - _get_nearplane_steps, + ) = _get_nearplane_steps( A1, A2, A4, @@ -231,171 +192,163 @@ def lstsq_grad( b2, A1_delta, A4_delta, - recover_psi=object_options is not None, + recover_psi=parameters.object_options is not None, recover_probe=recover_probe, m=0, - ))) + ) - if object_options is not None: - bbeta_object = comm.Allreduce_mean( + if parameters.object_options is not None: + bbeta_object = cp.mean( weighted_step_psi, axis=-5, )[..., 0, 0, 0] if recover_probe: - bbeta_probe = comm.Allreduce_mean( + bbeta_probe = cp.mean( weighted_step_probe, axis=-5, ) + print('helloworld1') + # Update each direction - if object_options is not None: - if algorithm_options.batch_method != 'compact': + if parameters.object_options is not None: + + print('helloworld') + + if parameters.algorithm_options.batch_method != "compact": # (27b) Object update - dpsi = bbeta_object[0] * object_update_precond[0] + dpsi = bbeta_object * object_update_precond - if object_options.use_adaptive_moment: + print(dpsi.shape, parameters.psi.shape) + + if parameters.object_options.use_adaptive_moment: ( dpsi, - object_options.v, - object_options.m, + parameters.object_options.v, + parameters.object_options.m, ) = tike.opt.momentum( g=dpsi, - v=object_options.v, - m=object_options.m, - vdecay=object_options.vdecay, - mdecay=object_options.mdecay, + v=parameters.object_options.v, + m=parameters.object_options.m, + vdecay=parameters.object_options.vdecay, + mdecay=parameters.object_options.mdecay, ) - psi[0] = psi[0] + dpsi - psi = comm.pool.bcast([psi[0]]) + parameters.psi = parameters.psi + dpsi else: - object_options.combined_update += object_upd_sum[0] + psi_combined_update += object_upd_sum if recover_probe: - dprobe = bbeta_probe[0] * m_probe_update[0] - probe_options.probe_update_sum += dprobe / algorithm_options.num_batch + dprobe = bbeta_probe * m_probe_update + probe_combined_update += dprobe / parameters.algorithm_options.num_batch # (27a) Probe update - probe[0] += dprobe - probe = comm.pool.bcast([probe[0]]) + parameters.probe += dprobe for c in costs: batch_cost = batch_cost + c.tolist() - if object_options is not None: + if parameters.object_options is not None: beta_object.append(bbeta_object) if recover_probe: beta_probe.append(bbeta_probe) - if eigen_probe is not None: - eigen_probe = beigen_probe - - if eigen_weights is not None: - eigen_weights = beigen_weights - - if position_options: - scan, position_options = zip(*comm.pool.map( - _update_position, - scan, - position_options, + if parameters.position_options: + parameters.scan, parameters.position_options = _update_position( + parameters.scan, + parameters.position_options, position_update_numerator, position_update_denominator, epoch=epoch, - )) + ) - algorithm_options.costs.append(batch_cost) + parameters.algorithm_options.costs.append(batch_cost) - if object_options and algorithm_options.batch_method == 'compact': + if ( + parameters.object_options + and parameters.algorithm_options.batch_method == "compact" + ): object_update_precond = _precondition_object_update( - object_options.combined_update, - object_options.preconditioner[0], + psi_combined_update, + parameters.object_options.preconditioner, ) # (27b) Object update beta_object = cp.mean(cp.stack(beta_object)) dpsi = beta_object * object_update_precond - psi[0] = psi[0] + dpsi + parameters.psi = psi + dpsi - if object_options.use_adaptive_moment: + if parameters.object_options.use_adaptive_moment: ( dpsi, - object_options.v, - object_options.m, + parameters.object_options.v, + parameters.object_options.m, ) = _momentum_checked( g=dpsi, - v=object_options.v, - m=object_options.m, - mdecay=object_options.mdecay, - errors=list(np.mean(x) for x in algorithm_options.costs[-3:]), + v=parameters.object_options.v, + m=parameters.object_options.m, + mdecay=parameters.object_options.mdecay, + errors=list( + float(np.mean(x)) for x in parameters.algorithm_options.costs[-3:] + ), beta=beta_object, memory_length=3, ) - weight = object_options.preconditioner[0] + weight = parameters.object_options.preconditioner weight = weight / (0.1 * weight.max() + weight) - psi[0] = psi[0] + weight * dpsi - - psi = comm.pool.bcast([psi[0]]) + parameters.psi = parameters.psi + weight * dpsi if recover_probe: - if probe_options.use_adaptive_moment: + if parameters.probe_options.use_adaptive_moment: beta_probe = cp.mean(cp.stack(beta_probe)) - dprobe = probe_options.probe_update_sum - if probe_options.v is None: - probe_options.v = np.zeros_like( + dprobe = probe_combined_update + if parameters.probe_options.v is None: + parameters.probe_options.v = np.zeros_like( dprobe, shape=(3, *dprobe.shape), ) - if probe_options.m is None: - probe_options.m = np.zeros_like(dprobe,) + if parameters.probe_options.m is None: + parameters.probe_options.m = np.zeros_like( + dprobe, + ) # ptychoshelves only applies momentum to the main probe mode = 0 ( d, - probe_options.v[..., mode, :, :], - probe_options.m[..., mode, :, :], + parameters.probe_options.v[..., mode, :, :], + parameters.probe_options.m[..., mode, :, :], ) = _momentum_checked( g=dprobe[..., mode, :, :], - v=probe_options.v[..., mode, :, :], - m=probe_options.m[..., mode, :, :], - mdecay=probe_options.mdecay, - errors=list(np.mean(x) for x in algorithm_options.costs[-3:]), + v=parameters.probe_options.v[..., mode, :, :], + m=parameters.probe_options.m[..., mode, :, :], + mdecay=parameters.probe_options.mdecay, + errors=list( + float(np.mean(x)) for x in parameters.algorithm_options.costs[-3:] + ), beta=beta_probe, memory_length=3, ) - probe[0][..., mode, :, :] = probe[0][..., mode, :, :] + d - probe = comm.pool.bcast([probe[0]]) - - parameters.probe = probe - parameters.psi = psi - parameters.scan = scan - parameters.algorithm_options = algorithm_options - parameters.probe_options = probe_options - parameters.object_options = object_options - parameters.position_options = position_options - parameters.eigen_weights = eigen_weights - parameters.eigen_probe = eigen_probe + parameters.probe[..., mode, :, :] = parameters.probe[..., mode, :, :] + d + return parameters def _update_nearplane( - comm: tike.communicators.Comm, - diff, - probe_update, - m_probe_update, - probe: typing.List[npt.NDArray[cp.csingle]], - eigen_probe: typing.List[npt.NDArray[cp.csingle]], - eigen_weights: typing.List[npt.NDArray[cp.single]], - patches, - batches, + diff: npt.NDArray[cp.csingle], + probe_update: npt.NDArray[cp.csingle], + m_probe_update: npt.NDArray[cp.csingle], + probe: npt.NDArray[cp.csingle], + eigen_probe: npt.NDArray[cp.csingle], + eigen_weights: npt.NDArray[cp.single], + patches: npt.NDArray[cp.csingle], + batches: typing.List[npt.NDArray[np.intc]], *, batch_index: int, num_batch: int, ): m = 0 - if eigen_weights[0] is not None: - - eigen_weights = comm.pool.map( - _get_coefs_intensity, + if eigen_weights is not None: + eigen_weights = _get_coefs_intensity( eigen_weights, diff, probe, @@ -406,23 +359,20 @@ def _update_nearplane( ) # (30) residual probe updates - if eigen_weights[0].shape[-2] > 1: - R = comm.pool.map( - _get_residuals, + if eigen_weights.shape[-2] > 1: + R = _get_residuals( probe_update, m_probe_update, m=m, ) - if eigen_probe[0] is not None and m < eigen_probe[0].shape[-3]: - assert eigen_weights[0].shape[-2] == eigen_probe[0].shape[-4] + 1 - for eigen_index in range(1, eigen_probe[0].shape[-4] + 1): - + if eigen_probe is not None and m < eigen_probe.shape[-3]: + assert eigen_weights.shape[-2] == eigen_probe.shape[-4] + 1 + for eigen_index in range(1, eigen_probe.shape[-4] + 1): ( eigen_probe, eigen_weights, ) = tike.ptycho.probe.update_eigen_probe( - comm, R, eigen_probe, eigen_weights, @@ -435,10 +385,9 @@ def _update_nearplane( m=m, ) - if eigen_index + 1 < eigen_weights[0].shape[-2]: + if eigen_index + 1 < eigen_weights.shape[-2]: # Subtract projection of R onto new probe from R - R = comm.pool.map( - _update_residuals, + R = _update_residuals( R, eigen_probe, batches, @@ -459,11 +408,11 @@ def _get_nearplane_gradients( psi: npt.NDArray[cp.csingle], scan: npt.NDArray[cp.single], probe: npt.NDArray[cp.csingle], - eigen_probe, - eigen_weights, - batches, - position_update_numerator, - position_update_denominator, + eigen_probe: npt.NDArray[cp.csingle], + eigen_weights: npt.NDArray[cp.csingle], + batches: typing.List[npt.NDArray[np.intc]], + position_update_numerator: npt.NDArray[cp.csingle], + position_update_denominator: npt.NDArray[cp.csingle], position_options: PositionOptions, streams: typing.List[cp.cuda.Stream], measured_pixels: npt.NDArray, @@ -610,12 +559,13 @@ def keep_some_args_constant( else: object_upd_sum = None + bpatches[blo:bhi] = op.diffraction.patch.fwd( + patches=cp.zeros_like(bchi[blo:bhi, ..., 0, 0, :, :]), + images=psi[0], + positions=scan[lo:hi], + )[..., None, None, :, :] + if recover_probe: - bpatches[blo:bhi] = op.diffraction.patch.fwd( - patches=cp.zeros_like(bchi[blo:bhi, ..., 0, 0, :, :]), - images=psi[0], - positions=scan[lo:hi], - )[..., None, None, :, :] # (24a) bprobe_update[blo:bhi] = cp.conj(bpatches[blo:bhi]) * bchi[blo:bhi] # (25a) Common probe gradient. Use simple average instead of @@ -629,7 +579,6 @@ def keep_some_args_constant( else: bprobe_update = None m_probe_update = None - bpatches = None if position_options: m = 0 @@ -706,23 +655,23 @@ def _precondition_object_update( def _precondition_nearplane_gradients( - nearplane, - scan, - unique_probe, - probe, - object_upd_sum, - m_probe_update, - psi_update_denominator, - patches, - batches, + nearplane: npt.NDArray[cp.csingle], + scan: npt.NDArray[cp.single], + unique_probe: npt.NDArray[cp.csingle], + probe: npt.NDArray[cp.csingle], + object_upd_sum: npt.NDArray[cp.csingle], + m_probe_update: npt.NDArray[cp.csingle], + psi_update_denominator: npt.NDArray[cp.csingle], + patches: npt.NDArray[cp.csingle], + batches: typing.List[npt.NDArray[np.intc]], *, batch_index: int, - op, - m, - recover_psi, - recover_probe, - alpha=0.05, - probe_options, + op: tike.operators.Ptycho, + m: int, + recover_psi: bool, + recover_probe: bool, + alpha: float = 0.05, + probe_options: ProbeOptions, ): lo = batches[batch_index][0] hi = lo + len(batches[batch_index]) @@ -737,6 +686,7 @@ def _precondition_nearplane_gradients( dOP = None dPO = None object_update_proj = None + object_update_precond = None if recover_psi: object_update_precond = _precondition_object_update( @@ -769,7 +719,7 @@ def _precondition_nearplane_gradients( 1, probe[0].shape[-3], dtype=tike.precision.floating, - )[..., m : m + 1, None, None] + )[..., m : m + 1, None, None] # type: ignore ) m_probe_update = m_probe_update - (b0 + b1) * probe[..., m : m + 1, :, :] @@ -787,16 +737,16 @@ def _precondition_nearplane_gradients( dPO = m_probe_update[..., m:m + 1, :, :] * patches A4 = cp.sum((dPO * dPO.conj()).real + eps, axis=(-2, -1)) - if recover_psi and recover_probe: + if dOP is not None and dPO is not None: b1 = cp.sum((dOP.conj() * nearplane[..., m:m + 1, :, :]).real, axis=(-2, -1)) b2 = cp.sum((dPO.conj() * nearplane[..., m:m + 1, :, :]).real, axis=(-2, -1)) A2 = cp.sum((dOP * dPO.conj()), axis=(-2, -1)) - elif recover_psi: + elif dOP is not None: b1 = cp.sum((dOP.conj() * nearplane[..., m:m + 1, :, :]).real, axis=(-2, -1)) - elif recover_probe: + elif dPO is not None: b2 = cp.sum((dPO.conj() * nearplane[..., m:m + 1, :, :]).real, axis=(-2, -1)) @@ -810,9 +760,18 @@ def _precondition_nearplane_gradients( ) -def _get_nearplane_steps(A1, A2, A4, b1, b2, A1_delta, A4_delta, recover_psi, - recover_probe, m): - +def _get_nearplane_steps( + A1, + A2, + A4, + b1, + b2, + A1_delta, + A4_delta, + recover_psi, + recover_probe, + m, +) -> typing.Tuple[npt.NDArray | None, npt.NDArray | None]: if recover_psi: A1 += 0.5 * A1_delta if recover_probe: @@ -832,7 +791,7 @@ def _get_nearplane_steps(A1, A2, A4, b1, b2, A1_delta, A4_delta, recover_psi, x1 = None x2 = None - if recover_psi: + if x1 is not None: step = 0.9 * cp.maximum(0, x1[..., None, None].real) # (27b) Object update @@ -840,7 +799,7 @@ def _get_nearplane_steps(A1, A2, A4, b1, b2, A1_delta, A4_delta, recover_psi, else: beta_object = None - if recover_probe: + if x2 is not None: step = 0.9 * cp.maximum(0, x2[..., None, None].real) beta_probe = cp.mean(step, axis=-5, keepdims=True) @@ -902,7 +861,7 @@ def _update_position( alpha=0.05, max_shift=1, epoch=0, -): +) -> typing.Tuple[npt.NDArray, PositionOptions]: if epoch < position_options.update_start: return scan, position_options @@ -978,10 +937,10 @@ def _momentum_checked( ).real.flatten() if np.all(previous_update_correlation > 0): friction, _ = tike.opt.fit_line_least_squares( - x=list(range(len(previous_update_correlation) + 1)), - y=[ - 0, - ] + np.log(previous_update_correlation).tolist(), + x=np.arange(len(previous_update_correlation) + 1, dtype=np.floating), + y=np.array( + [0, + ] + np.log(previous_update_correlation).tolist()), ) friction = 0.5 * max(-friction, 0) m = (1 - friction) * m + g diff --git a/src/tike/ptycho/solvers/options.py b/src/tike/ptycho/solvers/options.py index 57d9ca7c..3de25d87 100644 --- a/src/tike/ptycho/solvers/options.py +++ b/src/tike/ptycho/solvers/options.py @@ -134,7 +134,7 @@ class PtychoParameters(): default_factory=RpieOptions,) """A class containing algorithm specific parameters""" - exitwave_options: typing.Union[ExitWaveOptions, None] = None + exitwave_options: ExitWaveOptions = None """A class containing settings related to exitwave updates.""" probe_options: typing.Union[ProbeOptions, None] = None @@ -304,8 +304,10 @@ def join( )[reorder] if x[0].eigen_weights is not None else None, + # TODO: costs and times should be joined somehow? algorithm_options=x[0].algorithm_options, exitwave_options=x[0].exitwave_options, + # TODO: synchronize probe momentum elsewhere probe_options=x[0].probe_options, object_options=x[0].object_options, position_options=PositionOptions.join( diff --git a/src/tike/ptycho/solvers/rpie.py b/src/tike/ptycho/solvers/rpie.py index 005c3307..ccd17e16 100644 --- a/src/tike/ptycho/solvers/rpie.py +++ b/src/tike/ptycho/solvers/rpie.py @@ -176,7 +176,8 @@ def _update( errors: typing.Union[None, typing.List[float]] = None, ) -> PtychoParameters: if parameters.object_options: - dpsi = psi_update_numerator / ( + dpsi = psi_update_numerator + deno = ( (1 - parameters.algorithm_options.alpha) * parameters.object_options.preconditioner + parameters.algorithm_options.alpha @@ -185,6 +186,7 @@ def _update( keepdims=True, ) ) + parameters.psi = parameters.psi + dpsi / deno if parameters.object_options.use_adaptive_moment: if errors is not None: ( @@ -211,7 +213,7 @@ def _update( vdecay=parameters.object_options.vdecay, mdecay=parameters.object_options.mdecay, ) - parameters.psi = parameters.psi + dpsi + parameters.psi = parameters.psi + dpsi / deno if recover_probe and parameters.probe_options is not None: b0 = tike.ptycho.probe.finite_probe_support( @@ -229,7 +231,8 @@ def _update( dtype="float32", )[..., None, None] ) - dprobe = (probe_update_numerator - (b1 + b0) * parameters.probe) / ( + dprobe = (probe_update_numerator - (b1 + b0) * parameters.probe) + deno = ( (1 - parameters.algorithm_options.alpha) * parameters.probe_options.preconditioner + parameters.algorithm_options.alpha @@ -240,6 +243,7 @@ def _update( + b0 + b1 ) + parameters.probe = parameters.probe + dprobe / deno if parameters.probe_options.use_adaptive_moment: # ptychoshelves only applies momentum to the main probe mode = 0 @@ -268,7 +272,7 @@ def _update( vdecay=parameters.probe_options.vdecay, mdecay=parameters.probe_options.mdecay, ) - parameters.probe = parameters.probe + dprobe + parameters.probe = parameters.probe + dprobe / deno return parameters @@ -416,9 +420,9 @@ def keep_some_args_constant( parameters.scan[lo:hi].shape[0] * parameters.probe.shape[-3], *parameters.probe.shape[-2:], ) - psi_update_numerator = op.diffraction.patch.adj( + psi_update_numerator[0] = op.diffraction.patch.adj( patches=grad_psi, - images=psi_update_numerator, + images=psi_update_numerator[0], positions=parameters.scan[lo:hi], nrepeat=parameters.probe.shape[-3], ) @@ -426,7 +430,7 @@ def keep_some_args_constant( if parameters.position_options or parameters.probe_options: patches = op.diffraction.patch.fwd( patches=cp.zeros_like(diff[..., 0, 0, :, :]), - images=parameters.psi, + images=parameters.psi[0], positions=parameters.scan[lo:hi], )[..., None, None, :, :] @@ -436,7 +440,7 @@ def keep_some_args_constant( axis=-5, keepdims=True, ) - if parameters.eigen_weights: + if parameters.eigen_weights is not None: m: int = 0 OP = patches * parameters.probe[..., m : m + 1, :, :] eigen_numerator = cp.sum( From f6c9c182cc71b1cf0f988317f88836affe4efb1e Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Thu, 27 Jun 2024 12:26:26 -0500 Subject: [PATCH 10/31] DEV: Transition probe variable updates --- src/tike/ptycho/probe.py | 53 ++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/src/tike/ptycho/probe.py b/src/tike/ptycho/probe.py index 802ca94e..fce5851a 100644 --- a/src/tike/ptycho/probe.py +++ b/src/tike/ptycho/probe.py @@ -405,7 +405,6 @@ def _get_weights_mean(n, d, d_mean, weights, batches, *, batch_index, c, m): def update_eigen_probe( - comm, R, eigen_probe, weights, @@ -453,14 +452,13 @@ def update_eigen_probe( least-squares solver for generalized maximum-likelihood ptychography. Optics Express. 2018. """ - assert R[0].shape[-3] == R[0].shape[-4] == 1 - assert 1 == eigen_probe[0].shape[-5] - assert R[0].shape[:-5] == eigen_probe[0].shape[:-5] == weights[0].shape[:-3] - assert weights[0][batches[0][batch_index], :, :].shape[-3] == R[0].shape[-5] - assert R[0].shape[-2:] == eigen_probe[0].shape[-2:] - - update = comm.pool.map( - _get_update, + assert R.shape[-3] == R.shape[-4] == 1 + assert 1 == eigen_probe.shape[-5] + assert R.shape[:-5] == eigen_probe.shape[:-5] == weights.shape[:-3] + assert weights[batches[batch_index], :, :].shape[-3] == R.shape[-5] + assert R.shape[-2:] == eigen_probe.shape[-2:] + + update = _get_update( R, eigen_probe, weights, @@ -469,13 +467,12 @@ def update_eigen_probe( c=c, m=m, ) - update = comm.pool.bcast([comm.Allreduce_mean( + update = cp.mean( update, axis=-5, - )]) + ) - (eigen_probe, n, d, d_mean) = (list(a) for a in zip(*comm.pool.map( - _get_d, + (eigen_probe, n, d, d_mean) = _get_d( patches, diff, eigen_probe, @@ -483,25 +480,23 @@ def update_eigen_probe( β=β, c=c, m=m, - ))) + ) - d_mean = comm.pool.bcast([comm.Allreduce_mean( + d_mean = cp.mean( d_mean, axis=-3, - )]) - - weights = list( - comm.pool.map( - _get_weights_mean, - n, - d, - d_mean, - weights, - batches, - batch_index=batch_index, - c=c, - m=m, - )) + ) + + weights = _get_weights_mean( + n, + d, + d_mean, + weights, + batches, + batch_index=batch_index, + c=c, + m=m, + ) return eigen_probe, weights From 281dd7b27219c311d7374874c3134d5cc696713d Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Thu, 27 Jun 2024 16:37:44 -0500 Subject: [PATCH 11/31] DEV: Strip comm from inside rpie solver --- src/tike/ptycho/solvers/rpie.py | 188 ++++++++++++++++++-------------- 1 file changed, 104 insertions(+), 84 deletions(-) diff --git a/src/tike/ptycho/solvers/rpie.py b/src/tike/ptycho/solvers/rpie.py index 220cf1a7..bf4321ef 100644 --- a/src/tike/ptycho/solvers/rpie.py +++ b/src/tike/ptycho/solvers/rpie.py @@ -1,7 +1,9 @@ import logging +import typing import cupy as cp import cupyx.scipy.stats +import numpy as np import numpy.typing as npt import tike.communicators @@ -74,67 +76,73 @@ def rpie( .. seealso:: :py:mod:`tike.ptycho` """ - probe = parameters.probe - scan = parameters.scan - psi = parameters.psi + data0 = data[0] + batches0 = batches[0] + + probe = parameters.probe[0] + scan = parameters.scan[0] + psi = parameters.psi[0] algorithm_options = parameters.algorithm_options exitwave_options = parameters.exitwave_options probe_options = parameters.probe_options - if probe_options is None: - recover_probe = False - else: - recover_probe = probe_options.recover_probe + recover_probe = True - position_options = parameters.position_options + if parameters.position_options is None: + position_options = None + else: + position_options = parameters.position_options[0] object_options = parameters.object_options - eigen_probe = parameters.eigen_probe - eigen_weights = parameters.eigen_weights - - if eigen_probe is None: - beigen_probe = [None] * comm.pool.num_workers + if parameters.eigen_probe is None: + eigen_probe = None else: - beigen_probe = eigen_probe - - if eigen_weights is None: - beigen_weights = [None] * comm.pool.num_workers + eigen_probe = parameters.eigen_probe[0] + if parameters.eigen_weights is None: + eigen_weights = None else: - beigen_weights = eigen_weights + eigen_weights = parameters.eigen_weights[0] + + measured_pixels = exitwave_options.measured_pixels[0] + + streams0 = comm.streams[0] + + probe_options.preconditioner = probe_options.preconditioner[0] + object_options.preconditioner = object_options.preconditioner[0] + + # CONVERSTION AREA ABOVE --------------------------------------- if parameters.algorithm_options.batch_method == 'compact': order = range else: order = tike.random.randomizer_np.permutation - psi_update_numerator = [None] * comm.pool.num_workers - probe_update_numerator = [None] * comm.pool.num_workers - position_update_numerator = [None] * comm.pool.num_workers - position_update_denominator = [None] * comm.pool.num_workers + psi_update_numerator = None + probe_update_numerator = None + position_update_numerator = None + position_update_denominator = None batch_cost: typing.List[float] = [] for n in order(algorithm_options.num_batch): - ( cost, psi_update_numerator, probe_update_numerator, position_update_numerator, position_update_denominator, - beigen_weights, - ) = (list(a) for a in zip(*comm.pool.map( - _get_nearplane_gradients, - data, + eigen_weights, + ) = _get_nearplane_gradients( + data0, scan, psi, probe, - exitwave_options.measured_pixels, + measured_pixels, psi_update_numerator, probe_update_numerator, position_update_numerator, position_update_denominator, - beigen_probe, - beigen_weights, - batches, - comm.streams, + eigen_probe, + eigen_weights, + batches0, + streams0, n=n, op=op, object_options=object_options, @@ -142,16 +150,15 @@ def rpie( recover_probe=recover_probe, position_options=position_options, exitwave_options=exitwave_options, - ))) + ) - batch_cost.append(comm.Allreduce_mean(cost, axis=None).get()) + batch_cost.append(cost) if algorithm_options.batch_method != 'compact': ( psi, probe, ) = _update( - comm, psi, probe, psi_update_numerator, @@ -161,8 +168,8 @@ def rpie( recover_probe, algorithm_options, ) - psi_update_numerator = [None] * comm.pool.num_workers - probe_update_numerator = [None] * comm.pool.num_workers + psi_update_numerator = None + probe_update_numerator = None algorithm_options.costs.append(batch_cost) @@ -170,8 +177,7 @@ def rpie( ( scan, position_options, - ) = (list(a) for a in zip(*comm.pool.map( - _update_position, + ) = _update_position( scan, position_options, position_update_numerator, @@ -179,14 +185,13 @@ def rpie( max_shift=probe[0].shape[-1] * 0.1, alpha=algorithm_options.alpha, epoch=epoch, - ))) + ) if algorithm_options.batch_method == 'compact': ( psi, probe, ) = _update( - comm, psi, probe, psi_update_numerator, @@ -195,23 +200,37 @@ def rpie( probe_options, recover_probe, algorithm_options, - errors=list(np.mean(x) for x in algorithm_options.costs[-3:]), + errors=[float(np.mean(x)) for x in algorithm_options.costs[-3:]], ) if eigen_weights is not None: - eigen_weights = comm.pool.map( - _normalize_eigen_weights, - beigen_weights, + eigen_weights = _normalize_eigen_weights( + eigen_weights, ) - parameters.probe = probe - parameters.psi = psi - parameters.scan = scan + # CONVERSION AREA BELOW ---------------------- + + probe_options.preconditioner = [probe_options.preconditioner] + object_options.preconditioner = [object_options.preconditioner] + + parameters.probe[0] = probe + parameters.psi[0] = psi + parameters.scan[0] = scan parameters.algorithm_options = algorithm_options parameters.probe_options = probe_options parameters.object_options = object_options - parameters.position_options = position_options - parameters.eigen_weights = eigen_weights + if position_options is None: + parameters.position_options = None + else: + parameters.position_options[0] = position_options + if eigen_probe is None: + parameters.eigen_probe = None + else: + parameters.eigen_probe[0] = eigen_probe + if eigen_weights is None: + parameters.eigen_weights = None + else: + parameters.eigen_weights[0] = eigen_weights return parameters @@ -224,7 +243,6 @@ def _normalize_eigen_weights(eigen_weights): def _update( - comm: tike.communicators.Comm, psi: npt.NDArray[cp.csingle], probe: npt.NDArray[cp.csingle], psi_update_numerator: npt.NDArray[cp.csingle], @@ -233,19 +251,19 @@ def _update( probe_options: ProbeOptions, recover_probe: bool, algorithm_options: RpieOptions, - errors: typing.Union[None, typing.List[float]] = None, -): + errors: typing.Union[None, npt.NDArray] = None, +) -> typing.Tuple[npt.NDArray[cp.csingle], npt.NDArray[cp.csingle]]: if object_options: - psi_update_numerator = comm.Allreduce_reduce_gpu( - psi_update_numerator)[0] dpsi = psi_update_numerator deno = ( - (1 - algorithm_options.alpha) * object_options.preconditioner[0] + - algorithm_options.alpha * object_options.preconditioner[0].max( + (1 - algorithm_options.alpha) * object_options.preconditioner + + algorithm_options.alpha + * object_options.preconditioner.max( axis=(-2, -1), keepdims=True, - )) - psi[0] = psi[0] + dpsi / deno + ) + ) + psi = psi + dpsi / deno if object_options.use_adaptive_moment: if errors: ( @@ -272,29 +290,31 @@ def _update( vdecay=object_options.vdecay, mdecay=object_options.mdecay, ) - psi[0] = psi[0] + dpsi / deno - psi = comm.pool.bcast([psi[0]]) + psi = psi + dpsi / deno if recover_probe: - - probe_update_numerator = comm.Allreduce_reduce_gpu( - probe_update_numerator)[0] b0 = tike.ptycho.probe.finite_probe_support( - probe[0], + probe, p=probe_options.probe_support, radius=probe_options.probe_support_radius, degree=probe_options.probe_support_degree, ) - b1 = probe_options.additional_probe_penalty * cp.linspace( - 0, 1, probe[0].shape[-3], dtype='float32')[..., None, None] - dprobe = (probe_update_numerator - (b1 + b0) * probe[0]) + b1 = ( + probe_options.additional_probe_penalty + * cp.linspace(0, 1, probe.shape[-3], dtype="float32")[..., None, None] + ) + dprobe = probe_update_numerator - (b1 + b0) * probe deno = ( - (1 - algorithm_options.alpha) * probe_options.preconditioner[0] + - algorithm_options.alpha * probe_options.preconditioner[0].max( + (1 - algorithm_options.alpha) * probe_options.preconditioner + + algorithm_options.alpha + * probe_options.preconditioner.max( axis=(-2, -1), keepdims=True, - ) + b0 + b1) - probe[0] = probe[0] + dprobe / deno + ) + + b0 + + b1 + ) + probe = probe + dprobe / deno if probe_options.use_adaptive_moment: # ptychoshelves only applies momentum to the main probe mode = 0 @@ -323,8 +343,7 @@ def _update( vdecay=probe_options.vdecay, mdecay=probe_options.mdecay, ) - probe[0] = probe[0] + dprobe / deno - probe = comm.pool.bcast([probe[0]]) + probe = probe + dprobe / deno return psi, probe @@ -341,7 +360,7 @@ def _get_nearplane_gradients( position_update_denominator: typing.Union[None, npt.NDArray], eigen_probe: typing.Union[None, npt.NDArray], eigen_weights: typing.Union[None, npt.NDArray], - batches: typing.List[typing.List[int]], + batches: typing.List[npt.NDArray[np.intc]], streams: typing.List[cp.cuda.Stream], *, n: int, @@ -351,10 +370,11 @@ def _get_nearplane_gradients( recover_probe: bool, position_options: typing.Union[None, PositionOptions], exitwave_options: ExitWaveOptions, -) -> typing.List[npt.NDArray]: - - cost = 0.0 - count = 1.0 / len(batches[n]) +) -> typing.Tuple[ + float, npt.NDArray, npt.NDArray, npt.NDArray, npt.NDArray, npt.NDArray | None +]: + cost: float = 0.0 + count: float = 1.0 / len(batches[n]) psi_update_numerator = cp.zeros_like( psi) if psi_update_numerator is None else psi_update_numerator probe_update_numerator = cp.zeros_like( @@ -544,7 +564,7 @@ def keep_some_args_constant( ) return ( - cost, + float(cost), psi_update_numerator, probe_update_numerator, position_update_numerator, @@ -559,10 +579,10 @@ def _update_position( position_update_numerator: npt.NDArray, position_update_denominator: npt.NDArray, *, - alpha=0.05, - max_shift=1, - epoch=0, -): + alpha: float = 0.05, + max_shift: float = 1.0, + epoch: int = 0, +) -> typing.Tuple[cp.ndarray, PositionOptions]: if epoch < position_options.update_start: return scan, position_options From 80aabce3ad94e46f0c1c4faf78d1110cf63cc43a Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Thu, 27 Jun 2024 17:35:02 -0500 Subject: [PATCH 12/31] DEV: Remove comm pool from lstsq implementation --- src/tike/ptycho/probe.py | 63 ++++---- src/tike/ptycho/solvers/lstsq.py | 260 +++++++++++++++---------------- 2 files changed, 154 insertions(+), 169 deletions(-) diff --git a/src/tike/ptycho/probe.py b/src/tike/ptycho/probe.py index a5d2e847..8877975c 100644 --- a/src/tike/ptycho/probe.py +++ b/src/tike/ptycho/probe.py @@ -139,7 +139,7 @@ class ProbeOptions: """ median_filter_abs_probe: bool = False - """Binary switch on whether to apply a median filter to absolute value of + """Binary switch on whether to apply a median filter to absolute value of each shared probe mode. """ @@ -202,7 +202,7 @@ def resample(self, factor: float, interp) -> ProbeOptions: probe_support=self.probe_support, probe_support_degree=self.probe_support_degree, probe_support_radius=self.probe_support_radius, - median_filter_abs_probe=self.median_filter_abs_probe, + median_filter_abs_probe=self.median_filter_abs_probe, median_filter_abs_probe_px=self.median_filter_abs_probe_px, ) return options @@ -411,7 +411,6 @@ def _get_weights_mean(n, d, d_mean, weights, batches, *, batch_index, c, m): def update_eigen_probe( - comm, R, eigen_probe, weights, @@ -459,14 +458,13 @@ def update_eigen_probe( least-squares solver for generalized maximum-likelihood ptychography. Optics Express. 2018. """ - assert R[0].shape[-3] == R[0].shape[-4] == 1 - assert 1 == eigen_probe[0].shape[-5] - assert R[0].shape[:-5] == eigen_probe[0].shape[:-5] == weights[0].shape[:-3] - assert weights[0][batches[0][batch_index], :, :].shape[-3] == R[0].shape[-5] - assert R[0].shape[-2:] == eigen_probe[0].shape[-2:] - - update = comm.pool.map( - _get_update, + assert R.shape[-3] == R.shape[-4] == 1 + assert 1 == eigen_probe.shape[-5] + assert R.shape[:-5] == eigen_probe.shape[:-5] == weights.shape[:-3] + assert weights[batches[batch_index], :, :].shape[-3] == R.shape[-5] + assert R.shape[-2:] == eigen_probe.shape[-2:] + + update = _get_update( R, eigen_probe, weights, @@ -475,13 +473,18 @@ def update_eigen_probe( c=c, m=m, ) - update = comm.pool.bcast([comm.Allreduce_mean( + + update = cp.mean( update, axis=-5, - )]) + ) - (eigen_probe, n, d, d_mean) = (list(a) for a in zip(*comm.pool.map( - _get_d, + ( + eigen_probe, + n, + d, + d_mean, + ) = _get_d( patches, diff, eigen_probe, @@ -489,25 +492,23 @@ def update_eigen_probe( β=β, c=c, m=m, - ))) + ) - d_mean = comm.pool.bcast([comm.Allreduce_mean( + d_mean = cp.mean( d_mean, axis=-3, - )]) - - weights = list( - comm.pool.map( - _get_weights_mean, - n, - d, - d_mean, - weights, - batches, - batch_index=batch_index, - c=c, - m=m, - )) + ) + + weights = _get_weights_mean( + n, + d, + d_mean, + weights, + batches, + batch_index=batch_index, + c=c, + m=m, + ) return eigen_probe, weights diff --git a/src/tike/ptycho/solvers/lstsq.py b/src/tike/ptycho/solvers/lstsq.py index ece6f3fd..0cc5b816 100644 --- a/src/tike/ptycho/solvers/lstsq.py +++ b/src/tike/ptycho/solvers/lstsq.py @@ -69,55 +69,54 @@ def lstsq_grad( .. seealso:: :py:mod:`tike.ptycho` """ - probe = parameters.probe - scan = parameters.scan - psi = parameters.psi + data0 = data[0] + batches0 = batches[0] + probe = parameters.probe[0] + scan = parameters.scan[0] + psi = parameters.psi[0] algorithm_options = parameters.algorithm_options - + exitwave_options = parameters.exitwave_options probe_options = parameters.probe_options - if probe_options is None: - recover_probe = False - else: - recover_probe = probe_options.recover_probe + recover_probe = True - position_options = parameters.position_options + if parameters.position_options is None: + position_options = None + else: + position_options = parameters.position_options[0] object_options = parameters.object_options - exitwave_options = parameters.exitwave_options - - eigen_probe = parameters.eigen_probe - eigen_weights = parameters.eigen_weights - - position_update_numerator = [None] * comm.pool.num_workers - position_update_denominator = [None] * comm.pool.num_workers - - if eigen_probe is None: - beigen_probe = [None] * comm.pool.num_workers + if parameters.eigen_probe is None: + eigen_probe = None else: - beigen_probe = eigen_probe - - if eigen_weights is None: - beigen_weights = [None] * comm.pool.num_workers + eigen_probe = parameters.eigen_probe[0] + if parameters.eigen_weights is None: + eigen_weights = None else: - beigen_weights = eigen_weights + eigen_weights = parameters.eigen_weights[0] - if object_options is not None: - if algorithm_options.batch_method == 'compact': - object_options.combined_update = cp.zeros_like(psi[0]) + measured_pixels = exitwave_options.measured_pixels[0] - if recover_probe: - probe_options.probe_update_sum = cp.zeros_like(probe[0]) + streams0 = comm.streams[0] + + probe_options.preconditioner = probe_options.preconditioner[0] + object_options.preconditioner = object_options.preconditioner[0] + + # CONVERSTION AREA ABOVE --------------------------------------- if parameters.algorithm_options.batch_method == 'compact': order = range else: order = tike.random.randomizer_np.permutation - batch_cost = [] - beta_object = [] - beta_probe = [] - for batch_index in order(algorithm_options.num_batch): + object_combined_update = cp.zeros_like(psi) + probe_combined_update = cp.zeros_like(probe) + position_update_numerator = None + position_update_denominator = None + batch_cost: typing.List[float] = [] + beta_object: typing.List[float] = [] + beta_probe: typing.List[float] = [] + for batch_index in order(algorithm_options.num_batch): ( diff, unique_probe, @@ -129,21 +128,19 @@ def lstsq_grad( position_update_numerator, position_update_denominator, position_options, - ) = (list(a) for a in zip(*comm.pool.map( - _get_nearplane_gradients, - data, + ) = _get_nearplane_gradients( + data0, psi, scan, probe, - beigen_probe, - beigen_weights, - batches, + eigen_probe, + eigen_weights, + batches0, position_update_numerator, position_update_denominator, - [None] * comm.pool.num_workers if position_options is - None else position_options, - comm.streams, - exitwave_options.measured_pixels, + position_options, + streams0, + measured_pixels, object_options.preconditioner, batch_index=batch_index, num_batch=algorithm_options.num_batch, @@ -152,33 +149,26 @@ def lstsq_grad( recover_psi=object_options is not None, recover_probe=recover_probe, recover_positions=position_options is not None, - ))) - position_options = None if position_options[ - 0] is None else position_options - - if object_options is not None: - object_upd_sum = comm.Allreduce(object_upd_sum) + ) if recover_probe: - m_probe_update = comm.pool.bcast( - [comm.Allreduce_mean( - m_probe_update, - axis=-5, - )]) + m_probe_update = cp.mean( + m_probe_update, + axis=-5, + ) ( - beigen_probe, - beigen_weights, + eigen_probe, + eigen_weights, ) = _update_nearplane( - comm, diff, probe_update, m_probe_update, probe, - beigen_probe, - beigen_weights, + eigen_probe, + eigen_weights, patches, - batches, + batches0, batch_index=batch_index, num_batch=algorithm_options.num_batch, ) @@ -190,8 +180,7 @@ def lstsq_grad( A4, b1, b2, - ) = (list(a) for a in zip(*comm.pool.map( - _precondition_nearplane_gradients, + ) = _precondition_nearplane_gradients( diff, scan, unique_probe, @@ -200,30 +189,29 @@ def lstsq_grad( m_probe_update, object_options.preconditioner, patches, - batches, + batches0, batch_index=batch_index, op=op, m=0, recover_psi=object_options is not None, recover_probe=recover_probe, probe_options=probe_options, - ))) + ) if object_options is not None: - A1_delta = comm.pool.bcast([comm.Allreduce_mean(A1, axis=-3)]) + A1_delta = cp.mean(A1, axis=-3) else: - A1_delta = [None] * comm.pool.num_workers + A1_delta = None if recover_probe: - A4_delta = comm.pool.bcast([comm.Allreduce_mean(A4, axis=-3)]) + A4_delta = cp.mean(A4, axis=-3) else: - A4_delta = [None] * comm.pool.num_workers + A4_delta = None ( weighted_step_psi, weighted_step_probe, - ) = (list(a) for a in zip(*comm.pool.map( - _get_nearplane_steps, + ) = _get_nearplane_steps( A1, A2, A4, @@ -234,16 +222,16 @@ def lstsq_grad( recover_psi=object_options is not None, recover_probe=recover_probe, m=0, - ))) + ) if object_options is not None: - bbeta_object = comm.Allreduce_mean( + bbeta_object = cp.mean( weighted_step_psi, axis=-5, )[..., 0, 0, 0] if recover_probe: - bbeta_probe = comm.Allreduce_mean( + bbeta_probe = cp.mean( weighted_step_probe, axis=-5, ) @@ -252,7 +240,7 @@ def lstsq_grad( if object_options is not None: if algorithm_options.batch_method != 'compact': # (27b) Object update - dpsi = bbeta_object[0] * object_update_precond[0] + dpsi = bbeta_object * object_update_precond if object_options.use_adaptive_moment: ( @@ -266,55 +254,47 @@ def lstsq_grad( vdecay=object_options.vdecay, mdecay=object_options.mdecay, ) - psi[0] = psi[0] + dpsi - psi = comm.pool.bcast([psi[0]]) + psi = psi + dpsi else: - object_options.combined_update += object_upd_sum[0] - - if recover_probe: - dprobe = bbeta_probe[0] * m_probe_update[0] - probe_options.probe_update_sum += dprobe / algorithm_options.num_batch - # (27a) Probe update - probe[0] += dprobe - probe = comm.pool.bcast([probe[0]]) + object_combined_update += object_upd_sum - for c in costs: - batch_cost = batch_cost + c.tolist() - - if object_options is not None: beta_object.append(bbeta_object) if recover_probe: - beta_probe.append(bbeta_probe) + dprobe = bbeta_probe * m_probe_update + probe_combined_update += dprobe / algorithm_options.num_batch + # (27a) Probe update + probe += dprobe - if eigen_probe is not None: - eigen_probe = beigen_probe + beta_probe.append(bbeta_probe) - if eigen_weights is not None: - eigen_weights = beigen_weights + batch_cost += costs.tolist() - if position_options: - scan, position_options = zip(*comm.pool.map( - _update_position, + if ( + position_options is not None + and position_update_numerator is not None + and position_update_denominator is not None + ): + scan, position_options = _update_position( scan, position_options, position_update_numerator, position_update_denominator, epoch=epoch, - )) + ) algorithm_options.costs.append(batch_cost) if object_options and algorithm_options.batch_method == 'compact': object_update_precond = _precondition_object_update( - object_options.combined_update, - object_options.preconditioner[0], + object_combined_update, + object_options.preconditioner, ) # (27b) Object update beta_object = cp.mean(cp.stack(beta_object)) dpsi = beta_object * object_update_precond - psi[0] = psi[0] + dpsi + psi = psi + dpsi if object_options.use_adaptive_moment: ( @@ -326,20 +306,18 @@ def lstsq_grad( v=object_options.v, m=object_options.m, mdecay=object_options.mdecay, - errors=list(np.mean(x) for x in algorithm_options.costs[-3:]), + errors=list(float(np.mean(x)) for x in algorithm_options.costs[-3:]), beta=beta_object, memory_length=3, ) - weight = object_options.preconditioner[0] + weight = object_options.preconditioner weight = weight / (0.1 * weight.max() + weight) - psi[0] = psi[0] + weight * dpsi - - psi = comm.pool.bcast([psi[0]]) + psi = psi + weight * dpsi if recover_probe: if probe_options.use_adaptive_moment: beta_probe = cp.mean(cp.stack(beta_probe)) - dprobe = probe_options.probe_update_sum + dprobe = probe_combined_update if probe_options.v is None: probe_options.v = np.zeros_like( dprobe, @@ -358,44 +336,54 @@ def lstsq_grad( v=probe_options.v[..., mode, :, :], m=probe_options.m[..., mode, :, :], mdecay=probe_options.mdecay, - errors=list(np.mean(x) for x in algorithm_options.costs[-3:]), + errors=list(float(np.mean(x)) for x in algorithm_options.costs[-3:]), beta=beta_probe, memory_length=3, ) - probe[0][..., mode, :, :] = probe[0][..., mode, :, :] + d - probe = comm.pool.bcast([probe[0]]) + probe[..., mode, :, :] = probe[..., mode, :, :] + d - parameters.probe = probe - parameters.psi = psi - parameters.scan = scan + # CONVERSION AREA BELOW ---------------------- + + probe_options.preconditioner = [probe_options.preconditioner] + object_options.preconditioner = [object_options.preconditioner] + + parameters.probe[0] = probe + parameters.psi[0] = psi + parameters.scan[0] = scan parameters.algorithm_options = algorithm_options parameters.probe_options = probe_options parameters.object_options = object_options - parameters.position_options = position_options - parameters.eigen_weights = eigen_weights - parameters.eigen_probe = eigen_probe + if position_options is None: + parameters.position_options = None + else: + parameters.position_options[0] = position_options + if eigen_probe is None: + parameters.eigen_probe = None + else: + parameters.eigen_probe[0] = eigen_probe + if eigen_weights is None: + parameters.eigen_weights = None + else: + parameters.eigen_weights[0] = eigen_weights return parameters def _update_nearplane( - comm: tike.communicators.Comm, - diff, - probe_update, - m_probe_update, - probe: typing.List[npt.NDArray[cp.csingle]], - eigen_probe: typing.List[npt.NDArray[cp.csingle]], - eigen_weights: typing.List[npt.NDArray[cp.single]], - patches, + diff: npt.NDArray[cp.csingle], + probe_update: npt.NDArray[cp.csingle], + m_probe_update: npt.NDArray[cp.csingle], + probe: npt.NDArray[cp.csingle], + eigen_probe: npt.NDArray[cp.csingle], + eigen_weights: npt.NDArray[cp.single], + patches: npt.NDArray[cp.csingle], batches, *, batch_index: int, num_batch: int, ): m = 0 - if eigen_weights[0] is not None: - - eigen_weights = comm.pool.map( - _get_coefs_intensity, + if eigen_weights is not None: + eigen_weights = _get_coefs_intensity( eigen_weights, diff, probe, @@ -406,23 +394,20 @@ def _update_nearplane( ) # (30) residual probe updates - if eigen_weights[0].shape[-2] > 1: - R = comm.pool.map( - _get_residuals, + if eigen_weights.shape[-2] > 1: + R = _get_residuals( probe_update, m_probe_update, m=m, ) - if eigen_probe[0] is not None and m < eigen_probe[0].shape[-3]: - assert eigen_weights[0].shape[-2] == eigen_probe[0].shape[-4] + 1 - for eigen_index in range(1, eigen_probe[0].shape[-4] + 1): - + if eigen_probe is not None and m < eigen_probe.shape[-3]: + assert eigen_weights.shape[-2] == eigen_probe.shape[-4] + 1 + for eigen_index in range(1, eigen_probe.shape[-4] + 1): ( eigen_probe, eigen_weights, ) = tike.ptycho.probe.update_eigen_probe( - comm, R, eigen_probe, eigen_weights, @@ -435,10 +420,9 @@ def _update_nearplane( m=m, ) - if eigen_index + 1 < eigen_weights[0].shape[-2]: + if eigen_index + 1 < eigen_weights.shape[-2]: # Subtract projection of R onto new probe from R - R = comm.pool.map( - _update_residuals, + R = _update_residuals( R, eigen_probe, batches, From dc47d7c37c6da3eb86cf05808dc619bca80f6ab8 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Thu, 27 Jun 2024 17:54:01 -0500 Subject: [PATCH 13/31] DEV: Revert changes to solvers --- src/tike/ptycho/probe.py | 23 +- src/tike/ptycho/solvers/lstsq.py | 371 ++++++++++++++------------ src/tike/ptycho/solvers/rpie.py | 438 +++++++++++++++++-------------- 3 files changed, 462 insertions(+), 370 deletions(-) diff --git a/src/tike/ptycho/probe.py b/src/tike/ptycho/probe.py index fce5851a..7dab1ac1 100644 --- a/src/tike/ptycho/probe.py +++ b/src/tike/ptycho/probe.py @@ -293,7 +293,7 @@ def _constrain_variable_probe2(variable_probe, weights, power): return variable_probe, weights -def constrain_variable_probe(variable_probe, weights): +def constrain_variable_probe(comm, variable_probe, weights): """Add the following constraints to variable probe weights 1. Remove outliars from weights @@ -307,16 +307,21 @@ def constrain_variable_probe(variable_probe, weights): # sorting and synchronizing the weights with the host OR implementing # smoothing of non-gridded data with splines using device-local data only. - variable_probe, weights, power = _constrain_variable_probe1( + variable_probe, weights, power = zip(*comm.pool.map( + _constrain_variable_probe1, variable_probe, weights, - ) + )) + + # reduce power by sum across all devices + power = comm.pool.allreduce(power) - variable_probe, weights = _constrain_variable_probe2( + variable_probe, weights = (list(a) for a in zip(*comm.pool.map( + _constrain_variable_probe2, variable_probe, weights, power, - ) + ))) return variable_probe, weights @@ -467,12 +472,18 @@ def update_eigen_probe( c=c, m=m, ) + update = cp.mean( update, axis=-5, ) - (eigen_probe, n, d, d_mean) = _get_d( + ( + eigen_probe, + n, + d, + d_mean, + ) = _get_d( patches, diff, eigen_probe, diff --git a/src/tike/ptycho/solvers/lstsq.py b/src/tike/ptycho/solvers/lstsq.py index e2de5116..0cc5b816 100644 --- a/src/tike/ptycho/solvers/lstsq.py +++ b/src/tike/ptycho/solvers/lstsq.py @@ -17,25 +17,18 @@ import tike.ptycho.exitwave import tike.precision -from .options import ( - ExitWaveOptions, - ObjectOptions, - PositionOptions, - ProbeOptions, - PtychoParameters, - LstsqOptions, -) +from .options import * logger = logging.getLogger(__name__) def lstsq_grad( - parameters: PtychoParameters, - data: npt.NDArray, + op: tike.operators.Ptycho, + comm: tike.communicators.Comm, + data: typing.List[npt.NDArray], batches: typing.List[npt.NDArray[cp.intc]], - streams: typing.List[cp.cuda.Stream], *, - op: tike.operators.Ptycho, + parameters: PtychoParameters, epoch: int, ): """Solve the ptychography problem using Odstrcil et al's approach. @@ -76,25 +69,54 @@ def lstsq_grad( .. seealso:: :py:mod:`tike.ptycho` """ - print('helloworld3') + data0 = data[0] + batches0 = batches[0] + + probe = parameters.probe[0] + scan = parameters.scan[0] + psi = parameters.psi[0] + algorithm_options = parameters.algorithm_options + exitwave_options = parameters.exitwave_options + probe_options = parameters.probe_options + recover_probe = True + + if parameters.position_options is None: + position_options = None + else: + position_options = parameters.position_options[0] + object_options = parameters.object_options + if parameters.eigen_probe is None: + eigen_probe = None + else: + eigen_probe = parameters.eigen_probe[0] + if parameters.eigen_weights is None: + eigen_weights = None + else: + eigen_weights = parameters.eigen_weights[0] + + measured_pixels = exitwave_options.measured_pixels[0] + streams0 = comm.streams[0] + + probe_options.preconditioner = probe_options.preconditioner[0] + object_options.preconditioner = object_options.preconditioner[0] + + # CONVERSTION AREA ABOVE --------------------------------------- if parameters.algorithm_options.batch_method == 'compact': order = range else: order = tike.random.randomizer_np.permutation - psi_combined_update: None | cp.ndarray = None - probe_combined_update: None | cp.ndarray = None - position_update_numerator: None | cp.ndarray = None - position_update_denominator: None | cp.ndarray = None - - recover_probe: bool = parameters.probe_options is not None + object_combined_update = cp.zeros_like(psi) + probe_combined_update = cp.zeros_like(probe) + position_update_numerator = None + position_update_denominator = None - batch_cost = [] - beta_object = [] - beta_probe = [] - for batch_index in order(parameters.algorithm_options.num_batch): + batch_cost: typing.List[float] = [] + beta_object: typing.List[float] = [] + beta_probe: typing.List[float] = [] + for batch_index in order(algorithm_options.num_batch): ( diff, unique_probe, @@ -105,45 +127,50 @@ def lstsq_grad( patches, position_update_numerator, position_update_denominator, - parameters.position_options, + position_options, ) = _get_nearplane_gradients( - data, - parameters.psi, - parameters.scan, - parameters.probe, - parameters.eigen_probe, - parameters.eigen_weights, - batches, + data0, + psi, + scan, + probe, + eigen_probe, + eigen_weights, + batches0, position_update_numerator, position_update_denominator, - parameters.position_options, - streams, - parameters.exitwave_options.measured_pixels, - parameters.object_options.preconditioner, + position_options, + streams0, + measured_pixels, + object_options.preconditioner, batch_index=batch_index, - num_batch=parameters.algorithm_options.num_batch, - exitwave_options=parameters.exitwave_options, + num_batch=algorithm_options.num_batch, + exitwave_options=exitwave_options, op=op, - recover_psi=parameters.object_options is not None, + recover_psi=object_options is not None, recover_probe=recover_probe, - recover_positions=parameters.position_options is not None, + recover_positions=position_options is not None, ) - if parameters.probe_options: + if recover_probe: + m_probe_update = cp.mean( + m_probe_update, + axis=-5, + ) + ( - parameters.eigen_probe, - parameters.eigen_weights, + eigen_probe, + eigen_weights, ) = _update_nearplane( diff, probe_update, m_probe_update, - parameters.probe, - parameters.eigen_probe, - parameters.eigen_weights, + probe, + eigen_probe, + eigen_weights, patches, - batches, + batches0, batch_index=batch_index, - num_batch=parameters.algorithm_options.num_batch, + num_batch=algorithm_options.num_batch, ) ( @@ -155,23 +182,23 @@ def lstsq_grad( b2, ) = _precondition_nearplane_gradients( diff, - parameters.scan, + scan, unique_probe, - parameters.probe, + probe, object_upd_sum, m_probe_update, - parameters.object_options.preconditioner, + object_options.preconditioner, patches, - batches, + batches0, batch_index=batch_index, op=op, m=0, - recover_psi=parameters.object_options is not None, + recover_psi=object_options is not None, recover_probe=recover_probe, - probe_options=parameters.probe_options, + probe_options=probe_options, ) - if parameters.object_options is not None: + if object_options is not None: A1_delta = cp.mean(A1, axis=-3) else: A1_delta = None @@ -192,12 +219,12 @@ def lstsq_grad( b2, A1_delta, A4_delta, - recover_psi=parameters.object_options is not None, + recover_psi=object_options is not None, recover_probe=recover_probe, m=0, ) - if parameters.object_options is not None: + if object_options is not None: bbeta_object = cp.mean( weighted_step_psi, axis=-5, @@ -209,127 +236,135 @@ def lstsq_grad( axis=-5, ) - print('helloworld1') - # Update each direction - if parameters.object_options is not None: - - print('helloworld') - - if parameters.algorithm_options.batch_method != "compact": + if object_options is not None: + if algorithm_options.batch_method != 'compact': # (27b) Object update dpsi = bbeta_object * object_update_precond - print(dpsi.shape, parameters.psi.shape) - - if parameters.object_options.use_adaptive_moment: + if object_options.use_adaptive_moment: ( dpsi, - parameters.object_options.v, - parameters.object_options.m, + object_options.v, + object_options.m, ) = tike.opt.momentum( g=dpsi, - v=parameters.object_options.v, - m=parameters.object_options.m, - vdecay=parameters.object_options.vdecay, - mdecay=parameters.object_options.mdecay, + v=object_options.v, + m=object_options.m, + vdecay=object_options.vdecay, + mdecay=object_options.mdecay, ) - parameters.psi = parameters.psi + dpsi + psi = psi + dpsi else: - psi_combined_update += object_upd_sum + object_combined_update += object_upd_sum + + beta_object.append(bbeta_object) if recover_probe: dprobe = bbeta_probe * m_probe_update - probe_combined_update += dprobe / parameters.algorithm_options.num_batch + probe_combined_update += dprobe / algorithm_options.num_batch # (27a) Probe update - parameters.probe += dprobe - - for c in costs: - batch_cost = batch_cost + c.tolist() - - if parameters.object_options is not None: - beta_object.append(bbeta_object) + probe += dprobe - if recover_probe: beta_probe.append(bbeta_probe) - if parameters.position_options: - parameters.scan, parameters.position_options = _update_position( - parameters.scan, - parameters.position_options, + batch_cost += costs.tolist() + + if ( + position_options is not None + and position_update_numerator is not None + and position_update_denominator is not None + ): + scan, position_options = _update_position( + scan, + position_options, position_update_numerator, position_update_denominator, epoch=epoch, ) - parameters.algorithm_options.costs.append(batch_cost) + algorithm_options.costs.append(batch_cost) - if ( - parameters.object_options - and parameters.algorithm_options.batch_method == "compact" - ): + if object_options and algorithm_options.batch_method == 'compact': object_update_precond = _precondition_object_update( - psi_combined_update, - parameters.object_options.preconditioner, + object_combined_update, + object_options.preconditioner, ) # (27b) Object update beta_object = cp.mean(cp.stack(beta_object)) dpsi = beta_object * object_update_precond - parameters.psi = psi + dpsi + psi = psi + dpsi - if parameters.object_options.use_adaptive_moment: + if object_options.use_adaptive_moment: ( dpsi, - parameters.object_options.v, - parameters.object_options.m, + object_options.v, + object_options.m, ) = _momentum_checked( g=dpsi, - v=parameters.object_options.v, - m=parameters.object_options.m, - mdecay=parameters.object_options.mdecay, - errors=list( - float(np.mean(x)) for x in parameters.algorithm_options.costs[-3:] - ), + v=object_options.v, + m=object_options.m, + mdecay=object_options.mdecay, + errors=list(float(np.mean(x)) for x in algorithm_options.costs[-3:]), beta=beta_object, memory_length=3, ) - weight = parameters.object_options.preconditioner + weight = object_options.preconditioner weight = weight / (0.1 * weight.max() + weight) - parameters.psi = parameters.psi + weight * dpsi + psi = psi + weight * dpsi if recover_probe: - if parameters.probe_options.use_adaptive_moment: + if probe_options.use_adaptive_moment: beta_probe = cp.mean(cp.stack(beta_probe)) dprobe = probe_combined_update - if parameters.probe_options.v is None: - parameters.probe_options.v = np.zeros_like( + if probe_options.v is None: + probe_options.v = np.zeros_like( dprobe, shape=(3, *dprobe.shape), ) - if parameters.probe_options.m is None: - parameters.probe_options.m = np.zeros_like( - dprobe, - ) + if probe_options.m is None: + probe_options.m = np.zeros_like(dprobe,) # ptychoshelves only applies momentum to the main probe mode = 0 ( d, - parameters.probe_options.v[..., mode, :, :], - parameters.probe_options.m[..., mode, :, :], + probe_options.v[..., mode, :, :], + probe_options.m[..., mode, :, :], ) = _momentum_checked( g=dprobe[..., mode, :, :], - v=parameters.probe_options.v[..., mode, :, :], - m=parameters.probe_options.m[..., mode, :, :], - mdecay=parameters.probe_options.mdecay, - errors=list( - float(np.mean(x)) for x in parameters.algorithm_options.costs[-3:] - ), + v=probe_options.v[..., mode, :, :], + m=probe_options.m[..., mode, :, :], + mdecay=probe_options.mdecay, + errors=list(float(np.mean(x)) for x in algorithm_options.costs[-3:]), beta=beta_probe, memory_length=3, ) - parameters.probe[..., mode, :, :] = parameters.probe[..., mode, :, :] + d + probe[..., mode, :, :] = probe[..., mode, :, :] + d + + # CONVERSION AREA BELOW ---------------------- + probe_options.preconditioner = [probe_options.preconditioner] + object_options.preconditioner = [object_options.preconditioner] + + parameters.probe[0] = probe + parameters.psi[0] = psi + parameters.scan[0] = scan + parameters.algorithm_options = algorithm_options + parameters.probe_options = probe_options + parameters.object_options = object_options + if position_options is None: + parameters.position_options = None + else: + parameters.position_options[0] = position_options + if eigen_probe is None: + parameters.eigen_probe = None + else: + parameters.eigen_probe[0] = eigen_probe + if eigen_weights is None: + parameters.eigen_weights = None + else: + parameters.eigen_weights[0] = eigen_weights return parameters @@ -341,7 +376,7 @@ def _update_nearplane( eigen_probe: npt.NDArray[cp.csingle], eigen_weights: npt.NDArray[cp.single], patches: npt.NDArray[cp.csingle], - batches: typing.List[npt.NDArray[np.intc]], + batches, *, batch_index: int, num_batch: int, @@ -408,11 +443,11 @@ def _get_nearplane_gradients( psi: npt.NDArray[cp.csingle], scan: npt.NDArray[cp.single], probe: npt.NDArray[cp.csingle], - eigen_probe: npt.NDArray[cp.csingle], - eigen_weights: npt.NDArray[cp.csingle], - batches: typing.List[npt.NDArray[np.intc]], - position_update_numerator: npt.NDArray[cp.csingle], - position_update_denominator: npt.NDArray[cp.csingle], + eigen_probe, + eigen_weights, + batches, + position_update_numerator, + position_update_denominator, position_options: PositionOptions, streams: typing.List[cp.cuda.Stream], measured_pixels: npt.NDArray, @@ -559,13 +594,12 @@ def keep_some_args_constant( else: object_upd_sum = None - bpatches[blo:bhi] = op.diffraction.patch.fwd( - patches=cp.zeros_like(bchi[blo:bhi, ..., 0, 0, :, :]), - images=psi[0], - positions=scan[lo:hi], - )[..., None, None, :, :] - if recover_probe: + bpatches[blo:bhi] = op.diffraction.patch.fwd( + patches=cp.zeros_like(bchi[blo:bhi, ..., 0, 0, :, :]), + images=psi[0], + positions=scan[lo:hi], + )[..., None, None, :, :] # (24a) bprobe_update[blo:bhi] = cp.conj(bpatches[blo:bhi]) * bchi[blo:bhi] # (25a) Common probe gradient. Use simple average instead of @@ -579,6 +613,7 @@ def keep_some_args_constant( else: bprobe_update = None m_probe_update = None + bpatches = None if position_options: m = 0 @@ -655,23 +690,23 @@ def _precondition_object_update( def _precondition_nearplane_gradients( - nearplane: npt.NDArray[cp.csingle], - scan: npt.NDArray[cp.single], - unique_probe: npt.NDArray[cp.csingle], - probe: npt.NDArray[cp.csingle], - object_upd_sum: npt.NDArray[cp.csingle], - m_probe_update: npt.NDArray[cp.csingle], - psi_update_denominator: npt.NDArray[cp.csingle], - patches: npt.NDArray[cp.csingle], - batches: typing.List[npt.NDArray[np.intc]], + nearplane, + scan, + unique_probe, + probe, + object_upd_sum, + m_probe_update, + psi_update_denominator, + patches, + batches, *, batch_index: int, - op: tike.operators.Ptycho, - m: int, - recover_psi: bool, - recover_probe: bool, - alpha: float = 0.05, - probe_options: ProbeOptions, + op, + m, + recover_psi, + recover_probe, + alpha=0.05, + probe_options, ): lo = batches[batch_index][0] hi = lo + len(batches[batch_index]) @@ -686,7 +721,6 @@ def _precondition_nearplane_gradients( dOP = None dPO = None object_update_proj = None - object_update_precond = None if recover_psi: object_update_precond = _precondition_object_update( @@ -719,7 +753,7 @@ def _precondition_nearplane_gradients( 1, probe[0].shape[-3], dtype=tike.precision.floating, - )[..., m : m + 1, None, None] # type: ignore + )[..., m : m + 1, None, None] ) m_probe_update = m_probe_update - (b0 + b1) * probe[..., m : m + 1, :, :] @@ -737,16 +771,16 @@ def _precondition_nearplane_gradients( dPO = m_probe_update[..., m:m + 1, :, :] * patches A4 = cp.sum((dPO * dPO.conj()).real + eps, axis=(-2, -1)) - if dOP is not None and dPO is not None: + if recover_psi and recover_probe: b1 = cp.sum((dOP.conj() * nearplane[..., m:m + 1, :, :]).real, axis=(-2, -1)) b2 = cp.sum((dPO.conj() * nearplane[..., m:m + 1, :, :]).real, axis=(-2, -1)) A2 = cp.sum((dOP * dPO.conj()), axis=(-2, -1)) - elif dOP is not None: + elif recover_psi: b1 = cp.sum((dOP.conj() * nearplane[..., m:m + 1, :, :]).real, axis=(-2, -1)) - elif dPO is not None: + elif recover_probe: b2 = cp.sum((dPO.conj() * nearplane[..., m:m + 1, :, :]).real, axis=(-2, -1)) @@ -760,18 +794,9 @@ def _precondition_nearplane_gradients( ) -def _get_nearplane_steps( - A1, - A2, - A4, - b1, - b2, - A1_delta, - A4_delta, - recover_psi, - recover_probe, - m, -) -> typing.Tuple[npt.NDArray | None, npt.NDArray | None]: +def _get_nearplane_steps(A1, A2, A4, b1, b2, A1_delta, A4_delta, recover_psi, + recover_probe, m): + if recover_psi: A1 += 0.5 * A1_delta if recover_probe: @@ -791,7 +816,7 @@ def _get_nearplane_steps( x1 = None x2 = None - if x1 is not None: + if recover_psi: step = 0.9 * cp.maximum(0, x1[..., None, None].real) # (27b) Object update @@ -799,7 +824,7 @@ def _get_nearplane_steps( else: beta_object = None - if x2 is not None: + if recover_probe: step = 0.9 * cp.maximum(0, x2[..., None, None].real) beta_probe = cp.mean(step, axis=-5, keepdims=True) @@ -861,7 +886,7 @@ def _update_position( alpha=0.05, max_shift=1, epoch=0, -) -> typing.Tuple[npt.NDArray, PositionOptions]: +): if epoch < position_options.update_start: return scan, position_options @@ -937,10 +962,10 @@ def _momentum_checked( ).real.flatten() if np.all(previous_update_correlation > 0): friction, _ = tike.opt.fit_line_least_squares( - x=np.arange(len(previous_update_correlation) + 1, dtype=np.floating), - y=np.array( - [0, - ] + np.log(previous_update_correlation).tolist()), + x=np.arange(len(previous_update_correlation) + 1), + y=[ + 0, + ] + np.log(previous_update_correlation).tolist(), ) friction = 0.5 * max(-friction, 0) m = (1 - friction) * m + g diff --git a/src/tike/ptycho/solvers/rpie.py b/src/tike/ptycho/solvers/rpie.py index ccd17e16..bf4321ef 100644 --- a/src/tike/ptycho/solvers/rpie.py +++ b/src/tike/ptycho/solvers/rpie.py @@ -3,8 +3,8 @@ import cupy as cp import cupyx.scipy.stats -import numpy.typing as npt import numpy as np +import numpy.typing as npt import tike.communicators import tike.linalg @@ -17,26 +17,19 @@ import tike.precision import tike.random -from .options import ( - ExitWaveOptions, - ObjectOptions, - PositionOptions, - ProbeOptions, - PtychoParameters, - RpieOptions, -) +from .options import * from .lstsq import _momentum_checked logger = logging.getLogger(__name__) def rpie( - parameters: PtychoParameters, - data: npt.NDArray, - batches: typing.List[npt.NDArray[cp.intc]], - streams: typing.List[cp.cuda.Stream], - *, op: tike.operators.Ptycho, + comm: tike.communicators.Comm, + data: typing.List[npt.NDArray], + batches: typing.List[typing.List[npt.NDArray[cp.intc]]], + *, + parameters: PtychoParameters, epoch: int, ) -> PtychoParameters: """Solve the ptychography problem using regularized ptychographical engine. @@ -83,80 +76,161 @@ def rpie( .. seealso:: :py:mod:`tike.ptycho` """ + data0 = data[0] + batches0 = batches[0] + + probe = parameters.probe[0] + scan = parameters.scan[0] + psi = parameters.psi[0] + algorithm_options = parameters.algorithm_options + exitwave_options = parameters.exitwave_options + probe_options = parameters.probe_options + recover_probe = True + + if parameters.position_options is None: + position_options = None + else: + position_options = parameters.position_options[0] + object_options = parameters.object_options + if parameters.eigen_probe is None: + eigen_probe = None + else: + eigen_probe = parameters.eigen_probe[0] + if parameters.eigen_weights is None: + eigen_weights = None + else: + eigen_weights = parameters.eigen_weights[0] + + measured_pixels = exitwave_options.measured_pixels[0] + + streams0 = comm.streams[0] + + probe_options.preconditioner = probe_options.preconditioner[0] + object_options.preconditioner = object_options.preconditioner[0] + + # CONVERSTION AREA ABOVE --------------------------------------- + if parameters.algorithm_options.batch_method == 'compact': order = range else: order = tike.random.randomizer_np.permutation - psi_update_numerator: None | cp.ndarray = None - probe_update_numerator: None | cp.ndarray = None - position_update_numerator: None | cp.ndarray = None - position_update_denominator: None | cp.ndarray = None + psi_update_numerator = None + probe_update_numerator = None + position_update_numerator = None + position_update_denominator = None - for n in order(parameters.algorithm_options.num_batch): + batch_cost: typing.List[float] = [] + for n in order(algorithm_options.num_batch): ( cost, psi_update_numerator, probe_update_numerator, position_update_numerator, position_update_denominator, - parameters, + eigen_weights, ) = _get_nearplane_gradients( - data, - parameters, + data0, + scan, + psi, + probe, + measured_pixels, psi_update_numerator, probe_update_numerator, position_update_numerator, position_update_denominator, - batches, - streams, + eigen_probe, + eigen_weights, + batches0, + streams0, n=n, op=op, - epoch=epoch, + object_options=object_options, + probe_options=probe_options, + recover_probe=recover_probe, + position_options=position_options, + exitwave_options=exitwave_options, ) - if parameters.algorithm_options.batch_method != "compact": - parameters = _update( - parameters, + batch_cost.append(cost) + + if algorithm_options.batch_method != 'compact': + ( + psi, + probe, + ) = _update( + psi, + probe, psi_update_numerator, probe_update_numerator, - recover_probe=parameters.probe_options.update_start >= epoch, + object_options, + probe_options, + recover_probe, + algorithm_options, ) psi_update_numerator = None probe_update_numerator = None - parameters.algorithm_options.costs.append([cost]) + algorithm_options.costs.append(batch_cost) - if parameters.position_options is not None: + if position_options is not None: ( - parameters.scan, - parameters.position_options, + scan, + position_options, ) = _update_position( - parameters.scan, - parameters.position_options, + scan, + position_options, position_update_numerator, position_update_denominator, - max_shift=parameters.probe.shape[-1] * 0.1, - alpha=parameters.algorithm_options.alpha, + max_shift=probe[0].shape[-1] * 0.1, + alpha=algorithm_options.alpha, epoch=epoch, ) - if parameters.algorithm_options.batch_method == "compact": - parameters = _update( - parameters, + if algorithm_options.batch_method == 'compact': + ( + psi, + probe, + ) = _update( + psi, + probe, psi_update_numerator, probe_update_numerator, - recover_probe=parameters.probe_options.update_start >= epoch, - errors=list( - float(np.mean(x)) for x in parameters.algorithm_options.costs[-3:] - ), + object_options, + probe_options, + recover_probe, + algorithm_options, + errors=[float(np.mean(x)) for x in algorithm_options.costs[-3:]], ) - if parameters.eigen_weights is not None: - parameters.eigen_weights = _normalize_eigen_weights( - parameters.eigen_weights, + if eigen_weights is not None: + eigen_weights = _normalize_eigen_weights( + eigen_weights, ) + # CONVERSION AREA BELOW ---------------------- + + probe_options.preconditioner = [probe_options.preconditioner] + object_options.preconditioner = [object_options.preconditioner] + + parameters.probe[0] = probe + parameters.psi[0] = psi + parameters.scan[0] = scan + parameters.algorithm_options = algorithm_options + parameters.probe_options = probe_options + parameters.object_options = object_options + if position_options is None: + parameters.position_options = None + else: + parameters.position_options[0] = position_options + if eigen_probe is None: + parameters.eigen_probe = None + else: + parameters.eigen_probe[0] = eigen_probe + if eigen_weights is None: + parameters.eigen_weights = None + else: + parameters.eigen_weights[0] = eigen_weights return parameters @@ -169,157 +243,148 @@ def _normalize_eigen_weights(eigen_weights): def _update( - parameters: PtychoParameters, + psi: npt.NDArray[cp.csingle], + probe: npt.NDArray[cp.csingle], psi_update_numerator: npt.NDArray[cp.csingle], probe_update_numerator: npt.NDArray[cp.csingle], + object_options: ObjectOptions, + probe_options: ProbeOptions, recover_probe: bool, - errors: typing.Union[None, typing.List[float]] = None, -) -> PtychoParameters: - if parameters.object_options: + algorithm_options: RpieOptions, + errors: typing.Union[None, npt.NDArray] = None, +) -> typing.Tuple[npt.NDArray[cp.csingle], npt.NDArray[cp.csingle]]: + if object_options: dpsi = psi_update_numerator deno = ( - (1 - parameters.algorithm_options.alpha) - * parameters.object_options.preconditioner - + parameters.algorithm_options.alpha - * parameters.object_options.preconditioner.max( + (1 - algorithm_options.alpha) * object_options.preconditioner + + algorithm_options.alpha + * object_options.preconditioner.max( axis=(-2, -1), keepdims=True, ) ) - parameters.psi = parameters.psi + dpsi / deno - if parameters.object_options.use_adaptive_moment: - if errors is not None: + psi = psi + dpsi / deno + if object_options.use_adaptive_moment: + if errors: ( dpsi, - parameters.object_options.v, - parameters.object_options.m, + object_options.v, + object_options.m, ) = _momentum_checked( g=dpsi, - v=parameters.object_options.v, - m=parameters.object_options.m, - mdecay=parameters.object_options.mdecay, + v=object_options.v, + m=object_options.m, + mdecay=object_options.mdecay, errors=errors, memory_length=3, ) else: ( dpsi, - parameters.object_options.v, - parameters.object_options.m, + object_options.v, + object_options.m, ) = tike.opt.adam( g=dpsi, - v=parameters.object_options.v, - m=parameters.object_options.m, - vdecay=parameters.object_options.vdecay, - mdecay=parameters.object_options.mdecay, + v=object_options.v, + m=object_options.m, + vdecay=object_options.vdecay, + mdecay=object_options.mdecay, ) - parameters.psi = parameters.psi + dpsi / deno + psi = psi + dpsi / deno - if recover_probe and parameters.probe_options is not None: + if recover_probe: b0 = tike.ptycho.probe.finite_probe_support( - parameters.probe, - p=parameters.probe_options.probe_support, - radius=parameters.probe_options.probe_support_radius, - degree=parameters.probe_options.probe_support_degree, + probe, + p=probe_options.probe_support, + radius=probe_options.probe_support_radius, + degree=probe_options.probe_support_degree, ) b1 = ( - parameters.probe_options.additional_probe_penalty - * cp.linspace( - start=0, - stop=1, - num=parameters.probe.shape[-3], - dtype="float32", - )[..., None, None] + probe_options.additional_probe_penalty + * cp.linspace(0, 1, probe.shape[-3], dtype="float32")[..., None, None] ) - dprobe = (probe_update_numerator - (b1 + b0) * parameters.probe) + dprobe = probe_update_numerator - (b1 + b0) * probe deno = ( - (1 - parameters.algorithm_options.alpha) - * parameters.probe_options.preconditioner - + parameters.algorithm_options.alpha - * parameters.probe_options.preconditioner.max( + (1 - algorithm_options.alpha) * probe_options.preconditioner + + algorithm_options.alpha + * probe_options.preconditioner.max( axis=(-2, -1), keepdims=True, ) + b0 + b1 ) - parameters.probe = parameters.probe + dprobe / deno - if parameters.probe_options.use_adaptive_moment: + probe = probe + dprobe / deno + if probe_options.use_adaptive_moment: # ptychoshelves only applies momentum to the main probe mode = 0 if errors: ( dprobe[0, 0, mode, :, :], - parameters.probe_options.v, - parameters.probe_options.m, + probe_options.v, + probe_options.m, ) = _momentum_checked( - g=dprobe[0, 0, mode, :, :], - v=parameters.probe_options.v, - m=parameters.probe_options.m, - mdecay=parameters.probe_options.mdecay, + g=(dprobe)[0, 0, mode, :, :], + v=probe_options.v, + m=probe_options.m, + mdecay=probe_options.mdecay, errors=errors, memory_length=3, ) else: ( dprobe[0, 0, mode, :, :], - parameters.probe_options.v, - parameters.probe_options.m, + probe_options.v, + probe_options.m, ) = tike.opt.adam( - g=dprobe[0, 0, mode, :, :], - v=parameters.probe_options.v, - m=parameters.probe_options.m, - vdecay=parameters.probe_options.vdecay, - mdecay=parameters.probe_options.mdecay, + g=(dprobe)[0, 0, mode, :, :], + v=probe_options.v, + m=probe_options.m, + vdecay=probe_options.vdecay, + mdecay=probe_options.mdecay, ) - parameters.probe = parameters.probe + dprobe / deno + probe = probe + dprobe / deno - return parameters + return psi, probe def _get_nearplane_gradients( data: npt.NDArray, - parameters: PtychoParameters, + scan: npt.NDArray, + psi: npt.NDArray, + probe: npt.NDArray, + measured_pixels: npt.NDArray, psi_update_numerator: typing.Union[None, npt.NDArray], probe_update_numerator: typing.Union[None, npt.NDArray], position_update_numerator: typing.Union[None, npt.NDArray], position_update_denominator: typing.Union[None, npt.NDArray], + eigen_probe: typing.Union[None, npt.NDArray], + eigen_weights: typing.Union[None, npt.NDArray], batches: typing.List[npt.NDArray[np.intc]], streams: typing.List[cp.cuda.Stream], *, n: int, op: tike.operators.Ptycho, - epoch: int, + object_options: typing.Union[None, ObjectOptions] = None, + probe_options: typing.Union[None, ProbeOptions] = None, + recover_probe: bool, + position_options: typing.Union[None, PositionOptions], + exitwave_options: ExitWaveOptions, ) -> typing.Tuple[ - float, - npt.ArrayLike, - npt.ArrayLike, - npt.ArrayLike, - npt.ArrayLike, - PtychoParameters, + float, npt.NDArray, npt.NDArray, npt.NDArray, npt.NDArray, npt.NDArray | None ]: - cost = cp.zeros(1) - count = cp.array(1.0 / len(batches[n])) - psi_update_numerator = ( - cp.zeros_like(parameters.psi) - if psi_update_numerator is None - else psi_update_numerator - ) - probe_update_numerator = ( - cp.zeros_like(parameters.probe) - if probe_update_numerator is None - else probe_update_numerator - ) - position_update_numerator = ( - cp.empty_like(parameters.scan) - if position_update_numerator is None - else position_update_numerator - ) - position_update_denominator = ( - cp.empty_like(parameters.scan) - if position_update_denominator is None - else position_update_denominator - ) + cost: float = 0.0 + count: float = 1.0 / len(batches[n]) + psi_update_numerator = cp.zeros_like( + psi) if psi_update_numerator is None else psi_update_numerator + probe_update_numerator = cp.zeros_like( + probe) if probe_update_numerator is None else probe_update_numerator + position_update_numerator = cp.empty_like( + scan + ) if position_update_numerator is None else position_update_numerator + position_update_denominator = cp.empty_like( + scan + ) if position_update_denominator is None else position_update_denominator def keep_some_args_constant( ind_args, @@ -329,50 +394,47 @@ def keep_some_args_constant( (data,) = ind_args nonlocal cost, psi_update_numerator, probe_update_numerator nonlocal position_update_numerator, position_update_denominator + nonlocal eigen_weights, scan unique_probe = tike.ptycho.probe.get_varying_probe( - parameters.probe, - parameters.eigen_probe, - parameters.eigen_weights[lo:hi] - if parameters.eigen_weights is not None - else None, + probe, + eigen_probe, + eigen_weights[lo:hi] if eigen_weights is not None else None, ) - farplane = op.fwd( - probe=unique_probe, - scan=parameters.scan[lo:hi], - psi=parameters.psi, - ) + farplane = op.fwd(probe=unique_probe, scan=scan[lo:hi], psi=psi) intensity = cp.sum( cp.square(cp.abs(farplane)), axis=list(range(1, farplane.ndim - 2)), ) each_cost = getattr( tike.operators, - f"{parameters.exitwave_options.noise_model}_each_pattern", + f'{exitwave_options.noise_model}_each_pattern', )( - data[:, parameters.exitwave_options.measured_pixels][:, None, :], - intensity[:, parameters.exitwave_options.measured_pixels][:, None, :], + data[:, measured_pixels][:, None, :], + intensity[:, measured_pixels][:, None, :], ) cost += cp.sum(each_cost) * count - if parameters.exitwave_options.noise_model == "poisson": + if exitwave_options.noise_model == 'poisson': + xi = (1 - data / intensity)[:, None, None, :, :] grad_cost = farplane * xi step_length = cp.full( shape=(farplane.shape[0], 1, farplane.shape[2], 1, 1), - fill_value=parameters.exitwave_options.step_length_start, + fill_value=exitwave_options.step_length_start, ) - if parameters.exitwave_options.step_length_usemodes == "dominant_mode": + if exitwave_options.step_length_usemodes == 'dominant_mode': + step_length = tike.ptycho.exitwave.poisson_steplength_dominant_mode( xi, intensity, data, - parameters.exitwave_options.measured_pixels, + measured_pixels, step_length, - parameters.exitwave_options.step_length_weight, + exitwave_options.step_length_weight, ) else: @@ -382,67 +444,61 @@ def keep_some_args_constant( cp.square(cp.abs(farplane)), intensity, data, - parameters.exitwave_options.measured_pixels, + measured_pixels, step_length, - parameters.exitwave_options.step_length_weight, + exitwave_options.step_length_weight, ) - farplane[..., parameters.exitwave_options.measured_pixels] = ( - -step_length * grad_cost - )[..., parameters.exitwave_options.measured_pixels] + farplane[..., measured_pixels] = (-step_length * + grad_cost)[..., measured_pixels] else: # Gaussian noise model for exitwave updates, steplength = 1 # TODO: optimal step lengths using 2nd order taylor expansion - farplane[..., parameters.exitwave_options.measured_pixels] = -getattr( - tike.operators, f"{parameters.exitwave_options.noise_model}_grad" - )( - data, - farplane, - intensity, - )[..., parameters.exitwave_options.measured_pixels] + farplane[..., measured_pixels] = -getattr( + tike.operators, f'{exitwave_options.noise_model}_grad')( + data, + farplane, + intensity, + )[..., measured_pixels] - unmeasured_pixels = cp.logical_not(parameters.exitwave_options.measured_pixels) + unmeasured_pixels = cp.logical_not(measured_pixels) farplane[..., unmeasured_pixels] *= ( - parameters.exitwave_options.unmeasured_pixels_scaling - 1.0 - ) + exitwave_options.unmeasured_pixels_scaling - 1.0) pad, end = op.diffraction.pad, op.diffraction.end diff = op.propagation.adj(farplane, overwrite=True)[..., pad:end, pad:end] - if parameters.object_options: - grad_psi = ( - cp.conj(unique_probe) * diff / parameters.probe.shape[-3] - ).reshape( - parameters.scan[lo:hi].shape[0] * parameters.probe.shape[-3], - *parameters.probe.shape[-2:], - ) + if object_options: + grad_psi = (cp.conj(unique_probe) * diff / probe.shape[-3]).reshape( + scan[lo:hi].shape[0] * probe.shape[-3], *probe.shape[-2:]) psi_update_numerator[0] = op.diffraction.patch.adj( patches=grad_psi, images=psi_update_numerator[0], - positions=parameters.scan[lo:hi], - nrepeat=parameters.probe.shape[-3], + positions=scan[lo:hi], + nrepeat=probe.shape[-3], ) - if parameters.position_options or parameters.probe_options: + if position_options or probe_options: + patches = op.diffraction.patch.fwd( patches=cp.zeros_like(diff[..., 0, 0, :, :]), - images=parameters.psi[0], - positions=parameters.scan[lo:hi], + images=psi[0], + positions=scan[lo:hi], )[..., None, None, :, :] - if parameters.probe_options and parameters.probe_options.update_start >= epoch: + if recover_probe: probe_update_numerator += cp.sum( cp.conj(patches) * diff, axis=-5, keepdims=True, ) - if parameters.eigen_weights is not None: + if eigen_weights is not None: m: int = 0 - OP = patches * parameters.probe[..., m : m + 1, :, :] + OP = patches * probe[..., m:m + 1, :, :] eigen_numerator = cp.sum( cp.real(cp.conj(OP) * diff[..., m:m + 1, :, :]), axis=(-1, -2), @@ -451,14 +507,14 @@ def keep_some_args_constant( cp.abs(OP)**2, axis=(-1, -2), ) - parameters.eigen_weights[lo:hi, ..., 0:1, m:m+1] += ( + eigen_weights[lo:hi, ..., 0:1, m:m+1] += ( 0.1 * (eigen_numerator / eigen_denominator) ) # yapf: disable - if parameters.position_options: + if position_options: grad_x, grad_y = tike.ptycho.position.gaussian_gradient(patches) - crop = parameters.probe.shape[-1] // 4 + crop = probe.shape[-1] // 4 position_update_numerator[lo:hi, ..., 0] = cp.sum( cp.real( @@ -508,12 +564,12 @@ def keep_some_args_constant( ) return ( - float(cost.get()), + float(cost), psi_update_numerator, probe_update_numerator, position_update_numerator, position_update_denominator, - parameters, + eigen_weights, ) @@ -523,10 +579,10 @@ def _update_position( position_update_numerator: npt.NDArray, position_update_denominator: npt.NDArray, *, - alpha=0.05, - max_shift=1, - epoch=0, -): + alpha: float = 0.05, + max_shift: float = 1.0, + epoch: int = 0, +) -> typing.Tuple[cp.ndarray, PositionOptions]: if epoch < position_options.update_start: return scan, position_options From f95c260b063cd756620f36c60657658d6f23735d Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Thu, 27 Jun 2024 19:04:04 -0500 Subject: [PATCH 14/31] DEV: Merge two branches implementations --- src/tike/ptycho/solvers/lstsq.py | 85 +++++++++++--------------------- src/tike/ptycho/solvers/rpie.py | 81 ++++++++++-------------------- 2 files changed, 54 insertions(+), 112 deletions(-) diff --git a/src/tike/ptycho/solvers/lstsq.py b/src/tike/ptycho/solvers/lstsq.py index 0cc5b816..133da4ac 100644 --- a/src/tike/ptycho/solvers/lstsq.py +++ b/src/tike/ptycho/solvers/lstsq.py @@ -23,14 +23,14 @@ def lstsq_grad( - op: tike.operators.Ptycho, - comm: tike.communicators.Comm, - data: typing.List[npt.NDArray], + parameters: PtychoParameters, + data: npt.NDArray, batches: typing.List[npt.NDArray[cp.intc]], + streams: typing.List[cp.cuda.Stream], *, - parameters: PtychoParameters, + op: tike.operators.Ptycho, epoch: int, -): +) -> PtychoParameters: """Solve the ptychography problem using Odstrcil et al's approach. Object and probe are updated simultaneously using optimal step sizes @@ -69,37 +69,18 @@ def lstsq_grad( .. seealso:: :py:mod:`tike.ptycho` """ - data0 = data[0] - batches0 = batches[0] - - probe = parameters.probe[0] - scan = parameters.scan[0] - psi = parameters.psi[0] + scan = parameters.scan + psi = parameters.psi + probe = parameters.probe algorithm_options = parameters.algorithm_options + eigen_weights = parameters.eigen_weights + eigen_probe = parameters.eigen_probe + measured_pixels = parameters.exitwave_options.measured_pixels exitwave_options = parameters.exitwave_options - probe_options = parameters.probe_options - recover_probe = True - - if parameters.position_options is None: - position_options = None - else: - position_options = parameters.position_options[0] + position_options = parameters.position_options object_options = parameters.object_options - if parameters.eigen_probe is None: - eigen_probe = None - else: - eigen_probe = parameters.eigen_probe[0] - if parameters.eigen_weights is None: - eigen_weights = None - else: - eigen_weights = parameters.eigen_weights[0] - - measured_pixels = exitwave_options.measured_pixels[0] - - streams0 = comm.streams[0] - - probe_options.preconditioner = probe_options.preconditioner[0] - object_options.preconditioner = object_options.preconditioner[0] + probe_options = parameters.probe_options + recover_probe = probe_options is not None and epoch >= probe_options.update_start # CONVERSTION AREA ABOVE --------------------------------------- @@ -129,17 +110,17 @@ def lstsq_grad( position_update_denominator, position_options, ) = _get_nearplane_gradients( - data0, + data, psi, scan, probe, eigen_probe, eigen_weights, - batches0, + batches, position_update_numerator, position_update_denominator, position_options, - streams0, + streams, measured_pixels, object_options.preconditioner, batch_index=batch_index, @@ -168,7 +149,7 @@ def lstsq_grad( eigen_probe, eigen_weights, patches, - batches0, + batches, batch_index=batch_index, num_batch=algorithm_options.num_batch, ) @@ -189,7 +170,7 @@ def lstsq_grad( m_probe_update, object_options.preconditioner, patches, - batches0, + batches, batch_index=batch_index, op=op, m=0, @@ -344,27 +325,17 @@ def lstsq_grad( # CONVERSION AREA BELOW ---------------------- - probe_options.preconditioner = [probe_options.preconditioner] - object_options.preconditioner = [object_options.preconditioner] - - parameters.probe[0] = probe - parameters.psi[0] = psi - parameters.scan[0] = scan + parameters.scan = scan + parameters.psi = psi + parameters.probe = probe parameters.algorithm_options = algorithm_options - parameters.probe_options = probe_options + parameters.eigen_weights = eigen_weights + parameters.eigen_probe = eigen_probe + parameters.exitwave_options = exitwave_options + parameters.position_options = position_options parameters.object_options = object_options - if position_options is None: - parameters.position_options = None - else: - parameters.position_options[0] = position_options - if eigen_probe is None: - parameters.eigen_probe = None - else: - parameters.eigen_probe[0] = eigen_probe - if eigen_weights is None: - parameters.eigen_weights = None - else: - parameters.eigen_weights[0] = eigen_weights + parameters.probe_options = probe_options + return parameters diff --git a/src/tike/ptycho/solvers/rpie.py b/src/tike/ptycho/solvers/rpie.py index bf4321ef..878bad82 100644 --- a/src/tike/ptycho/solvers/rpie.py +++ b/src/tike/ptycho/solvers/rpie.py @@ -24,12 +24,12 @@ def rpie( - op: tike.operators.Ptycho, - comm: tike.communicators.Comm, - data: typing.List[npt.NDArray], - batches: typing.List[typing.List[npt.NDArray[cp.intc]]], - *, parameters: PtychoParameters, + data: npt.NDArray, + batches: typing.List[npt.NDArray[cp.intc]], + streams: typing.List[cp.cuda.Stream], + *, + op: tike.operators.Ptycho, epoch: int, ) -> PtychoParameters: """Solve the ptychography problem using regularized ptychographical engine. @@ -76,37 +76,18 @@ def rpie( .. seealso:: :py:mod:`tike.ptycho` """ - data0 = data[0] - batches0 = batches[0] - - probe = parameters.probe[0] - scan = parameters.scan[0] - psi = parameters.psi[0] + scan = parameters.scan + psi = parameters.psi + probe = parameters.probe algorithm_options = parameters.algorithm_options + eigen_weights = parameters.eigen_weights + eigen_probe = parameters.eigen_probe + measured_pixels = parameters.exitwave_options.measured_pixels exitwave_options = parameters.exitwave_options - probe_options = parameters.probe_options - recover_probe = True - - if parameters.position_options is None: - position_options = None - else: - position_options = parameters.position_options[0] + position_options = parameters.position_options object_options = parameters.object_options - if parameters.eigen_probe is None: - eigen_probe = None - else: - eigen_probe = parameters.eigen_probe[0] - if parameters.eigen_weights is None: - eigen_weights = None - else: - eigen_weights = parameters.eigen_weights[0] - - measured_pixels = exitwave_options.measured_pixels[0] - - streams0 = comm.streams[0] - - probe_options.preconditioner = probe_options.preconditioner[0] - object_options.preconditioner = object_options.preconditioner[0] + probe_options = parameters.probe_options + recover_probe = probe_options is not None and epoch >= probe_options.update_start # CONVERSTION AREA ABOVE --------------------------------------- @@ -130,7 +111,7 @@ def rpie( position_update_denominator, eigen_weights, ) = _get_nearplane_gradients( - data0, + data, scan, psi, probe, @@ -141,8 +122,8 @@ def rpie( position_update_denominator, eigen_probe, eigen_weights, - batches0, - streams0, + batches, + streams, n=n, op=op, object_options=object_options, @@ -210,27 +191,17 @@ def rpie( # CONVERSION AREA BELOW ---------------------- - probe_options.preconditioner = [probe_options.preconditioner] - object_options.preconditioner = [object_options.preconditioner] - - parameters.probe[0] = probe - parameters.psi[0] = psi - parameters.scan[0] = scan + parameters.scan = scan + parameters.psi = psi + parameters.probe = probe parameters.algorithm_options = algorithm_options - parameters.probe_options = probe_options + parameters.eigen_weights = eigen_weights + parameters.eigen_probe = eigen_probe + parameters.exitwave_options = exitwave_options + parameters.position_options = position_options parameters.object_options = object_options - if position_options is None: - parameters.position_options = None - else: - parameters.position_options[0] = position_options - if eigen_probe is None: - parameters.eigen_probe = None - else: - parameters.eigen_probe[0] = eigen_probe - if eigen_weights is None: - parameters.eigen_weights = None - else: - parameters.eigen_weights[0] = eigen_weights + parameters.probe_options = probe_options + return parameters From 4ef26d4f6c5dd01ce562c629afde09a1a7feff0f Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Fri, 28 Jun 2024 14:43:18 -0500 Subject: [PATCH 15/31] DOC: Improve documentation --- src/tike/cluster.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/tike/cluster.py b/src/tike/cluster.py index 18fb5876..de4f2268 100644 --- a/src/tike/cluster.py +++ b/src/tike/cluster.py @@ -17,6 +17,7 @@ def _split_gpu( x: npt.ArrayLike, dtype: npt.DTypeLike, ) -> npt.ArrayLike: + """Return x[m] as a CuPy array on the current device.""" return cp.asarray(x[m], dtype=dtype) @@ -25,6 +26,7 @@ def _split_host( x: npt.ArrayLike, dtype: npt.DTypeLike, ) -> npt.ArrayLike: + """Return x[m] as a NumPy array.""" return np.asarray(x[m], dtype=dtype) @@ -33,6 +35,7 @@ def _split_pinned( x: npt.ArrayLike, dtype: npt.DTypeLike, ) -> npt.ArrayLike: + """Return x[m] as a CuPy pinned host memory array.""" pinned = cupyx.empty_pinned(shape=(len(m), *x.shape[1:]), dtype=dtype) pinned[...] = x[m] return pinned @@ -174,18 +177,21 @@ def by_scan_stripes_contiguous( pool: tike.communicators.ThreadPool, shape: typing.Tuple[int], scan: npt.NDArray[np.float32], - batch_method, + batch_method: typing.Literal[ + "compact", "wobbly_center", "wobbly_center_random_bootstrap" + ], num_batch: int, ) -> typing.Tuple[ typing.List[npt.NDArray], typing.List[typing.List[npt.NDArray]], typing.List[int], ]: - """Split data by into stripes and create contiguously ordered batches. + """Return the indices that will split `scan` into 2D stripes of equal count + and create contiguously ordered batches within those stripes. - Divide the field of view into one stripe per devices; within each stripe, - create batches according to the batch_method loading the batches into - contiguous blocks in device memory. + Divide the field of view into one stripe per worker in `pool`; within each + stripe, create batches according to the batch_method loading the batches + into contiguous blocks in device memory. Parameters ---------- From 0ab637eb49fb15936c5f3f4ba5e1ff83e0e1a765 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Fri, 28 Jun 2024 14:48:40 -0500 Subject: [PATCH 16/31] DOC: Add missing docstring --- src/tike/communicators/comm.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/tike/communicators/comm.py b/src/tike/communicators/comm.py index 058de609..8a40e96c 100644 --- a/src/tike/communicators/comm.py +++ b/src/tike/communicators/comm.py @@ -146,6 +146,13 @@ def swap_edges( overlap: int, edges: typing.List[int], ) -> typing.List[cp.ndarray]: + """Swap the region of each x with its neighbor around the given edges. + + Given iterable x, a list of ND arrays; edges, the coordinates in x + along dimension -2; and overlap, the width of the region to swap around + the edge; trade [..., edge-overlap:edge] with [..., edge:edge+overlap] + between neighbors. + """ # FIXME: Swap edges between MPI nodes return self.pool.swap_edges( x=x, From def77706a81159d4932f9765925a95a0eb8611bf Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Fri, 28 Jun 2024 15:14:11 -0500 Subject: [PATCH 17/31] TST: Fixup some tests and docs --- src/tike/communicators/comm.py | 3 +-- src/tike/communicators/pool.py | 24 +++++++++++++----------- tests/communicators/test_pool.py | 22 +++++++++++++++++----- tests/ptycho/test_position.py | 11 +++++------ 4 files changed, 36 insertions(+), 24 deletions(-) diff --git a/src/tike/communicators/comm.py b/src/tike/communicators/comm.py index 8a40e96c..90af4391 100644 --- a/src/tike/communicators/comm.py +++ b/src/tike/communicators/comm.py @@ -150,8 +150,7 @@ def swap_edges( Given iterable x, a list of ND arrays; edges, the coordinates in x along dimension -2; and overlap, the width of the region to swap around - the edge; trade [..., edge-overlap:edge] with [..., edge:edge+overlap] - between neighbors. + the edge; trade [..., edge:(edge + overlap), :] between neighbors. """ # FIXME: Swap edges between MPI nodes return self.pool.swap_edges( diff --git a/src/tike/communicators/pool.py b/src/tike/communicators/pool.py index 03b3dca2..74cc4786 100644 --- a/src/tike/communicators/pool.py +++ b/src/tike/communicators/pool.py @@ -411,19 +411,21 @@ def swap_edges( overlap: int, edges: typing.List[int], ): - """Swap edge:(edge + overlap) between neighbors in-place + """Swap [..., edge:(edge + overlap), :] between neighbors in-place - For example, given overlap=1 and edges=[4, 8, 12, 16], the following + For example, given overlap=3 and edges=[0, 4, 8, 12], the following swap would be returned: ``` - [[0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0]] - [[1 1 1 0 0 1 1 2 2 1 1 1 1 1 1 1]] - [[2 2 2 2 2 2 2 1 1 2 2 3 3 2 2 2]] - [[3 3 3 3 3 3 3 3 3 3 3 2 2 3 3 3]] + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + + [[0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0]] + [[1 1 1 1 0 0 0 1 2 2 2 1 1 1 1 1]] + [[2 2 2 2 2 2 2 2 1 1 1 2 3 3 3 2]] + [[3 3 3 3 3 3 3 3 3 3 3 3 2 2 2 3]] ``` - Note that the minimum swapped region is 2 wide. + Note that the minimum swapped region is 1 wide. """ if overlap < 1: @@ -432,10 +434,10 @@ def swap_edges( for i in range(self.num_workers - 1): lo = edges[i + 1] hi = lo + overlap - temp0 = self._copy_to(x[i][:, lo:hi], self.workers[i + 1]) - temp1 = self._copy_to(x[i + 1][:, lo:hi], self.workers[i]) + temp0 = self._copy_to(x[i][..., lo:hi, :], self.workers[i + 1]) + temp1 = self._copy_to(x[i + 1][..., lo:hi, :], self.workers[i]) with self.Device(self.workers[i]): - x[i][:, lo:hi] = temp1 + x[i][..., lo:hi, :] = temp1 with self.Device(self.workers[i + 1]): - x[i + 1][:, lo:hi] = temp0 + x[i + 1][..., lo:hi, :] = temp0 return x diff --git a/tests/communicators/test_pool.py b/tests/communicators/test_pool.py index 575eb873..ed3ac4b8 100644 --- a/tests/communicators/test_pool.py +++ b/tests/communicators/test_pool.py @@ -143,19 +143,31 @@ def test_reduce_mean(self): def test_swap_edges(self): def init(i): - return self.xp.ones((1, 4 * self.pool.num_workers), dtype=int) * i + return self.xp.ones((1, 4 * self.pool.num_workers, 1), dtype=int) * i x = self.pool.map(init, list(range(self.pool.num_workers))) + edges = np.arange(self.pool.num_workers, dtype=int) * 4 + overlap = 3 + x1 = self.pool.swap_edges( x, - overlap=1, - edges=np.arange(self.pool.num_workers, dtype=int) * 4, + overlap=overlap, + edges=edges, ) print() - for element in x1: - print(element) + for i, element in enumerate(x1): + print(element.flatten()) + truth = self.xp.ones((1, 4 * self.pool.num_workers, 1), dtype=int) * i + if i > 0: + truth[..., edges[i] : edges[i] + overlap, :] = i - 1 + if i < len(x1) - 1: + truth[..., edges[i + 1] : (edges[i + 1] + overlap), :] = i + 1 + self.xp.testing.assert_array_equal( + element, + truth, + ) class TestSoloThreadPool(TestThreadPool): diff --git a/tests/ptycho/test_position.py b/tests/ptycho/test_position.py index 12482e54..b495ec78 100644 --- a/tests/ptycho/test_position.py +++ b/tests/ptycho/test_position.py @@ -27,6 +27,7 @@ def test_position_join(N=245, num_batch=11): assert np.amax(indices) == N - 1 np.random.shuffle(indices) batches = np.array_split(indices, num_batch) + reorder = np.argsort(np.concatenate(batches)) opts = tike.ptycho.PositionOptions( scan, @@ -35,19 +36,17 @@ def test_position_join(N=245, num_batch=11): optsb = [opts.split(b) for b in batches] - # Copies non-array params into new object - new_opts = optsb[0].split([]) + joined = PositionOptions.join(optsb, reorder=reorder) - for b, i in zip(optsb, batches): - new_opts = new_opts.join(b, i) + assert joined is not None np.testing.assert_array_equal( - new_opts.initial_scan, + joined.initial_scan, opts.initial_scan, ) np.testing.assert_array_equal( - new_opts._momentum, + joined._momentum, opts._momentum, ) From d7ed4c9b68f73d5b6da2df2d83e8e2a22e2b726b Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Fri, 28 Jun 2024 15:19:05 -0500 Subject: [PATCH 18/31] TST: Fix broken test --- tests/ptycho/test_probe.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/tests/ptycho/test_probe.py b/tests/ptycho/test_probe.py index 5c1fa7cc..5e40b3c3 100644 --- a/tests/ptycho/test_probe.py +++ b/tests/ptycho/test_probe.py @@ -27,29 +27,21 @@ def test_eigen_probe(self): high = 21 posi = 53 eigen = 1 - comm = Comm(2) - R = comm.pool.bcast([np.random.rand(*leading, posi, 1, 1, wide, high)]) - eigen_probe = comm.pool.bcast( - [np.random.rand(*leading, 1, eigen, 1, wide, high)]) + R = np.random.rand(*leading, posi, 1, 1, wide, high) + eigen_probe = np.random.rand(*leading, 1, eigen, 1, wide, high) weights = np.random.rand(*leading, posi, eigen + 1, 1) weights -= np.mean(weights, axis=-3, keepdims=True) - weights = comm.pool.bcast([weights]) - patches = comm.pool.bcast( - [np.random.rand(*leading, posi, 1, 1, wide, high)]) - diff = comm.pool.bcast( - [np.random.rand(*leading, posi, 1, 1, wide, high)]) + patches = np.random.rand(*leading, posi, 1, 1, wide, high) + diff = np.random.rand(*leading, posi, 1, 1, wide, high) new_probe, new_weights = tike.ptycho.probe.update_eigen_probe( - comm=comm, R=R, eigen_probe=eigen_probe, weights=weights, patches=patches, diff=diff, - batches=[[ - list(range(53)), - ]], + batches=[list(range(53))], batch_index=0, c=1, m=0, From aa58d4aa135d7477eba223185d913982a460a01b Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Fri, 28 Jun 2024 15:43:24 -0500 Subject: [PATCH 19/31] DOC: Improve docs --- src/tike/ptycho/object.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/tike/ptycho/object.py b/src/tike/ptycho/object.py index f43e3456..94b027c1 100644 --- a/src/tike/ptycho/object.py +++ b/src/tike/ptycho/object.py @@ -121,10 +121,11 @@ def resample(self, factor: float, interp) -> ObjectOptions: @staticmethod def join_psi( - x: typing.Iterable[np.ndarray], - stripe_start: typing.Iterable[int], + x: typing.List[np.ndarray], + stripe_start: typing.List[int], probe_width: int, ) -> np.ndarray: + """Recombine `x`, a list of psi, from a split reconstruction.""" joined_psi = x[0] for i in range(1, len(x)): lo = stripe_start[i] + probe_width @@ -134,10 +135,11 @@ def join_psi( @staticmethod def join( - x: typing.Iterable[ObjectOptions], - stripe_start: typing.Iterable[int], + x: typing.List[ObjectOptions], + stripe_start: typing.List[int], probe_width: int, ) -> ObjectOptions: + """Recombine `x`, a list of ObjectOptions, from split ObjectOptions""" options = ObjectOptions( convergence_tolerance=x[0].convergence_tolerance, positivity_constraint=x[0].positivity_constraint, @@ -166,6 +168,7 @@ def join( stripe_start, probe_width, ) + return options def positivity_constraint(x, r): From a4900f1dd7592bd06425efcac367331784fbe135 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Mon, 1 Jul 2024 14:01:15 -0500 Subject: [PATCH 20/31] API: Mark ptycho reconstruction get function as not implemented It is no longer safe to use the get_probe(), get_scan(), etc functions from the ptycho Reconstruction() context. These are marked as not implemented until they are fixed. --- src/tike/ptycho/ptycho.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index f7d4a994..c0a1bcb8 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -547,6 +547,7 @@ def iterate(self, num_iter: int) -> None: ) def get_scan(self) -> npt.NDArray: + raise NotImplementedError() reorder = np.argsort(np.concatenate(self.comm.order)) return self.comm.pool.gather_host( self.parameters.scan, @@ -576,7 +577,10 @@ def get_result(self) -> solvers.PtychoParameters: return parameters def __exit__(self, type, value, traceback): - self.parameters = self.get_result() + self.parameters = self.comm.pool.map( + solvers.PtychoParameters.copy_to_host, + self.parameters, + ) self.comm.__exit__(type, value, traceback) self.operator.__exit__(type, value, traceback) self.device.__exit__(type, value, traceback) @@ -589,6 +593,7 @@ def get_convergence( self ) -> typing.Tuple[typing.List[typing.List[float]], typing.List[float]]: """Return the cost function values and times as a tuple.""" + raise NotImplementedError() return ( self.parameters.algorithm_options.costs, self.parameters.algorithm_options.times, @@ -596,10 +601,12 @@ def get_convergence( def get_psi(self) -> np.array: """Return the current object estimate as a numpy array.""" + raise NotImplementedError() return self.parameters.psi[0].get() def get_probe(self) -> typing.Tuple[np.array, np.array, np.array]: """Return the current probe, eigen_probe, weights as numpy arrays.""" + raise NotImplementedError() reorder = np.argsort(np.concatenate(self.comm.order)) if self.parameters.eigen_probe is None: eigen_probe = None @@ -621,6 +628,7 @@ def peek(self) -> typing.Tuple[np.array, np.array, np.array, np.array]: Parameters returned in a tuple of object, probe, eigen_probe, eigen_weights. """ + raise NotImplementedError() psi = self.get_psi() probe, eigen_probe, eigen_weights = self.get_probe() return psi, probe, eigen_probe, eigen_weights From 20e7b6a885d568f819f978d372a4d29ee1c4aa6a Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Mon, 1 Jul 2024 17:57:34 -0500 Subject: [PATCH 21/31] REF: Avoid using copy.copy to prevent unused copies of arrays --- src/tike/ptycho/exitwave.py | 30 ++++++++++++++------- src/tike/ptycho/object.py | 26 ++++++++++++------ src/tike/ptycho/position.py | 34 +++++++++++++++++++----- src/tike/ptycho/probe.py | 53 ++++++++++++++++++++++++++++++------- 4 files changed, 109 insertions(+), 34 deletions(-) diff --git a/src/tike/ptycho/exitwave.py b/src/tike/ptycho/exitwave.py index 377b1b6c..8913a829 100644 --- a/src/tike/ptycho/exitwave.py +++ b/src/tike/ptycho/exitwave.py @@ -5,6 +5,7 @@ just free space propagation to the detector. """ + from __future__ import annotations import copy @@ -67,7 +68,7 @@ class ExitWaveOptions: exitwave updates in Fourier space. `1.0` for no scaling. """ - propagation_normalization: str = 'ortho' + propagation_normalization: str = "ortho" """Choose the scaling of the FFT in the forward model: "ortho" - the forward and adjoint operations are divided by sqrt(n) @@ -80,17 +81,27 @@ class ExitWaveOptions: def copy_to_device(self) -> ExitWaveOptions: """Copy to the current GPU memory.""" - options = copy.copy(self) - if self.measured_pixels is not None: - options.measured_pixels = cp.asarray(self.measured_pixels) - return options + return ExitWaveOptions( + measured_pixels=cp.asarray(self.measured_pixels), + noise_model=self.noise_model, + propagation_normalization=self.propagation_normalization, + step_length_start=self.step_length_start, + step_length_usemodes=self.step_length_usemodes, + step_length_weight=self.step_length_weight, + unmeasured_pixels_scaling=self.unmeasured_pixels_scaling, + ) def copy_to_host(self) -> ExitWaveOptions: """Copy to the host CPU memory.""" - options = copy.copy(self) - if self.measured_pixels is not None: - options.measured_pixels = cp.asnumpy(self.measured_pixels) - return options + return ExitWaveOptions( + measured_pixels=cp.asnumpy(self.measured_pixels), + noise_model=self.noise_model, + propagation_normalization=self.propagation_normalization, + step_length_start=self.step_length_start, + step_length_usemodes=self.step_length_usemodes, + step_length_weight=self.step_length_weight, + unmeasured_pixels_scaling=self.unmeasured_pixels_scaling, + ) def resample(self, factor: float) -> ExitWaveOptions: """Return a new `ExitWaveOptions` with the parameters rescaled.""" @@ -107,6 +118,7 @@ def resample(self, factor: float) -> ExitWaveOptions: propagation_normalization=self.propagation_normalization, ) + def poisson_steplength_all_modes( xi, abs2_Psi, diff --git a/src/tike/ptycho/object.py b/src/tike/ptycho/object.py index 94b027c1..d00049f7 100644 --- a/src/tike/ptycho/object.py +++ b/src/tike/ptycho/object.py @@ -71,18 +71,20 @@ class ObjectOptions: ) """The magnitude of the illumination used for conditioning the object updates.""" - combined_update: typing.Union[npt.NDArray, None] = dataclasses.field( - init=False, - default_factory=lambda: None, - ) - """Used for compact batch updates.""" - clip_magnitude: bool = False """Whether to force the object magnitude to remain <= 1.""" def copy_to_device(self) -> ObjectOptions: """Copy to the current GPU memory.""" - options = copy.copy(self) + options = ObjectOptions( + convergence_tolerance=self.convergence_tolerance, + positivity_constraint=self.positivity_constraint, + smoothness_constraint=self.smoothness_constraint, + use_adaptive_moment=self.use_adaptive_moment, + vdecay=self.vdecay, + mdecay=self.mdecay, + clip_magnitude=self.clip_magnitude, + ) options.update_mnorm = copy.copy(self.update_mnorm) if self.v is not None: options.v = cp.asarray(self.v) @@ -94,7 +96,15 @@ def copy_to_device(self) -> ObjectOptions: def copy_to_host(self) -> ObjectOptions: """Copy to the host CPU memory.""" - options = copy.copy(self) + options = ObjectOptions( + convergence_tolerance=self.convergence_tolerance, + positivity_constraint=self.positivity_constraint, + smoothness_constraint=self.smoothness_constraint, + use_adaptive_moment=self.use_adaptive_moment, + vdecay=self.vdecay, + mdecay=self.mdecay, + clip_magnitude=self.clip_magnitude, + ) options.update_mnorm = copy.copy(self.update_mnorm) if self.v is not None: options.v = cp.asnumpy(self.v) diff --git a/src/tike/ptycho/position.py b/src/tike/ptycho/position.py index 029a2526..9e241527 100644 --- a/src/tike/ptycho/position.py +++ b/src/tike/ptycho/position.py @@ -356,7 +356,7 @@ class PositionOptions: transform: AffineTransform = AffineTransform() """Global transform of positions.""" - origin: npt.ArrayLike = dataclasses.field( + origin: npt.NDArray = dataclasses.field( init=True, default_factory=lambda: np.zeros(2), ) @@ -491,9 +491,18 @@ def join( def copy_to_device(self): """Copy to the current GPU memory.""" - options = copy.copy(self) - options.initial_scan = cp.asarray(self.initial_scan) - options.origin = cp.array(self.origin) + options = PositionOptions( + initial_scan=cp.asarray(self.initial_scan), + use_adaptive_moment=self.use_adaptive_moment, + vdecay=self.vdecay, + mdecay=self.mdecay, + use_position_regularization=self.use_position_regularization, + update_magnitude_limit=self.update_magnitude_limit, + transform=self.transform, + confidence=self.confidence, + update_start=self.update_start, + origin=cp.asarray(self.origin), + ) if self.confidence is not None: options.confidence = cp.asarray(self.confidence) if self.use_adaptive_moment: @@ -502,9 +511,18 @@ def copy_to_device(self): def copy_to_host(self): """Copy to the host CPU memory.""" - options = copy.copy(self) - options.initial_scan = cp.asnumpy(self.initial_scan) - options.origin = cp.asnumpy(self.origin) + options = PositionOptions( + initial_scan=cp.asnumpy(self.initial_scan), + use_adaptive_moment=self.use_adaptive_moment, + vdecay=self.vdecay, + mdecay=self.mdecay, + use_position_regularization=self.use_position_regularization, + update_magnitude_limit=self.update_magnitude_limit, + transform=self.transform, + confidence=self.confidence, + update_start=self.update_start, + origin=cp.asnumpy(self.origin), + ) if self.confidence is not None: options.confidence = cp.asnumpy(self.confidence) if self.use_adaptive_moment: @@ -522,6 +540,8 @@ def resample(self, factor: float) -> PositionOptions: update_magnitude_limit=self.update_magnitude_limit, transform=self.transform.resample(factor), confidence=self.confidence, + update_start=self.update_start, + origin=self.origin * factor, ) # Momentum reset to zero when grid scale changes return new diff --git a/src/tike/ptycho/probe.py b/src/tike/ptycho/probe.py index 7dab1ac1..ae2d7d95 100644 --- a/src/tike/ptycho/probe.py +++ b/src/tike/ptycho/probe.py @@ -36,7 +36,6 @@ """ from __future__ import annotations -import copy import dataclasses import logging import typing @@ -146,12 +145,6 @@ class ProbeOptions: median_filter_abs_probe_px: typing.Tuple[float, float] = ( 1.0, 1.0 ) """A 2-element tuple with the median filter pixel widths along each dimension.""" - probe_update_sum: typing.Union[npt.NDArray, None] = dataclasses.field( - init=False, - default_factory=lambda: None, - ) - """Used for momentum updates.""" - preconditioner: typing.Union[npt.NDArray, None] = dataclasses.field( init=False, default_factory=lambda: None, @@ -165,7 +158,26 @@ class ProbeOptions: def copy_to_device(self) -> ProbeOptions: """Copy to the current GPU memory.""" - options = copy.copy(self) + options = ProbeOptions( + recover_probe=self.recover_probe, + update_start=self.update_start, + update_period=self.update_period, + init_rescale_from_measurements=self.init_rescale_from_measurements, + probe_photons=self.probe_photons, + force_orthogonality=self.force_orthogonality, + force_centered_intensity=self.force_centered_intensity, + force_sparsity=self.force_sparsity, + use_adaptive_moment=self.use_adaptive_moment, + vdecay=self.vdecay, + mdecay=self.mdecay, + probe_support=self.probe_support, + probe_support_radius=self.probe_support_radius, + probe_support_degree=self.probe_support_degree, + additional_probe_penalty=self.additional_probe_penalty, + median_filter_abs_probe=self.median_filter_abs_probe, + median_filter_abs_probe_px=self.median_filter_abs_probe_px, + ) + options.power=self.power if self.v is not None: options.v = cp.asarray(self.v) if self.m is not None: @@ -176,7 +188,26 @@ def copy_to_device(self) -> ProbeOptions: def copy_to_host(self) -> ProbeOptions: """Copy to the host CPU memory.""" - options = copy.copy(self) + options = ProbeOptions( + recover_probe=self.recover_probe, + update_start=self.update_start, + update_period=self.update_period, + init_rescale_from_measurements=self.init_rescale_from_measurements, + probe_photons=self.probe_photons, + force_orthogonality=self.force_orthogonality, + force_centered_intensity=self.force_centered_intensity, + force_sparsity=self.force_sparsity, + use_adaptive_moment=self.use_adaptive_moment, + vdecay=self.vdecay, + mdecay=self.mdecay, + probe_support=self.probe_support, + probe_support_radius=self.probe_support_radius, + probe_support_degree=self.probe_support_degree, + additional_probe_penalty=self.additional_probe_penalty, + median_filter_abs_probe=self.median_filter_abs_probe, + median_filter_abs_probe_px=self.median_filter_abs_probe_px, + ) + options.power=self.power if self.v is not None: options.v = cp.asnumpy(self.v) if self.m is not None: @@ -200,11 +231,13 @@ def resample(self, factor: float, interp) -> ProbeOptions: vdecay=self.vdecay, mdecay=self.mdecay, probe_support=self.probe_support, - probe_support_degree=self.probe_support_degree, probe_support_radius=self.probe_support_radius, + probe_support_degree=self.probe_support_degree, + additional_probe_penalty=self.additional_probe_penalty, median_filter_abs_probe=self.median_filter_abs_probe, median_filter_abs_probe_px=self.median_filter_abs_probe_px, ) + options.power=self.power return options # Momentum reset to zero when grid scale changes From 97fb0261ab3329de5d0b52ca99e48495e76cbcd0 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Tue, 2 Jul 2024 12:20:22 -0500 Subject: [PATCH 22/31] BUG: Prevent accidental type promotion --- src/tike/ptycho/exitwave.py | 2 +- src/tike/ptycho/object.py | 15 +++++++++--- src/tike/ptycho/position.py | 5 +++- src/tike/ptycho/probe.py | 15 +++++++++--- src/tike/ptycho/ptycho.py | 22 +++++++++++++---- src/tike/ptycho/solvers/options.py | 38 ++++++++++++++++++++++-------- 6 files changed, 75 insertions(+), 22 deletions(-) diff --git a/src/tike/ptycho/exitwave.py b/src/tike/ptycho/exitwave.py index 8913a829..ce0fd3e0 100644 --- a/src/tike/ptycho/exitwave.py +++ b/src/tike/ptycho/exitwave.py @@ -82,7 +82,7 @@ class ExitWaveOptions: def copy_to_device(self) -> ExitWaveOptions: """Copy to the current GPU memory.""" return ExitWaveOptions( - measured_pixels=cp.asarray(self.measured_pixels), + measured_pixels=cp.asarray(self.measured_pixels, dtype=bool), noise_model=self.noise_model, propagation_normalization=self.propagation_normalization, step_length_start=self.step_length_start, diff --git a/src/tike/ptycho/object.py b/src/tike/ptycho/object.py index d00049f7..4fb45c6e 100644 --- a/src/tike/ptycho/object.py +++ b/src/tike/ptycho/object.py @@ -87,11 +87,20 @@ def copy_to_device(self) -> ObjectOptions: ) options.update_mnorm = copy.copy(self.update_mnorm) if self.v is not None: - options.v = cp.asarray(self.v) + options.v = cp.asarray( + self.v, + dtype=tike.precision.floating, + ) if self.m is not None: - options.m = cp.asarray(self.m) + options.m = cp.asarray( + self.m, + dtype=tike.precision.floating, + ) if self.preconditioner is not None: - options.preconditioner = cp.asarray(self.preconditioner) + options.preconditioner = cp.asarray( + self.preconditioner, + dtype=tike.precision.cfloating, + ) return options def copy_to_host(self) -> ObjectOptions: diff --git a/src/tike/ptycho/position.py b/src/tike/ptycho/position.py index 9e241527..d9766712 100644 --- a/src/tike/ptycho/position.py +++ b/src/tike/ptycho/position.py @@ -506,7 +506,10 @@ def copy_to_device(self): if self.confidence is not None: options.confidence = cp.asarray(self.confidence) if self.use_adaptive_moment: - options._momentum = cp.asarray(self._momentum) + options._momentum = cp.asarray( + self._momentum, + dtype=tike.precision.floating, + ) return options def copy_to_host(self): diff --git a/src/tike/ptycho/probe.py b/src/tike/ptycho/probe.py index ae2d7d95..5328265a 100644 --- a/src/tike/ptycho/probe.py +++ b/src/tike/ptycho/probe.py @@ -179,11 +179,20 @@ def copy_to_device(self) -> ProbeOptions: ) options.power=self.power if self.v is not None: - options.v = cp.asarray(self.v) + options.v = cp.asarray( + self.v, + dtype=tike.precision.floating, + ) if self.m is not None: - options.m = cp.asarray(self.m) + options.m = cp.asarray( + self.m, + dtype=tike.precision.floating, + ) if self.preconditioner is not None: - options.preconditioner = cp.asarray(self.preconditioner) + options.preconditioner = cp.asarray( + self.preconditioner, + dtype=tike.precision.cfloating, + ) return options def copy_to_host(self) -> ProbeOptions: diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index c0a1bcb8..51821a7e 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -391,6 +391,9 @@ def __enter__(self): len(self.parameters), self.comm.pool.num_workers, ) + assert self.parameters[0].psi.dtype == tike.precision.cfloating, self.parameters[0].psi.dtype + assert self.parameters[0].probe.dtype == tike.precision.cfloating, self.parameters[0].probe.dtype + assert self.parameters[0].scan.dtype == tike.precision.floating, self.parameters[0].probe.dtype self.parameters = self.comm.pool.map( solvers.PtychoParameters.copy_to_device, @@ -400,6 +403,9 @@ def __enter__(self): len(self.parameters), self.comm.pool.num_workers, ) + assert self.parameters[0].psi.dtype == tike.precision.cfloating, self.parameters[0].psi.dtype + assert self.parameters[0].probe.dtype == tike.precision.cfloating, self.parameters[0].probe.dtype + assert self.parameters[0].scan.dtype == tike.precision.floating, self.parameters[0].probe.dtype if self.parameters[0].probe_options is not None: if self.parameters[0].probe_options.init_rescale_from_measurements: @@ -409,6 +415,9 @@ def __enter__(self): self.data, self.parameters, ) + assert self.parameters[0].psi.dtype == tike.precision.cfloating, self.parameters[0].psi.dtype + assert self.parameters[0].probe.dtype == tike.precision.cfloating, self.parameters[0].probe.dtype + assert self.parameters[0].scan.dtype == tike.precision.floating, self.parameters[0].probe.dtype return self @@ -563,9 +572,13 @@ def get_result(self) -> solvers.PtychoParameters: self.comm.pool.num_workers, ) - parameters = self.comm.pool.map( - solvers.PtychoParameters.copy_to_host, - self.parameters, + # Use plain map here instead of threaded map so this method can be + # called when the context is closed. + parameters = list( + map( + solvers.PtychoParameters.copy_to_host, + self.parameters, + ) ) parameters = solvers.PtychoParameters.join( @@ -912,7 +925,8 @@ def _rescale_probe( n = np.sqrt(comm.Allreduce_reduce_cpu(n)) - rescale = cp.asarray(n[0] / n[1]) + # Force precision to prevent type promotion downstream + rescale = cp.asarray(n[0] / n[1], dtype=tike.precision.floating) logger.info("Probe rescaled by %f", rescale) diff --git a/src/tike/ptycho/solvers/options.py b/src/tike/ptycho/solvers/options.py index 3de25d87..d6a163a2 100644 --- a/src/tike/ptycho/solvers/options.py +++ b/src/tike/ptycho/solvers/options.py @@ -9,6 +9,7 @@ import scipy.ndimage import cupy as cp +import tike.precision from tike.ptycho.object import ObjectOptions from tike.ptycho.position import PositionOptions, check_allowed_positions from tike.ptycho.probe import ProbeOptions @@ -205,13 +206,28 @@ def resample( def copy_to_device(self) -> PtychoParameters: """Copy to the current device.""" return PtychoParameters( - probe=cp.asarray(self.probe), - psi=cp.asarray(self.psi), - scan=cp.asarray(self.scan), - eigen_probe=cp.asarray(self.eigen_probe) + probe=cp.asarray( + self.probe, + dtype=tike.precision.cfloating, + ), + psi=cp.asarray( + self.psi, + dtype=tike.precision.cfloating, + ), + scan=cp.asarray( + self.scan, + dtype=tike.precision.floating, + ), + eigen_probe=cp.asarray( + self.eigen_probe, + dtype=tike.precision.cfloating, + ) if self.eigen_probe is not None else None, - eigen_weights=cp.asarray(self.eigen_weights) + eigen_weights=cp.asarray( + self.eigen_weights, + dtype=tike.precision.floating, + ) if self.eigen_weights is not None else None, algorithm_options=self.algorithm_options, @@ -264,11 +280,13 @@ def split( ) -> PtychoParameters: """Return a new PtychoParameters with only the data from the indices""" return PtychoParameters( - probe=x.probe, - psi=x.psi, - scan=x.scan[indices], - eigen_probe=x.eigen_probe, - eigen_weights=x.eigen_weights[indices] + probe=x.probe.astype(tike.precision.cfloating), + psi=x.psi.astype(tike.precision.cfloating), + scan=x.scan[indices].astype(tike.precision.floating), + eigen_probe=x.eigen_probe.astype(tike.precision.cfloating) + if x.eigen_probe is not None + else None, + eigen_weights=x.eigen_weights[indices].astype(tike.precision.floating) if x.eigen_weights is not None else None, algorithm_options=copy.deepcopy(x.algorithm_options), From 6ce09ce767f0422d343df292611ea2aaa47a0cae Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Tue, 2 Jul 2024 14:33:17 -0500 Subject: [PATCH 23/31] BUG: Soft transition when for join_psi --- src/tike/communicators/pool.py | 30 ++++++++++++++++++++-- src/tike/ptycho/object.py | 5 ++-- src/tike/ptycho/ptycho.py | 4 +-- src/tike/ptycho/solvers/_preconditioner.py | 2 +- src/tike/ptycho/solvers/options.py | 8 ++++-- 5 files changed, 40 insertions(+), 9 deletions(-) diff --git a/src/tike/communicators/pool.py b/src/tike/communicators/pool.py index 74cc4786..49f3f521 100644 --- a/src/tike/communicators/pool.py +++ b/src/tike/communicators/pool.py @@ -437,7 +437,33 @@ def swap_edges( temp0 = self._copy_to(x[i][..., lo:hi, :], self.workers[i + 1]) temp1 = self._copy_to(x[i + 1][..., lo:hi, :], self.workers[i]) with self.Device(self.workers[i]): - x[i][..., lo:hi, :] = temp1 + rampu = cp.linspace( + 0.0, + 1.0, + overlap + 2, + endpoint=True, + )[1:-1][..., None] + rampd = cp.linspace( + 1.0, + 0.0, + overlap + 2, + endpoint=True, + )[1:-1][..., None] + x[i][..., lo:hi, :] = rampd * x[i][..., lo:hi, :] + rampu * temp1 with self.Device(self.workers[i + 1]): - x[i + 1][..., lo:hi, :] = temp0 + rampu = cp.linspace( + 0.0, + 1.0, + overlap + 2, + endpoint=True, + )[1:-1][..., None] + rampd = cp.linspace( + 1.0, + 0.0, + overlap + 2, + endpoint=True, + )[1:-1][..., None] + x[i + 1][..., lo:hi, :] = ( + rampd * temp0 + rampu * x[i + 1][..., lo:hi, :] + ) return x diff --git a/src/tike/ptycho/object.py b/src/tike/ptycho/object.py index 4fb45c6e..6677b0f3 100644 --- a/src/tike/ptycho/object.py +++ b/src/tike/ptycho/object.py @@ -146,9 +146,10 @@ def join_psi( ) -> np.ndarray: """Recombine `x`, a list of psi, from a split reconstruction.""" joined_psi = x[0] + w = probe_width // 2 for i in range(1, len(x)): - lo = stripe_start[i] + probe_width - hi = stripe_start[i + 1] + probe_width if i + 1 < len(x) else x[0].shape[1] + lo = stripe_start[i] + w + hi = stripe_start[i + 1] + w if i + 1 < len(x) else x[0].shape[1] joined_psi[:, lo:hi, :] = x[i][:, lo:hi, :] return joined_psi diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index 51821a7e..f10327c3 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -501,9 +501,9 @@ def iterate(self, num_iter: int) -> None: self.comm.swap_edges( [e.psi for e in self.parameters], # reduce overlap to stay away from edge noise - overlap=pw // 5, + overlap=pw-1, # The actual edge is centered on the probe - edges=[e + pw // 2 for e in self.comm.stripe_start], + edges=self.comm.stripe_start, ), self.parameters, ): diff --git a/src/tike/ptycho/solvers/_preconditioner.py b/src/tike/ptycho/solvers/_preconditioner.py index 2ad803fb..8367456d 100644 --- a/src/tike/ptycho/solvers/_preconditioner.py +++ b/src/tike/ptycho/solvers/_preconditioner.py @@ -73,7 +73,7 @@ def make_certain_args_constant( hi=len(parameters.scan), ) - return psi_update_denominator + return psi_update_denominator[None, ...] @cp.fuse() diff --git a/src/tike/ptycho/solvers/options.py b/src/tike/ptycho/solvers/options.py index d6a163a2..1ef00677 100644 --- a/src/tike/ptycho/solvers/options.py +++ b/src/tike/ptycho/solvers/options.py @@ -308,7 +308,7 @@ def join( probe=x[0].probe, psi=ObjectOptions.join_psi( [e.psi for e in x], - probe_width=x[0].probe.shape[-2] // 2, + probe_width=x[0].probe.shape[-2], stripe_start=stripe_start, ), scan=np.concatenate( @@ -327,7 +327,11 @@ def join( exitwave_options=x[0].exitwave_options, # TODO: synchronize probe momentum elsewhere probe_options=x[0].probe_options, - object_options=x[0].object_options, + object_options=ObjectOptions.join( + [e.object_options for e in x], + stripe_start=stripe_start, + probe_width=x[0].probe.shape[-2], + ), position_options=PositionOptions.join( [e.position_options for e in x], reorder, From 20a0688a0f8206bfdf410349dd2712c58a6991c1 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Tue, 2 Jul 2024 14:33:43 -0500 Subject: [PATCH 24/31] BUG: Use get_result() instead of directly accessing parameters --- src/tike/ptycho/ptycho.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index f10327c3..50a7f181 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -228,27 +228,28 @@ def reconstruct( use_mpi, ) as context: context.iterate(parameters.algorithm_options.num_iter) + result = context.get_result() if ( logger.getEffectiveLevel() <= logging.INFO - ) and context.parameters.position_options: + ) and result.position_options: mean_scaling = 0.5 * ( - context.parameters.position_options.transform.scale0 - + context.parameters.position_options.transform.scale1 + result.position_options.transform.scale0 + + result.position_options.transform.scale1 ) logger.info( f"Global scaling of {mean_scaling:.3e} detected from position correction." " Probably your estimate of photon energy and/or sample to detector " "distance is off by that amount." ) - t = context.parameters.position_options.transform.asarray() + t = result.position_options.transform.asarray() logger.info(f"""Affine transform parameters: {t[0,0]: .3e}, {t[0,1]: .3e} {t[1,0]: .3e}, {t[1,1]: .3e} """) - return context.parameters + return result def _clip_magnitude(x, a_max): @@ -998,29 +999,30 @@ def reconstruct_multigrid( use_mpi=use_mpi, ) as context: context.iterate(resampled_parameters.algorithm_options.num_iter) + result = context.get_result() if level == 0: if ( logger.getEffectiveLevel() <= logging.INFO - ) and context.parameters.position_options: + ) and result.position_options: mean_scaling = 0.5 * ( - context.parameters.position_options.transform.scale0 - + context.parameters.position_options.transform.scale1 + result.position_options.transform.scale0 + + result.position_options.transform.scale1 ) logger.info( f"Global scaling of {mean_scaling:.3e} detected from position correction." " Probably your estimate of photon energy and/or sample to detector " "distance is off by that amount." ) - t = context.parameters.position_options.transform.asarray() + t = result.position_options.transform.asarray() logger.info(f"""Affine transform parameters: {t[0,0]: .3e}, {t[0,1]: .3e} {t[1,0]: .3e}, {t[1,1]: .3e} """) - return context.parameters + return result # Upsample result to next grid - resampled_parameters = context.parameters.resample(2.0, interp) + resampled_parameters = result.resample(2.0, interp) raise RuntimeError('This should not happen.') From 178cd292128eea249273de5c12f7ecb6d683eb46 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Wed, 3 Jul 2024 10:41:46 -0500 Subject: [PATCH 25/31] DEV: Reimplement some getter functions --- src/tike/ptycho/ptycho.py | 39 ++++++++++++++------------------------- 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index 50a7f181..80cffec4 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -557,11 +557,10 @@ def iterate(self, num_iter: int) -> None: ) def get_scan(self) -> npt.NDArray: - raise NotImplementedError() reorder = np.argsort(np.concatenate(self.comm.order)) - return self.comm.pool.gather_host( - self.parameters.scan, - axis=-2, + return np.concatenate( + [cp.asnumpy(e.scan) for e in self.parameters], + axis=0, )[reorder] def get_result(self) -> solvers.PtychoParameters: @@ -607,46 +606,36 @@ def get_convergence( self ) -> typing.Tuple[typing.List[typing.List[float]], typing.List[float]]: """Return the cost function values and times as a tuple.""" - raise NotImplementedError() return ( - self.parameters.algorithm_options.costs, - self.parameters.algorithm_options.times, + self.parameters[0].algorithm_options.costs, + self.parameters[0].algorithm_options.times, ) def get_psi(self) -> np.array: """Return the current object estimate as a numpy array.""" - raise NotImplementedError() - return self.parameters.psi[0].get() + return ObjectOptions.join_psi( + [cp.asnumpy(e.psi) for e in self.parameters], + probe_width=self.parameters[0].probe.shape[-2], + stripe_start=self.comm.stripe_start, + ) def get_probe(self) -> typing.Tuple[np.array, np.array, np.array]: """Return the current probe, eigen_probe, weights as numpy arrays.""" - raise NotImplementedError() - reorder = np.argsort(np.concatenate(self.comm.order)) - if self.parameters.eigen_probe is None: + if self.parameters[0].eigen_probe is None: eigen_probe = None else: - eigen_probe = self.parameters.eigen_probe[0].get() + eigen_probe = self.parameters[0].eigen_probe.get() if self.parameters.eigen_weights is None: eigen_weights = None else: + reorder = np.argsort(np.concatenate(self.comm.order)) eigen_weights = self.comm.pool.gather( self.parameters.eigen_weights, axis=-3, )[reorder].get() - probe = self.parameters.probe[0].get() + probe = self.parameters[0].probe.get() return probe, eigen_probe, eigen_weights - def peek(self) -> typing.Tuple[np.array, np.array, np.array, np.array]: - """Return the curent values of object and probe as numpy arrays. - - Parameters returned in a tuple of object, probe, eigen_probe, - eigen_weights. - """ - raise NotImplementedError() - psi = self.get_psi() - probe, eigen_probe, eigen_weights = self.get_probe() - return psi, probe, eigen_probe, eigen_weights - def append_new_data( self, new_data: npt.NDArray, From 6b47861966d24add7f2ca255db7c916425195a2f Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Wed, 3 Jul 2024 16:33:05 -0500 Subject: [PATCH 26/31] BUG: Fix broken conditional that skips probe orthogonalization --- src/tike/ptycho/probe.py | 66 ++++++++++----------------------------- src/tike/ptycho/ptycho.py | 17 +++++----- 2 files changed, 24 insertions(+), 59 deletions(-) diff --git a/src/tike/ptycho/probe.py b/src/tike/ptycho/probe.py index 5328265a..6d975d32 100644 --- a/src/tike/ptycho/probe.py +++ b/src/tike/ptycho/probe.py @@ -56,9 +56,6 @@ class ProbeOptions: """Manage data and setting related to probe correction.""" - recover_probe: bool = False - """Boolean switch used to indicate whether to update probe or not.""" - update_start: int = 0 """Start probe updates at this epoch.""" @@ -156,10 +153,13 @@ class ProbeOptions: ) """The power of the primary probe modes at each iteration.""" + def recover_probe(self, epoch: int) -> bool: + """Return whether to update probe or not.""" + return (epoch >= self.update_start) and (epoch % self.update_period == 0) + def copy_to_device(self) -> ProbeOptions: """Copy to the current GPU memory.""" options = ProbeOptions( - recover_probe=self.recover_probe, update_start=self.update_start, update_period=self.update_period, init_rescale_from_measurements=self.init_rescale_from_measurements, @@ -198,7 +198,6 @@ def copy_to_device(self) -> ProbeOptions: def copy_to_host(self) -> ProbeOptions: """Copy to the host CPU memory.""" options = ProbeOptions( - recover_probe=self.recover_probe, update_start=self.update_start, update_period=self.update_period, init_rescale_from_measurements=self.init_rescale_from_measurements, @@ -228,7 +227,6 @@ def copy_to_host(self) -> ProbeOptions: def resample(self, factor: float, interp) -> ProbeOptions: """Return a new `ProbeOptions` with the parameters rescaled.""" options = ProbeOptions( - recover_probe=self.recover_probe, update_start=self.update_start, update_period=self.update_period, init_rescale_from_measurements=self.init_rescale_from_measurements, @@ -284,8 +282,19 @@ def get_varying_probe(shared_probe, eigen_probe=None, weights=None): return shared_probe.copy() -def _constrain_variable_probe1(variable_probe, weights): - """Help use the thread pool with constrain_variable_probe""" +def constrain_variable_probe(variable_probe, weights): + """Add the following constraints to variable probe weights + + 1. Remove outliars from weights + 2. Enforce orthogonality once per epoch + 3. Sort the variable probes by their total energy + 4. Normalize the variable probes so the energy is contained in the weight + + """ + # TODO: No smoothing of variable probe weights yet because the weights are + # not stored consecutively in device memory. Smoothing would require either + # sorting and synchronizing the weights with the host OR implementing + # smoothing of non-gridded data with splines using device-local data only. # Normalize variable probes vnorm = tike.linalg.mnorm(variable_probe, axis=(-2, -1), keepdims=True) @@ -307,12 +316,6 @@ def _constrain_variable_probe1(variable_probe, weights): axis=-3, )**2 - return variable_probe, weights, power - - -def _constrain_variable_probe2(variable_probe, weights, power): - """Help use the thread pool with constrain_variable_probe""" - # Sort the probes by energy probes_with_modes = variable_probe.shape[-3] for i in range(probes_with_modes): @@ -335,39 +338,6 @@ def _constrain_variable_probe2(variable_probe, weights, power): return variable_probe, weights -def constrain_variable_probe(comm, variable_probe, weights): - """Add the following constraints to variable probe weights - - 1. Remove outliars from weights - 2. Enforce orthogonality once per epoch - 3. Sort the variable probes by their total energy - 4. Normalize the variable probes so the energy is contained in the weight - - """ - # TODO: No smoothing of variable probe weights yet because the weights are - # not stored consecutively in device memory. Smoothing would require either - # sorting and synchronizing the weights with the host OR implementing - # smoothing of non-gridded data with splines using device-local data only. - - variable_probe, weights, power = zip(*comm.pool.map( - _constrain_variable_probe1, - variable_probe, - weights, - )) - - # reduce power by sum across all devices - power = comm.pool.allreduce(power) - - variable_probe, weights = (list(a) for a in zip(*comm.pool.map( - _constrain_variable_probe2, - variable_probe, - weights, - power, - ))) - - return variable_probe, weights - - def _get_update(R, eigen_probe, weights, batches, *, batch_index, c, m): """ Parameters @@ -473,8 +443,6 @@ def update_eigen_probe( Parameters ---------- - comm : :py:class:`tike.communicators.Comm` - An object which manages communications between both GPUs and nodes. R : (POSI, 1, 1, WIDE, HIGH) complex64 Residual probe updates; what's left after subtracting the shared probe update from the varying probe updates for each position diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index 80cffec4..9bbf1760 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -442,15 +442,10 @@ def iterate(self, num_iter: int) -> None: total_epochs = len(self.parameters[0].algorithm_options.times) - # if self.parameters.probe_options is not None: - # self.parameters.probe_options.recover_probe = ( - # total_epochs >= self.parameters.probe_options.update_start - # and (total_epochs % self.parameters.probe_options.update_period) == 0 - # ) - self.parameters = self.comm.pool.map( _apply_probe_constraints, self.parameters, + epoch=total_epochs ) self.parameters = solvers.update_preconditioners( @@ -715,9 +710,11 @@ def append_new_data( def _apply_probe_constraints( parameters: solvers.PtychoParameters, + *, + epoch: int, ) -> solvers.PtychoParameters: if parameters.probe_options is not None: - if parameters.probe_options.recover_probe: + if parameters.probe_options.recover_probe(epoch): if parameters.probe_options.median_filter_abs_probe: parameters.probe = apply_median_filter_abs_probe( @@ -736,7 +733,7 @@ def _apply_probe_constraints( f=parameters.probe_options.force_sparsity, ) - if parameters[0].probe_options.force_orthogonality: + if parameters.probe_options.force_orthogonality: ( parameters.probe, power, @@ -748,7 +745,7 @@ def _apply_probe_constraints( parameters.probe, ) - parameters.probe_options.power.append(power[0].get()) + parameters.probe_options.power.append(cp.asnumpy(power)) if parameters.algorithm_options.rescale_method == "constant_probe_photons" and ( len(parameters.algorithm_options.costs) @@ -765,7 +762,7 @@ def _apply_probe_constraints( if ( parameters.eigen_probe is not None - and parameters.probe_options.recover_probe + and parameters.probe_options.recover_probe(epoch) ): ( parameters.eigen_probe, From 16698c428a3547979077e522ea9c5e1211c7b6bc Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Wed, 3 Jul 2024 18:38:19 -0500 Subject: [PATCH 27/31] Merge functions back together --- src/tike/ptycho/probe.py | 171 ++++++++++++--------------------------- 1 file changed, 52 insertions(+), 119 deletions(-) diff --git a/src/tike/ptycho/probe.py b/src/tike/ptycho/probe.py index 6d975d32..91a563a0 100644 --- a/src/tike/ptycho/probe.py +++ b/src/tike/ptycho/probe.py @@ -338,89 +338,6 @@ def constrain_variable_probe(variable_probe, weights): return variable_probe, weights -def _get_update(R, eigen_probe, weights, batches, *, batch_index, c, m): - """ - Parameters - ---------- - R : (B, 1, 1, H, W) - eigen_probe (1, C, M, H, W) - weights : (B, C, M) - """ - lo = batches[batch_index][0] - hi = lo + len(batches[batch_index]) - # (POSI, 1, 1, 1, 1) to match other arrays - weights = weights[lo:hi, c:c + 1, m:m + 1, None, None] - eigen_probe = eigen_probe[:, c - 1:c, m:m + 1, :, :] - norm_weights = tike.linalg.norm(weights, axis=-5, keepdims=True)**2 - - if np.all(norm_weights == 0): - raise ValueError('eigen_probe weights cannot all be zero?') - - # FIXME: What happens when weights is zero!? - proj = (np.real(R.conj() * eigen_probe) + weights) / norm_weights - return np.mean( - R * np.mean(proj, axis=(-2, -1), keepdims=True), - axis=-5, - keepdims=True, - ) - - -def _get_d(patches, diff, eigen_probe, update, *, β, c, m): - """ - Parameters - ---------- - patches : (B, 1, 1, H, W) - diff : (B, 1, M, H, W) - eigen_probe (1, C, M, H, W) - update : (1, 1, 1, H, W) - """ - eigen_probe[:, c - 1:c, m:m + 1, :, :] += β * update / tike.linalg.mnorm( - update, - axis=(-2, -1), - keepdims=True, - ) - eigen_probe[:, c - 1:c, m:m + 1, :, :] /= tike.linalg.mnorm( - eigen_probe[:, c - 1:c, m:m + 1, :, :], - axis=(-2, -1), - keepdims=True, - ) - assert np.all(np.isfinite(eigen_probe)) - - # Determine new eigen_weights for the updated eigen probe - phi = patches * eigen_probe[:, c - 1:c, m:m + 1, :, :] - n = np.mean( - np.real(diff[:, :, m:m + 1, :, :] * phi.conj()), - axis=(-1, -2), - keepdims=False, - ) - d = np.mean(np.square(np.abs(phi)), axis=(-1, -2), keepdims=False) - d_mean = np.mean(d, axis=-3, keepdims=True) - return eigen_probe, n, d, d_mean - - -def _get_weights_mean(n, d, d_mean, weights, batches, *, batch_index, c, m): - """ - Parameters - ---------- - n : (B, 1, 1) - d : (B, 1, 1) - d_mean : (1, 1, 1) - weights : (B, C, M) - """ - lo = batches[batch_index][0] - hi = lo + len(batches[batch_index]) - # yapf: disable - weight_update = ( - n / (d + 0.1 * d_mean) - ).reshape(*weights[lo:hi, c:c + 1, m:m + 1].shape) - # yapf: enable - assert np.all(np.isfinite(weight_update)) - - # (33) The sum of all previous steps constrained to zero-mean - weights[lo:hi, c:c + 1, m:m + 1] += weight_update - return weights - - def update_eigen_probe( R, eigen_probe, @@ -473,51 +390,67 @@ def update_eigen_probe( assert weights[batches[batch_index], :, :].shape[-3] == R.shape[-5] assert R.shape[-2:] == eigen_probe.shape[-2:] - update = _get_update( - R, - eigen_probe, - weights, - batches, - batch_index=batch_index, - c=c, - m=m, + lo = batches[batch_index][0] + hi = lo + len(batches[batch_index]) + # (POSI, 1, 1, 1, 1) to match other arrays + norm_weights = ( + tike.linalg.norm( + weights[lo:hi, c : c + 1, m : m + 1, None, None], + axis=-5, + keepdims=True, + ) + ** 2 ) - update = cp.mean( - update, + if np.all(norm_weights == 0): + raise ValueError("eigen_probe weights cannot all be zero?") + + # FIXME: What happens when weights is zero!? + proj = ( + np.real(R.conj() * eigen_probe[:, c - 1 : c, m : m + 1, :, :]) + + weights[lo:hi, c : c + 1, m : m + 1, None, None] + ) / norm_weights + update = np.mean( + R * np.mean(proj, axis=(-2, -1), keepdims=True), axis=-5, + keepdims=False, ) - ( - eigen_probe, - n, - d, - d_mean, - ) = _get_d( - patches, - diff, - eigen_probe, - update, - β=β, - c=c, - m=m, + eigen_probe[:, c - 1 : c, m : m + 1, :, :] += ( + β + * update + / tike.linalg.mnorm( + update, + axis=(-2, -1), + keepdims=True, + ) ) - - d_mean = cp.mean( - d_mean, - axis=-3, + eigen_probe[:, c - 1 : c, m : m + 1, :, :] /= tike.linalg.mnorm( + eigen_probe[:, c - 1 : c, m : m + 1, :, :], + axis=(-2, -1), + keepdims=True, ) + assert np.all(np.isfinite(eigen_probe)) - weights = _get_weights_mean( - n, - d, - d_mean, - weights, - batches, - batch_index=batch_index, - c=c, - m=m, + # Determine new eigen_weights for the updated eigen probe + phi = patches * eigen_probe[:, c - 1 : c, m : m + 1, :, :] + n = np.mean( + np.real(diff[:, :, m : m + 1, :, :] * phi.conj()), + axis=(-1, -2), + keepdims=False, ) + d = np.mean(np.square(np.abs(phi)), axis=(-1, -2), keepdims=False) + d_mean = np.mean(d, axis=-3, keepdims=False) + + # yapf: disable + weight_update = ( + n / (d + 0.1 * d_mean) + ).reshape(*weights[lo:hi, c:c + 1, m:m + 1].shape) + # yapf: enable + assert np.all(np.isfinite(weight_update)) + + # (33) The sum of all previous steps constrained to zero-mean + weights[lo:hi, c : c + 1, m : m + 1] += weight_update return eigen_probe, weights From d9f03b92dbd65b57e8a8ab2a18eb4a23e0d044f6 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Mon, 8 Jul 2024 12:35:16 -0500 Subject: [PATCH 28/31] API: Move work of default stream when using pool.map --- src/tike/communicators/comm.py | 5 ----- src/tike/communicators/pool.py | 13 ++++++++++--- src/tike/ptycho/ptycho.py | 4 ++-- src/tike/ptycho/solvers/_preconditioner.py | 4 ++-- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/tike/communicators/comm.py b/src/tike/communicators/comm.py index 90af4391..06a2387f 100644 --- a/src/tike/communicators/comm.py +++ b/src/tike/communicators/comm.py @@ -13,10 +13,6 @@ from .pool import ThreadPool -def _init_streams(): - return [cp.cuda.Stream() for _ in range(2)] - - class Comm: """A Ptychography communicator. @@ -47,7 +43,6 @@ def __init__( self.use_mpi = True self.mpi = mpi() self.pool = pool(gpu_count) - self.streams = self.pool.map(_init_streams) def __enter__(self): self.mpi.__enter__() diff --git a/src/tike/communicators/pool.py b/src/tike/communicators/pool.py index 49f3f521..78ec4056 100644 --- a/src/tike/communicators/pool.py +++ b/src/tike/communicators/pool.py @@ -89,6 +89,12 @@ def __init__( self.num_workers) if self.num_workers > 1 else NoPoolExecutor( self.num_workers) + def f(worker): + with self.Device(worker): + return [cp.cuda.Stream() for _ in range(2)] + + self.streams = list(self.executor.map(f, self.workers)) + def __enter__(self): if self.workers[0] != cp.cuda.Device().id: raise ValueError( @@ -397,13 +403,14 @@ def map( ) -> list: """ThreadPoolExecutor.map, but wraps call in a cuda.Device context.""" - def f(worker, *args): + def f(worker, streams, *args): with self.Device(worker): - return func(*args, **kwargs) + with streams[1]: + return func(*args, **kwargs) workers = self.workers if workers is None else workers - return list(self.executor.map(f, workers, *iterables)) + return list(self.executor.map(f, workers, self.streams, *iterables)) def swap_edges( self, diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index 9bbf1760..6c3ccef3 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -459,7 +459,7 @@ def iterate(self, num_iter: int) -> None: self.parameters, self.data, self.batches, - self.comm.streams, + self.comm.pool.streams, op=self.operator, epoch=len(self.parameters[0].algorithm_options.times), ) @@ -901,7 +901,7 @@ def _rescale_probe( _get_rescale, data, parameters, - comm.streams, + comm.pool.streams, operator=operator, ) except cp.cuda.memory.OutOfMemoryError: diff --git a/src/tike/ptycho/solvers/_preconditioner.py b/src/tike/ptycho/solvers/_preconditioner.py index 8367456d..2afb9b68 100644 --- a/src/tike/ptycho/solvers/_preconditioner.py +++ b/src/tike/ptycho/solvers/_preconditioner.py @@ -135,7 +135,7 @@ def update_preconditioners( preconditioner = comm.pool.map( _psi_preconditioner, parameters, - comm.streams, + comm.pool.streams, operator=operator, ) @@ -151,7 +151,7 @@ def update_preconditioners( preconditioner = comm.pool.map( _probe_preconditioner, parameters, - comm.streams, + comm.pool.streams, operator=operator, ) From f6611c8485e08376b935a4c2c95c2f46ab310dcd Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Tue, 9 Jul 2024 12:12:13 -0500 Subject: [PATCH 29/31] REF: Reorder object probe constraints --- src/tike/ptycho/ptycho.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/tike/ptycho/ptycho.py b/src/tike/ptycho/ptycho.py index 6c3ccef3..b0228dee 100644 --- a/src/tike/ptycho/ptycho.py +++ b/src/tike/ptycho/ptycho.py @@ -464,16 +464,6 @@ def iterate(self, num_iter: int) -> None: epoch=len(self.parameters[0].algorithm_options.times), ) - self.parameters = self.comm.pool.map( - _apply_object_constraints, - self.parameters, - ) - - self.parameters = self.comm.pool.map( - _apply_position_constraints, - self.parameters, - ) - for i, reduced_probe in enumerate( self.comm.Allreduce_mean( [e.probe[None, ...] for e in self.parameters], @@ -493,7 +483,6 @@ def iterate(self, num_iter: int) -> None: pw = self.parameters[0].probe.shape[-2] for swapped, parameters in zip( - # TODO: Try blending edges during swap instead of replacing self.comm.swap_edges( [e.psi for e in self.parameters], # reduce overlap to stay away from edge noise @@ -518,6 +507,16 @@ def iterate(self, num_iter: int) -> None: reduced_transform ) + self.parameters = self.comm.pool.map( + _apply_object_constraints, + self.parameters, + ) + + self.parameters = self.comm.pool.map( + _apply_position_constraints, + self.parameters, + ) + reduced_cost = np.mean( [e.algorithm_options.costs[-1] for e in self.parameters], ) From 5d59edb12d228cd83b3afda65cd234dd6b80ae1b Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Tue, 9 Jul 2024 12:12:51 -0500 Subject: [PATCH 30/31] REF: Avoid device host transfer until end of epoch --- src/tike/ptycho/solvers/lstsq.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/tike/ptycho/solvers/lstsq.py b/src/tike/ptycho/solvers/lstsq.py index 133da4ac..6cfa5588 100644 --- a/src/tike/ptycho/solvers/lstsq.py +++ b/src/tike/ptycho/solvers/lstsq.py @@ -94,7 +94,7 @@ def lstsq_grad( position_update_numerator = None position_update_denominator = None - batch_cost: typing.List[float] = [] + batch_cost = cp.empty(algorithm_options.num_batch, dtype=tike.precision.floating) beta_object: typing.List[float] = [] beta_probe: typing.List[float] = [] for batch_index in order(algorithm_options.num_batch): @@ -249,7 +249,7 @@ def lstsq_grad( beta_probe.append(bbeta_probe) - batch_cost += costs.tolist() + batch_cost[batch_index] = cp.mean(costs) if ( position_options is not None @@ -264,7 +264,7 @@ def lstsq_grad( epoch=epoch, ) - algorithm_options.costs.append(batch_cost) + algorithm_options.costs.append(batch_cost.tolist()) if object_options and algorithm_options.batch_method == 'compact': object_update_precond = _precondition_object_update( @@ -287,7 +287,7 @@ def lstsq_grad( v=object_options.v, m=object_options.m, mdecay=object_options.mdecay, - errors=list(float(np.mean(x)) for x in algorithm_options.costs[-3:]), + errors=list(float(cp.mean(x)) for x in algorithm_options.costs[-3:]), beta=beta_object, memory_length=3, ) @@ -300,12 +300,12 @@ def lstsq_grad( beta_probe = cp.mean(cp.stack(beta_probe)) dprobe = probe_combined_update if probe_options.v is None: - probe_options.v = np.zeros_like( + probe_options.v = cp.zeros_like( dprobe, shape=(3, *dprobe.shape), ) if probe_options.m is None: - probe_options.m = np.zeros_like(dprobe,) + probe_options.m = cp.zeros_like(dprobe,) # ptychoshelves only applies momentum to the main probe mode = 0 ( From a7a40ec0ae4cf08638eead9e566e6781afe16577 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Tue, 9 Jul 2024 13:39:53 -0500 Subject: [PATCH 31/31] BUG: Fix bugs from merge --- src/tike/ptycho/solvers/_preconditioner.py | 18 +++++++++--------- src/tike/ptycho/solvers/lstsq.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/tike/ptycho/solvers/_preconditioner.py b/src/tike/ptycho/solvers/_preconditioner.py index 048cc783..64857f79 100644 --- a/src/tike/ptycho/solvers/_preconditioner.py +++ b/src/tike/ptycho/solvers/_preconditioner.py @@ -46,7 +46,7 @@ def _psi_preconditioner( ) -> npt.NDArray: psi_update_denominator = cp.zeros( - shape=parameters.psi.shape[-2:], + shape=parameters.psi.shape, dtype=parameters.psi.dtype, ) @@ -58,25 +58,25 @@ def make_certain_args_constant( nonlocal psi_update_denominator probe_amp = _probe_amp_sum(parameters.probe)[:, 0] - psi_update_denominator[...] = operator.diffraction.patch.adj( + psi_update_denominator[0] = operator.diffraction.patch.adj( patches=probe_amp, - images=psi_update_denominator, + images=psi_update_denominator[0], positions=parameters.scan[lo:hi], ) - probe1 = probe[:, 0] - for i in range(1, len(psi)): + probe1 = parameters.probe[:, 0] + for i in range(1, len(parameters.psi)): probe1 = operator.diffraction.diffraction.fwd( probe=probe1, - scan=scan[lo:hi], - psi=psi[i-1], + scan=parameters.scan[lo:hi], + psi=parameters.psi[i-1], ) probe1 = operator.diffraction.propagation.fwd(probe1) probe_amp = _probe_amp_sum(probe1) psi_update_denominator[i] = operator.diffraction.patch.adj( patches=probe_amp, images=psi_update_denominator[i], - positions=scan[lo:hi], + positions=parameters.scan[lo:hi], ) tike.communicators.stream.stream_and_modify2( @@ -87,7 +87,7 @@ def make_certain_args_constant( hi=len(parameters.scan), ) - return psi_update_denominator[None, ...] + return psi_update_denominator @cp.fuse() diff --git a/src/tike/ptycho/solvers/lstsq.py b/src/tike/ptycho/solvers/lstsq.py index 6cfa5588..3700d123 100644 --- a/src/tike/ptycho/solvers/lstsq.py +++ b/src/tike/ptycho/solvers/lstsq.py @@ -287,7 +287,7 @@ def lstsq_grad( v=object_options.v, m=object_options.m, mdecay=object_options.mdecay, - errors=list(float(cp.mean(x)) for x in algorithm_options.costs[-3:]), + errors=list(float(np.mean(x)) for x in algorithm_options.costs[-3:]), beta=beta_object, memory_length=3, )