Skip to content

Commit

Permalink
events effects gaussian bumps
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Feb 3, 2025
1 parent 79a8008 commit 3ed272a
Showing 1 changed file with 207 additions and 0 deletions.
207 changes: 207 additions & 0 deletions pymc_marketing/mmm/events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
from dataclasses import dataclass

Check warning on line 1 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L1

Added line #L1 was not covered by tests

import numpy as np
import pandas as pd
import pymc as pm
import xarray as xr
from pydantic import Field, InstanceOf, validate_call

Check warning on line 7 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L3-L7

Added lines #L3 - L7 were not covered by tests

from pymc_marketing.deserialize import deserialize, register_deserialization
from pymc_marketing.mmm.components.base import Transformation, create_registration_meta
from pymc_marketing.prior import Prior, create_dim_handler

Check warning on line 11 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L9-L11

Added lines #L9 - L11 were not covered by tests

BASIS_TRANSFORMATIONS = {}
BasisMeta = create_registration_meta(BASIS_TRANSFORMATIONS)

Check warning on line 14 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L13-L14

Added lines #L13 - L14 were not covered by tests


class Basis(Transformation, metaclass=BasisMeta):
prefix: str = "basis"
lookup_name: str

Check warning on line 19 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L17-L19

Added lines #L17 - L19 were not covered by tests

@validate_call
def sample_curve(

Check warning on line 22 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L21-L22

Added lines #L21 - L22 were not covered by tests
self,
parameters: InstanceOf[xr.Dataset] = Field(
..., description="Parameters of the saturation transformation."
),
days: int = Field(0, ge=0, description="Minimum number of days."),
) -> xr.DataArray:
"""Sample the curve of the saturation transformation given parameters.
Parameters
----------
parameters : xr.Dataset
Dataset with the parameters of the saturation transformation.
max_value : float, optional
Maximum value of the curve, by default 1.0.
Returns
-------
xr.DataArray
Curve of the saturation transformation.
"""
x = np.linspace(-days, days, 100)

Check warning on line 44 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L44

Added line #L44 was not covered by tests

coords = {

Check warning on line 46 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L46

Added line #L46 was not covered by tests
"x": x,
}

return self._sample_curve(

Check warning on line 50 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L50

Added line #L50 was not covered by tests
var_name="saturation",
parameters=parameters,
x=x,
coords=coords,
)


def basis_from_dict(data: dict) -> Basis:

Check warning on line 58 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L58

Added line #L58 was not covered by tests
"""Create a basis transformation from a dictionary."""
data = data.copy()
lookup_name = data.pop("lookup_name")
cls = BASIS_TRANSFORMATIONS[lookup_name]

Check warning on line 62 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L60-L62

Added lines #L60 - L62 were not covered by tests

if "priors" in data:
data["priors"] = {k: deserialize(v) for k, v in data["priors"].items()}

Check warning on line 65 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L64-L65

Added lines #L64 - L65 were not covered by tests

return cls(**data)

Check warning on line 67 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L67

Added line #L67 was not covered by tests


def _is_basis(data):
return "lookup_name" in data and data["lookup_name"] in BASIS_TRANSFORMATIONS

Check warning on line 71 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L70-L71

Added lines #L70 - L71 were not covered by tests


register_deserialization(

Check warning on line 74 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L74

Added line #L74 was not covered by tests
is_type=_is_basis,
deserialize=basis_from_dict,
)


@dataclass
class EventEffect:
basis: Basis
effect_size: Prior
dims: tuple[str, ...]

Check warning on line 84 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L80-L84

Added lines #L80 - L84 were not covered by tests

def apply(self, X, name: str = "event"):

Check warning on line 86 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L86

Added line #L86 was not covered by tests
"""Apply the event effect to the data."""
dim_handler = create_dim_handler(("x", *self.dims))
return self.basis.apply(X, dims=self.dims) * dim_handler(

Check warning on line 89 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L88-L89

Added lines #L88 - L89 were not covered by tests
self.effect_size.create_variable(f"{name}_effect_size"),
self.effect_size.dims,
)

def to_dict(self):
return {

Check warning on line 95 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L94-L95

Added lines #L94 - L95 were not covered by tests
"class": "EventEffect",
"data": {
"basis": self.basis.to_dict(),
"effect_size": self.effect_size.to_dict(),
"dims": self.dims,
},
}

@classmethod
def from_dict(cls, data):
return cls(

Check warning on line 106 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L104-L106

Added lines #L104 - L106 were not covered by tests
basis=deserialize(data["basis"]),
effect_size=deserialize(data["effect_size"]),
dims=data["dims"],
)


def _is_event_effect(data):
return data["class"] == "EventEffect"

Check warning on line 114 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L113-L114

Added lines #L113 - L114 were not covered by tests


register_deserialization(

Check warning on line 117 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L117

Added line #L117 was not covered by tests
is_type=_is_event_effect,
deserialize=lambda data: EventEffect.from_dict(data["data"]),
)


if __name__ == "__main__":
import matplotlib.pyplot as plt

Check warning on line 124 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L123-L124

Added lines #L123 - L124 were not covered by tests

from pymc_marketing.plot import plot_curve

Check warning on line 126 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L126

Added line #L126 was not covered by tests

class GaussianBasis(Basis):
lookup_name = "gaussian"

Check warning on line 129 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L128-L129

Added lines #L128 - L129 were not covered by tests

def function(self, x, sigma):
return pm.math.exp(-0.5 * (x / sigma) ** 2)

Check warning on line 132 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L131-L132

Added lines #L131 - L132 were not covered by tests

default_priors = {

Check warning on line 134 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L134

Added line #L134 was not covered by tests
"sigma": Prior("Gamma", mu=7, sigma=1),
}

gaussian = GaussianBasis(

Check warning on line 138 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L138

Added line #L138 was not covered by tests
priors={
"sigma": Prior("Gamma", mu=[4, 7, 10], sigma=1, dims="event"),
},
)
coords = {"event": ["NYE", "Grand Opening Game Show", "Super Bowl"]}
prior = gaussian.sample_prior(coords=coords)
curve = gaussian.sample_curve(prior, days=21)

Check warning on line 145 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L143-L145

Added lines #L143 - L145 were not covered by tests

fig, axes = gaussian.plot_curve(curve, same_axes=True)
fig.suptitle("Gaussian Basis")
plt.savefig("gaussian-basis")
plt.close()

Check warning on line 150 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L147-L150

Added lines #L147 - L150 were not covered by tests

df_events = pd.DataFrame(

Check warning on line 152 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L152

Added line #L152 was not covered by tests
{
"event": ["first", "second"],
"start_date": pd.to_datetime(["2023-01-01", "2023-01-20"]),
"end_date": pd.to_datetime(["2023-01-02", "2023-01-25"]),
}
)

def difference_in_days(model_dates, event_dates):
if hasattr(model_dates, "to_numpy"):
model_dates = model_dates.to_numpy()
if hasattr(event_dates, "to_numpy"):
event_dates = event_dates.to_numpy()

Check warning on line 164 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L160-L164

Added lines #L160 - L164 were not covered by tests

one_day = np.timedelta64(1, "D")
return (model_dates[:, None] - event_dates) / one_day

Check warning on line 167 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L166-L167

Added lines #L166 - L167 were not covered by tests

def create_basis_matrix(df_events: pd.DataFrame, model_dates: np.ndarray):
start_dates = df_events["start_date"]
end_dates = df_events["end_date"]

Check warning on line 171 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L169-L171

Added lines #L169 - L171 were not covered by tests

s_ref = difference_in_days(model_dates, start_dates)
e_ref = difference_in_days(model_dates, end_dates)

Check warning on line 174 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L173-L174

Added lines #L173 - L174 were not covered by tests

return np.where(

Check warning on line 176 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L176

Added line #L176 was not covered by tests
(s_ref >= 0) & (e_ref <= 0),
0,
np.where(np.abs(s_ref) < np.abs(e_ref), s_ref, e_ref),
)

gaussian = GaussianBasis(

Check warning on line 182 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L182

Added line #L182 was not covered by tests
priors={
"sigma": Prior("Gamma", mu=7, sigma=1, dims="event"),
}
)
effect_size = Prior("Normal", mu=1, sigma=1, dims="event")
effect = EventEffect(basis=gaussian, effect_size=effect_size, dims=("event",))

Check warning on line 188 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L187-L188

Added lines #L187 - L188 were not covered by tests

dates = pd.date_range("2022-12-01", periods=3 * 31, freq="D")

Check warning on line 190 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L190

Added line #L190 was not covered by tests

X = create_basis_matrix(df_events, model_dates=dates)

Check warning on line 192 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L192

Added line #L192 was not covered by tests

coords = {"date": dates, "event": df_events["event"].to_numpy()}
with pm.Model(coords=coords) as model:
pm.Deterministic("effect", effect.apply(X), dims=("date", "event"))

Check warning on line 196 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L194-L196

Added lines #L194 - L196 were not covered by tests

idata = pm.sample_prior_predictive()

Check warning on line 198 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L198

Added line #L198 was not covered by tests

fig, axes = idata.prior.effect.pipe(

Check warning on line 200 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L200

Added line #L200 was not covered by tests
plot_curve,
{"date"},
subplot_kwargs={"ncols": 1},
)
fig.suptitle("Gaussian Event Effect")
plt.savefig("gaussian-event")
plt.close()

Check warning on line 207 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L205-L207

Added lines #L205 - L207 were not covered by tests

0 comments on commit 3ed272a

Please sign in to comment.