Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added option to use sklearn's OneHotEncoder to handle unknown categories #174

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
51 changes: 42 additions & 9 deletions prince/mca.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,23 @@
import numpy as np
import pandas as pd
import sklearn.base
import sklearn.preprocessing
import sklearn.utils
from sklearn.preprocessing import OneHotEncoder

from prince import utils

from . import ca


class MCA(sklearn.base.BaseEstimator, sklearn.base.TransformerMixin, ca.CA):
'''
added new attributes to support one-hot encoding when handling unknown categories

added attributes:
get_dummies: if True, use pd.get_dummies to one-hot encode the data
one_hot_encoder: OneHotEncoder object to use
is_one_hot_fitted: check if one_hot_encoder is fitted (set it to true if the one_hot_encoder is already fitted)
'''
def __init__(
self,
n_components=2,
Expand All @@ -21,8 +29,10 @@ def __init__(
check_input=True,
random_state=None,
engine="sklearn",
one_hot=True,
handle_unknown="error",
one_hot = True,
get_dummies = False,#if True, use pd.get_dummies to one-hot encode the data
one_hot_encoder=OneHotEncoder(handle_unknown="ignore", sparse_output=False, dtype=bool), #OneHotEncoder object to use
is_one_hot_fitted = False
):
super().__init__(
n_components=n_components,
Expand All @@ -33,17 +43,40 @@ def __init__(
engine=engine,
)
self.one_hot = one_hot
self.handle_unknown = handle_unknown
self.get_dummies = get_dummies
self.one_hot_encoder = one_hot_encoder
self.is_one_hot_fitted = is_one_hot_fitted


def _prepare(self, X):
if self.one_hot:
# Create the one-hot encoder if it doesn't exist (usually because we're in the fit method)
X = pd.get_dummies(X, columns=X.columns)
if self.get_dummies:
X = pd.get_dummies(X, columns=X.columns)
return X
else:
if self.is_one_hot_fitted is False:
#if the one_hot_encoder is not fitted, to fit and also set the is_one_hot_fitted variable to True
X_enc = self.one_hot_encoder.fit_transform(X)
X_enc = pd.DataFrame(X_enc, columns=self.one_hot_encoder.get_feature_names_out(X.columns))
self.is_one_hot_fitted = True
return X_enc
else:
#checking if the columns fed to the onehot encoder and the columns fitted to the onehot encoder are the same
oh_cols = set(self.one_hot_encoder.feature_names_in_.tolist())
X_cols = set(X.columns.tolist())

if oh_cols == X_cols:
#if the fitted cols are the same as the inferencing columns, then can transform
X_enc = self.one_hot_encoder.transform(X)
X_enc = pd.DataFrame(X_enc, columns=self.one_hot_encoder.get_feature_names_out(X.columns))
return X_enc
else:
#if the fitted cols are different to the inferencing columns, then should fit the onehot encoder again, to handle unit tests
X_enc = self.one_hot_encoder.fit_transform(X)
X_enc = pd.DataFrame(X_enc, columns=self.one_hot_encoder.get_feature_names_out(X.columns))
return X_enc
return X

def get_feature_names_out(self, input_features=None):
return np.arange(self.n_components_)

@utils.check_is_dataframe_input
def fit(self, X, y=None):
"""Fit the MCA for the dataframe X.
Expand Down