Skip to content

Commit

Permalink
support multitask gam interactions
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Mar 10, 2024
1 parent 650d3c6 commit 0efced5
Showing 1 changed file with 29 additions and 20 deletions.
49 changes: 29 additions & 20 deletions imodels/algebraic/gam_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,26 @@ class MultiTaskGAM(BaseEstimator):

def __init__(
self,
ebm_kwargs={'interactions': 0, 'n_jobs': 1},
ebm_kwargs={'n_jobs': 1},
multitask=True,
interactions=0.95,
linear_penalty='ridge',
random_state=42,
):
"""
Params
------
Note: args override ebm_kwargs if there are duplicates
"""
self.ebm_kwargs = ebm_kwargs
self.multitask = multitask
self.linear_penalty = linear_penalty
self.random_state = random_state
if not 'random_state' in ebm_kwargs:
ebm_kwargs['random_state'] = random_state
self.interactions = interactions

# override ebm_kwargs
ebm_kwargs['random_state'] = random_state
ebm_kwargs['interactions'] = interactions
self.ebm_ = ExplainableBoostingRegressor(**(ebm_kwargs or {}))

def fit(self, X, y, sample_weight=None):
Expand Down Expand Up @@ -77,6 +82,10 @@ def fit(self, X, y, sample_weight=None):
self.ebms_[num_features].fit(X, y, sample_weight=sample_weight)

# extract features
self.term_names_list_ = [
ebm_.term_names_ for ebm_ in self.ebms_]
self.term_names_ = sum(self.term_names_list_, [])

feats = self._extract_ebm_features(X)

# fit a linear model to the features
Expand All @@ -92,17 +101,16 @@ def fit(self, X, y, sample_weight=None):

def _extract_ebm_features(self, X):
'''
Extract features by predicting each feature with each EBM
Note: this doesn't currently handle interactions
Extract features by extracting all terms with EBM
'''
num_ebms = X.shape[1] + 1
num_features = X.shape[1]
feats = np.zeros((X.shape[0], num_ebms * num_features))
for ebm_num in range(num_ebms):
feats = np.empty((X.shape[0], len(self.term_names_)))
offset = 0
for ebm_num in range(len(self.ebms_)):
# see eval_terms function: https://interpret.ml/docs/python/api/ExplainableBoostingRegressor.html#interpret.glassbox.ExplainableBoostingRegressor.eval_terms
feats[:, ebm_num * num_features: (ebm_num + 1) * num_features] = \
n_features_ebm_num = len(self.term_names_list_[ebm_num])
feats[:, offset: offset + n_features_ebm_num] = \
self.ebms_[ebm_num].eval_terms(X)

offset += n_features_ebm_num
return feats

def predict(self, X):
Expand Down Expand Up @@ -133,27 +141,28 @@ def test_multitask_extraction():
# X, y, feature_names = imodels.get_clean_dataset("bike_sharing")

# remove some features to speed things up
X = X[:10]
X = X[:10, :4]
y = y[:10]
X, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# unit test
gam = MultiTaskGAMRegressor(multitask=False)
gam.fit(X, y_train)
ebm = gam.ebm_
# print('feature_names_in', ebm.feature_names_in_)
# ebm = gam.ebm_
gam2 = MultiTaskGAMRegressor(multitask=True)
gam2.fit(X, y_train)
preds_orig = gam.predict(X_test)
assert np.allclose(preds_orig, gam2.ebms_[-1].predict(X_test))

# extracted curves should sum to original predictions
# extracted curves + intercept should sum to original predictions
feats_extracted = gam2._extract_ebm_features(X_test)
num_samples = X_test.shape[0]
num_features = X_test.shape[1]
num_ebms = num_features + 1
feats_extracted_target = feats_extracted[:, -num_features:]
assert feats_extracted_target.shape == (num_samples, num_features)

# get features for ebm that predicts target
feats_extracted_target = feats_extracted[:,
-len(gam2.term_names_list_[-1]):]
# assert feats_extracted_target.shape == (num_samples, num_features)
preds_extracted_target = np.sum(feats_extracted_target, axis=1) + \
gam2.ebms_[-1].intercept_
diff = preds_extracted_target - preds_orig
Expand All @@ -178,8 +187,8 @@ def test_multitask_extraction():
for gam in tqdm([
# AdaBoostRegressor(estimator=MultiTaskGAMRegressor(
# multitask=True), n_estimators=2),
MultiTaskGAMRegressor(multitask=False),
# MultiTaskGAMRegressor(multitask=True),
# MultiTaskGAMRegressor(multitask=False),
MultiTaskGAMRegressor(multitask=True),
# ExplainableBoostingRegressor(n_jobs=1, interactions=0)
]):
np.random.seed(42)
Expand Down

0 comments on commit 0efced5

Please sign in to comment.