Skip to content

Commit

Permalink
Merge branch 'main' of [email protected]:SyneRBI/PETRIC-MaGeZ.git into s…
Browse files Browse the repository at this point in the history
…imulation
  • Loading branch information
gschramm committed Jul 23, 2024
2 parents 6cc4cef + d9fbdc0 commit 40c6be1
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 61 deletions.
1 change: 1 addition & 0 deletions SIRF_data_preparation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ Participants should never have to use these (unless you want to create your own

## Helpers

- `evaluation_utilities.py`: reading/plotting helpers for values of the objective function and metrics
- `PET_plot_functions.py`: plotting helpers
97 changes: 69 additions & 28 deletions SIRF_data_preparation/data_QC.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,13 @@ def plot_sinogram_profile(prompts, background, sumaxis=(0, 1), select=0, srcdir=
plt.plot(np.sum(prompts.as_array(), axis=sumaxis)[select, :], label='prompts')
plt.plot(np.sum(background.as_array(), axis=sumaxis)[select, :], label='background')
ax.legend()
plt.savefig(srcdir + 'prompts_background_profiles.png')
plt.savefig(os.path.join(srcdir, 'prompts_background_profiles.png'))


def plot_image(image, save_name=None, transverse_slice=-1, coronal_slice=-1, sagittal_slice=-1, vmin=0, vmax=None):
def plot_image(image, save_name=None, transverse_slice=-1, coronal_slice=-1, sagittal_slice=-1, vmin=0, vmax=None,
alpha=None, **kwargs):
"""
Plot a profile through sirf.STIR.ImageData
Plot transverse/coronal/sagital slices through sirf.STIR.ImageData
"""
if transverse_slice < 0:
transverse_slice = image.dimensions()[0] // 2
Expand All @@ -63,13 +64,24 @@ def plot_image(image, save_name=None, transverse_slice=-1, coronal_slice=-1, sag
vmax = image.max()

arr = image.as_array()
plt.figure()
plt.subplot(131)
plt.imshow(arr[transverse_slice, :, :], vmin=vmin, vmax=vmax)
plt.subplot(132)
plt.imshow(arr[:, coronal_slice, :], vmin=vmin, vmax=vmax)
plt.subplot(133)
plt.imshow(arr[:, :, sagittal_slice], vmin=vmin, vmax=vmax)
alpha_trans = None
alpha_cor = None
alpha_sag = None
if alpha is not None:
alpha_arr = alpha.as_array()
alpha_trans = alpha_arr[transverse_slice, :, :]
alpha_cor = alpha_arr[:, coronal_slice, :]
alpha_sag = alpha_arr[:, :, sagittal_slice]

ax = plt.subplot(131)
plt.imshow(arr[transverse_slice, :, :], vmin=vmin, vmax=vmax, alpha=alpha_trans, **kwargs)
ax.set_title(f"T={transverse_slice}")
ax = plt.subplot(132)
plt.imshow(arr[:, coronal_slice, :], vmin=vmin, vmax=vmax, alpha=alpha_cor, **kwargs)
ax.set_title(f"C={coronal_slice}")
ax = plt.subplot(133)
plt.imshow(arr[:, :, sagittal_slice], vmin=vmin, vmax=vmax, alpha=alpha_sag, **kwargs)
ax.set_title(f"S={sagittal_slice}")
plt.colorbar(shrink=.6)
if save_name is not None:
plt.savefig(save_name + '_slices.png')
Expand All @@ -79,28 +91,59 @@ def plot_image(image, save_name=None, transverse_slice=-1, coronal_slice=-1, sag
def plot_image_if_exists(prefix, **kwargs):
if os.path.isfile(prefix + '.hv'):
im = STIR.ImageData(prefix + '.hv')
plt.figure()
plot_image(im, prefix, **kwargs)
return im
else:
print(f"Image {prefix}.hv does not exist")
return None


def VOI_mean(image, VOI):
return float((image * VOI).sum() / VOI.sum())


def VOI_checks(allVOInames, OSEM_image, reference_image, srcdir='.'):
from scipy import ndimage


def VOI_checks(allVOInames, OSEM_image=None, reference_image=None, srcdir='.', **kwargs):
if len(allVOInames) == 0:
return
OSEM_VOI_values = []
ref_VOI_values = []
allVOIs = None
VOIkwargs = kwargs.copy()
VOIkwargs['vmax'] = 1
VOIkwargs['vmin'] = 0
for VOIname in allVOInames:
VOI = plot_image_if_exists(os.path.join(srcdir, VOIname), transverse_slice=transverse_slice,
coronal_slice=coronal_slice, sagittal_slice=sagittal_slice)
if OSEM_image:
prefix = os.path.join(srcdir, VOIname)
filename = prefix + '.hv'
if not os.path.isfile(filename):
print(f"VOI {VOIname} does not exist")
continue
VOI = STIR.ImageData(filename)
COM = np.rint(ndimage.center_of_mass(VOI.as_array()))
plt.figure()
plot_image(VOI, save_name=prefix, vmin=0, vmax=1, transverse_slice=int(COM[0]), coronal_slice=int(COM[1]),
sagittal_slice=int(COM[2]))

# construct transparency image
if VOIname == 'VOI_whole_object':
VOI /= 2
if allVOIs is None:
allVOIs = VOI.clone()
else:
allVOIs += VOI
if OSEM_image is not None:
OSEM_VOI_values.append(VOI_mean(OSEM_image, VOI))
if reference_image:
ref_VOI_values.append(VOI_mean(reference_image, VOI))
allVOIs /= allVOIs.max()

if OSEM_image is not None:
plt.figure()
plot_image(OSEM_image, alpha=allVOIs, save_name="OSEM_image_and_VOIs", **kwargs)

# unformatted print of VOI values for now
print(allVOInames)
print(OSEM_VOI_values)
Expand All @@ -111,27 +154,25 @@ def main(argv=None):
args = docopt(__doc__, argv=argv, version=__version__)
srcdir = args['--srcdir']
skip_sino_profiles = args['--skip_sino_profiles']
transverse_slice = literal_eval(args['--transverse_slice'])
coronal_slice = literal_eval(args['--coronal_slice'])
sagittal_slice = literal_eval(args['--sagittal_slice'])
slices = {}
slices["transverse_slice"] = literal_eval(args['--transverse_slice'])
slices["coronal_slice"] = literal_eval(args['--coronal_slice'])
slices["sagittal_slice"] = literal_eval(args['--sagittal_slice'])

if not skip_sino_profiles:
acquired_data = STIR.AcquisitionData(srcdir + 'prompts.hs')
additive_term = STIR.AcquisitionData(srcdir + 'additive_term.hs')
mult_factors = STIR.AcquisitionData(srcdir + 'mult_factors.hs')
acquired_data = STIR.AcquisitionData(os.path.join(srcdir, 'prompts.hs'))
additive_term = STIR.AcquisitionData(os.path.join(srcdir, 'additive_term.hs'))
mult_factors = STIR.AcquisitionData(os.path.join(srcdir, 'mult_factors.hs'))
background = additive_term * mult_factors
plot_sinogram_profile(acquired_data, background, srcdir=srcdir)

OSEM_image = plot_image_if_exists('OSEM_image', transverse_slice=transverse_slice, coronal_slice=coronal_slice,
sagittal_slice=sagittal_slice)
plot_image_if_exists('kappa', transverse_slice=transverse_slice, coronal_slice=coronal_slice,
sagittal_slice=sagittal_slice)
reference_image = plot_image_if_exists(srcdir + 'PETRIC/reference_image', transverse_slice=transverse_slice,
coronal_slice=coronal_slice, sagittal_slice=sagittal_slice)
OSEM_image = plot_image_if_exists(os.path.join(srcdir, 'OSEM_image'), **slices)
plot_image_if_exists(os.path.join(srcdir, 'kappa'), **slices)
reference_image = plot_image_if_exists(os.path.join(srcdir, 'PETRIC/reference_image'), **slices)

VOIdir = os.path.join(srcdir, srcdir + "PETRIC")
VOIdir = os.path.join(srcdir, 'PETRIC')
allVOInames = [os.path.basename(str(voi)[:-3]) for voi in Path(VOIdir).glob("VOI_*.hv")]
VOI_checks(allVOInames, OSEM_image, reference_image, srcdir=VOIdir)
VOI_checks(allVOInames, OSEM_image, reference_image, srcdir=VOIdir, **slices)
plt.show()


Expand Down
54 changes: 54 additions & 0 deletions SIRF_data_preparation/evaluation_utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Some utilities for plotting objectives and metrics."""
import csv
from pathlib import Path
from typing import Iterator

import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import binary_erosion

import sirf.STIR as STIR
from petric import QualityMetrics


def read_objectives(datadir='.'):
"""Reads objectives.csv and returns as 2d array"""
with (Path(datadir) / 'objectives.csv').open() as csvfile:
reader = csv.reader(csvfile)
next(reader) # skip first (header) line
return np.asarray([tuple(map(float, row)) for row in reader])


def get_metrics(qm: QualityMetrics, iters: Iterator[int], srcdir='.'):
"""Read 'iter_{iter_glob}.hv' images from datadir, compute metrics and return as 2d array"""
return np.asarray([
list(qm.evaluate(STIR.ImageData(str(Path(srcdir) / f'iter_{i:04d}.hv'))).values()) for i in iters])


def pass_index(metrics: np.ndarray, thresh: Iterator, window: int = 1) -> int:
"""
Returns first index of `metrics` with value <= `thresh`.
The values must remain below the respective thresholds for at least `window` number of entries.
Otherwise raises IndexError.
"""
thr_arr = np.asanyarray(thresh)
assert metrics.ndim == 2
assert thr_arr.ndim == 1
assert metrics.shape[1] == thr_arr.shape[0]
passed = (metrics <= thr_arr[None]).all(axis=1)
res = binary_erosion(passed, structure=np.ones(window), origin=-(window // 2))
return np.where(res)[0][0]


def plot_metrics(iters: Iterator[int], m: np.ndarray, labels=None, suffix=""):
"""Make 2 subplots of metrics"""
if labels is None:
labels = [""] * m.shape[1]
ax = plt.subplot(121)
plt.plot(iters, m[:, 0], label=labels[0] + suffix)
plt.plot(iters, m[:, 1], label=labels[1] + suffix)
ax.legend()
ax = plt.subplot(122)
for i in range(2, m.shape[1]):
plt.plot(iters, m[:, i], label=labels[i] + suffix)
ax.legend()
87 changes: 54 additions & 33 deletions petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
ANY CHANGES TO THIS FILE ARE IGNORED BY THE ORGANISERS.
Only the `main.py` file may be modified by participants.
This file is not intended for participants to use.
This file is not intended for participants to use, except for
the `get_data` function (and possibly `QualityMetrics` class).
It is used by the organisers to run the submissions in a controlled way.
It is included here purely in the interest of transparency.
Expand All @@ -27,7 +28,7 @@

import sirf.STIR as STIR
from cil.optimisation.algorithms import Algorithm
from cil.optimisation.utilities import callbacks as cbks
from cil.optimisation.utilities import callbacks as cil_callbacks
from img_quality_cil_stir import ImageQualityCallback

log = logging.getLogger('petric')
Expand All @@ -38,17 +39,31 @@
SRCDIR = Path("./data")


class SaveIters(cbks.Callback):
class Callback(cil_callbacks.Callback):
"""
CIL Callback but with `self.skip_iteration` checking `min(self.interval, algo.update_objective_interval)`.
TODO: backport this class to CIL.
"""
def __init__(self, interval: int = 1 << 31, **kwargs):
super().__init__(**kwargs)
self.interval = interval

def skip_iteration(self, algo: Algorithm) -> bool:
return algo.iteration % min(self.interval,
algo.update_objective_interval) != 0 and algo.iteration != algo.max_iteration


class SaveIters(Callback):
"""Saves `algo.x` as "iter_{algo.iteration:04d}.hv" and `algo.loss` in `csv_file`"""
def __init__(self, verbose=1, outdir=OUTDIR, csv_file='objectives.csv'):
super().__init__(verbose)
def __init__(self, outdir=OUTDIR, csv_file='objectives.csv', **kwargs):
super().__init__(**kwargs)
self.outdir = Path(outdir)
self.outdir.mkdir(parents=True, exist_ok=True)
self.csv = csv.writer((self.outdir / csv_file).open("w", buffering=1))
self.csv.writerow(("iter", "objective"))

def __call__(self, algo: Algorithm):
if algo.iteration % algo.update_objective_interval == 0 or algo.iteration == algo.max_iteration:
if not self.skip_iteration(algo):
log.debug("saving iter %d...", algo.iteration)
algo.x.write(str(self.outdir / f'iter_{algo.iteration:04d}.hv'))
self.csv.writerow((algo.iteration, algo.get_last_loss()))
Expand All @@ -57,53 +72,55 @@ def __call__(self, algo: Algorithm):
algo.x.write(str(self.outdir / 'iter_final.hv'))


class StatsLog(cbks.Callback):
class StatsLog(Callback):
"""Log image slices & objective value"""
def __init__(self, verbose=1, transverse_slice=None, coronal_slice=None, vmax=None, logdir=OUTDIR):
super().__init__(verbose)
def __init__(self, transverse_slice=None, coronal_slice=None, vmax=None, logdir=OUTDIR, **kwargs):
super().__init__(**kwargs)
self.transverse_slice = transverse_slice
self.coronal_slice = coronal_slice
self.vmax = vmax
self.x_prev = None
self.tb = logdir if isinstance(logdir, SummaryWriter) else SummaryWriter(logdir=str(logdir))

def __call__(self, algo: Algorithm):
if algo.iteration % algo.update_objective_interval != 0 and algo.iteration != algo.max_iteration:
if self.skip_iteration(algo):
return
t = getattr(self, '__time', None) or time()
log.debug("logging iter %d...", algo.iteration)
# initialise `None` values
self.transverse_slice = algo.x.dimensions()[0] // 2 if self.transverse_slice is None else self.transverse_slice
self.coronal_slice = algo.x.dimensions()[1] // 2 if self.coronal_slice is None else self.coronal_slice
self.vmax = algo.x.max() if self.vmax is None else self.vmax

self.tb.add_scalar("objective", algo.get_last_loss(), algo.iteration)
self.tb.add_scalar("objective", algo.get_last_loss(), algo.iteration, t)
if self.x_prev is not None:
normalised_change = (algo.x - self.x_prev).norm() / algo.x.norm()
self.tb.add_scalar("normalised_change", normalised_change, algo.iteration)
self.tb.add_scalar("normalised_change", normalised_change, algo.iteration, t)
self.x_prev = algo.x.clone()
self.tb.add_image("transverse",
np.clip(algo.x.as_array()[self.transverse_slice:self.transverse_slice + 1] / self.vmax, 0, 1),
algo.iteration)
self.tb.add_image("coronal", np.clip(algo.x.as_array()[None, :, self.coronal_slice] / self.vmax, 0, 1),
algo.iteration)
x_arr = algo.x.as_array()
self.tb.add_image("transverse", np.clip(x_arr[self.transverse_slice:self.transverse_slice + 1] / self.vmax, 0,
1), algo.iteration, t)
self.tb.add_image("coronal", np.clip(x_arr[None, :, self.coronal_slice] / self.vmax, 0, 1), algo.iteration, t)
log.debug("...logged")


class QualityMetrics(ImageQualityCallback):
class QualityMetrics(ImageQualityCallback, Callback):
"""From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds"""
def __init__(self, reference_image, whole_object_mask, background_mask, **kwargs):
super().__init__(reference_image, **kwargs)
def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 1 << 31, **kwargs):
# TODO: drop multiple inheritance once `interval` included in CIL
Callback.__init__(self, interval=interval)
ImageQualityCallback.__init__(self, reference_image, **kwargs)
self.whole_object_indices = np.where(whole_object_mask.as_array())
self.background_indices = np.where(background_mask.as_array())
self.ref_im_arr = reference_image.as_array()
self.norm = self.ref_im_arr[self.background_indices].mean()

def __call__(self, algo: Algorithm):
iteration = algo.iteration
if iteration % algo.update_objective_interval != 0 and iteration != algo.max_iteration:
if self.skip_iteration(algo):
return
t = getattr(self, '__time', None) or time()
for tag, value in self.evaluate(algo.x).items():
self.tb_summary_writer.add_scalar(tag, value, iteration)
self.tb_summary_writer.add_scalar(tag, value, algo.iteration, t)

def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]:
assert not any(self.filter.values()), "Filtering not implemented"
Expand All @@ -119,30 +136,34 @@ def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]:
for voi_name, voi_indices in sorted(self.voi_indices.items())}
return {**whole, **local}

def keys(self):
return ["RMSE_whole_object", "RMSE_background"] + [f"AEM_VOI_{name}" for name in sorted(self.voi_indices)]


class MetricsWithTimeout(cbks.Callback):
class MetricsWithTimeout(cil_callbacks.Callback):
"""Stops the algorithm after `seconds`"""
def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, verbose=1):
super().__init__(verbose)
def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, **kwargs):
super().__init__(**kwargs)
self._seconds = seconds
self.callbacks = [
cbks.ProgressCallback(),
cil_callbacks.ProgressCallback(),
SaveIters(outdir=outdir),
(tb_cbk := StatsLog(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))]
self.tb = tb_cbk.tb
self.tb = tb_cbk.tb # convenient access to the underlying SummaryWriter
self.reset()

def reset(self, seconds=None):
self.limit = time() + (self._seconds if seconds is None else seconds)
self.offset = 0

def __call__(self, algo: Algorithm):
if (now := time()) > self.limit:
if (now := time()) > self.limit + self.offset:
log.warning("Timeout reached. Stopping algorithm.")
raise StopIteration
if self.callbacks:
for c in self.callbacks:
c(algo)
self.limit += time() - now
for c in self.callbacks:
c.__time = now - self.offset # privately inject walltime-excluding-petric-callbacks
c(algo)
self.offset += time() - now

@staticmethod
def mean_absolute_error(y, x):
Expand Down

0 comments on commit 40c6be1

Please sign in to comment.