Skip to content

Commit

Permalink
1325 invalid way of calling first element of pdseries (#1433)
Browse files Browse the repository at this point in the history
* add test cases with index not starting at zero

* make the change and add docs
  • Loading branch information
wd60622 authored Jan 24, 2025
1 parent d9f0c51 commit 9994ee9
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 25 deletions.
21 changes: 19 additions & 2 deletions pymc_marketing/mmm/tvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,27 @@ def create_time_varying_gp_multiplier(


def infer_time_index(
    date_series_new: pd.Series,
    date_series: pd.Series,
    time_resolution: int,
) -> npt.NDArray[np.int_]:
    """Infer the time-index given a new dataset.

    Infers the time-indices by calculating the number of days since the first
    date in the original dataset and floor-dividing by the time resolution.

    Parameters
    ----------
    date_series_new : pd.Series
        New date series.
    date_series : pd.Series
        Original date series. Its first element *by position* is used as the
        reference date, so the Series index does not need to start at zero.
    time_resolution : int
        Resolution of time points in days.

    Returns
    -------
    np.ndarray
        Time index. Dates earlier than the reference date give negative
        values.

    """
    # Use positional access: ``date_series[0]`` is a label lookup and raises
    # ``KeyError`` when the index has no label 0 (e.g. a filtered or sliced
    # frame whose index starts at 10).
    reference_date = date_series.iloc[0]
    days_since_reference = (date_series_new - reference_date).dt.days.values
    return days_since_reference // time_resolution
44 changes: 21 additions & 23 deletions tests/mmm/test_tvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,28 +149,26 @@ def test_create_time_varying_intercept(coords, model_config):
assert isinstance(result, pt.TensorVariable)


@pytest.mark.parametrize(
    "freq, time_resolution",
    [
        pytest.param("D", 1, id="daily"),
        pytest.param("W", 7, id="weekly"),
    ],
)
@pytest.mark.parametrize(
    "index",
    [np.arange(5), np.arange(5) + 10],
    ids=["zero-start", "non-zero-start"],
)
@pytest.mark.parametrize(
    "offset, expected",
    [(0, np.arange(0, 5)), (5, np.arange(5, 10)), (-5, np.arange(-5, 0))],
    ids=["in-sample", "oos_forward", "oos_backward"],
)
def test_infer_time_index(freq, time_resolution, index, offset, expected):
    """Check the inferred time index for shifted dates and Series indexes.

    Covers in-sample, out-of-sample forward and out-of-sample backward dates
    for daily and weekly data. Each case runs with a Series index that starts
    at zero and one that does not, so the reference date must be taken by
    position rather than by label.
    """
    dates = pd.date_range(start="1/1/2022", periods=5, freq=freq)
    date_series = pd.Series(dates, index=index)
    # Shift the whole series by ``offset`` periods of the same frequency.
    date_series_new = date_series + pd.Timedelta(offset, unit=freq)
    result = infer_time_index(date_series_new, date_series, time_resolution)
    np.testing.assert_array_equal(result, expected)

0 comments on commit 9994ee9

Please sign in to comment.