Skip to content

Commit

Permalink
REF: Move probe support outside of algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
carterbox committed Jul 22, 2024
1 parent 03d5faa commit 08450de
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 42 deletions.
22 changes: 22 additions & 0 deletions src/tike/ptycho/ptycho.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
get_varying_probe,
apply_median_filter_abs_probe,
orthogonalize_eig,
finite_probe_support,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -715,6 +716,27 @@ def _apply_probe_constraints(
if parameters.probe_options is not None:
if parameters.probe_options.recover_probe(epoch):

if parameters.probe_options.probe_support > 0:
b0 = finite_probe_support(
parameters.probe,
p=parameters.probe_options.probe_support,
radius=parameters.probe_options.probe_support_radius,
degree=parameters.probe_options.probe_support_degree,
)
parameters.probe -= b0 * cp.conj(b0 * parameters.probe)

if parameters.probe_options.additional_probe_penalty > 0:
b1 = (
parameters.probe_options.additional_probe_penalty
* cp.linspace(
0,
1,
parameters.probe.shape[-3],
dtype=tike.precision.floating,
)[..., None, None]
)
parameters.probe -= b1 * cp.conj(b1 * parameters.probe)

if parameters.probe_options.median_filter_abs_probe:
parameters.probe = apply_median_filter_abs_probe(
parameters.probe,
Expand Down
29 changes: 0 additions & 29 deletions src/tike/ptycho/solvers/lstsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,35 +710,6 @@ def _precondition_nearplane_gradients(
A1 = cp.sum((dOP * dOP.conj()).real + eps, axis=(-2, -1))

if recover_probe:
b0 = tike.ptycho.probe.finite_probe_support(
unique_probe[..., m : m + 1, :, :],
p=probe_options.probe_support,
radius=probe_options.probe_support_radius,
degree=probe_options.probe_support_degree,
)

b1 = (
probe_options.additional_probe_penalty
* cp.linspace(
0,
1,
probe[0].shape[-3],
dtype=tike.precision.floating,
)[..., m : m + 1, None, None]
)

m_probe_update = m_probe_update - (b0 + b1) * probe[..., m : m + 1, :, :]
# / (
# (1 - alpha) * probe_update_denominator
# + alpha
# * probe_update_denominator.max(
# axis=(-2, -1),
# keepdims=True,
# )
# + b0
# + b1
# )

dPO = m_probe_update[..., m:m + 1, :, :] * patches
A4 = cp.sum((dPO * dPO.conj()).real + eps, axis=(-2, -1))

Expand Down
14 changes: 1 addition & 13 deletions src/tike/ptycho/solvers/rpie.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,26 +264,14 @@ def _update(
psi = psi + dpsi / deno

if recover_probe:
b0 = tike.ptycho.probe.finite_probe_support(
probe,
p=probe_options.probe_support,
radius=probe_options.probe_support_radius,
degree=probe_options.probe_support_degree,
)
b1 = (
probe_options.additional_probe_penalty
* cp.linspace(0, 1, probe.shape[-3], dtype="float32")[..., None, None]
)
dprobe = probe_update_numerator - (b1 + b0) * probe
dprobe = probe_update_numerator
deno = (
(1 - algorithm_options.alpha) * probe_options.preconditioner
+ algorithm_options.alpha
* probe_options.preconditioner.max(
axis=(-2, -1),
keepdims=True,
)
+ b0
+ b1
)
probe = probe + dprobe / deno
if probe_options.use_adaptive_moment:
Expand Down

0 comments on commit 08450de

Please sign in to comment.