From 59237ad394af0e1b01a9295b67b85b33498c3c67 Mon Sep 17 00:00:00 2001 From: Daniel Cohen Date: Fri, 25 Oct 2024 11:26:24 -0700 Subject: [PATCH] Redo constraint violation indicators and legend on insample and predicted plots (#2955) Summary: Instead of varying sizing, we're adding a red outline of varying opacity based on the overall probability of constraint violation. There is a hack to make sure the outlines don't appear in the legend, as the legend normally reflects the first point in each group. We're hiding the original and creating a duplicate group of the same name at point (None, None). In doing this, we give up the ability to toggle points on and off the graph by clicking on their legend item, which I just discovered existed. {F1944580069} Reviewed By: ItsMrLin Differential Revision: D64850289 --- ax/analysis/plotly/arm_effects/utils.py | 89 ++++++++++++++++++- .../plotly/tests/test_insample_effects.py | 4 +- .../plotly/tests/test_predicted_effects.py | 6 +- 3 files changed, 93 insertions(+), 6 deletions(-) diff --git a/ax/analysis/plotly/arm_effects/utils.py b/ax/analysis/plotly/arm_effects/utils.py index e0d68a5d494..c8244ec0e5c 100644 --- a/ax/analysis/plotly/arm_effects/utils.py +++ b/ax/analysis/plotly/arm_effects/utils.py @@ -65,9 +65,60 @@ def prepare_arm_effects_plot( color="source", # TODO: can we format this by callable or string template? hover_data=_get_parameter_columns(df), - size="size_column", - size_max=10, + # avoid red because it will match the constraint violation indicator + color_discrete_sequence=px.colors.qualitative.Vivid, ) + dot_size = 8 + # set all dots to size 8 in plots + fig.update_traces(marker={"line": {"width": 2}, "size": dot_size}) + + # Manually create each constraint violation indicator + # as a red outline around the dot, with alpha based on the + # probability of constraint violation. + for trace in fig.data: + # there is a trace per source, so get the rows of df + # pertaining to this trace + indices = df["source"] == trace.name + trace.marker.line.color = [ + # raising the alpha to a power < 1 makes the colors more + # visible when there is a lower chance of constraint violation + f"rgba(255, 0, 0, {(alpha) ** .75})" + for alpha in df.loc[indices, "overall_probability_constraints_violated"] + ] + # Create a separate trace for the legend, otherwise the legend + # will have the constraint violation indicator of the first arm + # in the source group + legend_trace = go.Scatter( + # (None, None) is a hack to get a legend item without + # appearing on the plot + x=[None], + y=[None], + mode="markers", + marker={ + "size": dot_size, + "color": trace.marker.color, + }, + name=trace.name, + ) + fig.add_trace(legend_trace) + trace.showlegend = False + + # Add an item to the legend for the constraint violation indicator + legend_trace = go.Scatter( + # (None, None) is a hack to get a legend item without + # appearing on the plot + x=[None], + y=[None], + mode="markers", + marker={ + "size": dot_size, + "color": "white", + "line": {"width": 2, "color": "red"}, + }, + name="Constraint Violation", + ) + fig.add_trace(legend_trace) + _add_style_to_effects_by_arm_plot( fig=fig, df=df, metric_name=metric_name, outcome_constraints=outcome_constraints ) @@ -100,6 +151,20 @@ def _add_style_to_effects_by_arm_plot( y=df[df["arm_name"] == "status_quo"]["mean"].iloc[0], line_width=1, line_color="red", + showlegend=True, + name="Status Quo Mean", + ) + # Add the status quo mean to the legend + fig.add_trace( + go.Scatter( + # (None, None) is a hack to get a legend item without + # appearing on the plot + x=[None], + y=[None], + mode="lines", + line={"color": "red", "width": 1}, + name="Status Quo Mean", + ) ) for constraint in outcome_constraints: if constraint.metric.name == metric_name: @@ -110,10 +175,25 @@ def _add_style_to_effects_by_arm_plot( line_color="red", line_dash="dash", ) + # Add the constraint bound to the legend + fig.add_trace( + go.Scatter( + # (None, None) is a hack to get a legend item without + # appearing on the plot + x=[None], + y=[None], + mode="lines", + line={"color": "red", "width": 1, "dash": "dash"}, + name="Constraint Bound", + ) + ) fig.update_layout( xaxis={ "tickangle": 45, }, + legend={ + "title": None, + }, ) @@ -206,7 +286,10 @@ def get_predictions_by_arm( "constraints_violated": format_constraint_violated_probabilities( constraints_violated[i] ), - "size_column": 100 - probabilities_not_feasible[i] * 100, + # used for constraint violation indicator + "overall_probability_constraints_violated": round( + probabilities_not_feasible[i], ndigits=2 + ), "parameters": format_parameters_for_effects_by_arm_plot( parameters=features[i].parameters ), diff --git a/ax/analysis/plotly/tests/test_insample_effects.py b/ax/analysis/plotly/tests/test_insample_effects.py index 48562a9588a..f160b7ab6f2 100644 --- a/ax/analysis/plotly/tests/test_insample_effects.py +++ b/ax/analysis/plotly/tests/test_insample_effects.py @@ -405,7 +405,9 @@ def test_constraints(self) -> None: str(non_sq_df["constraints_violated"][0]), ) # AND THEN it marks that constraints are not violated for the SQ - self.assertEqual(sq_row["size_column"].iloc[0], 100) + self.assertEqual( + sq_row["overall_probability_constraints_violated"].iloc[0], 0 + ) self.assertEqual( sq_row["constraints_violated"].iloc[0], "No constraints violated" ) diff --git a/ax/analysis/plotly/tests/test_predicted_effects.py b/ax/analysis/plotly/tests/test_predicted_effects.py index 9f045c12d9e..4c22376d3b1 100644 --- a/ax/analysis/plotly/tests/test_predicted_effects.py +++ b/ax/analysis/plotly/tests/test_predicted_effects.py @@ -180,7 +180,7 @@ def test_compute(self) -> None: "sem", "error_margin", "constraints_violated", - "size_column", + "overall_probability_constraints_violated", }, ) self.assertIsNotNone(card.blob) @@ -380,7 +380,9 @@ def test_constraints(self) -> None: str(non_sq_df["constraints_violated"][0]), ) # AND THEN it marks that constraints are not violated for the SQ - self.assertEqual(sq_row["size_column"].iloc[0], 100) + self.assertEqual( + sq_row["overall_probability_constraints_violated"].iloc[0], 0 + ) self.assertEqual( sq_row["constraints_violated"].iloc[0], "No constraints violated" )