Skip to content

Commit

Permalink
Raise error when two dims are not compatible (#1249)
Browse files Browse the repository at this point in the history
* add quick fail for incompat dims

* correct the error message

* more test below related test
  • Loading branch information
wd60622 authored Dec 2, 2024
1 parent 9e1a89e commit 1d89f66
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
8 changes: 7 additions & 1 deletion pymc_marketing/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def custom_transform(x):


class UnsupportedShapeError(Exception):
"""Error for when the shape of the hierarchical variable is not supported."""
"""Error for when the shapes from variables are not compatible."""


class UnsupportedDistributionError(Exception):
Expand Down Expand Up @@ -169,6 +169,12 @@ def handle_dims(x: pt.TensorLike, dims: Dims, desired_dims: Dims) -> pt.TensorVa
dims = dims if isinstance(dims, tuple) else (dims,)
desired_dims = desired_dims if isinstance(desired_dims, tuple) else (desired_dims,)

if difference := set(dims).difference(desired_dims):
raise UnsupportedShapeError(
f"Dims {dims} of data are not a subset of the desired dims {desired_dims}. "
f"{difference} is missing from the desired dims."
)

aligned_dims = np.array(dims)[:, None] == np.array(desired_dims)

missing_dims = aligned_dims.sum(axis=0) == 0
Expand Down
14 changes: 14 additions & 0 deletions tests/test_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,20 @@ def test_handle_dims(x, dims, desired_dims, expected_fn) -> None:
np.testing.assert_array_equal(result, expected_fn(x))


@pytest.mark.parametrize(
"x, dims, desired_dims",
[
(np.ones(3), "channel", "something_else"),
(np.ones((3, 2)), ("a", "b"), ("a", "B")),
],
ids=["no_incommon", "some_incommon"],
)
def test_handle_dims_with_impossible_dims(x, dims, desired_dims) -> None:
match = " are not a subset of the desired dims "
with pytest.raises(UnsupportedShapeError, match=match):
handle_dims(x, dims, desired_dims)


def test_missing_transform() -> None:
match = "Neither pytensor.tensor nor pymc.math have the function 'foo_bar'"
with pytest.raises(UnknownTransformError, match=match):
Expand Down

0 comments on commit 1d89f66

Please sign in to comment.