Skip to content

Commit

Permalink
Support `renormalize_features` in multi-task GAM
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Mar 11, 2024
1 parent 657216f commit ec52209
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions imodels/algebraic/gam_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from collections import defaultdict
import pandas as pd
import json
from sklearn.preprocessing import StandardScaler

import imodels
from interpret.glassbox import ExplainableBoostingClassifier, ExplainableBoostingRegressor
Expand All @@ -37,7 +38,8 @@ def __init__(
multitask=True,
interactions=0.95,
linear_penalty='ridge',
onehot_prior=True,
onehot_prior=False,
renormalize_features=False,
random_state=42,
):
"""
Expand All @@ -54,6 +56,7 @@ def __init__(
self.random_state = random_state
self.interactions = interactions
self.onehot_prior = onehot_prior
self.renormalize_features = renormalize_features

# override ebm_kwargs
ebm_kwargs['random_state'] = random_state
Expand Down Expand Up @@ -90,9 +93,12 @@ def fit(self, X, y, sample_weight=None):
self.term_names_list_ = [
ebm_.term_names_ for ebm_ in self.ebms_]
self.term_names_ = sum(self.term_names_list_, [])

feats = self._extract_ebm_features(X)

if self.renormalize_features:
self.scaler_ = StandardScaler()
feats = self.scaler_.fit_transform(feats)

# fit a linear model to the features
if self.linear_penalty == 'ridge':
self.lin_model = RidgeCV(alphas=np.logspace(-2, 3, 7))
Expand Down Expand Up @@ -126,13 +132,16 @@ def _extract_ebm_features(self, X):
feats[:, offset: offset + n_features_ebm_num] = \
self.ebms_[ebm_num].eval_terms(X)
offset += n_features_ebm_num

return feats

def predict(self, X):
    """Predict targets for ``X`` using the fitted model.

    If the estimator holds per-task EBMs (``self.ebms_``), their term
    contributions are extracted with ``_extract_ebm_features``, optionally
    rescaled with ``self.scaler_``, and passed to ``self.lin_model``.
    Otherwise prediction is delegated to ``self.ebm_``.

    Parameters
    ----------
    X : array-like
        Input samples; validated with ``check_array(X, accept_sparse=False)``.

    Returns
    -------
    The output of ``self.lin_model.predict`` (multi-EBM path) or of
    ``self.ebm_.predict`` (single-EBM path).
    """
    check_is_fitted(self)
    X = check_array(X, accept_sparse=False)

    # Single-EBM path: no extracted features, no linear model on top.
    if not hasattr(self, 'ebms_'):
        return self.ebm_.predict(X)

    feats = self._extract_ebm_features(X)

    # Rescale only when renormalization is enabled *and* a scaler was fitted.
    # Checking `scaler_` alone is not enough: fit() creates it only when
    # renormalize_features is true and never removes it, so a refit with
    # renormalize_features=False would leave a stale scaler that must not be
    # applied to features the linear model saw unscaled.
    use_scaler = (
        getattr(self, 'renormalize_features', False)
        and hasattr(self, 'scaler_')
    )
    if use_scaler:
        feats = self.scaler_.transform(feats)

    return self.lin_model.predict(feats)
Expand Down Expand Up @@ -183,7 +192,7 @@ def test_multitask_extraction():


if __name__ == "__main__":
test_multitask_extraction()
# test_multitask_extraction()
# X, y, feature_names = imodels.get_clean_dataset("heart")
X, y, feature_names = imodels.get_clean_dataset("bike_sharing")
# X, y, feature_names = imodels.get_clean_dataset("diabetes")
Expand All @@ -199,9 +208,10 @@ def test_multitask_extraction():
for gam in tqdm([
# AdaBoostRegressor(estimator=MultiTaskGAMRegressor(
# multitask=True), n_estimators=2),
MultiTaskGAMRegressor(multitask=False, onehot_prior=True),
MultiTaskGAMRegressor(multitask=False, onehot_prior=False),
MultiTaskGAMRegressor(multitask=True),
# MultiTaskGAMRegressor(multitask=True, onehot_prior=True),
# MultiTaskGAMRegressor(multitask=True, onehot_prior=False),
MultiTaskGAMRegressor(multitask=True, renormalize_features=True),
MultiTaskGAMRegressor(multitask=True, renormalize_features=False),
# ExplainableBoostingRegressor(n_jobs=1, interactions=0)
]):
np.random.seed(42)
Expand Down

0 comments on commit ec52209

Please sign in to comment.