
Commit

Merge pull request #981 from sarthakpati/synthesis_metrics_debug
Ensure synthesis metrics have an option to take voided image
sarthakpati authored Dec 20, 2024
2 parents 7a5bac9 + 1e9447c commit f9aabec
Showing 5 changed files with 77 additions and 100 deletions.
5 changes: 1 addition & 4 deletions .spelling/.spelling/expect.txt
@@ -487,6 +487,7 @@ rgbatorgb
rgbtorgba
rigourous
Ritesh
rmse
rmsprop
rocm
rocmdocs
@@ -561,7 +562,6 @@ thresholded
thresholding
Thu
tiatoolbox
tiffslide
timepoints
timm
tio
@@ -597,8 +597,6 @@ unittests
unitwise
unsqueeze
upenn
Uploaing
Uploded
upsample
upsampled
upsampling
@@ -725,7 +723,6 @@ zsuokb
zwezggl
zzokqk
thirdparty
adopy
Shohei
crcrpar
lrs
63 changes: 34 additions & 29 deletions GANDLF/cli/generate_metrics.py
@@ -15,13 +15,11 @@
overall_stats,
structural_similarity_index,
mean_squared_error,
root_mean_squared_error,
peak_signal_noise_ratio,
mean_squared_log_error,
mean_absolute_error,
ncc_mean,
ncc_std,
ncc_max,
ncc_min,
ncc_metrics,
)
from GANDLF.losses.segmentation import dice
from GANDLF.metrics.segmentation import (
@@ -302,21 +300,22 @@ def __percentile_clip(
reference_tensor = (
input_tensor if reference_tensor is None else reference_tensor
)
v_min, v_max = np.percentile(
reference_tensor, [p_min, p_max]
) # get p_min percentile and p_max percentile

# get p_min percentile and p_max percentile
v_min, v_max = np.percentile(reference_tensor, [p_min, p_max])
# set lower bound to be 0 if strictlyPositive is enabled
v_min = max(v_min, 0.0) if strictlyPositive else v_min
output_tensor = np.clip(
input_tensor, v_min, v_max
) # clip values to percentiles from reference_tensor
output_tensor = (output_tensor - v_min) / (
v_max - v_min
) # normalizes values to [0;1]
# clip values to percentiles from reference_tensor
output_tensor = np.clip(input_tensor, v_min, v_max)
# normalizes values to [0;1]
output_tensor = (output_tensor - v_min) / (v_max - v_min)
return output_tensor

input_df = __update_header_location_case_insensitive(input_df, "Mask", False)
# these are additional columns that could be present for synthesis tasks
for column_to_make_case_insensitive in ["Mask", "VoidImage"]:
input_df = __update_header_location_case_insensitive(
input_df, column_to_make_case_insensitive, False
)

for _, row in tqdm(input_df.iterrows(), total=input_df.shape[0]):
current_subject_id = row["SubjectID"]
overall_stats_dict[current_subject_id] = {}
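
The restructured __percentile_clip above keeps its original behaviour: clip to percentile bounds taken from a reference tensor, then rescale to [0; 1]. A minimal standalone sketch of that logic, with hypothetical values and plain numpy:

import numpy as np

img = np.array([0.0, 1.0, 5.0, 100.0], dtype=np.float32)
# percentile bounds come from the reference tensor (here the image itself)
v_min, v_max = np.percentile(img, [0.5, 99.5])
v_min = max(v_min, 0.0)  # the strictlyPositive branch
clipped = np.clip(img, v_min, v_max)
normalized = (clipped - v_min) / (v_max - v_min)  # now in [0; 1]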
@@ -332,16 +331,26 @@ def __percentile_clip(
)
).byte()

void_image_present = True if "VoidImage" in row else False
void_image = (
__fix_2d_tensor(torchio.ScalarImage(row["VoidImage"]).data)
if "VoidImage" in row
else torch.from_numpy(
np.ones(target_image.numpy().shape, dtype=np.uint8)
)
)

# Get Infill region (we really are only interested in the infill region)
output_infill = (pred_image * mask).float()
gt_image_infill = (target_image * mask).float()

# Normalize to [0;1] based on GT (otherwise MSE will depend on the image intensity range)
normalize = parameters.get("normalize", True)
if normalize:
# use all the tissue that is not masked for normalization
reference_tensor = (
target_image * ~mask
) # use all the tissue that is not masked for normalization
target_image * ~mask if not void_image_present else void_image
)
gt_image_infill = __percentile_clip(
gt_image_infill,
reference_tensor=reference_tensor,
@@ -364,18 +373,10 @@
# ncc metrics
compute_ncc = parameters.get("compute_ncc", True)
if compute_ncc:
overall_stats_dict[current_subject_id]["ncc_mean"] = ncc_mean(
output_infill, gt_image_infill
)
overall_stats_dict[current_subject_id]["ncc_std"] = ncc_std(
output_infill, gt_image_infill
)
overall_stats_dict[current_subject_id]["ncc_max"] = ncc_max(
output_infill, gt_image_infill
)
overall_stats_dict[current_subject_id]["ncc_min"] = ncc_min(
output_infill, gt_image_infill
)
calculated_ncc_metrics = ncc_metrics(output_infill, gt_image_infill)
for key, value in calculated_ncc_metrics.items():
# we don't need the ".item()" here, since the values are already scalars
overall_stats_dict[current_subject_id][key] = value

# only voxels that are to be inferred (-> flat array)
# these are required for mse, psnr, etc.
@@ -386,6 +387,10 @@
output_infill, gt_image_infill
).item()

overall_stats_dict[current_subject_id]["rmse"] = root_mean_squared_error(
output_infill, gt_image_infill
).item()

overall_stats_dict[current_subject_id]["msle"] = mean_squared_log_error(
output_infill, gt_image_infill
).item()
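With this change, the input CSV for synthesis metrics may carry an optional VoidImage column (matched case-insensitively, like Mask); when present, the void image replaces the unmasked tissue as the normalization reference. A hypothetical input file — SubjectID, Mask, and VoidImage are the column names the code actually reads, while the image-path column names here are only illustrative:

SubjectID,Target,Prediction,Mask,VoidImage
sub001,sub001_t1.nii.gz,sub001_pred.nii.gz,sub001_mask.nii.gz,sub001_void.nii.gz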
2 changes: 1 addition & 1 deletion GANDLF/entrypoints/hf_hub_integration.py
@@ -96,7 +96,7 @@
@click.option(
"--hf-template",
"-hft",
help="Adding the template path for the model card it is Required during Uploaing a model",
help="Adding the template path for the model card: it is required during model upload",
default=huggingface_file_path,
type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
6 changes: 2 additions & 4 deletions GANDLF/metrics/__init__.py
@@ -32,13 +32,11 @@
from .synthesis import (
structural_similarity_index,
mean_squared_error,
root_mean_squared_error,
peak_signal_noise_ratio,
mean_squared_log_error,
mean_absolute_error,
ncc_mean,
ncc_std,
ncc_max,
ncc_min,
ncc_metrics,
)
import GANDLF.metrics.classification as classification
import GANDLF.metrics.regression as regression
101 changes: 39 additions & 62 deletions GANDLF/metrics/synthesis.py
@@ -46,8 +46,28 @@ def mean_squared_error(prediction: torch.Tensor, target: torch.Tensor) -> torch.
Args:
prediction (torch.Tensor): The prediction tensor.
target (torch.Tensor): The target tensor.
Returns:
torch.Tensor: The mean squared error.
"""
mse = MeanSquaredError()
mse = MeanSquaredError(squared=True)
return mse(preds=prediction, target=target)


def root_mean_squared_error(
prediction: torch.Tensor, target: torch.Tensor
) -> torch.Tensor:
"""
Computes the root mean squared error between the target and prediction.
Args:
prediction (torch.Tensor): The prediction tensor.
target (torch.Tensor): The target tensor.
Returns:
torch.Tensor: The root mean squared error.
"""
mse = MeanSquaredError(squared=False)
return mse(preds=prediction, target=target)
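
Both helpers wrap the same torchmetrics class; the squared flag is what switches MSE to RMSE. A quick sketch of the expected relationship, with made-up tensors:

import torch
from torchmetrics import MeanSquaredError

pred = torch.tensor([2.0, 4.0])
target = torch.tensor([0.0, 0.0])
mse = MeanSquaredError(squared=True)(pred, target)  # (4 + 16) / 2 = 10
rmse = MeanSquaredError(squared=False)(pred, target)  # sqrt(10) ~= 3.162
assert torch.isclose(rmse, mse.sqrt())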


@@ -78,10 +98,9 @@ def peak_signal_noise_ratio(
return psnr(preds=prediction, target=target)
else: # implementation of PSNR that does not give 'inf'/'nan' when 'mse==0'
mse = mean_squared_error(target, prediction)
if data_range == None: # compute data_range like torchmetrics if not given
min_v = (
0 if torch.min(target) > 0 else torch.min(target)
) # look at this line
if data_range is None: # compute data_range like torchmetrics if not given
# put the min value to 0 if all values are positive
min_v = 0 if torch.min(target) > 0 else torch.min(target)
max_v = torch.max(target)
else:
min_v, max_v = data_range
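
The truncated remainder of this hunk presumably evaluates the usual definition, PSNR = 10 * log10((max_v - min_v)^2 / mse), so passing a data_range tuple fixes the dynamic range explicitly instead of deriving it from the target.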
@@ -158,69 +177,27 @@ def __convert_to_grayscale(image: sitk.Image) -> sitk.Image:
return correlation_filter.Execute(target_image, pred_image)


def ncc_mean(prediction: torch.Tensor, target: torch.Tensor) -> float:
"""
Computes normalized cross correlation mean between target and prediction.
Args:
prediction (torch.Tensor): The prediction tensor.
target (torch.Tensor): The target tensor.
Returns:
float: The normalized cross correlation mean.
"""
stats_filter = sitk.StatisticsImageFilter()
corr_image = _get_ncc_image(target, prediction)
stats_filter.Execute(corr_image)
return stats_filter.GetMean()


def ncc_std(prediction: torch.Tensor, target: torch.Tensor) -> float:
"""
Computes normalized cross correlation standard deviation between target and prediction.
Args:
prediction (torch.Tensor): The prediction tensor.
target (torch.Tensor): The target tensor.
Returns:
float: The normalized cross correlation standard deviation.
"""
stats_filter = sitk.StatisticsImageFilter()
corr_image = _get_ncc_image(target, prediction)
stats_filter.Execute(corr_image)
return stats_filter.GetSigma()


def ncc_max(prediction: torch.Tensor, target: torch.Tensor) -> float:
"""
Computes normalized cross correlation maximum between target and prediction.
Args:
prediction (torch.Tensor): The prediction tensor.
target (torch.Tensor): The target tensor.
Returns:
float: The normalized cross correlation maximum.
"""
stats_filter = sitk.StatisticsImageFilter()
corr_image = _get_ncc_image(target, prediction)
stats_filter.Execute(corr_image)
return stats_filter.GetMaximum()


def ncc_min(prediction: torch.Tensor, target: torch.Tensor) -> float:
def ncc_metrics(prediction: torch.Tensor, target: torch.Tensor) -> dict:
"""
Computes normalized cross correlation minimum between target and prediction.
Computes normalized cross correlation metrics between target and prediction.
Args:
prediction (torch.Tensor): The prediction tensor.
target (torch.Tensor): The target tensor.
Returns:
float: The normalized cross correlation minimum.
dict: The normalized cross correlation metrics.
"""
stats_filter = sitk.StatisticsImageFilter()
corr_image = _get_ncc_image(target, prediction)
stats_filter.Execute(corr_image)
return stats_filter.GetMinimum()
stats_filter = sitk.LabelStatisticsImageFilter()
stats_filter.UseHistogramsOn()
# ensure that we are not considering zeros
onesImage = corr_image == corr_image
stats_filter.Execute(corr_image, onesImage)
return {
"ncc_mean": stats_filter.GetMean(1),
"ncc_std": stats_filter.GetSigma(1),
"ncc_max": stats_filter.GetMaximum(1),
"ncc_min": stats_filter.GetMinimum(1),
"ncc_median": stats_filter.GetMedian(1),
}
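
Folding the four statistics into one call means the NCC image is computed once instead of four times, and the histogram-backed LabelStatisticsImageFilter adds a median essentially for free. A hedged usage sketch (tensor shapes are illustrative):

import torch
from GANDLF.metrics.synthesis import ncc_metrics

pred = torch.rand(1, 32, 32, 32)  # hypothetical prediction volume
target = torch.rand(1, 32, 32, 32)  # hypothetical ground-truth volume
stats = ncc_metrics(pred, target)
print(stats["ncc_mean"], stats["ncc_median"])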
