Skip to content

Commit

Permalink
adding test for PR_AUC, FAR and FNR
Browse files Browse the repository at this point in the history
  • Loading branch information
tourniert committed Sep 4, 2024
1 parent 9629979 commit 1f8ff85
Showing 1 changed file with 55 additions and 19 deletions.
74 changes: 55 additions & 19 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import numpy as np
import pytest
import torch
from mfai.torch.metrics import CSINeighborood
from mfai.torch.metrics import CSINeighborood, PR_AUC, FAR, FNR


@pytest.mark.parametrize("num_neighbors,expected", [(0, 0.36), (1, 0.91)])
def test_csi_binary(num_neighbors: int, expected: float):
@pytest.mark.parametrize("num_neighbors,expected_value", [(0, 0.36), (1, 0.91)])
def test_csi_binary(num_neighbors: int, expected_value: float):
"""
Build two 5x5 tensors and compute the CSI for binary task.
Build tensors of size (2, 1, 5, 5), compute the CSI for binary task and check if the output is
the result expected.
"""
y_true = torch.tensor(
np.array(
Expand Down Expand Up @@ -66,15 +67,14 @@ def test_csi_binary(num_neighbors: int, expected: float):
csi.update(preds=y_hat, targets=y_true)
csi_score = torch.round(csi.compute(), decimals=2).float()

assert (
csi_score - expected < 0.01
), f"Failed to compute the CSI, return {csi_score} instead of {expected}."
assert pytest.approx(csi_score, 0.001) == expected_value


@pytest.mark.parametrize("num_neighbors,expected", [(0, 0.43), (1, 0.79)])
def test_csi_multiclass(num_neighbors: int, expected: torch.Tensor):
@pytest.mark.parametrize("num_neighbors,expected_value", [(0, 0.43), (1, 0.79)])
def test_csi_multiclass(num_neighbors: int, expected_value: torch.Tensor):
"""
Build two 5x5 tensors and compute the CSI for binary task.
Build tensors of size (1, 1, 5, 5), compute the CSI for multiclass taskand check if the output is
the result expected.
"""
y_true = torch.tensor(
np.array(
Expand Down Expand Up @@ -118,15 +118,14 @@ def test_csi_multiclass(num_neighbors: int, expected: torch.Tensor):
print(csi.false_negatives)
print(csi_score)

assert (
torch.sum(csi_score - expected) < 0.01
), f"Failed to compute the CSI, return {csi_score} instead of {expected}."
assert pytest.approx(csi_score, 0.001) == expected_value


@pytest.mark.parametrize("num_neighbors,expected", [(0, 0.36), (1, 0.81)])
def test_csi_multilabel(num_neighbors: int, expected: torch.Tensor):
@pytest.mark.parametrize("num_neighbors,expected_value", [(0, 0.36), (1, 0.81)])
def test_csi_multilabel(num_neighbors: int, expected_value: torch.Tensor):
"""
Build two 5x5 tensors and compute the CSI for binary task.
Build tensors of size (1, 3, 5, 5), compute the CSI for multilabel task and check if the output is
the result expected.
"""
y_true = torch.tensor(
np.array(
Expand Down Expand Up @@ -191,6 +190,43 @@ def test_csi_multilabel(num_neighbors: int, expected: torch.Tensor):
csi.update(preds=y_hat, targets=y_true)
csi_score = torch.round(csi.compute(), decimals=2).float()

assert (
torch.sum(csi_score - expected) < 0.001
), f"Failed to compute the CSI, return {csi_score} instead of {expected}."
assert pytest.approx(csi_score, 0.001) == expected_value


def test_pr_auc():
"""
Test of the compute of the Precision-Recall Area Under the Curve.
"""
preds = torch.tensor([0.0, 1.0, 0.0, 1.0])
targets = torch.tensor([0, 0, 1, 1])
far = PR_AUC()
far.update(preds, targets)
auc_value = far.compute()
expected_value = 0.125
assert pytest.approx(auc_value, 0.001) == expected_value


def test_far():
"""
Test of the compute of the False Alarm Rate.
"""
preds = torch.tensor([0.0, 1.0, 0.0, 1.0])
targets = torch.tensor([0, 0, 1, 1])
far = FAR("binary")
far.update(preds, targets)
auc_value = far.compute()
expected_value = 0.5
assert pytest.approx(auc_value, 0.001) == expected_value


def test_fnr():
"""
Test of the compute of the False Alarm Rate.
"""
preds = torch.tensor([0.0, 1.0, 0.0, 1.0])
targets = torch.tensor([0, 0, 1, 1])
fnr = FNR()
fnr.update(preds, targets)
auc_value = fnr.compute()
expected_value = 0.5
assert pytest.approx(auc_value, 0.001) == expected_value

0 comments on commit 1f8ff85

Please sign in to comment.