SVRG linear decay
samdporter committed Sep 26, 2024
1 parent 0b442fb commit dd58a0e
Showing 1 changed file with 102 additions and 53 deletions.
155 changes: 102 additions & 53 deletions main.py
@@ -10,48 +10,58 @@
#%%
from cil.optimisation.algorithms import ISTA, Algorithm
from cil.optimisation.functions import IndicatorBox, SVRGFunction
from cil.optimisation.utilities import (Preconditioner, Sampler,
callbacks, StepSizeRule)
from cil.optimisation.utilities import (Preconditioner, Sampler,
StepSizeRule)
from petric import Dataset
from sirf.contrib.partitioner import partitioner
import numbers
import sirf.STIR as pet
import numpy as np
import pandas as pd
import types


assert issubclass(ISTA, Algorithm)

#%%
class MaxIteration(callbacks.Callback):
"""
    The organisers try `Submission(data).run(inf)`, i.e. run for infinite iterations (until timeout).
This callback forces stopping after `max_iteration` instead.
"""
def __init__(self, max_iteration: int, verbose: int = 1):
super().__init__(verbose)
self.max_iteration = max_iteration

def __call__(self, algorithm: Algorithm):
if algorithm.iteration >= self.max_iteration:
raise StopIteration
class BSREMPreconditioner(Preconditioner):
    '''Diagonal preconditioner for the BSREM algorithm.

    .. math:: x^+ = x + t \nabla \log L(y|x)

    with :math:`t = x / s`, where :math:`s` is the sensitivity image, i.e. the adjoint of the
    acquisition model applied to a one-filled image of its range geometry.
    '''
    def __init__(self, acq_models, freeze_iter=np.inf, epsilon=1e-6):

class MyPreconditioner(Preconditioner):
"""
Example based on the row-sum of the Hessian of the log-likelihood. See: Tsai et al. Fast Quasi-Newton Algorithms
for Penalized Reconstruction in Emission Tomography and Further Improvements via Preconditioning,
IEEE TMI https://doi.org/10.1109/tmi.2017.2786865
"""
def __init__(self, kappa):
# add an epsilon to avoid division by zero (probably should make epsilon dependent on kappa)
        self.kappasq = kappa*kappa + (kappa.max()/1000)**2
self.epsilon = epsilon
self.freeze_iter = freeze_iter
self.t = None

        for i, el in enumerate(acq_models):
            if i == 0:
                self.s_sum = el.domain_geometry().get_uniform_copy(0.)
            # sensitivity image for this subset: adjoint of a one-filled sinogram
            ones = el.range_geometry().allocate(1.)
            s = el.adjoint(ones)
            s.maximum(self.epsilon, out=s)
            # accumulate the reciprocal sensitivities
            arr = s.as_array()
            np.reciprocal(arr, out=arr)
            s.fill(arr)
            self.s_sum += s

def apply(self, algorithm, gradient, out=None):
return gradient.divide(self.kappasq, out=out)

if algorithm.iteration < self.freeze_iter:
t = algorithm.solution * self.s_sum + self.epsilon
else:
if self.t is None:
self.t = algorithm.solution * self.s_sum + self.epsilon
t = self.t

return gradient.multiply(t, out=out)

def apply_without_algorithm(self, gradient, x, out=None):
t = x * self.s_sum + self.epsilon
return gradient.multiply(t, out=out)
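
For intuition, a minimal NumPy sketch of the scaling this preconditioner applies (the arrays are hypothetical stand-ins for the SIRF image objects):

import numpy as np

x = np.array([1.0, 2.0, 4.0])         # current image estimate
s = np.array([0.5, 1.0, 2.0])         # subset sensitivity image (adjoint of ones)
epsilon = 1e-6

s_sum = 1.0 / np.maximum(s, epsilon)  # reciprocal sensitivity, as accumulated in __init__
grad = np.array([0.1, -0.2, 0.3])

t = x * s_sum + epsilon               # diagonal factor t = x / s
precond_grad = grad * t               # what apply() returns before the freeze iteration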

class LinearDecayStepSizeRule(StepSizeRule):
"""
Linear decay of the step size.
Linear decay of the step size with iteration.
"""
def __init__(self, initial_step_size: float, decay: float):
self.initial_step_size = initial_step_size
@@ -61,56 +71,95 @@ def __init__(self, initial_step_size: float, decay: float):
def get_step_size(self, algorithm):
return self.initial_step_size / (1 + self.decay * algorithm.iteration)
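
For a quick sanity check of the schedule (hypothetical numbers):

initial_step_size, decay = 0.1, 0.01  # illustrative values
for k in (0, 10, 100):
    print(k, initial_step_size / (1 + decay * k))  # 0.1, ~0.0909, 0.05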

def initial_step_size_search_rule(x, f, g, grad, max_iter=100, tol=0.1):
def armijo_step_size_search_rule(x, f, g, grad, precond_grad, step_size=2.0, beta=0.5, max_iter=100, tol=0.2):
    """
    Backtracking (Armijo) line search for the initial step size.
    """
step_size = 1.0
f_x = f(x) + g(x)
g_norm = grad.squared_norm()
g_norm = grad.dot(precond_grad)
for _ in range(max_iter):
x_new = g.proximal(x - step_size * grad, step_size)
x_new = g.proximal(x - step_size * precond_grad, step_size)
f_x_new = f(x_new) + g(x_new)
if f_x_new <= f_x - tol * step_size * g_norm:
break
step_size /= 2
step_size *= beta
return step_size
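
A self-contained sketch of the same backtracking logic on a one-dimensional quadratic with no proximal term (all names and values here are illustrative only):

f = lambda x: 0.5 * x**2              # smooth objective, gradient is x
x = 4.0
grad = x
precond_grad = grad                   # identity preconditioner for this demo

step_size, beta, tol = 2.0, 0.5, 0.2
f_x = f(x)
g_norm = grad * precond_grad          # plays the role of grad.dot(precond_grad)
for _ in range(100):
    x_new = x - step_size * precond_grad
    if f(x_new) <= f_x - tol * step_size * g_norm:
        break                         # sufficient decrease: accept this step size
    step_size *= beta                 # otherwise shrink and retry
# for these numbers the loop accepts step_size = 0.5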

#%%
def calculate_subsets(sino, min_counts_per_subset=2**20, max_subsets=30):
"""
Calculate the number of subsets for a given sinogram such that each subset
has at least the minimum number of counts.
Args:
sino: A sinogram object with .dimensions() and .sum() methods.
        min_counts_per_subset (float): Minimum number of counts per subset (default 2**20).
        max_subsets (int): Upper bound on the number of subsets (default 30).
Returns:
int: The number of subsets that can be created while maintaining the minimum counts per subset.
"""
views = sino.dimensions()[2] # Extract the number of views
total_counts = sino.sum() # Sum of counts for the sinogram

    # Largest number of subsets that keeps at least min_counts_per_subset counts in each
    suggested_subsets = int(total_counts / min_counts_per_subset)
    # ensure no more than views // 4 subsets
    suggested_subsets = min(suggested_subsets, views // 4)
    # respect the max_subsets cap
    suggested_subsets = min(suggested_subsets, max_subsets)

    # Start from the suggestion, clamped to [1, views]
    subsets = max(1, min(views, suggested_subsets))

# Ensure subsets is a divisor of views
while views % subsets != 0 and subsets > 1:
subsets -= 1

return subsets
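
A worked example with made-up numbers (the stub mimics only the two sinogram methods used above):

class StubSino:
    def dimensions(self): return (1, 128, 252, 256)  # 252 views at index 2
    def sum(self): return 5e7                        # total counts

# 5e7 / 2**20 -> 47; capped at min(47, 252 // 4, 30) = 30;
# then reduced to 28, the largest divisor of 252 not exceeding 30
print(calculate_subsets(StubSino()))  # 28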

class Submission(ISTA):
"""Stochastic variance reduced subset version of preconditioned ISTA"""

# note that `issubclass(ISTA, Algorithm) == True`
def __init__(self, data: Dataset):
"""
Initialisation function, setting up data & (hyper)parameters.
NB: in practice, `num_subsets` should likely be determined from the data.
This is just an example. Try to modify and improve it!
"""

num_subsets = data.acquired_data.dimensions()[2]//8
initial_step_size = 0.5
decay = (1/0.9-1) / num_subsets

# Very simple heuristic to determine the number of subsets
self.num_subsets = calculate_subsets(data.acquired_data, min_counts_per_subset=2**20)
print(f"Number of subsets: {self.num_subsets}")
update_interval = self.num_subsets
# 10% decay per update interval
upper_decay_perc = 0.1
upper_decay = (1/(1-upper_decay_perc) - 1)/update_interval
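        # e.g. with upper_decay_perc = 0.1, the factor 1/(1 + upper_decay*update_interval)
        # equals 0.9, i.e. a 10% step-size reduction after one full update interval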
beta = 0.5

data_subs, acq_models, obj_funs = partitioner.data_partition(data.acquired_data, data.additive_term,
data.mult_factors, num_subsets, mode='staggered',
data.mult_factors, self.num_subsets, mode='staggered',
initial_image=data.OSEM_image)
# WARNING: modifies prior strength with 1/num_subsets (as currently needed for BSREM implementations)


data.prior.set_penalisation_factor(data.prior.get_penalisation_factor() / len(obj_funs))
data.prior.set_up(data.OSEM_image)

grad = data.OSEM_image.get_uniform_copy(0)
preconditioner = MyPreconditioner(data.kappa)

for f, d in zip(obj_funs, data_subs): # add prior to every objective function
f.set_prior(data.prior)
grad -= preconditioner.apply(self, f.gradient(data.OSEM_image))
grad -= f.gradient(data.OSEM_image)
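        # grad now holds -sum_i grad f_i at the OSEM image, i.e. the gradient of the
        # negated objective f defined below, as needed for the Armijo line search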

sampler = Sampler.random_without_replacement(len(obj_funs))
f = -SVRGFunction(obj_funs, sampler=sampler, snapshot_update_interval=None, store_gradients=True)
g = IndicatorBox(lower=0, accelerated=True) # non-negativity constraint
initial_step_size = initial_step_size_search_rule(data.OSEM_image, f, g, grad)
step_size_rule = LinearDecayStepSizeRule(initial_step_size, decay=decay)
f = -SVRGFunction(obj_funs, sampler=sampler, snapshot_update_interval=update_interval, store_gradients=True)

super().__init__(initial=data.OSEM_image, f=f, g=g, step_size=step_size_rule, preconditioner=preconditioner,
update_objective_interval=0)
preconditioner = BSREMPreconditioner(acq_models, epsilon=data.OSEM_image.max()/1e6, freeze_iter=10*update_interval)
g = IndicatorBox(lower=0, accelerated=True) # non-negativity constraint

precond_grad = preconditioner.apply_without_algorithm(grad, data.OSEM_image)

        initial_step_size = armijo_step_size_search_rule(data.OSEM_image, f, g, grad, precond_grad, beta=beta, step_size=0.08, tol=0.2)
step_size_rule = LinearDecayStepSizeRule(initial_step_size, 0.01)

super().__init__(initial=data.OSEM_image, f=f, g=g, step_size=step_size_rule,
preconditioner=preconditioner, update_objective_interval=update_interval)

submission_callbacks = [MaxIteration(np.inf)]
submission_callbacks = []
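
For reference, a minimal sketch of the variance-reduced gradient estimator that SVRG-type methods use, with toy quadratic subset objectives (illustrative only, not the CIL SVRGFunction implementation):

# subset objectives f_j(x) = 0.5*(x - c_j)**2, so grad f_j(x) = x - c_j
centres = [1.0, 3.0]
subset_grads = [lambda x, c=c: x - c for c in centres]
n = len(subset_grads)

x_tilde = 0.0                                      # snapshot point
full_grad = sum(g(x_tilde) for g in subset_grads)  # full gradient stored at the snapshot

x, step = x_tilde, 0.1
for k in range(4):
    j = k % n                                      # subset index from the sampler
    # unbiased over j, and exactly full_grad when x == x_tilde
    g = n * (subset_grads[j](x) - subset_grads[j](x_tilde)) + full_grad
    x -= step * g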
