diff --git a/causalml/metrics/visualize.py b/causalml/metrics/visualize.py index 165e2e97..065baafb 100644 --- a/causalml/metrics/visualize.py +++ b/causalml/metrics/visualize.py @@ -137,7 +137,10 @@ def get_cumlift( assert ( (outcome_col in df.columns and df[outcome_col].notnull().all()) and (treatment_col in df.columns and df[treatment_col].notnull().all()) - or (treatment_effect_col in df.columns and df[treatment_effect_col].notnull().all()) + or ( + treatment_effect_col in df.columns + and df[treatment_effect_col].notnull().all() + ) ), "{outcome_col} and {treatment_col}, or {treatment_effect_col} should be present without null.".format( outcome_col=outcome_col, treatment_col=treatment_col, @@ -266,7 +269,10 @@ def get_qini( assert ( (outcome_col in df.columns and df[outcome_col].notnull().all()) and (treatment_col in df.columns and df[treatment_col].notnull().all()) - or (treatment_effect_col in df.columns and df[treatment_effect_col].notnull().all()) + or ( + treatment_effect_col in df.columns + and df[treatment_effect_col].notnull().all() + ) ), "{outcome_col} and {treatment_col}, or {treatment_effect_col} should be present without null.".format( outcome_col=outcome_col, treatment_col=treatment_col,