Skip to content

Commit

Permalink
Merge pull request #328 from carterbox/merge-functions
Browse files Browse the repository at this point in the history
REF: Merge functions that are no longer separated by reduction
  • Loading branch information
carterbox authored Jul 29, 2024
2 parents cf8f005 + a3fbd94 commit 8532bfc
Showing 1 changed file with 18 additions and 78 deletions.
96 changes: 18 additions & 78 deletions src/tike/ptycho/solvers/lstsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,6 @@ def lstsq_grad(
)

if recover_probe:
m_probe_update = cp.mean(
m_probe_update,
axis=-5,
)

(
eigen_probe,
eigen_weights,
Expand All @@ -156,11 +151,8 @@ def lstsq_grad(

(
object_update_precond,
A1,
A2,
A4,
b1,
b2,
bbeta_object,
bbeta_probe,
) = _precondition_nearplane_gradients(
diff,
scan,
Expand All @@ -179,44 +171,6 @@ def lstsq_grad(
probe_options=probe_options,
)

if object_options is not None:
A1_delta = cp.mean(A1, axis=-3)
else:
A1_delta = None

if recover_probe:
A4_delta = cp.mean(A4, axis=-3)
else:
A4_delta = None

(
weighted_step_psi,
weighted_step_probe,
) = _get_nearplane_steps(
A1,
A2,
A4,
b1,
b2,
A1_delta,
A4_delta,
recover_psi=object_options is not None,
recover_probe=recover_probe,
m=0,
)

if object_options is not None:
bbeta_object = cp.mean(
weighted_step_psi,
axis=-5,
)[..., 0, 0, 0]

if recover_probe:
bbeta_probe = cp.mean(
weighted_step_probe,
axis=-5,
)

# Update each direction
if object_options is not None:
if algorithm_options.batch_method != 'compact':
Expand Down Expand Up @@ -709,51 +663,33 @@ def _precondition_nearplane_gradients(
None, :, :] * unique_probe[..., m:m + 1, :, :]

A1 = cp.sum((dOP * dOP.conj()).real + eps, axis=(-2, -1))
A1_delta = cp.mean(A1, axis=-3)
A1 += 0.5 * A1_delta

if recover_probe:
dPO = m_probe_update[..., m:m + 1, :, :] * patches
A4 = cp.sum((dPO * dPO.conj()).real + eps, axis=(-2, -1))
A4_delta = cp.mean(A4, axis=-3)
A4 += 0.5 * A4_delta

# (22) Use least-squares to find the optimal step sizes simultaneously
if recover_psi and recover_probe:
b1 = cp.sum((dOP.conj() * nearplane[..., m:m + 1, :, :]).real,
axis=(-2, -1))
b2 = cp.sum((dPO.conj() * nearplane[..., m:m + 1, :, :]).real,
axis=(-2, -1))
A2 = cp.sum((dOP * dPO.conj()), axis=(-2, -1))
elif recover_psi:
b1 = cp.sum((dOP.conj() * nearplane[..., m:m + 1, :, :]).real,
axis=(-2, -1))
elif recover_probe:
b2 = cp.sum((dPO.conj() * nearplane[..., m:m + 1, :, :]).real,
axis=(-2, -1))

return (
object_update_precond,
A1,
A2,
A4,
b1,
b2,
)


def _get_nearplane_steps(A1, A2, A4, b1, b2, A1_delta, A4_delta, recover_psi,
recover_probe, m):

if recover_psi:
A1 += 0.5 * A1_delta
if recover_probe:
A4 += 0.5 * A4_delta

# (22) Use least-squares to find the optimal step sizes simultaneously
if recover_psi and recover_probe:
A3 = A2.conj()
determinant = A1 * A4 - A2 * A3
x1 = -cp.conj(A2 * b2 - A4 * b1) / determinant
x2 = cp.conj(A1 * b2 - A3 * b1) / determinant
elif recover_psi:
b1 = cp.sum((dOP.conj() * nearplane[..., m:m + 1, :, :]).real,
axis=(-2, -1))
x1 = b1 / A1
elif recover_probe:
b2 = cp.sum((dPO.conj() * nearplane[..., m:m + 1, :, :]).real,
axis=(-2, -1))
x2 = b2 / A4
else:
x1 = None
Expand All @@ -763,18 +699,22 @@ def _get_nearplane_steps(A1, A2, A4, b1, b2, A1_delta, A4_delta, recover_psi,
step = 0.9 * cp.maximum(0, x1[..., None, None].real)

# (27b) Object update
beta_object = cp.mean(step, keepdims=True, axis=-5)
beta_object = cp.mean(step, keepdims=False, axis=-5)[..., 0, 0, 0]
else:
beta_object = None

if recover_probe:
step = 0.9 * cp.maximum(0, x2[..., None, None].real)

beta_probe = cp.mean(step, axis=-5, keepdims=True)
beta_probe = cp.mean(step, axis=-5, keepdims=False)
else:
beta_probe = None

return beta_object, beta_probe
return (
object_update_precond,
beta_object,
beta_probe,
)


def _get_coefs_intensity(weights, xi, P, O, batches, *, batch_index, m):
Expand Down

0 comments on commit 8532bfc

Please sign in to comment.