diff --git a/src/tike/ptycho/solvers/lstsq.py b/src/tike/ptycho/solvers/lstsq.py index 5797739d..14933aac 100644 --- a/src/tike/ptycho/solvers/lstsq.py +++ b/src/tike/ptycho/solvers/lstsq.py @@ -140,12 +140,16 @@ def lstsq_grad( recover_probe=probe_options is not None, recover_positions=position_options is not None, ))) - object_upd_sum = comm.Allreduce(object_upd_sum) - m_probe_update = comm.pool.bcast( - [comm.Allreduce_mean( - m_probe_update, - axis=-5, - )]) + + if object_options is not None: + object_upd_sum = comm.Allreduce(object_upd_sum) + + if probe_options is not None: + m_probe_update = comm.pool.bcast( + [comm.Allreduce_mean( + m_probe_update, + axis=-5, + )]) ( beigen_probe, @@ -188,8 +192,16 @@ def lstsq_grad( recover_probe=probe_options is not None, probe_options=probe_options, ))) - A1_delta = comm.pool.bcast([comm.Allreduce_mean(A1, axis=-3)]) - A4_delta = comm.pool.bcast([comm.Allreduce_mean(A4, axis=-3)]) + + if object_options is not None: + A1_delta = comm.pool.bcast([comm.Allreduce_mean(A1, axis=-3)]) + else: + A1_delta = [None] * comm.pool.num_workers + + if probe_options is not None: + A4_delta = comm.pool.bcast([comm.Allreduce_mean(A4, axis=-3)]) + else: + A4_delta = [None] * comm.pool.num_workers ( weighted_step_psi, @@ -207,14 +219,18 @@ def lstsq_grad( recover_probe=probe_options is not None, m=0, ))) - bbeta_object = comm.Allreduce_mean( - weighted_step_psi, - axis=-5, - )[..., 0, 0, 0] - bbeta_probe = comm.Allreduce_mean( - weighted_step_probe, - axis=-5, - ) + + if object_options is not None: + bbeta_object = comm.Allreduce_mean( + weighted_step_psi, + axis=-5, + )[..., 0, 0, 0] + + if probe_options is not None: + bbeta_probe = comm.Allreduce_mean( + weighted_step_probe, + axis=-5, + ) # Update each direction if object_options is not None: @@ -249,8 +265,11 @@ def lstsq_grad( for c in costs: batch_cost = batch_cost + c.tolist() - beta_object.append(bbeta_object) - beta_probe.append(bbeta_probe) + if object_options is not None: + beta_object.append(bbeta_object) + + if probe_options is not None: + beta_probe.append(bbeta_probe) if eigen_probe is not None: eigen_probe = beigen_probe @@ -671,7 +690,8 @@ def keep_some_args_constant( unique_probe, probe_update, object_upd_sum, - m_probe_update / len(batches[batch_index]), + m_probe_update / + len(batches[batch_index]) if m_probe_update is not None else None, costs, patches, position_update_numerator, @@ -715,6 +735,15 @@ def _precondition_nearplane_gradients( eps = op.xp.float32(1e-9) / (nearplane.shape[-2] * nearplane.shape[-1]) + A1 = None + A2 = None + A4 = None + b1 = None + b2 = None + dOP = None + dPO = None + object_update_proj = None + if recover_psi: object_update_precond = _precondition_object_update( object_upd_sum, @@ -730,10 +759,6 @@ def _precondition_nearplane_gradients( None, :, :] * unique_probe[..., m:m + 1, :, :] A1 = cp.sum((dOP * dOP.conj()).real + eps, axis=(-2, -1)) - else: - object_update_proj = None - dOP = None - A1 = None if recover_probe: @@ -761,9 +786,6 @@ def _precondition_nearplane_gradients( dPO = m_probe_update[..., m:m + 1, :, :] * patches A4 = cp.sum((dPO * dPO.conj()).real + eps, axis=(-2, -1)) - else: - dPO = None - A4 = None if recover_psi and recover_probe: b1 = cp.sum((dOP.conj() * nearplane[..., m:m + 1, :, :]).real, @@ -791,8 +813,10 @@ def _precondition_nearplane_gradients( def _get_nearplane_steps(A1, A2, A4, b1, b2, A1_delta, A4_delta, recover_psi, recover_probe, m): - A1 += 0.5 * A1_delta - A4 += 0.5 * A4_delta + if recover_psi: + A1 += 0.5 * A1_delta + if recover_probe: + A4 += 0.5 * A4_delta # (22) Use least-squares to find the optimal step sizes simultaneously if recover_psi and recover_probe: diff --git a/tests/ptycho/io.py b/tests/ptycho/io.py index d6900681..3ecba0cf 100644 --- a/tests/ptycho/io.py +++ b/tests/ptycho/io.py @@ -1,8 +1,11 @@ -import warnings import os +import typing +import warnings import numpy as np +import numpy.typing as npt import tike.view +import tike.ptycho test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) @@ -41,7 +44,12 @@ def _save_eigen_probe(output_folder, eigen_probe): ) -def _save_probe(output_folder, probe, probe_options, algorithm): +def _save_probe( + output_folder: str, + probe: npt.NDArray, + probe_options: typing.Union[None, tike.ptycho.ProbeOptions], + algorithm: str, +): import matplotlib matplotlib.use('Agg') from matplotlib import pyplot as plt @@ -54,7 +62,7 @@ def _save_probe(output_folder, probe, probe_options, algorithm): f'{output_folder}/probe.png', tike.view.complexHSV_to_RGB(flattened), ) - if len(probe_options.power) > 0: + if probe_options is not None and len(probe_options.power) > 0: f = plt.figure() tike.view.plot_probe_power_series(probe_options.power) plt.title(algorithm) diff --git a/tests/ptycho/test_ptycho.py b/tests/ptycho/test_ptycho.py index 0099593a..40672ade 100644 --- a/tests/ptycho/test_ptycho.py +++ b/tests/ptycho/test_ptycho.py @@ -384,6 +384,25 @@ def test_consistent_lstsq_grad(self): params=params, ), f"mpi{self.mpi_size}-lstsq_grad{self.post_name}") + def test_consistent_lstsq_grad_no_probe(self): + """Check ptycho.solver.lstsq_grad for consistency.""" + params = tike.ptycho.PtychoParameters( + psi=self.psi, + probe=self.probe, + scan=self.scan, + algorithm_options=tike.ptycho.LstsqOptions( + num_batch=5, + num_iter=16, + ), + object_options=ObjectOptions(use_adaptive_moment=True,), + ) + + _save_ptycho_result( + self.template_consistent_algorithm( + data=self.data, + params=params, + ), f"mpi{self.mpi_size}-lstsq_grad-no-probe{self.post_name}") + def test_consistent_lstsq_grad_compact(self): """Check ptycho.solver.lstsq_grad for consistency.""" params = tike.ptycho.PtychoParameters( @@ -408,6 +427,27 @@ def test_consistent_lstsq_grad_compact(self): params=params, ), f"mpi{self.mpi_size}-lstsq_grad-compact{self.post_name}") + def test_consistent_lstsq_grad_compact_no_probe(self): + """Check ptycho.solver.lstsq_grad for consistency.""" + params = tike.ptycho.PtychoParameters( + psi=self.psi, + probe=self.probe, + scan=self.scan, + algorithm_options=tike.ptycho.LstsqOptions( + num_batch=5, + num_iter=16, + batch_method='compact', + ), + object_options=ObjectOptions(use_adaptive_moment=True,), + ) + + _save_ptycho_result( + self.template_consistent_algorithm( + data=self.data, + params=params, + ), + f"mpi{self.mpi_size}-lstsq_grad-compact-no-probe{self.post_name}") + def test_consistent_lstsq_grad_variable_probe(self): """Check ptycho.solver.lstsq_grad for consistency.""" params = tike.ptycho.PtychoParameters( @@ -544,9 +584,7 @@ def test_consistent_rpie(self): num_iter=16, ), probe_options=ProbeOptions(force_orthogonality=True,), - object_options=ObjectOptions( - smoothness_constraint=0.01, - ), + object_options=ObjectOptions(smoothness_constraint=0.01,), ) _save_ptycho_result( @@ -557,6 +595,27 @@ def test_consistent_rpie(self): f"mpi{self.mpi_size}-rpie{self.post_name}", ) + def test_consistent_rpie_no_probe(self): + """Check ptycho.solver.rpie for consistency.""" + params = tike.ptycho.PtychoParameters( + psi=self.psi, + probe=self.probe, + scan=self.scan, + algorithm_options=tike.ptycho.RpieOptions( + num_batch=5, + num_iter=16, + ), + object_options=ObjectOptions(smoothness_constraint=0.01,), + ) + + _save_ptycho_result( + self.template_consistent_algorithm( + data=self.data, + params=params, + ), + f"mpi{self.mpi_size}-rpie-no-probe{self.post_name}", + ) + def test_consistent_rpie_compact(self): """Check ptycho.solver.rpie for consistency.""" params = tike.ptycho.PtychoParameters( @@ -581,6 +640,25 @@ def test_consistent_rpie_compact(self): params=params, ), f"mpi{self.mpi_size}-rpie-compact{self.post_name}") + def test_consistent_rpie_compact_no_probe(self): + """Check ptycho.solver.rpie for consistency.""" + params = tike.ptycho.PtychoParameters( + psi=self.psi, + probe=self.probe, + scan=self.scan, + algorithm_options=tike.ptycho.RpieOptions( + num_batch=5, + num_iter=16, + batch_method='compact', + ), + object_options=ObjectOptions(use_adaptive_moment=True,), + ) + _save_ptycho_result( + self.template_consistent_algorithm( + data=self.data, + params=params, + ), f"mpi{self.mpi_size}-rpie-compact-no-probe{self.post_name}") + def test_consistent_rpie_variable_probe(self): """Check ptycho.solver.lstsq_grad for consistency.""" params = tike.ptycho.PtychoParameters( @@ -636,6 +714,27 @@ def test_consistent_dm(self): f"mpi{self.mpi_size}-dm{self.post_name}", ) + def test_consistent_dm_no_probe(self): + """Check ptycho.solver.dm for consistency.""" + params = tike.ptycho.PtychoParameters( + psi=self.psi, + probe=self.probe, + scan=self.scan, + algorithm_options=tike.ptycho.DmOptions( + num_iter=16, + num_batch=5, + ), + object_options=ObjectOptions(), + ) + + _save_ptycho_result( + self.template_consistent_algorithm( + data=self.data, + params=params, + ), + f"mpi{self.mpi_size}-dm-no-probe{self.post_name}", + ) + class TestPtychoRecon( PtychoRecon,