Skip to content

Commit

Permalink
Fix #1351: Merge small improvements from PR #1281
Browse files Browse the repository at this point in the history
  • Loading branch information
trentmc committed Jul 5, 2024
1 parent 0752578 commit dae2dc5
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
2 changes: 2 additions & 0 deletions pdr_backend/aimodel/aimodel_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def _build_wrapped_regr(
ss = self.ss
assert ss.do_regr
assert ycont is not None
assert X.shape[0] == ycont.shape[0], (X.shape[0], ycont.shape[0])
do_constant = min(ycont) == max(ycont) or ss.approach == "RegrConstant"

# weight newest sample 10x, and 2nd-newest sample 5x
Expand Down Expand Up @@ -145,6 +146,7 @@ def _build_direct_classif(
) -> Aimodel:
ss = self.ss
assert not ss.do_regr
assert X.shape[0] == len(ytrue), (X.shape[0], len(ytrue))
n_True, n_False = sum(ytrue), sum(np.invert(ytrue))
smallest_n = min(n_True, n_False)
do_constant = (smallest_n == 0) or ss.approach == "ClassifConstant"
Expand Down
7 changes: 4 additions & 3 deletions pdr_backend/aimodel/aimodel_plotdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ def __init__(
model: Aimodel,
X_train: np.ndarray,
ytrue_train: np.ndarray,
ycont_train: np.ndarray,
y_thr: float,
ycont_train: Optional[np.ndarray],
y_thr: Optional[float],
colnames: List[str],
slicing_x: np.ndarray,
sweep_vars: Optional[List[int]] = None,
Expand All @@ -45,7 +45,8 @@ def __init__(
assert len(colnames) == n, (len(colnames), n)
assert slicing_x.shape[0] == n, (slicing_x.shape[0], n)
assert ytrue_train.shape[0] == N, (ytrue_train.shape[0], N)
assert ycont_train.shape[0] == N, (ycont_train.shape[0], N)
if ycont_train is not None:
assert ycont_train.shape[0] == N, (ycont_train.shape[0], N)
assert sweep_vars is None or len(sweep_vars) in [1, 2]

# set values
Expand Down
2 changes: 2 additions & 0 deletions pdr_backend/aimodel/aimodel_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ def _plot_lineplot_1var(aimodel_plotdata: AimodelPlotdata):
# line plot: regressor response, training data
if d.model.do_regr:
assert mesh_ycont_hat is not None
assert y_thr is not None
assert ycont is not None
fig.add_trace(
go.Scatter(
x=mesh_chosen_x,
Expand Down

0 comments on commit dae2dc5

Please sign in to comment.