Skip to content

Commit

Permalink
add SVRG test scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
gschramm committed Jul 22, 2024
1 parent 363a2ca commit 5b7ebab
Show file tree
Hide file tree
Showing 4 changed files with 531 additions and 23 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
data/
output/
output_test/
results/
tmp*/
err*.txt
info.txt
Expand All @@ -12,3 +14,4 @@ __pycache__/
*.ahv
*.hv
*.v
run*.sh
115 changes: 92 additions & 23 deletions main_SVRG.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from sirf.contrib.partitioner.partitioner import partition_indices
from sirf.contrib.partitioner import partitioner

import numpy as np


class MaxIteration(callbacks.Callback):
"""
Expand Down Expand Up @@ -44,9 +46,9 @@ class Submission(Algorithm):
def __init__(
self,
data: Dataset,
num_subsets: int = 7,
initial_step_size: float = 1.0,
update_objective_interval: int = 10,
**kwargs
**kwargs,
):
"""
Initialisation function, setting up data & (hyper)parameters.
Expand All @@ -56,11 +58,25 @@ def __init__(

self.subset = 0
self.x = data.OSEM_image.clone()

# find views in each subset
# (note that SIRF can currently only do subsets over views)
views = data.mult_factors.dimensions()[2]

# hard coded number of subsets for now
if views == 50:
num_subsets = 10
elif views == 128:
num_subsets = 16
elif views == 252:
num_subsets = 21
else:
raise ValueError(f"Unknown number of views: {views}")

self._num_subsets = num_subsets
self._update = 0

#############################################################################
#############################################################################
self._step_size = initial_step_size
self._subset_number_list = []

self._data_sub, self._acq_models, self._obj_funs = partitioner.data_partition(
data.acquired_data,
Expand All @@ -70,6 +86,7 @@ def __init__(
initial_image=data.OSEM_image,
mode="staggered",
)
print("start init")
# WARNING: modifies prior strength with 1/num_subsets (as currently needed for BSREM implementations)
data.prior.set_penalisation_factor(
data.prior.get_penalisation_factor() / num_subsets
Expand All @@ -85,45 +102,97 @@ def __init__(
self._adjoint_ones = 0 * data.OSEM_image

for i in range(num_subsets):
print(f"Calculating subset {i} sensitivity")
subset_adjoint_ones = self._obj_funs[i].get_subset_sensitivity(0)
self._subset_adjoint_ones.append(subset_adjoint_ones)
self._adjoint_ones += subset_adjoint_ones

self._fov_mask = 0 * self._adjoint_ones
tmp = 1.0 * (self._adjoint_ones.as_array() > 0)
self._fov_mask.fill(tmp)

# add a small number to avoid NaN in division
self._adjoint_ones += self._adjoint_ones.max() * 1e-6

# setup the preconditioner
self._precond = (data.OSEM_image + 1e-6) / self._adjoint_ones
self._precond = (
data.OSEM_image + (1e-6) * data.OSEM_image.max()
) / self._adjoint_ones

#############################################################################
#############################################################################
# initialize list / ImageData for all subset gradients and sum of gradients
# TODO better way to create image full of zeros
self._summed_subset_gradients = 0 * self.x
self._subset_gradients = []

def update(self):
# compute forward projection for the denominator
# add a small number to avoid NaN in division, as OSEM leads to 0/0 or worse.
# (Theoretically, MLEM cannot, but it might nevertheless due to numerical issues)
super().__init__(update_objective_interval=update_objective_interval, **kwargs)
self.configured = True # required by Algorithm

######################################################################
######################################################################
######################################################################
def update_all_subset_gradients(self) -> None:
    """
    Recompute and cache the gradients of all subset objectives.

    Evaluates every subset objective's gradient at the current iterate
    ``self.x`` (the SVRG "anchor" point), stores each one in
    ``self._subset_gradients`` and accumulates their sum into
    ``self._summed_subset_gradients``.
    """
    # NOTE(review): removed unused local `x_cur = self.x` (leftover from
    # an earlier OSEM-style update that no longer exists in this block)

    # TODO better way to create image full of zeros
    self._summed_subset_gradients = 0 * self.x
    self._subset_gradients = []

    for i in range(self._num_subsets):
        # gradient of the i-th subset objective at the anchor point
        subset_grad = self._obj_funs[i].gradient(self.x)
        self._subset_gradients.append(subset_grad)
        self._summed_subset_gradients += subset_grad

# additive OSEM update using gradient of subset objective function
x2 = x_cur + (x_cur / self._sensitivities[self.subset]) * self._obj_funs[
self.subset
].gradient(x_cur)
def update(self):
    """
    Perform one SVRG update step (preconditioned gradient ascent).

    Every ``2 * num_subsets`` updates, the gradients of all subset
    objectives are recomputed at the current iterate (the SVRG anchor
    point) and their sum is used as the gradient estimate.  In all other
    updates, the variance-reduced estimate
    ``N * (g_i(x) - g_i(x_anchor)) + sum_j g_j(x_anchor)``
    based on a single, randomly drawn subset ``i`` is used.
    """
    # NOTE(review): removed a stale residue line that advanced
    # `self.subset` via the nonexistent attribute `self._prompts`
    recalc_all = self._update % (2 * self._num_subsets) == 0

    if recalc_all:
        print(f" {self._update}, {self.subset}, recalculating all subset gradients")
        self.update_all_subset_gradients()
        grad = self._summed_subset_gradients
    else:
        # draw subsets without replacement; refill the list once exhausted
        if not self._subset_number_list:
            self.create_subset_number_list()

        self.subset = self._subset_number_list.pop()
        # fixed bug: original printed `self.update` (the bound method)
        # instead of the update counter `self._update`
        print(f" {self._update}, {self.subset}, subset gradient update")

        # SVRG variance-reduced gradient estimate
        grad = (
            self._num_subsets
            * (
                self._obj_funs[self.subset].gradient(self.x)
                - self._subset_gradients[self.subset]
            )
            + self._summed_subset_gradients
        )

    ### Objective has to be maximized -> "+" for gradient ascent
    self.x = self.x + self._step_size * self._precond * self._fov_mask * grad

    # enforce non-negative constraint
    tmp = self.x.as_array()
    np.clip(tmp, 0, None, out=tmp)
    self.x.fill(tmp)

    self._update += 1

def update_objective(self):
def update_objective(self) -> None:
    """
    Append the current objective value to ``self.loss``.

    NOTE(review): the previous docstring claimed this returns `0`; the
    body actually records ``calc_cost(self.x)`` — which itself currently
    returns a constant 0 placeholder, since PETRIC does not require the
    objective value.
    """

    self.loss.append(self.calc_cost(self.x))

def calc_cost(self, x):
    """
    Return the objective value at ``x``.

    The true cost would be the sum of all subset objective values,
    ``sum(self._obj_funs[i](x) for i in range(self._num_subsets))``,
    but evaluating it is expensive and not required by PETRIC, so a
    constant ``0`` placeholder is returned instead (dead commented-out
    implementation removed).
    """
    return 0

def create_subset_number_list(self):
    """Refill ``self._subset_number_list`` with a random permutation
    of the subset indices ``0 .. num_subsets - 1``."""
    # np.random.permutation(n) is equivalent to shuffling arange(n)
    # and consumes the global RNG state identically
    shuffled = np.random.permutation(self._num_subsets)
    self._subset_number_list = list(shuffled)


submission_callbacks = [MaxIteration(660)]
submission_callbacks = []
1 change: 1 addition & 0 deletions simulations/.gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
*.png
*.npy
run*.sh
Loading

0 comments on commit 5b7ebab

Please sign in to comment.