Skip to content

Commit

Permalink
Add absolute value to montecarlo plot
Browse files Browse the repository at this point in the history
  • Loading branch information
pachoning committed Jul 13, 2023
1 parent 77314a5 commit 7086b7b
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions survlimepy/survlime_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ def plot_montecarlo_weights(
scale_with_data_point: bool = False,
figure_path: Optional[str] = None,
with_colour: bool = True,
absolute_vals: bool = False,
) -> None:
"""Generates explanations for a prediction.
Expand All @@ -403,6 +404,7 @@ def plot_montecarlo_weights(
scale_with_data_point (bool): whether to perform the elementwise multiplication between the point to be explained and the coefficients.
figure_path (Optional[str]): path to save the figure.
with_colour (bool): boolean indicating whether the colour palette for positive coefficients is different than thecolour palette for negative coefficients. Default is set to True.
absolute_vals (bool): whether to plot the absolute values of the coefficients.
Returns:
None.
Expand Down Expand Up @@ -454,9 +456,25 @@ def plot_montecarlo_weights(
colors_up = {key: val for key, val in zip(median_up.keys(), pal_up)}
colors_down = {key: val for key, val in zip(median_down.keys(), pal_down)}
custom_pal = {**colors_up, **colors_down}
data_reindex = data.reindex(columns=custom_pal.keys())
data_melt = pd.melt(data_reindex)

if absolute_vals:
absolute_order = {**median_up, **median_down}
absolute_order = {key: np.abs(val) for key, val in absolute_order.items()}
absolute_order = dict(
sorted(
absolute_order.items(),
key=lambda item: np.abs(item[1]),
reverse=True,
)
)
custom_pal = {key: custom_pal[key] for key in absolute_order.keys()}
data_reindex = data.reindex(columns=list(custom_pal.keys()))
data_melt = pd.melt(data_reindex)
data_melt.value = np.abs(data_melt.value)
plot_title = "Absolute feature importance"
else:
data_reindex = data.reindex(columns=custom_pal.keys())
data_melt = pd.melt(data_reindex)
plot_title = "Feature importance"
_, ax = plt.subplots(figsize=figsize)
ax.tick_params(labelrotation=90)
if with_colour:
Expand All @@ -481,7 +499,7 @@ def plot_montecarlo_weights(
p.yaxis.grid(True)
p.xaxis.grid(True)

p.set_title("Feature importance", fontsize=16, fontweight="bold")
p.set_title(plot_title, fontsize=16, fontweight="bold")

plt.xticks(fontsize=16, rotation=90)
plt.yticks(fontsize=14, rotation=0)
Expand Down

0 comments on commit 7086b7b

Please sign in to comment.