Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add configuration options to PostProcessor #2547

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 108 additions & 47 deletions src/anomalib/post_processing/one_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ class OneClassPostProcessor(PostProcessor):
- Formatting results for downstream use

Args:
enable_normalization (bool, optional): Enable normalization of anomaly scores.
Defaults to True.
enable_thresholding (bool, optional): Enable thresholding of anomaly scores.
Defaults to True.
enable_threshold_matching (bool, optional): Use image-level threshold for pixel-level predictions when
pixel-level threshold is not available, and vice-versa.
Defaults to True.
image_sensitivity (float | None, optional): Sensitivity value for image-level
predictions. Higher values make the model more sensitive to anomalies.
Defaults to None.
Expand All @@ -53,32 +60,39 @@ class OneClassPostProcessor(PostProcessor):

def __init__(
self,
image_sensitivity: float | None = None,
pixel_sensitivity: float | None = None,
enable_normalization: bool = True,
enable_thresholding: bool = True,
enable_threshold_matching: bool = True,
image_sensitivity: float = 0.5,
pixel_sensitivity: float = 0.5,
**kwargs,
) -> None:
super().__init__(**kwargs)

self.enable_thresholding = enable_thresholding
self.enable_normalization = enable_normalization
self.enable_threshold_matching = enable_threshold_matching

# configure sensitivity values
self.image_sensitivity = image_sensitivity
self.pixel_sensitivity = pixel_sensitivity

# initialize threshold and normalization metrics
self._image_threshold = F1AdaptiveThreshold(fields=["pred_score", "gt_label"], strict=False)
self._pixel_threshold = F1AdaptiveThreshold(fields=["anomaly_map", "gt_mask"], strict=False)
self._image_min_max = MinMax(fields=["pred_score"], strict=False)
self._pixel_min_max = MinMax(fields=["anomaly_map"], strict=False)
self._image_threshold_metric = F1AdaptiveThreshold(fields=["pred_score", "gt_label"], strict=False)
self._pixel_threshold_metric = F1AdaptiveThreshold(fields=["anomaly_map", "gt_mask"], strict=False)
self._image_min_max_metric = MinMax(fields=["pred_score"], strict=False)
self._pixel_min_max_metric = MinMax(fields=["anomaly_map"], strict=False)

# register buffers to persist threshold and normalization values
self.register_buffer("image_threshold", torch.tensor(0))
self.register_buffer("pixel_threshold", torch.tensor(0))
self.register_buffer("image_min", torch.tensor(0))
self.register_buffer("image_max", torch.tensor(1))
self.register_buffer("pixel_min", torch.tensor(0))
self.register_buffer("pixel_max", torch.tensor(1))

self.image_threshold: torch.Tensor
self.pixel_threshold: torch.Tensor
self.register_buffer("_image_threshold", torch.tensor(float("nan")))
self.register_buffer("_pixel_threshold", torch.tensor(float("nan")))
self.register_buffer("image_min", torch.tensor(float("nan")))
self.register_buffer("image_max", torch.tensor(float("nan")))
self.register_buffer("pixel_min", torch.tensor(float("nan")))
self.register_buffer("pixel_max", torch.tensor(float("nan")))

self._image_threshold: torch.Tensor
self._pixel_threshold: torch.Tensor
self.image_min: torch.Tensor
self.image_max: torch.Tensor
self.pixel_min: torch.Tensor
Expand All @@ -102,10 +116,14 @@ def on_validation_batch_end(
**kwargs: Arbitrary keyword arguments.
"""
del trainer, pl_module, args, kwargs # Unused arguments.
self._image_threshold.update(outputs)
self._pixel_threshold.update(outputs)
self._image_min_max.update(outputs)
self._pixel_min_max.update(outputs)
if self.enable_thresholding:
# update threshold metrics
self._image_threshold_metric.update(outputs)
self._pixel_threshold_metric.update(outputs)
if self.enable_normalization:
# update normalization metrics
self._image_min_max_metric.update(outputs)
self._pixel_min_max_metric.update(outputs)

def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Compute final threshold and normalization values.
Expand All @@ -115,14 +133,22 @@ def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule)
pl_module (LightningModule): PyTorch Lightning module instance.
"""
del trainer, pl_module
if self._image_threshold.update_called:
self.image_threshold = self._image_threshold.compute()
if self._pixel_threshold.update_called:
self.pixel_threshold = self._pixel_threshold.compute()
if self._image_min_max.update_called:
self.image_min, self.image_max = self._image_min_max.compute()
if self._pixel_min_max.update_called:
self.pixel_min, self.pixel_max = self._pixel_min_max.compute()
if self.enable_thresholding:
# compute threshold values
if self._image_threshold_metric.update_called:
self._image_threshold = self._image_threshold_metric.compute()
self._image_threshold_metric.reset()
if self._pixel_threshold_metric.update_called:
self._pixel_threshold = self._pixel_threshold_metric.compute()
self._pixel_threshold_metric.reset()
if self.enable_normalization:
# compute normalization values
if self._image_min_max_metric.update_called:
self.image_min, self.image_max = self._image_min_max_metric.compute()
self._image_min_max_metric.reset()
if self._pixel_min_max_metric.update_called:
self.pixel_min, self.pixel_max = self._pixel_min_max_metric.compute()
self._pixel_min_max_metric.reset()

def on_test_batch_end(
self,
Expand Down Expand Up @@ -181,10 +207,21 @@ def forward(self, predictions: InferenceBatch) -> InferenceBatch:
msg = "At least one of pred_score or anomaly_map must be provided."
raise ValueError(msg)
pred_score = predictions.pred_score or torch.amax(predictions.anomaly_map, dim=(-2, -1))
pred_score = self._normalize(pred_score, self.image_min, self.image_max, self.image_threshold)
anomaly_map = self._normalize(predictions.anomaly_map, self.pixel_min, self.pixel_max, self.pixel_threshold)
pred_label = self._apply_threshold(pred_score, self.normalized_image_threshold)
pred_mask = self._apply_threshold(anomaly_map, self.normalized_pixel_threshold)

if self.enable_normalization:
pred_score = self._normalize(pred_score, self.image_min, self.image_max, self.image_threshold)
anomaly_map = self._normalize(predictions.anomaly_map, self.pixel_min, self.pixel_max, self.pixel_threshold)
else:
pred_score = predictions.pred_score
anomaly_map = predictions.anomaly_map

if self.enable_thresholding:
pred_label = self._apply_threshold(pred_score, self.normalized_image_threshold)
pred_mask = self._apply_threshold(anomaly_map, self.normalized_pixel_threshold)
else:
pred_label = None
pred_mask = None

return InferenceBatch(
pred_label=pred_label,
pred_score=pred_score,
Expand All @@ -201,9 +238,11 @@ def post_process_batch(self, batch: Batch) -> None:
batch (Batch): Batch containing model predictions.
"""
# apply normalization
self.normalize_batch(batch)
if self.enable_normalization:
self.normalize_batch(batch)
# apply threshold
self.threshold_batch(batch)
if self.enable_thresholding:
self.threshold_batch(batch)

def threshold_batch(self, batch: Batch) -> None:
"""Apply thresholding to batch predictions.
Expand Down Expand Up @@ -236,7 +275,7 @@ def normalize_batch(self, batch: Batch) -> None:
@staticmethod
def _apply_threshold(
preds: torch.Tensor | None,
threshold: torch.Tensor | None,
threshold: torch.Tensor,
) -> torch.Tensor | None:
"""Apply thresholding to a single tensor.

Expand All @@ -247,16 +286,16 @@ def _apply_threshold(
Returns:
torch.Tensor | None: Thresholded predictions or None if input is None.
"""
if preds is None or threshold is None:
if preds is None or threshold.isnan():
return preds
return preds > threshold

@staticmethod
def _normalize(
preds: torch.Tensor | None,
norm_min: torch.Tensor | None,
norm_max: torch.Tensor | None,
threshold: torch.Tensor | None,
norm_min: torch.Tensor,
norm_max: torch.Tensor,
threshold: torch.Tensor,
) -> torch.Tensor | None:
"""Normalize a tensor using min, max, and threshold values.

Expand All @@ -269,29 +308,51 @@ def _normalize(
Returns:
torch.Tensor | None: Normalized predictions or None if input is None.
"""
if preds is None or norm_min is None or norm_max is None or threshold is None:
if preds is None or norm_min.isnan() or norm_max.isnan():
return preds
if threshold.isnan():
threshold = (norm_max + norm_min) / 2
preds = ((preds - threshold) / (norm_max - norm_min)) + 0.5
return preds.clamp(min=0, max=1)

@property
def normalized_image_threshold(self) -> float:
def image_threshold(self) -> torch.tensor:
"""Get the image-level threshold.

Returns:
float: Image-level threshold value.
"""
if not self._image_threshold.isnan():
return self._image_threshold
return self._pixel_threshold if self.enable_threshold_matching else torch.tensor(float("nan"))

@property
def pixel_threshold(self) -> torch.tensor:
"""Get the pixel-level threshold.

If the pixel-level threshold is not set, the image-level threshold is used.

Returns:
float: Pixel-level threshold value.
"""
if not self._pixel_threshold.isnan():
return self._pixel_threshold
return self._image_threshold if self.enable_threshold_matching else torch.tensor(float("nan"))

@property
def normalized_image_threshold(self) -> torch.tensor:
"""Get the normalized image-level threshold.

Returns:
float: Normalized image-level threshold value, adjusted by sensitivity.
"""
if self.image_sensitivity is not None:
return torch.tensor(1.0) - self.image_sensitivity
return torch.tensor(0.5)
return torch.tensor(1.0) - self.image_sensitivity

@property
def normalized_pixel_threshold(self) -> float:
def normalized_pixel_threshold(self) -> torch.tensor:
"""Get the normalized pixel-level threshold.

Returns:
float: Normalized pixel-level threshold value, adjusted by sensitivity.
"""
if self.pixel_sensitivity is not None:
return torch.tensor(1.0) - self.pixel_sensitivity
return torch.tensor(0.5)
return torch.tensor(1.0) - self.pixel_sensitivity
108 changes: 108 additions & 0 deletions tests/unit/post_processing/test_post_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Test the PostProcessor class."""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from anomalib.data import ImageBatch
from anomalib.post_processing import OneClassPostProcessor


class TestPostProcessor:
"""Test the PreProcessor class."""

@staticmethod
@pytest.mark.parametrize(
("preds", "min_val", "max_val", "thresh", "target"),
[
(torch.tensor([20, 40, 60, 80]), 0, 100, 50, torch.tensor([0.2, 0.4, 0.6, 0.8])),
(torch.tensor([20, 40, 60, 80]), 0, 100, 40, torch.tensor([0.3, 0.5, 0.7, 0.9])), # lower threshold
(torch.tensor([20, 40, 60, 80]), 0, 100, 60, torch.tensor([0.1, 0.3, 0.5, 0.7])), # higher threshold
(torch.tensor([0, 40, 80, 120]), 20, 100, 50, torch.tensor([0.0, 0.375, 0.875, 1.0])), # out of bounds
(torch.tensor([-80, -60, -40, -20]), -100, 0, -50, torch.tensor([0.2, 0.4, 0.6, 0.8])), # negative values
(torch.tensor([20, 40, 60, 80]), 0, 100, -50, torch.tensor([1.0, 1.0, 1.0, 1.0])), # threshold below range
(torch.tensor([20, 40, 60, 80]), 0, 100, 150, torch.tensor([0.0, 0.0, 0.0, 0.0])), # threshold above range
(torch.tensor([20, 40, 60, 80]), 50, 50, 50, torch.tensor([0.0, 0.0, 1.0, 1.0])), # all same
(torch.tensor(60), 0, 100, 50, torch.tensor(0.6)), # scalar tensor
(torch.tensor([[20, 40], [60, 80]]), 0, 100, 50, torch.tensor([[0.2, 0.4], [0.6, 0.8]])), # 2D tensor
],
)
def test_normalize(
preds: torch.Tensor,
min_val: float,
max_val: float,
thresh: float,
target: torch.Tensor,
) -> None:
"""Test the normalize method."""
pre_processor = OneClassPostProcessor()
normalized = pre_processor._normalize( # noqa: SLF001
preds,
torch.tensor(min_val),
torch.tensor(max_val),
torch.tensor(thresh),
)
assert torch.allclose(normalized, target)

@staticmethod
@pytest.mark.parametrize(
("preds", "thresh", "target"),
[
(torch.tensor(20), 50, torch.tensor(0).bool()), # test scalar
(torch.tensor([20, 40, 60, 80]), 50, torch.tensor([0, 0, 1, 1]).bool()), # test 1d tensor
(torch.tensor([[20, 40], [60, 80]]), 50, torch.tensor([[0, 0], [1, 1]]).bool()), # test 2d tensor
(torch.tensor(50), 50, torch.tensor(0).bool()), # test on threshold labeled as normal
(torch.tensor([-80, -60, -40, -20]), -50, torch.tensor([0, 0, 1, 1]).bool()), # test negative
],
)
def test_apply_threshold(preds: torch.Tensor, thresh: float, target: torch.Tensor) -> None:
"""Test the apply_threshold method."""
pre_processor = OneClassPostProcessor()
binary_preds = pre_processor._apply_threshold(preds, torch.tensor(thresh)) # noqa: SLF001
assert torch.allclose(binary_preds, target)

@staticmethod
def test_thresholds_computed() -> None:
"""Test that both image and pixel threshold are computed correctly."""
batch = ImageBatch(
image=torch.rand(4, 3, 3, 3),
anomaly_map=torch.tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]]),
gt_mask=torch.tensor([[0, 0, 0], [0, 0, 0], [0, 1, 1]]),
pred_score=torch.tensor([20, 40, 60, 80]),
gt_label=torch.tensor([0, 0, 1, 1]),
)
pre_processor = OneClassPostProcessor()
pre_processor.on_validation_batch_end(None, None, batch)
pre_processor.on_validation_epoch_end(None, None)
assert pre_processor.image_threshold == 60
assert pre_processor.pixel_threshold == 80

@staticmethod
def test_pixel_threshold_matching() -> None:
"""Test that pixel_threshold is used as image threshold when no gt masks are available."""
batch = ImageBatch(
image=torch.rand(4, 3, 10, 10),
anomaly_map=torch.rand(4, 10, 10),
pred_score=torch.tensor([20, 40, 60, 80]),
gt_label=torch.tensor([0, 0, 1, 1]),
)
pre_processor = OneClassPostProcessor(enable_threshold_matching=True)
pre_processor.on_validation_batch_end(None, None, batch)
pre_processor.on_validation_epoch_end(None, None)
assert pre_processor.image_threshold == pre_processor.pixel_threshold

@staticmethod
def test_image_threshold_matching() -> None:
"""Test that pixel_threshold is used as image threshold when no gt masks are available."""
batch = ImageBatch(
image=torch.rand(4, 3, 3, 3),
anomaly_map=torch.tensor([[10, 20, 30], [40, 50, 60], [70, 80, 90]]),
gt_mask=torch.tensor([[0, 0, 0], [0, 0, 0], [0, 1, 1]]),
pred_score=torch.tensor([20, 40, 60, 80]),
)
pre_processor = OneClassPostProcessor(enable_threshold_matching=True)
pre_processor.on_validation_batch_end(None, None, batch)
pre_processor.on_validation_epoch_end(None, None)
assert pre_processor.image_threshold == pre_processor.pixel_threshold
Loading