Skip to content

Commit

Permalink
1271 predictions=true for sample posterior predictive throws error (#…
Browse files Browse the repository at this point in the history
…1272)

* group conditional on the kwargs

* flip the case

* test for predictions kwarg
  • Loading branch information
wd60622 authored Dec 13, 2024
1 parent 7007b4f commit 250a8ee
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
7 changes: 5 additions & 2 deletions pymc_marketing/mmm/mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1929,9 +1929,12 @@ def sample_posterior_predictive(
if extend_idata:
self.idata.extend(post_pred, join="right") # type: ignore

posterior_predictive_samples = az.extract(
post_pred, "posterior_predictive", combined=combined
group = (
"predictions"
if sample_posterior_predictive_kwargs.get("predictions", False)
else "posterior_predictive"
)
posterior_predictive_samples = az.extract(post_pred, group, combined=combined)

if include_last_observations:
posterior_predictive_samples = posterior_predictive_samples.isel(
Expand Down
24 changes: 24 additions & 0 deletions tests/mmm/test_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,30 @@ def test_new_data_sample_posterior_predictive_method(
)


@pytest.mark.parametrize(
"predictions",
[True, False],
)
def test_sample_posterior_predictive_with_prediction_kwarg(
generate_data,
mmm_fitted,
predictions: bool,
) -> None:
new_dates = pd.date_range("2022-01-01", "2022-03-01", freq="W-MON")
X_pred = generate_data(new_dates)

predictions = mmm_fitted.sample_posterior_predictive(
X_pred=X_pred,
extend_idata=False,
combined=True,
predictions=predictions,
)
pd.testing.assert_index_equal(
pd.DatetimeIndex(predictions.coords["date"]),
new_dates,
)


@pytest.mark.parametrize(
"model_name", ["mmm_fitted", "mmm_fitted_with_fourier_features"]
)
Expand Down

0 comments on commit 250a8ee

Please sign in to comment.