From e97d8092595f3d7eeff52b8c3613e2f34b7078a6 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 17 Dec 2024 21:29:13 +0100 Subject: [PATCH] support for deserialization of three classes --- pymc_marketing/mmm/components/adstock.py | 11 +++++++++++ pymc_marketing/mmm/components/saturation.py | 8 ++++++++ pymc_marketing/prior.py | 9 +++++++++ 3 files changed, 28 insertions(+) diff --git a/pymc_marketing/mmm/components/adstock.py b/pymc_marketing/mmm/components/adstock.py index 70107bb35..2f376e8b7 100644 --- a/pymc_marketing/mmm/components/adstock.py +++ b/pymc_marketing/mmm/components/adstock.py @@ -56,6 +56,7 @@ def function(self, x, alpha): import xarray as xr from pydantic import Field, InstanceOf, validate_call +from pymc_marketing.deserialize import register_deserialization from pymc_marketing.mmm.components.base import Transformation from pymc_marketing.mmm.transformers import ( ConvMode, @@ -343,3 +344,13 @@ def adstock_from_dict(data: dict) -> AdstockTransformation: if "priors" in data: data["priors"] = {k: Prior.from_json(v) for k, v in data["priors"].items()} return cls(**data) + + +def _is_adstock(data): + return "lookup_name" in data and data["lookup_name"] in ADSTOCK_TRANSFORMATIONS + + +register_deserialization( + is_type=_is_adstock, + deserialize=adstock_from_dict, +) diff --git a/pymc_marketing/mmm/components/saturation.py b/pymc_marketing/mmm/components/saturation.py index 15b3383ad..3a9a70fa8 100644 --- a/pymc_marketing/mmm/components/saturation.py +++ b/pymc_marketing/mmm/components/saturation.py @@ -76,6 +76,7 @@ def function(self, x, b): import xarray as xr from pydantic import Field, InstanceOf, validate_call +from pymc_marketing.deserialize import register_deserialization from pymc_marketing.mmm.components.base import Transformation from pymc_marketing.mmm.transformers import ( hill_function, @@ -483,3 +484,10 @@ def saturation_from_dict(data: dict) -> SaturationTransformation: key: Prior.from_json(value) for key, value in data["priors"].items() } return cls(**data) + + +def _is_saturation(data): + return "lookup_name" in data and data["lookup_name"] in SATURATION_TRANSFORMATIONS + + +register_deserialization(_is_saturation, saturation_from_dict) diff --git a/pymc_marketing/prior.py b/pymc_marketing/prior.py index 4d99ac74e..8fd8c81ae 100644 --- a/pymc_marketing/prior.py +++ b/pymc_marketing/prior.py @@ -108,6 +108,8 @@ def custom_transform(x): from pydantic import validate_call from pymc.distributions.shape_utils import Dims +from pymc_marketing.deserialize import register_deserialization + class UnsupportedShapeError(Exception): """Error for when the shapes from variables are not compatible.""" @@ -987,3 +989,10 @@ def create_likelihood_variable( distribution.parameters["mu"] = mu distribution.parameters["observed"] = observed return distribution.create_variable(name) + + +def _is_prior_type(data: dict) -> bool: + return "dist" in data + + +register_deserialization(is_type=_is_prior_type, deserialize=Prior.from_json)