Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Redo constraint violation indicators and legend on inample and predicted plots #2955

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 86 additions & 3 deletions ax/analysis/plotly/arm_effects/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,60 @@ def prepare_arm_effects_plot(
color="source",
# TODO: can we format this by callable or string template?
hover_data=_get_parameter_columns(df),
size="size_column",
size_max=10,
# avoid red because it will match the constraint violation indicator
color_discrete_sequence=px.colors.qualitative.Vivid,
)
dot_size = 8
# set all dots to size 8 in plots
fig.update_traces(marker={"line": {"width": 2}, "size": dot_size})

# Manually create each constraint violation indicator
# as a red outline around the dot, with alpha based on the
# probability of constraint violation.
for trace in fig.data:
# there is a trace per source, so get the rows of df
# pertaining to this trace
indices = df["source"] == trace.name
trace.marker.line.color = [
# raising the alpha to a power < 1 makes the colors more
# visible when there is a lower chance of constraint violation
f"rgba(255, 0, 0, {(alpha) ** .75})"
for alpha in df.loc[indices, "overall_probability_constraints_violated"]
]
# Create a separate trace for the legend, otherwise the legend
# will have the constraint violation indicator of the first arm
# in the source group
legend_trace = go.Scatter(
# (None, None) is a hack to get a legend item without
# appearing on the plot
x=[None],
y=[None],
mode="markers",
marker={
"size": dot_size,
"color": trace.marker.color,
},
name=trace.name,
)
fig.add_trace(legend_trace)
trace.showlegend = False

# Add an item to the legend for the constraint violation indicator
legend_trace = go.Scatter(
# (None, None) is a hack to get a legend item without
# appearing on the plot
x=[None],
y=[None],
mode="markers",
marker={
"size": dot_size,
"color": "white",
"line": {"width": 2, "color": "red"},
},
name="Constraint Violation",
)
fig.add_trace(legend_trace)

_add_style_to_effects_by_arm_plot(
fig=fig, df=df, metric_name=metric_name, outcome_constraints=outcome_constraints
)
Expand Down Expand Up @@ -100,6 +151,20 @@ def _add_style_to_effects_by_arm_plot(
y=df[df["arm_name"] == "status_quo"]["mean"].iloc[0],
line_width=1,
line_color="red",
showlegend=True,
name="Status Quo Mean",
)
# Add the status quo mean to the legend
fig.add_trace(
go.Scatter(
# (None, None) is a hack to get a legend item without
# appearing on the plot
x=[None],
y=[None],
mode="lines",
line={"color": "red", "width": 1},
name="Status Quo Mean",
)
)
for constraint in outcome_constraints:
if constraint.metric.name == metric_name:
Expand All @@ -110,10 +175,25 @@ def _add_style_to_effects_by_arm_plot(
line_color="red",
line_dash="dash",
)
# Add the constraint bound to the legend
fig.add_trace(
go.Scatter(
# (None, None) is a hack to get a legend item without
# appearing on the plot
x=[None],
y=[None],
mode="lines",
line={"color": "red", "width": 1, "dash": "dash"},
name="Constraint Bound",
)
)
fig.update_layout(
xaxis={
"tickangle": 45,
},
legend={
"title": None,
},
)


Expand Down Expand Up @@ -206,7 +286,10 @@ def get_predictions_by_arm(
"constraints_violated": format_constraint_violated_probabilities(
constraints_violated[i]
),
"size_column": 100 - probabilities_not_feasible[i] * 100,
# used for constraint violation indicator
"overall_probability_constraints_violated": round(
probabilities_not_feasible[i], ndigits=2
),
"parameters": format_parameters_for_effects_by_arm_plot(
parameters=features[i].parameters
),
Expand Down
4 changes: 3 additions & 1 deletion ax/analysis/plotly/tests/test_insample_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,9 @@ def test_constraints(self) -> None:
str(non_sq_df["constraints_violated"][0]),
)
# AND THEN it marks that constraints are not violated for the SQ
self.assertEqual(sq_row["size_column"].iloc[0], 100)
self.assertEqual(
sq_row["overall_probability_constraints_violated"].iloc[0], 0
)
self.assertEqual(
sq_row["constraints_violated"].iloc[0], "No constraints violated"
)
Expand Down
6 changes: 4 additions & 2 deletions ax/analysis/plotly/tests/test_predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def test_compute(self) -> None:
"sem",
"error_margin",
"constraints_violated",
"size_column",
"overall_probability_constraints_violated",
},
)
self.assertIsNotNone(card.blob)
Expand Down Expand Up @@ -380,7 +380,9 @@ def test_constraints(self) -> None:
str(non_sq_df["constraints_violated"][0]),
)
# AND THEN it marks that constraints are not violated for the SQ
self.assertEqual(sq_row["size_column"].iloc[0], 100)
self.assertEqual(
sq_row["overall_probability_constraints_violated"].iloc[0], 0
)
self.assertEqual(
sq_row["constraints_violated"].iloc[0], "No constraints violated"
)
Expand Down