Skip to content

Commit

Permalink
fix iPrePostNEGD input validation
Browse files Browse the repository at this point in the history
  • Loading branch information
jpreszler committed Sep 21, 2023
1 parent 133b987 commit 526d8a8
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions causalpy/pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from causalpy.custom_exceptions import BadIndexException # NOQA
from causalpy.custom_exceptions import DataException, FormulaException
from causalpy.plot_utils import plot_xY
from causalpy.utils import _is_variable_dummy_coded, _series_has_2_levels
from causalpy.utils import _is_variable_dummy_coded

LEGEND_FONT_SIZE = 12
az.style.use("arviz-darkgrid")
Expand Down Expand Up @@ -978,7 +978,8 @@ class PrePostNEGD(ExperimentalDesign):
:param formula:
A statistical model formula
:param group_variable_name:
Name of the column in data for the group variable
Name of the column in data for the group variable, should be either
binary or boolean
:param pretreatment_variable_name:
Name of the column in data for the pretreatment variable
:param model:
Expand Down Expand Up @@ -1058,17 +1059,19 @@ def __init__(
self.group_variable_name: np.zeros(self.pred_xi.shape),
}
)
(new_x,) = build_design_matrices([self._x_design_info], x_pred_untreated)
self.pred_untreated = self.model.predict(X=np.asarray(new_x))
(new_x_untreated,) = build_design_matrices(
[self._x_design_info], x_pred_untreated
)
self.pred_untreated = self.model.predict(X=np.asarray(new_x_untreated))
# treated
x_pred_untreated = pd.DataFrame(
x_pred_treated = pd.DataFrame(
{
self.pretreatment_variable_name: self.pred_xi,
self.group_variable_name: np.ones(self.pred_xi.shape),
}
)
(new_x,) = build_design_matrices([self._x_design_info], x_pred_untreated)
self.pred_treated = self.model.predict(X=np.asarray(new_x))
(new_x_treated,) = build_design_matrices([self._x_design_info], x_pred_treated)
self.pred_treated = self.model.predict(X=np.asarray(new_x_treated))

# Evaluate causal impact as equal to the trestment effect
self.causal_impact = self.idata.posterior["beta"].sel(
Expand All @@ -1079,7 +1082,7 @@ def __init__(

def _input_validation(self) -> None:
"""Validate the input data and model formula for correctness"""
if not _series_has_2_levels(self.data[self.group_variable_name]):
if not _is_variable_dummy_coded(self.data[self.group_variable_name]):
raise DataException(
f"""
There must be 2 levels of the grouping variable
Expand Down Expand Up @@ -1165,7 +1168,7 @@ def _get_treatment_effect_coeff(self) -> str:
then we want `C(group)[T.1]`.
"""
for label in self.labels:
if ("group" in label) & (":" not in label):
if (self.group_variable_name in label) & (":" not in label):
return label

raise NameError("Unable to find coefficient name for the treatment effect")
Expand Down

0 comments on commit 526d8a8

Please sign in to comment.