Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CIL recon: adding implicit formulation of TV with warm start #1972

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 56 additions & 26 deletions mantidimaging/core/reconstruct/cil_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from cil.framework import AcquisitionData, AcquisitionGeometry, DataOrder, ImageGeometry, BlockGeometry
from cil.optimisation.algorithms import PDHG, SPDHG
from cil.optimisation.operators import GradientOperator, BlockOperator
from cil.optimisation.functions import MixedL21Norm, L2NormSquared, BlockFunction, ZeroFunction, IndicatorBox, Function
from cil.optimisation.functions import MixedL21Norm, L2NormSquared, BlockFunction, ZeroFunction, IndicatorBox, Function, TotalVariation
from cil.plugins.astra.operators import ProjectionOperator

from mantidimaging.core.data import ImageStack
Expand All @@ -34,51 +34,81 @@
class CILRecon(BaseRecon):

@staticmethod
def _set_stochastic_data_fitting_objective_function(
acquisition_data: AcquisitionData,
Grad: GradientOperator, alpha: float) -> list:
# now, A is a BlockOperator as acquisition_data is a BlockDataContainer
fs = []
for i, _ in enumerate(acquisition_data.geometry):
fs.append(L2NormSquared(b=acquisition_data.get_item(i)))

return fs
@staticmethod
def set_up_TV_regularisation(
image_geometry: ImageGeometry, acquisition_data: AcquisitionData,
recon_params: ReconstructionParameters) -> tuple[BlockOperator, BlockFunction, Function]:

# Forward operator
A2d = ProjectionOperator(image_geometry, acquisition_data.geometry, 'gpu')
# Projection operator
A = ProjectionOperator(image_geometry, acquisition_data.geometry, 'gpu')

if recon_params.stochastic:
for partition_geometry, partition_operator in zip(acquisition_data.geometry, A2d, strict=True):
for partition_geometry, partition_operator in zip(acquisition_data.geometry, A, strict=True):
CILRecon.set_approx_norm(partition_operator, partition_geometry, image_geometry)
else:
CILRecon.set_approx_norm(A2d, acquisition_data.geometry, image_geometry)
CILRecon.set_approx_norm(A, acquisition_data.geometry, image_geometry)

# Define Gradient Operator and BlockOperator
alpha = recon_params.alpha
Grad = GradientOperator(image_geometry)

if recon_params.stochastic:
# now, A2d is a BlockOperator as acquisition_data is a BlockDataContainer
fs = []
for i, _ in enumerate(acquisition_data.geometry):
fs.append(L2NormSquared(b=acquisition_data.get_item(i)))
fs.append(MixedL21Norm())
# TODO: add explicit/implicit option to the GUI
if recon_params.explicit:
if recon_params.stochastic:
fs = CILRecon._set_stochastic_objective_function(acquisition_data)
fs.append(MixedL21Norm())

F = BlockFunction(*fs)
F = BlockFunction(*fs)

# needs to unrol the A2d BlockOperator and put it in another, followed by
# the gradient operator, as in the deterministic case
K = BlockOperator(*A2d.get_as_list(), alpha * Grad)
# needs to unrol the A BlockOperator and put it in another, followed by
# the gradient operator, as in the deterministic case
K = BlockOperator(*A.get_as_list(), alpha * Grad)

else:
# Define BlockFunction F using the MixedL21Norm() and the L2NormSquared()
f1 = MixedL21Norm()
f2 = L2NormSquared(b=acquisition_data)
else:
# Define BlockFunction F using the MixedL21Norm() and the L2NormSquared()
f1 = MixedL21Norm()
f2 = L2NormSquared(b=acquisition_data)

F = BlockFunction(f1, f2)
F = BlockFunction(f1, f2)

# define the BlockOperator
K = BlockOperator(alpha * Grad, A2d)
# define the BlockOperator
K = BlockOperator(alpha * Grad, A)

if recon_params.non_negative:
G = IndicatorBox(lower=0)
if recon_params.non_negative:
G = IndicatorBox(lower=0)
else:
# Define Function G simply as zero
G = ZeroFunction()
else:
# Define Function G simply as zero
G = ZeroFunction()
# implicit
if recon_params.stochastic:
fs = CILRecon._set_stochastic_objective_function(acquisition_data)

F = BlockFunction(*fs)

else:
F = L2NormSquared(b=acquisition_data)

# define the ProjectionOperator
K = A

# regulariser
lower = 0 if recon_params.non_negative else None

# Here we could add different regularisers from the CCPi Regularisation toolkit
# like TGV that is not currently in CIL.
# It will require a new control from the GUI, possibly a drop down menu
G = alpha * TotalVariation(lower=lower)


return (K, F, G)

Expand Down