diff --git a/tests/test_efaar.py b/tests/test_efaar.py index 1878933..7f1294e 100644 --- a/tests/test_efaar.py +++ b/tests/test_efaar.py @@ -1,5 +1,6 @@ import numpy as np import pandas as pd +import pytest from anndata import AnnData from efaar_benchmarking import efaar @@ -63,3 +64,30 @@ def test_filter_cell_profiler_features(): filtered_features, _ = efaar.filter_cell_profiler_features(features, metadata) assert not filtered_features.empty assert "Image_Granularity_12_ER" not in filtered_features.columns + + +@pytest.mark.parametrize( + "metadata, expected_shape, expect_error", + [ + ( + pd.DataFrame({"perturbation": ["control", "control", "treatment", "treatment"], "batch": [1, 1, 1, 1]}), + (4, 2), + False, + ), + ( + pd.DataFrame({"perturbation": ["treatment", "treatment", "treatment", "treatment"], "batch": [1, 1, 1, 1]}), + None, + True, + ), + ], +) +def test_pca_centerscale_on_controls(metadata, expected_shape, expect_error): + embeddings = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) + pert_col = "perturbation" + control_key = "control" + if expect_error: + with pytest.raises(ValueError, match=f"No control samples found for {control_key}"): + efaar.pca_centerscale_on_controls(embeddings, metadata, pert_col, control_key, None) + else: + result = efaar.pca_centerscale_on_controls(embeddings, metadata, pert_col, control_key, None) + assert result.shape == expected_shape