From 9fc95eeb5f0adc9cf3fab385c1c7556e287f81a1 Mon Sep 17 00:00:00 2001 From: Daniel Ching Date: Thu, 18 Jul 2024 12:15:07 -0500 Subject: [PATCH] REF: Merge functions that are no longer separated by reduction --- src/tike/ptycho/solvers/lstsq.py | 96 ++++++-------------------------- 1 file changed, 18 insertions(+), 78 deletions(-) diff --git a/src/tike/ptycho/solvers/lstsq.py b/src/tike/ptycho/solvers/lstsq.py index 3700d123..0215ddee 100644 --- a/src/tike/ptycho/solvers/lstsq.py +++ b/src/tike/ptycho/solvers/lstsq.py @@ -133,11 +133,6 @@ def lstsq_grad( ) if recover_probe: - m_probe_update = cp.mean( - m_probe_update, - axis=-5, - ) - ( eigen_probe, eigen_weights, @@ -156,11 +151,8 @@ def lstsq_grad( ( object_update_precond, - A1, - A2, - A4, - b1, - b2, + bbeta_object, + bbeta_probe, ) = _precondition_nearplane_gradients( diff, scan, @@ -179,44 +171,6 @@ def lstsq_grad( probe_options=probe_options, ) - if object_options is not None: - A1_delta = cp.mean(A1, axis=-3) - else: - A1_delta = None - - if recover_probe: - A4_delta = cp.mean(A4, axis=-3) - else: - A4_delta = None - - ( - weighted_step_psi, - weighted_step_probe, - ) = _get_nearplane_steps( - A1, - A2, - A4, - b1, - b2, - A1_delta, - A4_delta, - recover_psi=object_options is not None, - recover_probe=recover_probe, - m=0, - ) - - if object_options is not None: - bbeta_object = cp.mean( - weighted_step_psi, - axis=-5, - )[..., 0, 0, 0] - - if recover_probe: - bbeta_probe = cp.mean( - weighted_step_probe, - axis=-5, - ) - # Update each direction if object_options is not None: if algorithm_options.batch_method != 'compact': @@ -708,6 +662,8 @@ def _precondition_nearplane_gradients( None, :, :] * unique_probe[..., m:m + 1, :, :] A1 = cp.sum((dOP * dOP.conj()).real + eps, axis=(-2, -1)) + A1_delta = cp.mean(A1, axis=-3) + A1 += 0.5 * A1_delta if recover_probe: b0 = tike.ptycho.probe.finite_probe_support( @@ -741,47 +697,27 @@ def _precondition_nearplane_gradients( dPO = m_probe_update[..., m:m + 1, :, :] * patches A4 = cp.sum((dPO * dPO.conj()).real + eps, axis=(-2, -1)) + A4_delta = cp.mean(A4, axis=-3) + A4 += 0.5 * A4_delta + # (22) Use least-squares to find the optimal step sizes simultaneously if recover_psi and recover_probe: b1 = cp.sum((dOP.conj() * nearplane[..., m:m + 1, :, :]).real, axis=(-2, -1)) b2 = cp.sum((dPO.conj() * nearplane[..., m:m + 1, :, :]).real, axis=(-2, -1)) A2 = cp.sum((dOP * dPO.conj()), axis=(-2, -1)) - elif recover_psi: - b1 = cp.sum((dOP.conj() * nearplane[..., m:m + 1, :, :]).real, - axis=(-2, -1)) - elif recover_probe: - b2 = cp.sum((dPO.conj() * nearplane[..., m:m + 1, :, :]).real, - axis=(-2, -1)) - - return ( - object_update_precond, - A1, - A2, - A4, - b1, - b2, - ) - - -def _get_nearplane_steps(A1, A2, A4, b1, b2, A1_delta, A4_delta, recover_psi, - recover_probe, m): - - if recover_psi: - A1 += 0.5 * A1_delta - if recover_probe: - A4 += 0.5 * A4_delta - - # (22) Use least-squares to find the optimal step sizes simultaneously - if recover_psi and recover_probe: A3 = A2.conj() determinant = A1 * A4 - A2 * A3 x1 = -cp.conj(A2 * b2 - A4 * b1) / determinant x2 = cp.conj(A1 * b2 - A3 * b1) / determinant elif recover_psi: + b1 = cp.sum((dOP.conj() * nearplane[..., m:m + 1, :, :]).real, + axis=(-2, -1)) x1 = b1 / A1 elif recover_probe: + b2 = cp.sum((dPO.conj() * nearplane[..., m:m + 1, :, :]).real, + axis=(-2, -1)) x2 = b2 / A4 else: x1 = None @@ -791,18 +727,22 @@ def _get_nearplane_steps(A1, A2, A4, b1, b2, A1_delta, A4_delta, recover_psi, step = 0.9 * cp.maximum(0, x1[..., None, None].real) # (27b) Object update - beta_object = cp.mean(step, keepdims=True, axis=-5) + beta_object = cp.mean(step, keepdims=False, axis=-5)[..., 0, 0, 0] else: beta_object = None if recover_probe: step = 0.9 * cp.maximum(0, x2[..., None, None].real) - beta_probe = cp.mean(step, axis=-5, keepdims=True) + beta_probe = cp.mean(step, axis=-5, keepdims=False) else: beta_probe = None - return beta_object, beta_probe + return ( + object_update_precond, + beta_object, + beta_probe, + ) def _get_coefs_intensity(weights, xi, P, O, batches, *, batch_index, m):