-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add PDHG implementation optimizing the logL only and compare against …
…MLEM
- Loading branch information
Showing
2 changed files
with
593 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
"""Main file to modify for submissions. | ||
Once renamed or symlinked as `main.py`, it will be used by `petric.py` as follows: | ||
>>> from main import Submission, submission_callbacks | ||
>>> from petric import data, metrics | ||
>>> algorithm = Submission(data) | ||
>>> algorithm.run(np.inf, callbacks=metrics + submission_callbacks) | ||
""" | ||
|
||
import sirf.STIR as STIR | ||
from cil.optimisation.algorithms import Algorithm | ||
from cil.optimisation.utilities import callbacks | ||
from petric import Dataset | ||
from sirf.contrib.partitioner.partitioner import partition_indices | ||
from sirf.contrib.partitioner import partitioner | ||
import numpy as np | ||
|
||
|
||
class MaxIteration(callbacks.Callback): | ||
""" | ||
The organisers try to `Submission(data).run(inf)` i.e. for infinite iterations (until timeout). | ||
This callback forces stopping after `max_iteration` instead. | ||
""" | ||
|
||
def __init__(self, max_iteration: int, verbose: int = 1): | ||
super().__init__(verbose) | ||
self.max_iteration = max_iteration | ||
|
||
def __call__(self, algorithm: Algorithm): | ||
if algorithm.iteration >= self.max_iteration: | ||
raise StopIteration | ||
|
||
|
||
class Submission(Algorithm): | ||
""" | ||
OSEM algorithm example. | ||
NB: In OSEM, the multiplicative term cancels in the back-projection of the quotient of measured & estimated data | ||
(so this is used here for efficiency). | ||
A similar optimisation can be used for all algorithms using the Poisson log-likelihood. | ||
NB: OSEM does not use `data.prior` and thus does not converge to the MAP reference used in PETRIC. | ||
NB: this example does not use the `sirf.STIR` Poisson objective function. | ||
NB: see https://github.com/SyneRBI/SIRF-Contribs/tree/master/src/Python/sirf/contrib/BSREM | ||
""" | ||
|
||
def __init__( | ||
self, data: Dataset, update_objective_interval: int = 10000000, **kwargs | ||
): | ||
""" | ||
Initialisation function, setting up data & (hyper)parameters. | ||
NB: in practice, `num_subsets` should likely be determined from the data. | ||
This is just an example. Try to modify and improve it! | ||
""" | ||
self.x_pdhg = data.OSEM_image.clone() | ||
self.x_mlem = data.OSEM_image.clone() | ||
self.x = data.OSEM_image.clone() | ||
|
||
self.data = data.acquired_data | ||
|
||
_, acq_models, obj_funs = partitioner.data_partition( | ||
data.acquired_data, | ||
data.additive_term, | ||
data.mult_factors, | ||
1, | ||
mode="staggered", | ||
initial_image=data.OSEM_image, | ||
) | ||
|
||
self.acquisition_model = acq_models[0] | ||
self.obj_fun = obj_funs[0] | ||
|
||
self.sensitivity_img = self.acquisition_model.backward( | ||
self.data.get_uniform_copy(1) | ||
) | ||
|
||
# clip 0s in the sensitivity image | ||
self.sensitivity_img.maximum(1e-6, out=self.sensitivity_img) | ||
self.i_update = 0 | ||
|
||
# initialise PDHG variables | ||
x_pdhg_fwd = self.acquisition_model.forward(self.x_pdhg) | ||
# clip zero values | ||
x_pdhg_fwd.maximum(1e-6, out=x_pdhg_fwd) | ||
# initialise y | ||
self.y_pdhg = -self.data / x_pdhg_fwd + 1 | ||
|
||
# initialise z and zbar | ||
self.z = self.acquisition_model.backward(self.y_pdhg) | ||
self.z_bar = self.z.clone() | ||
|
||
# initialise the step size S | ||
self.rho = 1.0 | ||
self.gamma = 1.0 / data.OSEM_image.max() | ||
|
||
ones_img = self.x.get_uniform_copy(1) | ||
# get the linear part of the acquisition model | ||
A = self.acquisition_model.get_linear_acquisition_model() | ||
ones_img_fwd = A.forward(ones_img) | ||
# clip zero values | ||
ones_img_fwd.maximum(1e-6, out=ones_img_fwd) | ||
self.S = ones_img_fwd.power(-1) * (self.rho * self.gamma) | ||
|
||
# initialise the step size T | ||
self.T = self.sensitivity_img.power(-1) * (self.rho / self.gamma) | ||
|
||
super().__init__(update_objective_interval=update_objective_interval, **kwargs) | ||
self.configured = True # required by Algorithm | ||
|
||
def update(self): | ||
|
||
######################################################################################## | ||
# MLEM update in additive form, clipping happens only in denominator which keeps the sign | ||
# of the gradient | ||
# remember that self.acquisition_model.forward(self.x_mlem) includes the additive term | ||
exp_data = self.acquisition_model.forward(self.x_mlem) | ||
denom = exp_data.maximum(1e-6) | ||
|
||
precond_mlem = self.x_mlem / self.sensitivity_img | ||
|
||
self.x_mlem = self.x_mlem + precond_mlem * self.acquisition_model.backward( | ||
(self.data - exp_data) / denom | ||
) | ||
|
||
######################################################################################## | ||
# PDHG update | ||
|
||
self.x_pdhg = self.x_pdhg - self.T * self.z_bar | ||
# clip negative values | ||
self.x_pdhg.maximum(0, out=self.x_pdhg) | ||
|
||
x_pdhg_fwd = self.acquisition_model.forward(self.x_pdhg) | ||
y_plus = self.y_pdhg + self.S * x_pdhg_fwd | ||
# apply the proximal operator | ||
tmp = (y_plus - 1) * (y_plus - 1) + self.data * self.S * 4 | ||
tmp.power(0.5, out=tmp) | ||
y_plus = (y_plus + 1 - tmp) / 2 | ||
|
||
delta_z = self.acquisition_model.backward(y_plus - self.y_pdhg) | ||
|
||
self.y_pdhg = y_plus | ||
self.z = self.z + delta_z | ||
self.z_bar = self.z + delta_z | ||
|
||
######################################################################################## | ||
self.i_update += 1 | ||
print(self.i_update, self.x_mlem.max(), self.x_pdhg.max()) | ||
|
||
if (self.i_update < 10) or (self.i_update % 50 == 0): | ||
np.save(f"mlem_{self.i_update}.npy", self.x_mlem.as_array()) | ||
np.save(f"pdhg_{self.i_update}.npy", self.x_pdhg.as_array()) | ||
|
||
print(f"obj_fun mlem: {self.obj_fun(self.x_mlem)}") | ||
print(f"obj_fun pdhg: {self.obj_fun(self.x_pdhg)}") | ||
|
||
def update_objective(self): | ||
""" | ||
NB: The objective value is not required by OSEM nor by PETRIC, so this returns `0`. | ||
NB: It should be `sum(prompts * log(acq_model.forward(self.x)) - self.x * sensitivity)` across all subsets. | ||
""" | ||
return 0 | ||
|
||
|
||
submission_callbacks = [] |
Oops, something went wrong.