Skip to content

Commit

Permalink
Merge pull request #281 from carterbox/probe-disabled
Browse files Browse the repository at this point in the history
BUG: Fix errors when probe updates disabled in lstsq method
  • Loading branch information
carterbox authored Oct 3, 2023
2 parents b8a452d + 263a298 commit e412e33
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 34 deletions.
80 changes: 52 additions & 28 deletions src/tike/ptycho/solvers/lstsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,16 @@ def lstsq_grad(
recover_probe=probe_options is not None,
recover_positions=position_options is not None,
)))
object_upd_sum = comm.Allreduce(object_upd_sum)
m_probe_update = comm.pool.bcast(
[comm.Allreduce_mean(
m_probe_update,
axis=-5,
)])

if object_options is not None:
object_upd_sum = comm.Allreduce(object_upd_sum)

if probe_options is not None:
m_probe_update = comm.pool.bcast(
[comm.Allreduce_mean(
m_probe_update,
axis=-5,
)])

(
beigen_probe,
Expand Down Expand Up @@ -188,8 +192,16 @@ def lstsq_grad(
recover_probe=probe_options is not None,
probe_options=probe_options,
)))
A1_delta = comm.pool.bcast([comm.Allreduce_mean(A1, axis=-3)])
A4_delta = comm.pool.bcast([comm.Allreduce_mean(A4, axis=-3)])

if object_options is not None:
A1_delta = comm.pool.bcast([comm.Allreduce_mean(A1, axis=-3)])
else:
A1_delta = [None] * comm.pool.num_workers

if probe_options is not None:
A4_delta = comm.pool.bcast([comm.Allreduce_mean(A4, axis=-3)])
else:
A4_delta = [None] * comm.pool.num_workers

(
weighted_step_psi,
Expand All @@ -207,14 +219,18 @@ def lstsq_grad(
recover_probe=probe_options is not None,
m=0,
)))
bbeta_object = comm.Allreduce_mean(
weighted_step_psi,
axis=-5,
)[..., 0, 0, 0]
bbeta_probe = comm.Allreduce_mean(
weighted_step_probe,
axis=-5,
)

if object_options is not None:
bbeta_object = comm.Allreduce_mean(
weighted_step_psi,
axis=-5,
)[..., 0, 0, 0]

if probe_options is not None:
bbeta_probe = comm.Allreduce_mean(
weighted_step_probe,
axis=-5,
)

# Update each direction
if object_options is not None:
Expand Down Expand Up @@ -249,8 +265,11 @@ def lstsq_grad(
for c in costs:
batch_cost = batch_cost + c.tolist()

beta_object.append(bbeta_object)
beta_probe.append(bbeta_probe)
if object_options is not None:
beta_object.append(bbeta_object)

if probe_options is not None:
beta_probe.append(bbeta_probe)

if eigen_probe is not None:
eigen_probe = beigen_probe
Expand Down Expand Up @@ -671,7 +690,8 @@ def keep_some_args_constant(
unique_probe,
probe_update,
object_upd_sum,
m_probe_update / len(batches[batch_index]),
m_probe_update /
len(batches[batch_index]) if m_probe_update is not None else None,
costs,
patches,
position_update_numerator,
Expand Down Expand Up @@ -715,6 +735,15 @@ def _precondition_nearplane_gradients(

eps = op.xp.float32(1e-9) / (nearplane.shape[-2] * nearplane.shape[-1])

A1 = None
A2 = None
A4 = None
b1 = None
b2 = None
dOP = None
dPO = None
object_update_proj = None

if recover_psi:
object_update_precond = _precondition_object_update(
object_upd_sum,
Expand All @@ -730,10 +759,6 @@ def _precondition_nearplane_gradients(
None, :, :] * unique_probe[..., m:m + 1, :, :]

A1 = cp.sum((dOP * dOP.conj()).real + eps, axis=(-2, -1))
else:
object_update_proj = None
dOP = None
A1 = None

if recover_probe:

Expand Down Expand Up @@ -761,9 +786,6 @@ def _precondition_nearplane_gradients(

dPO = m_probe_update[..., m:m + 1, :, :] * patches
A4 = cp.sum((dPO * dPO.conj()).real + eps, axis=(-2, -1))
else:
dPO = None
A4 = None

if recover_psi and recover_probe:
b1 = cp.sum((dOP.conj() * nearplane[..., m:m + 1, :, :]).real,
Expand Down Expand Up @@ -791,8 +813,10 @@ def _precondition_nearplane_gradients(
def _get_nearplane_steps(A1, A2, A4, b1, b2, A1_delta, A4_delta, recover_psi,
recover_probe, m):

A1 += 0.5 * A1_delta
A4 += 0.5 * A4_delta
if recover_psi:
A1 += 0.5 * A1_delta
if recover_probe:
A4 += 0.5 * A4_delta

# (22) Use least-squares to find the optimal step sizes simultaneously
if recover_psi and recover_probe:
Expand Down
14 changes: 11 additions & 3 deletions tests/ptycho/io.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import warnings
import os
import typing
import warnings

import numpy as np
import numpy.typing as npt
import tike.view
import tike.ptycho

test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))

Expand Down Expand Up @@ -41,7 +44,12 @@ def _save_eigen_probe(output_folder, eigen_probe):
)


def _save_probe(output_folder, probe, probe_options, algorithm):
def _save_probe(
output_folder: str,
probe: npt.NDArray,
probe_options: typing.Union[None, tike.ptycho.ProbeOptions],
algorithm: str,
):
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
Expand All @@ -54,7 +62,7 @@ def _save_probe(output_folder, probe, probe_options, algorithm):
f'{output_folder}/probe.png',
tike.view.complexHSV_to_RGB(flattened),
)
if len(probe_options.power) > 0:
if probe_options is not None and len(probe_options.power) > 0:
f = plt.figure()
tike.view.plot_probe_power_series(probe_options.power)
plt.title(algorithm)
Expand Down
105 changes: 102 additions & 3 deletions tests/ptycho/test_ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,25 @@ def test_consistent_lstsq_grad(self):
params=params,
), f"mpi{self.mpi_size}-lstsq_grad{self.post_name}")

def test_consistent_lstsq_grad_no_probe(self):
"""Check ptycho.solver.lstsq_grad for consistency."""
params = tike.ptycho.PtychoParameters(
psi=self.psi,
probe=self.probe,
scan=self.scan,
algorithm_options=tike.ptycho.LstsqOptions(
num_batch=5,
num_iter=16,
),
object_options=ObjectOptions(use_adaptive_moment=True,),
)

_save_ptycho_result(
self.template_consistent_algorithm(
data=self.data,
params=params,
), f"mpi{self.mpi_size}-lstsq_grad-no-probe{self.post_name}")

def test_consistent_lstsq_grad_compact(self):
"""Check ptycho.solver.lstsq_grad for consistency."""
params = tike.ptycho.PtychoParameters(
Expand All @@ -408,6 +427,27 @@ def test_consistent_lstsq_grad_compact(self):
params=params,
), f"mpi{self.mpi_size}-lstsq_grad-compact{self.post_name}")

def test_consistent_lstsq_grad_compact_no_probe(self):
"""Check ptycho.solver.lstsq_grad for consistency."""
params = tike.ptycho.PtychoParameters(
psi=self.psi,
probe=self.probe,
scan=self.scan,
algorithm_options=tike.ptycho.LstsqOptions(
num_batch=5,
num_iter=16,
batch_method='compact',
),
object_options=ObjectOptions(use_adaptive_moment=True,),
)

_save_ptycho_result(
self.template_consistent_algorithm(
data=self.data,
params=params,
),
f"mpi{self.mpi_size}-lstsq_grad-compact-no-probe{self.post_name}")

def test_consistent_lstsq_grad_variable_probe(self):
"""Check ptycho.solver.lstsq_grad for consistency."""
params = tike.ptycho.PtychoParameters(
Expand Down Expand Up @@ -544,9 +584,7 @@ def test_consistent_rpie(self):
num_iter=16,
),
probe_options=ProbeOptions(force_orthogonality=True,),
object_options=ObjectOptions(
smoothness_constraint=0.01,
),
object_options=ObjectOptions(smoothness_constraint=0.01,),
)

_save_ptycho_result(
Expand All @@ -557,6 +595,27 @@ def test_consistent_rpie(self):
f"mpi{self.mpi_size}-rpie{self.post_name}",
)

def test_consistent_rpie_no_probe(self):
"""Check ptycho.solver.rpie for consistency."""
params = tike.ptycho.PtychoParameters(
psi=self.psi,
probe=self.probe,
scan=self.scan,
algorithm_options=tike.ptycho.RpieOptions(
num_batch=5,
num_iter=16,
),
object_options=ObjectOptions(smoothness_constraint=0.01,),
)

_save_ptycho_result(
self.template_consistent_algorithm(
data=self.data,
params=params,
),
f"mpi{self.mpi_size}-rpie-no-probe{self.post_name}",
)

def test_consistent_rpie_compact(self):
"""Check ptycho.solver.rpie for consistency."""
params = tike.ptycho.PtychoParameters(
Expand All @@ -581,6 +640,25 @@ def test_consistent_rpie_compact(self):
params=params,
), f"mpi{self.mpi_size}-rpie-compact{self.post_name}")

def test_consistent_rpie_compact_no_probe(self):
"""Check ptycho.solver.rpie for consistency."""
params = tike.ptycho.PtychoParameters(
psi=self.psi,
probe=self.probe,
scan=self.scan,
algorithm_options=tike.ptycho.RpieOptions(
num_batch=5,
num_iter=16,
batch_method='compact',
),
object_options=ObjectOptions(use_adaptive_moment=True,),
)
_save_ptycho_result(
self.template_consistent_algorithm(
data=self.data,
params=params,
), f"mpi{self.mpi_size}-rpie-compact-no-probe{self.post_name}")

def test_consistent_rpie_variable_probe(self):
"""Check ptycho.solver.lstsq_grad for consistency."""
params = tike.ptycho.PtychoParameters(
Expand Down Expand Up @@ -636,6 +714,27 @@ def test_consistent_dm(self):
f"mpi{self.mpi_size}-dm{self.post_name}",
)

def test_consistent_dm_no_probe(self):
"""Check ptycho.solver.dm for consistency."""
params = tike.ptycho.PtychoParameters(
psi=self.psi,
probe=self.probe,
scan=self.scan,
algorithm_options=tike.ptycho.DmOptions(
num_iter=16,
num_batch=5,
),
object_options=ObjectOptions(),
)

_save_ptycho_result(
self.template_consistent_algorithm(
data=self.data,
params=params,
),
f"mpi{self.mpi_size}-dm-no-probe{self.post_name}",
)


class TestPtychoRecon(
PtychoRecon,
Expand Down

0 comments on commit e412e33

Please sign in to comment.