Skip to content

Commit

Permalink
proba fix with reshape
Browse files Browse the repository at this point in the history
  • Loading branch information
deadsoul44 committed Oct 8, 2024
1 parent b598b62 commit 423f7cb
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
5 changes: 4 additions & 1 deletion python-package/python/perpetual/booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,10 @@ def predict_proba(self, X, parallel: Union[bool, None] = None) -> np.ndarray:
cols=cols,
parallel=parallel,
)
return np.concatenate([probabilities, 1 - probabilities], axis=1)
return np.concatenate(
[probabilities.reshape(-1, 1), (1.0 - probabilities).reshape(-1, 1)],
axis=1,
)
else:
raise NotImplementedError(
f"predict_proba not implemented for regression. n_classes = {len(self.classes_)}"
Expand Down
11 changes: 10 additions & 1 deletion python-package/tests/test_booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,16 @@ def test_booster_from_numpy(X_y):
assert np.allclose(model2_preds, model3_preds)


def test_predict_proba(X_y):
X, y = X_y
model = PerpetualBooster(objective="LogLoss")
model.fit(X, y)

y_proba = model.predict_proba(X)

assert np.allclose(y_proba.shape, (len(X), 2))


def test_get_node_list(X_y):
X, y = X_y
X = X
Expand Down Expand Up @@ -614,7 +624,6 @@ def test_booster_metadata(
):
f64_model_path = tmp_path / "modelf64_sl.json"
X, y = X_y
X = X
model = PerpetualBooster(objective="SquaredLoss")
model.fit(X, y)
preds = model.predict(X)
Expand Down

0 comments on commit 423f7cb

Please sign in to comment.