Skip to content

Commit

Permalink
support for deserialization of three classes
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Dec 17, 2024
1 parent 6d70476 commit e97d809
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 0 deletions.
11 changes: 11 additions & 0 deletions pymc_marketing/mmm/components/adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def function(self, x, alpha):
import xarray as xr
from pydantic import Field, InstanceOf, validate_call

from pymc_marketing.deserialize import register_deserialization
from pymc_marketing.mmm.components.base import Transformation
from pymc_marketing.mmm.transformers import (
ConvMode,
Expand Down Expand Up @@ -343,3 +344,13 @@ def adstock_from_dict(data: dict) -> AdstockTransformation:
if "priors" in data:
data["priors"] = {k: Prior.from_json(v) for k, v in data["priors"].items()}
return cls(**data)


def _is_adstock(data):
return "lookup_name" in data and data["lookup_name"] in ADSTOCK_TRANSFORMATIONS


register_deserialization(
is_type=_is_adstock,
deserialize=adstock_from_dict,
)
8 changes: 8 additions & 0 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def function(self, x, b):
import xarray as xr
from pydantic import Field, InstanceOf, validate_call

from pymc_marketing.deserialize import register_deserialization
from pymc_marketing.mmm.components.base import Transformation
from pymc_marketing.mmm.transformers import (
hill_function,
Expand Down Expand Up @@ -483,3 +484,10 @@ def saturation_from_dict(data: dict) -> SaturationTransformation:
key: Prior.from_json(value) for key, value in data["priors"].items()
}
return cls(**data)


def _is_saturation(data):
return "lookup_name" in data and data["lookup_name"] in SATURATION_TRANSFORMATIONS


register_deserialization(_is_saturation, saturation_from_dict)
9 changes: 9 additions & 0 deletions pymc_marketing/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def custom_transform(x):
from pydantic import validate_call
from pymc.distributions.shape_utils import Dims

from pymc_marketing.deserialize import register_deserialization


class UnsupportedShapeError(Exception):
"""Error for when the shapes from variables are not compatible."""
Expand Down Expand Up @@ -987,3 +989,10 @@ def create_likelihood_variable(
distribution.parameters["mu"] = mu
distribution.parameters["observed"] = observed
return distribution.create_variable(name)


def _is_prior_type(data: dict) -> bool:
return "dist" in data


register_deserialization(is_type=_is_prior_type, deserialize=Prior.from_json)

0 comments on commit e97d809

Please sign in to comment.