diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 805ac30..40b10ba 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -5,13 +5,14 @@ import numpy as np import pytest import torch -from mfai.torch.metrics import CSINeighborood +from mfai.torch.metrics import CSINeighborood, PR_AUC, FAR, FNR -@pytest.mark.parametrize("num_neighbors,expected", [(0, 0.36), (1, 0.91)]) -def test_csi_binary(num_neighbors: int, expected: float): +@pytest.mark.parametrize("num_neighbors,expected_value", [(0, 0.36), (1, 0.91)]) +def test_csi_binary(num_neighbors: int, expected_value: float): """ - Build two 5x5 tensors and compute the CSI for binary task. + Build tensors of size (2, 1, 5, 5), compute the CSI for binary task and check if the output is + the result expected. """ y_true = torch.tensor( np.array( @@ -66,15 +67,14 @@ def test_csi_binary(num_neighbors: int, expected: float): csi.update(preds=y_hat, targets=y_true) csi_score = torch.round(csi.compute(), decimals=2).float() - assert ( - csi_score - expected < 0.01 - ), f"Failed to compute the CSI, return {csi_score} instead of {expected}." + assert pytest.approx(csi_score, 0.001) == expected_value -@pytest.mark.parametrize("num_neighbors,expected", [(0, 0.43), (1, 0.79)]) -def test_csi_multiclass(num_neighbors: int, expected: torch.Tensor): +@pytest.mark.parametrize("num_neighbors,expected_value", [(0, 0.43), (1, 0.79)]) +def test_csi_multiclass(num_neighbors: int, expected_value: torch.Tensor): """ - Build two 5x5 tensors and compute the CSI for binary task. + Build tensors of size (1, 1, 5, 5), compute the CSI for multiclass taskand check if the output is + the result expected. """ y_true = torch.tensor( np.array( @@ -118,15 +118,14 @@ def test_csi_multiclass(num_neighbors: int, expected: torch.Tensor): print(csi.false_negatives) print(csi_score) - assert ( - torch.sum(csi_score - expected) < 0.01 - ), f"Failed to compute the CSI, return {csi_score} instead of {expected}." + assert pytest.approx(csi_score, 0.001) == expected_value -@pytest.mark.parametrize("num_neighbors,expected", [(0, 0.36), (1, 0.81)]) -def test_csi_multilabel(num_neighbors: int, expected: torch.Tensor): +@pytest.mark.parametrize("num_neighbors,expected_value", [(0, 0.36), (1, 0.81)]) +def test_csi_multilabel(num_neighbors: int, expected_value: torch.Tensor): """ - Build two 5x5 tensors and compute the CSI for binary task. + Build tensors of size (1, 3, 5, 5), compute the CSI for multilabel task and check if the output is + the result expected. """ y_true = torch.tensor( np.array( @@ -191,6 +190,43 @@ def test_csi_multilabel(num_neighbors: int, expected: torch.Tensor): csi.update(preds=y_hat, targets=y_true) csi_score = torch.round(csi.compute(), decimals=2).float() - assert ( - torch.sum(csi_score - expected) < 0.001 - ), f"Failed to compute the CSI, return {csi_score} instead of {expected}." + assert pytest.approx(csi_score, 0.001) == expected_value + + +def test_pr_auc(): + """ + Test of the compute of the Precision-Recall Area Under the Curve. + """ + preds = torch.tensor([0.0, 1.0, 0.0, 1.0]) + targets = torch.tensor([0, 0, 1, 1]) + far = PR_AUC() + far.update(preds, targets) + auc_value = far.compute() + expected_value = 0.125 + assert pytest.approx(auc_value, 0.001) == expected_value + + +def test_far(): + """ + Test of the compute of the False Alarm Rate. + """ + preds = torch.tensor([0.0, 1.0, 0.0, 1.0]) + targets = torch.tensor([0, 0, 1, 1]) + far = FAR("binary") + far.update(preds, targets) + auc_value = far.compute() + expected_value = 0.5 + assert pytest.approx(auc_value, 0.001) == expected_value + + +def test_fnr(): + """ + Test of the compute of the False Alarm Rate. + """ + preds = torch.tensor([0.0, 1.0, 0.0, 1.0]) + targets = torch.tensor([0, 0, 1, 1]) + fnr = FNR() + fnr.update(preds, targets) + auc_value = fnr.compute() + expected_value = 0.5 + assert pytest.approx(auc_value, 0.001) == expected_value