From 1d89f66e1404b53dc11af2afa8b1e6c4f42c0c3c Mon Sep 17 00:00:00 2001 From: Will Dean <57733339+wd60622@users.noreply.github.com> Date: Mon, 2 Dec 2024 12:57:52 +0100 Subject: [PATCH] Raise error when two dims are not compatible (#1249) * add quick fail for incompat dims * correct the error message * more test below related test --- pymc_marketing/prior.py | 8 +++++++- tests/test_prior.py | 14 ++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/pymc_marketing/prior.py b/pymc_marketing/prior.py index c68ed49a5..8d82f29a3 100644 --- a/pymc_marketing/prior.py +++ b/pymc_marketing/prior.py @@ -110,7 +110,7 @@ def custom_transform(x): class UnsupportedShapeError(Exception): - """Error for when the shape of the hierarchical variable is not supported.""" + """Error for when the shapes from variables are not compatible.""" class UnsupportedDistributionError(Exception): @@ -169,6 +169,12 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa dims = dims if isinstance(dims, tuple) else (dims,) desired_dims = desired_dims if isinstance(desired_dims, tuple) else (desired_dims,) + if difference := set(dims).difference(desired_dims): + raise UnsupportedShapeError( + f"Dims {dims} of data are not a subset of the desired dims {desired_dims}. " + f"{difference} is missing from the desired dims." + ) + aligned_dims = np.array(dims)[:, None] == np.array(desired_dims) missing_dims = aligned_dims.sum(axis=0) == 0 diff --git a/tests/test_prior.py b/tests/test_prior.py index edc5e2ae4..7f0b870cd 100644 --- a/tests/test_prior.py +++ b/tests/test_prior.py @@ -72,6 +72,20 @@ def test_handle_dims(x, dims, desired_dims, expected_fn) -> None: np.testing.assert_array_equal(result, expected_fn(x)) +@pytest.mark.parametrize( + "x, dims, desired_dims", + [ + (np.ones(3), "channel", "something_else"), + (np.ones((3, 2)), ("a", "b"), ("a", "B")), + ], + ids=["no_incommon", "some_incommon"], +) +def test_handle_dims_with_impossible_dims(x, dims, desired_dims) -> None: + match = " are not a subset of the desired dims " + with pytest.raises(UnsupportedShapeError, match=match): + handle_dims(x, dims, desired_dims) + + def test_missing_transform() -> None: match = "Neither pytensor.tensor nor pymc.math have the function 'foo_bar'" with pytest.raises(UnknownTransformError, match=match):