Skip to content

Commit

Permalink
multiply by the media variable if present
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Dec 3, 2024
1 parent 043c593 commit cf666ad
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions pymc_marketing/mmm/mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,15 @@ def channel_contributions_forward_pass(
progressbar=False,
)

return idata.posterior_predictive.channel_contributions.to_numpy()
channel_contributions = idata.posterior_predictive.channel_contributions
if self.time_varying_media:
# This is coupled with the name of the
# latent process Deterministic
name = "media_temporal_latent_multiplier"
mutliplier = self.fit_result[name]
channel_contributions = channel_contributions * mutliplier

Check warning on line 631 in pymc_marketing/mmm/mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/mmm.py#L629-L631

Added lines #L629 - L631 were not covered by tests

return channel_contributions.to_numpy()

@property
def _serializable_model_config(self) -> dict[str, Any]:
Expand Down Expand Up @@ -996,7 +1004,8 @@ def get_channel_contributions_forward_pass_grid(
delta * self.preprocessed_data["X"][self.channel_columns].to_numpy()
)
channel_contribution_forward_pass = self.channel_contributions_forward_pass(
channel_data=channel_data, disable_logger_stdout=True
channel_data=channel_data,
disable_logger_stdout=True,
)
channel_contributions.append(channel_contribution_forward_pass)
return DataArray(
Expand Down

0 comments on commit cf666ad

Please sign in to comment.