Skip to content

Commit

Permalink
Merge pull request #46 from SyneRBI/metrics
Browse files Browse the repository at this point in the history
add metrics as per wiki
  • Loading branch information
casperdcl authored Jul 10, 2024
2 parents 85393b1 + a5860b5 commit beb42e9
Showing 1 changed file with 61 additions and 16 deletions.
77 changes: 61 additions & 16 deletions petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from traceback import print_exc

import numpy as np
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio
from skimage.metrics import mean_squared_error as mse
from tensorboardX import SummaryWriter

import sirf.STIR as STIR
Expand Down Expand Up @@ -88,26 +88,51 @@ def __call__(self, algo: Algorithm):
log.debug("...logged")


class QualityMetrics(ImageQualityCallback):
"""From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds"""
def __init__(self, reference_image, whole_object_mask, background_mask, **kwargs):
super().__init__(reference_image, **kwargs)
self.whole_object_indices = np.where(whole_object_mask == 1)
self.background_indices = np.where(background_mask == 1)
self.ref_im_arr = reference_image.as_array()
self.norm = self.ref_im_arr[self.background_indices].mean()

def __call__(self, algorithm):
iteration = algorithm.iteration
if iteration % algorithm.update_objective_interval != 0 and iteration != algorithm.max_iteration:
return

assert not any(self.filter.values()), "Filtering not implemented"
test_im_arr = algorithm.x.as_array()

# (1) global metrics & statistics
self.tb_summary_writer.add_scalar(
"RMSE_whole_object",
np.sqrt(mse(self.ref_im_arr[self.whole_object_indices], test_im_arr[self.whole_object_indices])) /
self.norm, iteration)
self.tb_summary_writer.add_scalar(
"RMSE_background",
np.sqrt(mse(self.ref_im_arr[self.background_indices], test_im_arr[self.background_indices])) / self.norm,
iteration)

# (2) local metrics & statistics
for voi_name, voi_indices in sorted(self.voi_indices.items()):
# AEM not to be confused with MAE
self.tb_summary_writer.add_scalar(
f"AEM_VOI_{voi_name}",
np.abs(test_im_arr[voi_indices].mean() - self.ref_im_arr[voi_indices].mean()) / self.norm, iteration)


class MetricsWithTimeout(cbks.Callback):
"""Stops the algorithm after `seconds`"""
def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, reference_image=None,
verbose=1):
def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, verbose=1):
super().__init__(verbose)
self._seconds = seconds
self.callbacks = [
cbks.ProgressCallback(),
SaveIters(outdir=outdir),
(tb_cbk := TensorBoard(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))]

if reference_image:
roi_image_dict = {f'S{i}': STIR.ImageData(f'S{i}.hv') for i in range(1, 8)}
# NB: these metrics are for testing only.
# The final evaluation will use metrics described in https://github.com/SyneRBI/PETRIC/wiki
self.callbacks.append(
ImageQualityCallback(
reference_image, tb_cbk.tb, roi_mask_dict=roi_image_dict, metrics_dict={
'MSE': mean_squared_error, 'MAE': self.mean_absolute_error, 'PSNR': peak_signal_noise_ratio},
statistics_dict={'MEAN': np.mean, 'STDDEV': np.std, 'MAX': np.max}))
self.tb = tb_cbk.tb
self.reset()

def reset(self, seconds=None):
Expand Down Expand Up @@ -144,7 +169,9 @@ def construct_RDP(penalty_strength, initial_image, kappa, max_scaling=1e-3):
return prior


Dataset = namedtuple('Dataset', ['acquired_data', 'additive_term', 'mult_factors', 'OSEM_image', 'prior', 'kappa'])
Dataset = namedtuple('Dataset', [
'acquired_data', 'additive_term', 'mult_factors', 'OSEM_image', 'prior', 'kappa', 'reference_image',
'whole_object_mask', 'background_mask', 'voi_masks'])


def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
Expand All @@ -165,7 +192,20 @@ def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
penalty_strength = 1 / 700 # default choice
prior = construct_RDP(penalty_strength, OSEM_image, kappa)

return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa)
def get_image(fname):
if (source := srcdir / 'PETRIC' / fname).is_file():
return STIR.ImageData(str(source))
return None # explicit to suppress linter warnings

reference_image = get_image('reference_image.hv')
whole_object_mask = get_image('VOI_whole_object.hv')
background_mask = get_image('VOI_background.hv')
voi_masks = {
voi.stem: STIR.ImageData(str(voi))
for voi in (srcdir / 'PETRIC').glob("VOI_*.hv") if voi.stem[4:] not in ('background', 'whole_object')}

return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa, reference_image,
whole_object_mask, background_mask, voi_masks)


if SRCDIR.is_dir():
Expand Down Expand Up @@ -194,7 +234,12 @@ def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
assert issubclass(Submission, Algorithm)
for srcdir, outdir, metrics in data_dirs_metrics:
data = get_data(srcdir=srcdir, outdir=outdir)
metrics[0].reset() # timeout from now
metrics_with_timeout = metrics[0]
if data.reference_image is not None:
metrics_with_timeout.callbacks.append(
QualityMetrics(data.reference_image, data.whole_object_mask, data.background_mask,
tb_summary_writer=metrics_with_timeout.tb, roi_mask_dict=data.voi_masks))
metrics_with_timeout.reset() # timeout from now
algo = Submission(data)
try:
algo.run(np.inf, callbacks=metrics + submission_callbacks)
Expand Down

0 comments on commit beb42e9

Please sign in to comment.