Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add metrics as per wiki #46

Merged
merged 4 commits into from
Jul 10, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 60 additions & 16 deletions petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from traceback import print_exc

import numpy as np
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio
from skimage.metrics import mean_squared_error as mse
from tensorboardX import SummaryWriter

import sirf.STIR as STIR
Expand Down Expand Up @@ -88,26 +88,51 @@ def __call__(self, algo: Algorithm):
log.debug("...logged")


class QualityMetrics(ImageQualityCallback):
"""From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds"""
def __init__(self, reference_image, whole_object_mask, background_mask, **kwargs):
super().__init__(reference_image, **kwargs)
self.whole_object_indices = np.where(whole_object_mask == 1)
self.background_indices = np.where(background_mask == 1)
self.ref_im_arr = reference_image.as_array()
casperdcl marked this conversation as resolved.
Show resolved Hide resolved

def __call__(self, algorithm):
iteration = algorithm.iteration
if iteration % algorithm.update_objective_interval != 0 and iteration != algorithm.max_iteration:
casperdcl marked this conversation as resolved.
Show resolved Hide resolved
return

for filter_name, filter_func in self.filter.items():
if filter_func is None:
filter_func = lambda x: x
test_im, ref_im = (filter_func(img_data).as_array() for img_data in (algorithm.x, self.reference_image))
KrisThielemans marked this conversation as resolved.
Show resolved Hide resolved

# (1) global metrics & statistics
norm = ref_im[self.background_indices].mean()
self.tb_summary_writer.add_scalar(
f"RMSE_whole_object{filter_name}",
np.sqrt(mse(ref_im[self.whole_object_indices], test_im[self.whole_object_indices])) / norm, iteration)
self.tb_summary_writer.add_scalar(
f"RMSE_background{filter_name}",
np.sqrt(mse(ref_im[self.background_indices], test_im[self.background_indices])) / norm, iteration)

# (2) local metrics & statistics
for voi_name, voi_indices in self.voi_indices.items():
# AEM not to be confused with MAE
self.tb_summary_writer.add_scalar(
f"AEM_VOI_{voi_name}{filter_name}",
np.abs(test_im[voi_indices].mean() - ref_im[voi_indices].mean()) / norm, iteration)


class MetricsWithTimeout(cbks.Callback):
"""Stops the algorithm after `seconds`"""
def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, reference_image=None,
verbose=1):
def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_slice=None, verbose=1):
super().__init__(verbose)
self._seconds = seconds
self.callbacks = [
cbks.ProgressCallback(),
SaveIters(outdir=outdir),
(tb_cbk := TensorBoard(logdir=outdir, transverse_slice=transverse_slice, coronal_slice=coronal_slice))]

if reference_image:
roi_image_dict = {f'S{i}': STIR.ImageData(f'S{i}.hv') for i in range(1, 8)}
# NB: these metrics are for testing only.
# The final evaluation will use metrics described in https://github.com/SyneRBI/PETRIC/wiki
self.callbacks.append(
ImageQualityCallback(
reference_image, tb_cbk.tb, roi_mask_dict=roi_image_dict, metrics_dict={
'MSE': mean_squared_error, 'MAE': self.mean_absolute_error, 'PSNR': peak_signal_noise_ratio},
statistics_dict={'MEAN': np.mean, 'STDDEV': np.std, 'MAX': np.max}))
self.tb = tb_cbk.tb
self.reset()

def reset(self, seconds=None):
Expand Down Expand Up @@ -144,7 +169,9 @@ def construct_RDP(penalty_strength, initial_image, kappa, max_scaling=1e-3):
return prior


Dataset = namedtuple('Dataset', ['acquired_data', 'additive_term', 'mult_factors', 'OSEM_image', 'prior', 'kappa'])
Dataset = namedtuple('Dataset', [
'acquired_data', 'additive_term', 'mult_factors', 'OSEM_image', 'prior', 'kappa', 'reference_image',
'whole_object_mask', 'background_mask', 'voi_masks'])


def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
Expand All @@ -165,7 +192,19 @@ def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
penalty_strength = 1 / 700 # default choice
prior = construct_RDP(penalty_strength, OSEM_image, kappa)

return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa)
reference_image = STIR.ImageData(str(srcdir / 'reference_image.hv')) if (srcdir /
'reference_image.hv').is_file() else None
whole_object_mask = STIR.ImageData(str(srcdir /
casperdcl marked this conversation as resolved.
Show resolved Hide resolved
'VOI_whole_object.hv')) if (srcdir /
'VOI_whole_object.hv').is_file() else None
background_mask = STIR.ImageData(str(srcdir / 'VOI_background.hv')) if (srcdir /
'VOI_background.hv').is_file() else None
voi_masks = {
voi.stem: STIR.ImageData(str(voi))
for voi in srcdir.glob("VOI_*.hv") if voi.stem[4:] not in ('background', 'whole_object')}

return Dataset(acquired_data, additive_term, mult_factors, OSEM_image, prior, kappa, reference_image,
whole_object_mask, background_mask, voi_masks)


if SRCDIR.is_dir():
Expand Down Expand Up @@ -194,7 +233,12 @@ def get_data(srcdir=".", outdir=OUTDIR, sirf_verbosity=0):
assert issubclass(Submission, Algorithm)
for srcdir, outdir, metrics in data_dirs_metrics:
data = get_data(srcdir=srcdir, outdir=outdir)
metrics[0].reset() # timeout from now
metrics_with_timeout = metrics[0]
if data.reference_image is not None:
metrics_with_timeout.callbacks.append(
QualityMetrics(data.reference_image, data.whole_object_mask, data.background_mask,
tb_summary_writer=metrics_with_timeout.tb, roi_mask_dict=data.voi_masks))
metrics_with_timeout.reset() # timeout from now
algo = Submission(data)
try:
algo.run(np.inf, callbacks=metrics + submission_callbacks)
Expand Down