Commit f3b5420

support multitask gam classification

csinva committed Mar 12, 2024
1 parent 22865c2 commit f3b5420
Showing 3 changed files with 305 additions and 138 deletions.
221 changes: 93 additions & 128 deletions imodels/algebraic/gam_multitask.py
@@ -2,7 +2,7 @@
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator
from sklearn.linear_model import ElasticNetCV, LinearRegression, RidgeCV, LassoCV
from sklearn.linear_model import ElasticNetCV, LinearRegression, RidgeCV, LassoCV, LogisticRegressionCV
from sklearn.tree import DecisionTreeRegressor
from sklearn.utils.validation import check_is_fitted
from sklearn.utils import check_array
@@ -13,7 +13,7 @@
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score
from tqdm import tqdm
from sklearn.multioutput import MultiOutputRegressor
from sklearn.multioutput import MultiOutputRegressor, MultiOutputClassifier
from collections import defaultdict
import pandas as pd
import json
@@ -25,9 +25,10 @@
from sklearn.base import RegressorMixin, ClassifierMixin


# See notes in this implementation:
# https://github.com/interpretml/interpret/blob/develop/python/interpret-core/interpret/glassbox/_ebm/_ebm.py
# See notes on EBM in the docs
# main file: https://github.com/interpretml/interpret/blob/develop/python/interpret-core/interpret/glassbox/_ebm/_ebm.py
# merge ebms: https://github.com/interpretml/interpret/blob/develop/python/interpret-core/interpret/glassbox/_ebm/_merge_ebms.py#L280
# eval_terms: https://interpret.ml/docs/python/api/ExplainableBoostingRegressor.html#interpret.glassbox.ExplainableBoostingRegressor.eval_terms
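# A minimal sketch of the eval_terms API referenced above (assumes a fitted
# ExplainableBoostingRegressor `ebm` and a feature matrix `X`):
#   contribs = ebm.eval_terms(X)  # shape (n_samples, n_terms), one column per term
#   preds = contribs.sum(axis=1) + ebm.intercept_  # matches ebm.predict(X)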

class MultiTaskGAM(BaseEstimator):
"""EBM-based GAM that shares curves for predicting different outputs.
@@ -44,6 +45,7 @@ def __init__(
onehot_prior=False,
renormalize_features=False,
random_state=42,
use_internal_classifiers=False,
):
"""
Params
@@ -52,6 +54,8 @@ def __init__(
onehot_prior: bool
If True and multitask, the linear model will be fit with a prior that the ebm
features predicting the target should have coef 1
use_internal_classifiers: bool
whether to use internal classifiers (as opposed to regressors)
"""
self.ebm_kwargs = ebm_kwargs
self.multitask = multitask
@@ -60,24 +64,43 @@
self.interactions = interactions
self.onehot_prior = onehot_prior
self.renormalize_features = renormalize_features
self.use_internal_classifiers = use_internal_classifiers

# override ebm_kwargs
ebm_kwargs['random_state'] = random_state
ebm_kwargs['interactions'] = interactions
self.ebm_ = ExplainableBoostingRegressor(**(ebm_kwargs or {}))

def fit(self, X, y, sample_weight=None):
X, y = check_X_y(X, y, accept_sparse=False, multi_output=True)
self.n_outputs_ = 1 if len(y.shape) == 1 else y.shape[1]
if isinstance(self, ClassifierMixin):
check_classification_targets(y)
self.classes_, y = np.unique(y, return_inverse=True)
if self.n_outputs_ == 1:
self.classes_, y = np.unique(y, return_inverse=True)
if len(self.classes_) > 2:
raise ValueError(
"MultiTaskGAMClassifier currently only supports binary classification")
elif self.n_outputs_ > 1:
self.classes_ = [np.unique(y[:, i])
for i in range(self.n_outputs_)]
if any(len(c) > 2 for c in self.classes_):
raise ValueError(
"MultiTaskGAMClassifier currently only supports binary classification")
sample_weight = _check_sample_weight(sample_weight, X, dtype=None)
self.n_outputs_ = 1 if len(y.shape) == 1 else y.shape[1]

# just fit ebm normally
if not self.multitask:
if isinstance(self, ClassifierMixin):
self.ebm_ = ExplainableBoostingClassifier(**self.ebm_kwargs)
else:
self.ebm_ = ExplainableBoostingRegressor(**self.ebm_kwargs)

# fit
if self.n_outputs_ > 1:
self.ebm_multioutput_ = MultiOutputRegressor(self.ebm_)
if isinstance(self, ClassifierMixin):
self.ebm_multioutput_ = MultiOutputClassifier(self.ebm_)
else:
self.ebm_multioutput_ = MultiOutputRegressor(self.ebm_)
self.ebm_multioutput_.fit(X, y, sample_weight=sample_weight)
else:
self.ebm_.fit(X, y, sample_weight=sample_weight)
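# Sketch of the wrapper behavior above (standard sklearn semantics):
# MultiOutputClassifier(est).fit(X, Y) clones `est` once per column of Y,
# fits each clone on (X, Y[:, i]), and stacks the per-column predictions.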
@@ -88,23 +111,27 @@ def fit(self, X, y, sample_weight=None):
num_features = X.shape[1]

if self.n_outputs_ == 1:
# with 1 output, we fit an EBM to each feature
for task_num in tqdm(range(num_features)):
self.ebms_.append(deepcopy(self.ebm_))
y_ = np.ascontiguousarray(X[:, task_num])
X_ = deepcopy(X)
X_[:, task_num] = 0
self.ebms_.append(self._initialize_ebm_internal(y_))
if isinstance(self, ClassifierMixin):
_, y_ = np.unique(y_, return_inverse=True)
self.ebms_[task_num].fit(X_, y_, sample_weight=sample_weight)

# finally, fit EBM to the target
self.ebms_.append(deepcopy(self.ebm_))
# also fit an EBM to the target
self.ebms_.append(self._initialize_ebm_internal(y))
self.ebms_[num_features].fit(X, y, sample_weight=sample_weight)
elif self.n_outputs_ > 1:
# with multiple outputs, we fit an EBM to each output
for task_num in tqdm(range(self.n_outputs_)):
self.ebms_.append(deepcopy(self.ebm_))
self.ebms_.append(self._initialize_ebm_internal(y))
y_ = np.ascontiguousarray(y[:, task_num])
self.ebms_[task_num].fit(X, y_, sample_weight=sample_weight)

# extract features
# extract features from EBMs
self.term_names_list_ = [
ebm_.term_names_ for ebm_ in self.ebms_]
self.term_names_ = sum(self.term_names_list_, [])
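# Layout sketch of the extracted features (assuming 2 inputs f0, f1,
# interactions=0, and a single output): ebms_ holds an EBM predicting f0
# (fit with f0 zeroed in X), an EBM predicting f1 (f1 zeroed), and a final
# EBM predicting y; term_names_ concatenates the three term lists, so the
# downstream linear model gets one input column per (EBM, term) pair.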
@@ -113,25 +140,48 @@
if self.renormalize_features:
self.scaler_ = StandardScaler()
feats = self.scaler_.fit_transform(feats)
feats[np.isinf(feats)] = 0

# fit linear model
self.lin_model = self._fit_linear_model(feats, y, sample_weight)

return self

def _initialize_ebm_internal(self, y):
if self.use_internal_classifiers and len(np.unique(y)) == 2:
return ExplainableBoostingClassifier(**self.ebm_kwargs)
else:
return ExplainableBoostingRegressor(**self.ebm_kwargs)

def _fit_linear_model(self, feats, y, sample_weight):
# fit a linear model to the features
self.lin_model = {
'ridge': RidgeCV(alphas=np.logspace(-2, 3, 7)),
'elasticnet': ElasticNetCV(n_alphas=7),
'lasso': LassoCV(n_alphas=7)
}[self.linear_penalty]

if not self.onehot_prior:
self.lin_model.fit(feats, y, sample_weight=sample_weight)
if isinstance(self, ClassifierMixin):
lin_model = {
'ridge': LogisticRegressionCV(penalty='l2'),
# note: scikit-learn's default lbfgs solver supports only l2, so the
# l1/elasticnet penalties need solver='saga' (and elasticnet an l1_ratios grid)
'elasticnet': LogisticRegressionCV(penalty='elasticnet', solver='saga', l1_ratios=[0.5]),
'lasso': LogisticRegressionCV(penalty='l1', solver='saga'),
}[self.linear_penalty]
if self.n_outputs_ > 1:
lin_model = MultiOutputClassifier(lin_model)
else:
lin_model = {
'ridge': RidgeCV(alphas=np.logspace(-2, 3, 7)),
'elasticnet': ElasticNetCV(n_alphas=7),
'lasso': LassoCV(n_alphas=7),
}[self.linear_penalty]

# onehot prior is a prior (for regression only) that
# the ebm features predicting the target should have coef 1
if not self.onehot_prior or isinstance(self, ClassifierMixin):
lin_model.fit(feats, y, sample_weight=sample_weight)
else:
coef_prior_ = np.zeros((feats.shape[1], ))
coef_prior_[:num_features] = 1
coef_prior_[:-len(self.term_names_list_)] = 1
preds_prior = feats @ coef_prior_
residuals = y - preds_prior
self.lin_model.fit(feats, residuals, sample_weight=sample_weight)
self.lin_model.coef_ = self.lin_model.coef_ + coef_prior_

return self
lin_model.fit(feats, residuals, sample_weight=sample_weight)
lin_model.coef_ = lin_model.coef_ + coef_prior_
return lin_model
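# Algebra sketch for the onehot prior above: writing coef = coef_prior_ + delta
# gives y ≈ feats @ coef_prior_ + feats @ delta, so delta is fit on the
# residuals y - feats @ coef_prior_ and coef_prior_ is added back afterward.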

def _extract_ebm_features(self, X):
'''
@@ -140,7 +190,6 @@ def _extract_ebm_features(self, X):
feats = np.empty((X.shape[0], len(self.term_names_)))
offset = 0
for ebm_num in range(len(self.ebms_)):
# see eval_terms function: https://interpret.ml/docs/python/api/ExplainableBoostingRegressor.html#interpret.glassbox.ExplainableBoostingRegressor.eval_terms
n_features_ebm_num = len(self.term_names_list_[ebm_num])
feats[:, offset: offset + n_features_ebm_num] = \
self.ebms_[ebm_num].eval_terms(X)
@@ -155,20 +204,32 @@ def predict(self, X):
feats = self._extract_ebm_features(X)
if hasattr(self, 'scaler_'):
feats = self.scaler_.transform(feats)
feats[np.isinf(feats)] = 0
return self.lin_model.predict(feats)

# multi-output without multitask learning
elif hasattr(self, 'ebm_multioutput_'):
return self.ebm_multioutput_.predict(X)

# single-task standard
else:
elif hasattr(self, 'ebm_'):
return self.ebm_.predict(X)

# def predict_proba(self, X):
# check_is_fitted(self)
# X = check_array(X, accept_sparse=False)
# return self.ebm_.predict_proba(X)
def predict_proba(self, X):
check_is_fitted(self)
if hasattr(self, 'ebms_'):
feats = self._extract_ebm_features(X)
if hasattr(self, 'scaler_'):
feats = self.scaler_.transform(feats)
return self.lin_model.predict_proba(feats)

# multi-output without multitask learning
elif hasattr(self, 'ebm_multioutput_'):
return self.ebm_multioutput_.predict_proba(X)

# single-task standard
elif hasattr(self, 'ebm_'):
return self.ebm_.predict_proba(X)


class MultiTaskGAMRegressor(MultiTaskGAM, RegressorMixin):
@@ -177,99 +238,3 @@ class MultiTaskGAMRegressor(MultiTaskGAM, RegressorMixin):

class MultiTaskGAMClassifier(MultiTaskGAM, ClassifierMixin):
...


def test_single_output_self_supervised():
X, y, feature_names = imodels.get_clean_dataset("california_housing")
# X, y, feature_names = imodels.get_clean_dataset("bike_sharing")

# remove some features to speed things up
X = X[:10, :4]

# remove some outcomes to speed things up
y = y[:10, :3]
X, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# unit test
gam = MultiTaskGAMRegressor(multitask=False)
gam.fit(X, y_train)
gam2 = MultiTaskGAMRegressor(multitask=True)
gam2.fit(X, y_train)
preds_orig = gam.predict(X_test)
assert np.allclose(preds_orig, gam2.ebms_[-1].predict(X_test))

# extracted curves + intercept should sum to original predictions
feats_extracted = gam2._extract_ebm_features(X_test)

# get features for ebm that predicts target
feats_extracted_target = feats_extracted[:,
-len(gam2.term_names_list_[-1]):]
# assert feats_extracted_target.shape == (num_samples, num_features)
preds_extracted_target = np.sum(feats_extracted_target, axis=1) + \
gam2.ebms_[-1].intercept_
diff = preds_extracted_target - preds_orig
assert np.allclose(preds_extracted_target, preds_orig), diff
print('Tests pass successfully')


def test_multi_output():
X, y, feature_names = imodels.get_clean_dataset("water-quality_multitask")

# remove some features to speed things up
X = X[:10, :4]
y = y[:10]
X, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
print('shapes', X.shape, y_train.shape, X_test.shape, y_test.shape)

gam_mt = MultiTaskGAMRegressor(multitask=True)
gam_mt.fit(X, y_train)
print('multitask r2_test', gam_mt.score(X_test, y_test))

gam = MultiTaskGAMRegressor(multitask=False)
gam.fit(X, y_train)
print('single-task r2_test', gam.score(X_test, y_test))


def test_compare_models():
# X, y, feature_names = imodels.get_clean_dataset("heart")
X, y, feature_names = imodels.get_clean_dataset("bike_sharing")
# X, y, feature_names = imodels.get_clean_dataset("diabetes")

# remove some features to speed things up
X = X[:, :2]
X, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

kwargs = dict(
random_state=42,
)
results = defaultdict(list)
for gam in tqdm([
# AdaBoostRegressor(estimator=MultiTaskGAMRegressor(
# multitask=True), n_estimators=2),
# MultiTaskGAMRegressor(multitask=True, onehot_prior=True),
# MultiTaskGAMRegressor(multitask=True, onehot_prior=False),
MultiTaskGAMRegressor(multitask=True, renormalize_features=True),
MultiTaskGAMRegressor(multitask=True, renormalize_features=False),
# ExplainableBoostingRegressor(n_jobs=1, interactions=0)
]):
np.random.seed(42)
results["model_name"].append(gam)
print('Fitting', results['model_name'][-1])
gam.fit(X, y_train)
results['test_corr'].append(np.corrcoef(
y_test, gam.predict(X_test))[0, 1].round(3))
results['test_r2'].append(gam.score(X_test, y_test).round(3))
if hasattr(gam, 'lin_model'):
print('lin model coef', gam.lin_model.coef_)

# don't round strings
with pd.option_context(
"display.max_rows", None, "display.max_columns", None, "display.width", 1000
):
print(pd.DataFrame(results).round(3))


if __name__ == "__main__":
# test_single_output_self_supervised()
test_multi_output()
# test_compare_models()
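A usage sketch for the classifier added in this commit (dataset choice and shapes are illustrative; 'heart' is among the names registered in the dataset file below):

    import imodels
    from imodels.algebraic.gam_multitask import MultiTaskGAMClassifier

    X, y, feature_names = imodels.get_clean_dataset('heart')
    clf = MultiTaskGAMClassifier(multitask=True)
    clf.fit(X, y)
    probs = clf.predict_proba(X)  # (n_samples, 2) from the fitted LogisticRegressionCV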
23 changes: 13 additions & 10 deletions imodels/util/data_util.py
@@ -13,8 +13,10 @@

from imodels.util.tree_interaction_utils import make_rj, make_vp


DSET_CLASSIFICATION_KWARGS = {
# classification
'iris': {'dataset_name': 61, 'data_source': 'openml'},
"pima_diabetes": {"dataset_name": 40715, "data_source": "openml"},
"sonar": {"dataset_name": "sonar", "data_source": "pmlb"},
"heart": {"dataset_name": "heart", "data_source": "imodels"},
@@ -59,20 +61,21 @@
},
# 'breast_tumor': {'dataset_name': '1201_BNG_breastTumor', 'data_source': 'pmlb' # v big
}
DSET_MULTITASK_NAMES = ['3s-bbc1000', '3s-guardian1000', '3s-inter3000', '3s-reuters1000',
'birds', 'cal500', 'chd_49', 'corel16k001', 'corel16k002',
'corel16k003', 'corel16k004', 'corel16k005', 'corel16k006',
'corel16k007', 'corel16k008', 'corel16k009', 'corel16k010',
'corel5k', 'emotions', 'flags', 'foodtruck', 'genbase', 'image',
'mediamill', 'scene', 'stackex_chemistry', 'stackex_chess',
'stackex_cooking', 'stackex_cs', 'water-quality', 'yeast', 'yelp']
DSET_MULTITASK_KWARGS = {
DSET_CLASSIFICATION_MULTITASK_NAMES = [
'3s-bbc1000', '3s-guardian1000', '3s-inter3000', '3s-reuters1000',
'birds', 'cal500', 'chd_49', 'corel16k001', 'corel16k002',
'corel16k003', 'corel16k004', 'corel16k005', 'corel16k006',
'corel16k007', 'corel16k008', 'corel16k009', 'corel16k010',
'corel5k', 'emotions', 'flags', 'foodtruck', 'genbase', 'image',
'mediamill', 'scene', 'stackex_chemistry', 'stackex_chess',
'stackex_cooking', 'stackex_cs', 'water-quality', 'yeast', 'yelp']
DSET_CLASSIFICATION_MULTITASK_KWARGS = {
name + '_multitask': {"dataset_name": name, "data_source": "imodels-multitask"}
for name in DSET_MULTITASK_NAMES
for name in DSET_CLASSIFICATION_MULTITASK_NAMES
}
DSET_KWARGS = {
**DSET_CLASSIFICATION_KWARGS, **DSET_REGRESSION_KWARGS,
**DSET_MULTITASK_KWARGS}
**DSET_CLASSIFICATION_MULTITASK_KWARGS}
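# Usage sketch: each name in DSET_CLASSIFICATION_MULTITASK_NAMES is registered
# with an '_multitask' suffix, e.g. get_clean_dataset('water-quality_multitask')
# resolves to {'dataset_name': 'water-quality', 'data_source': 'imodels-multitask'}.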


def get_clean_dataset(