From f60b1ac230ab97bd3c6ad2c210f79882293b306d Mon Sep 17 00:00:00 2001 From: AmenRa Date: Mon, 1 Jul 2024 19:44:52 +0200 Subject: [PATCH] Formatting --- ranx/statistical_tests/__init__.py | 18 +++++++++--------- tests/unit/ranx/fusion/logn_isr_test.py | 24 ++++++------------------ tests/unit/ranx/meta/evaluate_test.py | 12 ++---------- tests/unit/ranx/metrics_test.py | 12 ++---------- 4 files changed, 19 insertions(+), 47 deletions(-) diff --git a/ranx/statistical_tests/__init__.py b/ranx/statistical_tests/__init__.py index 338895a..9f20e51 100644 --- a/ranx/statistical_tests/__init__.py +++ b/ranx/statistical_tests/__init__.py @@ -72,15 +72,15 @@ def compute_statistical_significance( treatment_metric_scores = metric_scores[treatment] # Compute statistical significance - comparisons[ - frozenset([control, treatment]) - ] = _compute_statistical_significance( - control_metric_scores, - treatment_metric_scores, - stat_test, - n_permutations, - max_p, - random_seed, + comparisons[frozenset([control, treatment])] = ( + _compute_statistical_significance( + control_metric_scores, + treatment_metric_scores, + stat_test, + n_permutations, + max_p, + random_seed, + ) ) return comparisons diff --git a/tests/unit/ranx/fusion/logn_isr_test.py b/tests/unit/ranx/fusion/logn_isr_test.py index db3641e..2b0986b 100644 --- a/tests/unit/ranx/fusion/logn_isr_test.py +++ b/tests/unit/ranx/fusion/logn_isr_test.py @@ -48,21 +48,9 @@ def test_logn_isr(run_1, run_2, run_3): assert len(combined_run["q1"]) == 3 assert len(combined_run["q2"]) == 3 - assert combined_run["q1"]["d1"] == ((1 / (3**2)) + (1 / (1**2))) * np.log( - 2 + sigma - ) - assert combined_run["q1"]["d2"] == ((1 / (2**2)) + (1 / (2**2))) * np.log( - 2 + sigma - ) - assert combined_run["q1"]["d3"] == ((1 / (1**2)) + (1 / (1**2))) * np.log( - 2 + sigma - ) - assert combined_run["q2"]["d1"] == ((1 / (2**2)) + (1 / (2**2))) * np.log( - 2 + sigma - ) - assert combined_run["q2"]["d2"] == ((1 / (1**2)) + (1 / (2**2))) * np.log( - 2 + sigma - ) - assert combined_run["q2"]["d3"] == ((1 / (1**2)) + (1 / (1**2))) * np.log( - 2 + sigma - ) + assert combined_run["q1"]["d1"] == ((1 / (3**2)) + (1 / (1**2))) * np.log(2 + sigma) + assert combined_run["q1"]["d2"] == ((1 / (2**2)) + (1 / (2**2))) * np.log(2 + sigma) + assert combined_run["q1"]["d3"] == ((1 / (1**2)) + (1 / (1**2))) * np.log(2 + sigma) + assert combined_run["q2"]["d1"] == ((1 / (2**2)) + (1 / (2**2))) * np.log(2 + sigma) + assert combined_run["q2"]["d2"] == ((1 / (1**2)) + (1 / (2**2))) * np.log(2 + sigma) + assert combined_run["q2"]["d3"] == ((1 / (1**2)) + (1 / (1**2))) * np.log(2 + sigma) diff --git a/tests/unit/ranx/meta/evaluate_test.py b/tests/unit/ranx/meta/evaluate_test.py index 4b9cc6e..8c3be5f 100644 --- a/tests/unit/ranx/meta/evaluate_test.py +++ b/tests/unit/ranx/meta/evaluate_test.py @@ -298,11 +298,7 @@ def test_dcg_burges(): assert np.allclose( evaluate(y_true, y_pred_1, f"dcg_burges@{k}"), - ( - (2**5 - 1) / np.log2(3) - + (2**4 - 1) / np.log2(5) - + (2**3 - 1) / np.log2(6) - ), + ((2**5 - 1) / np.log2(3) + (2**4 - 1) / np.log2(5) + (2**3 - 1) / np.log2(6)), ) assert np.allclose( @@ -343,11 +339,7 @@ def test_ndcg_burges(): assert np.allclose( evaluate(y_true, y_pred_1, f"ndcg_burges@{k}"), - ( - (2**5 - 1) / np.log2(3) - + (2**4 - 1) / np.log2(5) - + (2**3 - 1) / np.log2(6) - ) + ((2**5 - 1) / np.log2(3) + (2**4 - 1) / np.log2(5) + (2**3 - 1) / np.log2(6)) / idcg, ) diff --git a/tests/unit/ranx/metrics_test.py b/tests/unit/ranx/metrics_test.py index 615ba6e..a2677e0 100644 --- a/tests/unit/ranx/metrics_test.py +++ b/tests/unit/ranx/metrics_test.py @@ -256,11 +256,7 @@ def test_dcg_burges(): assert np.allclose( rm.dcg_burges(y_true, y_pred_1, k)[0], - ( - (2**5 - 1) / np.log2(3) - + (2**4 - 1) / np.log2(5) - + (2**3 - 1) / np.log2(6) - ), + ((2**5 - 1) / np.log2(3) + (2**4 - 1) / np.log2(5) + (2**3 - 1) / np.log2(6)), ) assert np.allclose( @@ -301,11 +297,7 @@ def test_ndcg_burges(): assert np.allclose( rm.ndcg_burges(y_true, y_pred_1, k)[0], - ( - (2**5 - 1) / np.log2(3) - + (2**4 - 1) / np.log2(5) - + (2**3 - 1) / np.log2(6) - ) + ((2**5 - 1) / np.log2(3) + (2**4 - 1) / np.log2(5) + (2**3 - 1) / np.log2(6)) / idcg, )