diff --git a/pymc_marketing/mmm/mmm.py b/pymc_marketing/mmm/mmm.py index c5c32f441..71d4c1772 100644 --- a/pymc_marketing/mmm/mmm.py +++ b/pymc_marketing/mmm/mmm.py @@ -1929,9 +1929,12 @@ def sample_posterior_predictive( if extend_idata: self.idata.extend(post_pred, join="right") # type: ignore - posterior_predictive_samples = az.extract( - post_pred, "posterior_predictive", combined=combined + group = ( + "predictions" + if sample_posterior_predictive_kwargs.get("predictions", False) + else "posterior_predictive" ) + posterior_predictive_samples = az.extract(post_pred, group, combined=combined) if include_last_observations: posterior_predictive_samples = posterior_predictive_samples.isel( diff --git a/tests/mmm/test_mmm.py b/tests/mmm/test_mmm.py index 1df4ed0a2..7617a9fd9 100644 --- a/tests/mmm/test_mmm.py +++ b/tests/mmm/test_mmm.py @@ -897,6 +897,30 @@ def test_new_data_sample_posterior_predictive_method( ) +@pytest.mark.parametrize( + "predictions", + [True, False], +) +def test_sample_posterior_predictive_with_prediction_kwarg( + generate_data, + mmm_fitted, + predictions: bool, +) -> None: + new_dates = pd.date_range("2022-01-01", "2022-03-01", freq="W-MON") + X_pred = generate_data(new_dates) + + predictions = mmm_fitted.sample_posterior_predictive( + X_pred=X_pred, + extend_idata=False, + combined=True, + predictions=predictions, + ) + pd.testing.assert_index_equal( + pd.DatetimeIndex(predictions.coords["date"]), + new_dates, + ) + + @pytest.mark.parametrize( "model_name", ["mmm_fitted", "mmm_fitted_with_fourier_features"] )