Skip to content

Commit

Permalink
address the serialization after fit method
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Jan 24, 2025
1 parent a5c8f52 commit 264ed2e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
4 changes: 4 additions & 0 deletions pymc_marketing/mmm/hsgp.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def register_data(self, X: TensorLike) -> Self:
"""Register the data."""
...

Check warning on line 51 in pymc_marketing/mmm/hsgp.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/hsgp.py#L51

Added line #L51 was not covered by tests

def to_dict(self) -> dict:
"""Convert the object to a dictionary."""
...

Check warning on line 55 in pymc_marketing/mmm/hsgp.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/hsgp.py#L55

Added line #L55 was not covered by tests


@validate_call
def create_complexity_penalizing_prior(
Expand Down
12 changes: 10 additions & 2 deletions pymc_marketing/mmm/mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,16 @@ def create_idata_attrs(self) -> dict[str, str]:
attrs["channel_columns"] = json.dumps(self.channel_columns)
attrs["validate_data"] = json.dumps(self.validate_data)
attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)
attrs["time_varying_intercept"] = json.dumps(self.time_varying_intercept)
attrs["time_varying_media"] = json.dumps(self.time_varying_media)
attrs["time_varying_intercept"] = json.dumps(
self.time_varying_intercept
if not isinstance(self.time_varying_intercept, HSGPLike)
else self.time_varying_intercept.to_dict()
)
attrs["time_varying_media"] = json.dumps(
self.time_varying_media
if not isinstance(self.time_varying_media, HSGPLike)
else self.time_varying_media.to_dict()
)
attrs["dag"] = json.dumps(self.dag)
attrs["treatment_nodes"] = json.dumps(self.treatment_nodes)
attrs["outcome_node"] = json.dumps(self.outcome_node)
Expand Down

0 comments on commit 264ed2e

Please sign in to comment.