add PDHG implementation optimizing the logL only and compare against MLEM
gschramm committed Oct 2, 2024
1 parent 19132d0 commit 3f3e987
Showing 2 changed files with 593 additions and 0 deletions.
163 changes: 163 additions & 0 deletions main_PDHG.py
@@ -0,0 +1,163 @@
"""Main file to modify for submissions.
Once renamed or symlinked as `main.py`, it will be used by `petric.py` as follows:
>>> from main import Submission, submission_callbacks
>>> from petric import data, metrics
>>> algorithm = Submission(data)
>>> algorithm.run(np.inf, callbacks=metrics + submission_callbacks)
"""

import numpy as np

import sirf.STIR as STIR
from cil.optimisation.algorithms import Algorithm
from cil.optimisation.utilities import callbacks
from petric import Dataset
from sirf.contrib.partitioner import partitioner


class MaxIteration(callbacks.Callback):
    """
    The organisers try to `Submission(data).run(inf)` i.e. for infinite iterations (until timeout).
    This callback forces stopping after `max_iteration` instead.
    """

    def __init__(self, max_iteration: int, verbose: int = 1):
        super().__init__(verbose)
        self.max_iteration = max_iteration

    def __call__(self, algorithm: Algorithm):
        if algorithm.iteration >= self.max_iteration:
            raise StopIteration


class Submission(Algorithm):
    """
    PDHG on the Poisson log-likelihood (without any prior), run alongside a
    classical MLEM iteration on the same data for comparison.
    NB: a single "subset" containing all of the data is used for both updates.
    NB: neither branch uses `data.prior`, so neither converges to the MAP reference used in PETRIC.
    NB: see https://github.com/SyneRBI/SIRF-Contribs/tree/master/src/Python/sirf/contrib/BSREM
    """

    def __init__(
        self, data: Dataset, update_objective_interval: int = 10000000, **kwargs
    ):
        """
        Initialisation function, setting up data & (hyper)parameters.
        NB: a single subset (i.e. all of the data) is used here, so the updates
        below are deterministic (non-stochastic) PDHG / MLEM iterations.
        This is just an example. Try to modify and improve it!
        """
        self.x_pdhg = data.OSEM_image.clone()
        self.x_mlem = data.OSEM_image.clone()
        self.x = data.OSEM_image.clone()

        self.data = data.acquired_data

        # a single "subset" containing all of the data
        _, acq_models, obj_funs = partitioner.data_partition(
            data.acquired_data,
            data.additive_term,
            data.mult_factors,
            1,
            mode="staggered",
            initial_image=data.OSEM_image,
        )

        self.acquisition_model = acq_models[0]
        self.obj_fun = obj_funs[0]

        # sensitivity image K^T 1 (`backward` applies the adjoint of the linear part)
        self.sensitivity_img = self.acquisition_model.backward(
            self.data.get_uniform_copy(1)
        )

        # clip 0s in the sensitivity image to avoid divisions by zero
        self.sensitivity_img.maximum(1e-6, out=self.sensitivity_img)
        self.i_update = 0

        # initialise PDHG variables
        # NB: `forward` applies the full affine model, i.e. K x + additive term
        x_pdhg_fwd = self.acquisition_model.forward(self.x_pdhg)
        # clip zero values to avoid divisions by zero
        x_pdhg_fwd.maximum(1e-6, out=x_pdhg_fwd)
        # initialise the dual variable with the gradient of the negative Poisson
        # logL at the initial image: y0 = 1 - data / (K x0 + s)
        self.y_pdhg = -self.data / x_pdhg_fwd + 1

        # initialise z = K^T y and its extrapolation zbar
        self.z = self.acquisition_model.backward(self.y_pdhg)
        self.z_bar = self.z.clone()

        # initialise the step sizes S (dual) and T (primal)
        self.rho = 1.0
        self.gamma = 1.0 / data.OSEM_image.max()

        ones_img = self.x.get_uniform_copy(1)
        # get the linear part K of the (affine) acquisition model
        A = self.acquisition_model.get_linear_acquisition_model()
        ones_img_fwd = A.forward(ones_img)
        # clip zero values to avoid divisions by zero
        ones_img_fwd.maximum(1e-6, out=ones_img_fwd)
        # S = rho * gamma / (K 1)
        self.S = ones_img_fwd.power(-1) * (self.rho * self.gamma)

        # T = rho / (gamma * K^T 1)
        self.T = self.sensitivity_img.power(-1) * (self.rho / self.gamma)
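
        # these choices follow the usual diagonally preconditioned (S)PDHG
        # step-size rule (cf. Pock & Chambolle 2011; Ehrhardt et al. 2019),
        # for which ||sqrt(S) K sqrt(T)|| <= rho; gamma trades off primal vs
        # dual step lengths and is scaled here to the initial image magnitude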

        super().__init__(update_objective_interval=update_objective_interval, **kwargs)
        self.configured = True  # required by Algorithm

    def update(self):

        ########################################################################################
        # MLEM update in additive form: x+ = x + (x / sens) * K^T((data - ybar) / ybar);
        # clipping happens only in the denominator, which keeps the sign of the gradient.
        # Remember that self.acquisition_model.forward(self.x_mlem) includes the additive term.
        exp_data = self.acquisition_model.forward(self.x_mlem)
        denom = exp_data.maximum(1e-6)

        precond_mlem = self.x_mlem / self.sensitivity_img

        self.x_mlem = self.x_mlem + precond_mlem * self.acquisition_model.backward(
            (self.data - exp_data) / denom
        )
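
        # NB: up to the clipping, this is algebraically identical to the classical
        # multiplicative MLEM step x+ = (x / sens) * K^T(data / ybar), since
        # x + (x / sens) * (K^T(data / ybar) - K^T 1) collapses using K^T 1 = sens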

        ########################################################################################
        # PDHG update

        # primal update: preconditioned gradient step along the extrapolated dual
        # back-projection, followed by projection onto the non-negative orthant
        self.x_pdhg = self.x_pdhg - self.T * self.z_bar
        # clip negative values
        self.x_pdhg.maximum(0, out=self.x_pdhg)

        x_pdhg_fwd = self.acquisition_model.forward(self.x_pdhg)
        y_plus = self.y_pdhg + self.S * x_pdhg_fwd
        # apply the proximal operator of the convex conjugate of the Poisson
        # negative logL, which has the closed form
        #   prox(y) = (y + 1 - sqrt((y - 1)^2 + 4 * S * data)) / 2
        tmp = (y_plus - 1) * (y_plus - 1) + self.data * self.S * 4
        tmp.power(0.5, out=tmp)
        y_plus = (y_plus + 1 - tmp) / 2

        delta_z = self.acquisition_model.backward(y_plus - self.y_pdhg)

        self.y_pdhg = y_plus
        self.z = self.z + delta_z
        # extrapolation with theta = 1: z_bar = z_new + delta_z = K^T(2 y_new - y_old)
        self.z_bar = self.z + delta_z

        ########################################################################################
        self.i_update += 1
        print(self.i_update, self.x_mlem.max(), self.x_pdhg.max())

        # periodically dump intermediate images for offline comparison of MLEM vs PDHG
        if (self.i_update < 10) or (self.i_update % 50 == 0):
            np.save(f"mlem_{self.i_update}.npy", self.x_mlem.as_array())
            np.save(f"pdhg_{self.i_update}.npy", self.x_pdhg.as_array())

            print(f"obj_fun mlem: {self.obj_fun(self.x_mlem)}")
            print(f"obj_fun pdhg: {self.obj_fun(self.x_pdhg)}")

    def update_objective(self):
        """
        NB: The objective value is not required by this example nor by PETRIC, so this returns `0`.
        NB: It should be `sum(prompts * log(acq_model.forward(self.x)) - self.x * sensitivity)` across all subsets.
        """
        # NB: `self.obj_fun` from the partitioner evaluates exactly this Poisson
        # logL, so e.g. `return self.obj_fun(self.x_pdhg)` would work, at the
        # cost of an extra forward projection per evaluation
        return 0


submission_callbacks = []
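
# a minimal local-run sketch (an assumption for illustration; it only uses the
# `petric` interface already shown in the module docstring): stop after 200
# updates instead of running until the organisers' timeout
if __name__ == "__main__":
    from petric import data

    algo = Submission(data)
    algo.run(np.inf, callbacks=[MaxIteration(200)])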