Skip to content

Commit

Permalink
fix hstree feature printing
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Nov 16, 2023
1 parent d97be1a commit 6891928
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion imodels/tree/hierarchical_shrinkage.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def fit(self, X, y, sample_weight=None, *args, **kwargs):
# None returned if not passed
feature_names = kwargs.pop("feature_names", None)
X, y, feature_names = check_fit_arguments(self, X, y, feature_names)
if feature_names is not None:
self.feature_names = feature_names
self.estimator_ = self.estimator_.fit(
X, y, *args, sample_weight=sample_weight, **kwargs
)
Expand Down Expand Up @@ -341,7 +343,7 @@ def fit(self, X, y, *args, **kwargs):
base_est.fit(X_in, y_in)
for i, reg_param in enumerate(self.reg_param_list):
est_hs = HSTreeClassifier(base_est, reg_param)
est_hs.fit(X_in, y_in)
est_hs.fit(X_in, y_in, *args, **kwargs)
self.scores_[i].append(
scorer(y_out, est_hs.predict_proba(X_out)))
self.scores_ = [np.mean(s) for s in self.scores_]
Expand Down

0 comments on commit 6891928

Please sign in to comment.