avoid pandas deprecation errors
jwdink committed Nov 15, 2023
1 parent 4dbb7c6 commit b5dbb27
Showing 1 changed file with 12 additions and 12 deletions.
--- a/foundry/evaluation/marginal_effects.py
+++ b/foundry/evaluation/marginal_effects.py
@@ -34,7 +34,7 @@ def __init__(self, col: str, bins: Union[int, Sequence] = 20, **kwargs):
         self.kwargs = kwargs
 
     def __call__(self, df: pd.DataFrame) -> pd.Series:
-        if pd.api.types.is_categorical_dtype(df[self.orig_name]):
+        if isinstance(df[self.orig_name].dtype, pd.CategoricalDtype):
             return df[self.orig_name]
 
         if hasattr(self.bins, '__iter__'):
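
For context on the first kind of change in this commit: `pd.api.types.is_categorical_dtype` was deprecated in pandas 2.1 in favor of an `isinstance` check against the Series' dtype. A minimal standalone sketch of the old and new spellings (not from this repo):

    import warnings

    import pandas as pd

    s = pd.Series(["a", "b", "a"], dtype="category")

    # Old check: deprecated in pandas 2.1, warns on every call.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # silence the deprecation warning
        assert pd.api.types.is_categorical_dtype(s)

    # New check, as used throughout this commit: inspect the dtype object directly.
    assert isinstance(s.dtype, pd.CategoricalDtype)
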
@@ -204,7 +204,7 @@ def __call__(self,
             agg_kwargs['actual_lower'] = ('_actual', lambda s: s.quantile(.25, interpolation='lower'))
             agg_kwargs['actual_upper'] = ('_actual', lambda s: s.quantile(.75, interpolation='higher'))
         df_vary_grid = (X
-                        .groupby(groupby_colnames + list(vary_features_mappings.keys()))
+                        .groupby(groupby_colnames + list(vary_features_mappings.keys()), observed=False)
                         .agg(**agg_kwargs)
                         .reset_index())
         if 'actual' in df_vary_grid.columns: # handle binary targets with low-n
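
The other recurring change is passing `observed=False` explicitly to `groupby`. When any grouping key is categorical, pandas 2.1 emits a FutureWarning that the default of `observed` will flip from False to True in a future release; spelling out `observed=False` keeps the current behavior, where every category level gets a group even if it never occurs in the data, and silences the warning. A minimal standalone sketch (not from this repo):

    import pandas as pd

    df = pd.DataFrame({
        "bin": pd.Categorical(["low", "low", "high"], categories=["low", "mid", "high"]),
        "y": [1.0, 2.0, 3.0],
    })

    # observed=False keeps the unobserved "mid" level as a group (mean is NaN),
    # which is what the binned-grid logic in this file relies on:
    print(df.groupby("bin", observed=False)["y"].mean())
    # low     1.5
    # mid     NaN
    # high    3.0

    # observed=True (the future default) would silently drop the "mid" row.
    print(df.groupby("bin", observed=True)["y"].mean())
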
@@ -228,7 +228,7 @@ def __call__(self,
         if df_no_vary.shape[0] > 100_000 and not self.quiet:
             print("Consider setting `marginalize_aggfun='downsample100000'`.")
 
-        chunks = [df for _, df in df_vary_grid.groupby(list(vary_features), sort=False)]
+        chunks = [df for _, df in df_vary_grid.groupby(list(vary_features), sort=False, observed=False)]
         if not self.quiet:
             chunks = tqdm(chunks, delay=10)

@@ -241,7 +241,7 @@ def __call__(self,
                 pred_colnames.add(col)
 
             _df_collapsed = (_df_merged
-                             .groupby(groupby_colnames + list(vary_features))
+                             .groupby(groupby_colnames + list(vary_features), observed=False)
                              .agg(**{k: (k, y_aggfun) for k in pred_colnames})
                              .reset_index())

@@ -348,7 +348,7 @@ def plot(self,
             theme_bw() +
             theme(figure_size=(8, 6), subplots_adjust={'wspace': 0.10})
         )
-        if pd.api.types.is_categorical_dtype(data[x.replace('_binned', '')]):
+        if isinstance(data[x.replace('_binned', '')].dtype, pd.CategoricalDtype):
             plot += geom_col()
         else:
             plot += geom_line()
@@ -374,7 +374,7 @@ def _get_maybe_binned_features(X: pd.DataFrame,
         # `bin_fun` will be a no-op if the user passed ``raw(feature)`` or if they passed
         # the feature and it's categorical.
         # TODO: test this works!
-        if not pd.api.types.is_categorical_dtype(binned_feature):
+        if not isinstance(binned_feature.dtype, pd.CategoricalDtype):
             warn(f"{fname} is not categorical, values not present in the data will be dropped.")
             binned_fname = fname
         else:
@@ -431,7 +431,7 @@ def _standardize_maybe_binned(data: pd.DataFrame, features: Collection[Union[str
             deffy_binned = maybe_binned
         elif not isinstance(maybe_binned, str):
             raise ValueError(f"Expected {maybe_binned} to be a string or wrapped in ``binned``.")
-        elif pd.api.types.is_categorical_dtype(data[maybe_binned]):
+        elif isinstance(data[maybe_binned].dtype, pd.CategoricalDtype):
             deffy_binned = Binned(maybe_binned, bins=False)
         elif pd.api.types.is_numeric_dtype(data[maybe_binned]):
             deffy_binned = Binned(maybe_binned)
@@ -477,10 +477,10 @@ def _get_binned_feature_map(X: pd.DataFrame,
        # creates a df with unique values of `binned_fname` and `nans` for `fname`.
        # this will then get filled with the midpoint below:
        # todo: less hacky way to do this
-        df_mapping = X.groupby(binned_fname)[fname].agg('count').reset_index()
+        df_mapping = X.groupby(binned_fname, observed=False)[fname].agg('count').reset_index()
         df_mapping[fname] = float('nan')
     else:
-        df_mapping = X.groupby(binned_fname)[fname].agg(aggfun).reset_index()
+        df_mapping = X.groupby(binned_fname, observed=False)[fname].agg(aggfun).reset_index()
 
     # for any bins that aren't actually observed, use the midpoint:
     midpoints = pd.Series([x.mid for x in df_mapping[binned_fname]])
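
A side note on why `observed=False` matters in this function: grouping by the binned column with `observed=False` produces one row per bin even for bins with no data, and those empty bins are then backfilled with the interval midpoint via `Interval.mid`. A standalone sketch of that midpoint trick, assuming the bins come from `pd.cut` as elsewhere in this module:

    import pandas as pd

    x = pd.Series([1.0, 2.0, 9.0], name="x")
    binned = pd.cut(x, bins=[0, 3, 6, 9]).rename("x_binned")  # the (3, 6] bin is empty

    # observed=False yields a row for every bin, including the empty one:
    df_mapping = x.groupby(binned, observed=False).count().reset_index()

    # each bin label is a pd.Interval, so .mid recovers its midpoint:
    midpoints = pd.Series([iv.mid for iv in df_mapping["x_binned"]])
    print(midpoints.tolist())  # [1.5, 4.5, 7.5]
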
@@ -525,11 +525,11 @@ def _get_df_novary(self,
             raise NotImplementedError("TODO")
 
         # collapse:
-        df_no_vary = X.groupby(list(groupby_colnames) + single_cols).agg(**agg_kwargs).reset_index()
+        df_no_vary = X.groupby(list(groupby_colnames) + single_cols, observed=False).agg(**agg_kwargs).reset_index()
         # validate:
-        if (df_no_vary.groupby(groupby_colnames).size() > 1).any():
+        if (df_no_vary.groupby(groupby_colnames, observed=False).size() > 1).any():
             for col in colnames_to_agg:
-                if (df_no_vary.groupby(groupby_colnames)[col].nunique() > 1).any():
+                if (df_no_vary.groupby(groupby_colnames, observed=False)[col].nunique() > 1).any():
                     raise ValueError(f"The aggfun {agg_kwargs[col]} did not result in 1 value per group for {col}")
             # shouldn't generally get here:
             raise ValueError(f"The ``marginalize_aggfun`` did not result in one row per group:\n{df_no_vary}")
