diff --git a/survlimepy/survlime_explainer.py b/survlimepy/survlime_explainer.py index 0dfdb10..8498c4b 100644 --- a/survlimepy/survlime_explainer.py +++ b/survlimepy/survlime_explainer.py @@ -394,6 +394,7 @@ def plot_montecarlo_weights( scale_with_data_point: bool = False, figure_path: Optional[str] = None, with_colour: bool = True, + absolute_vals: bool = False, ) -> None: """Generates explanations for a prediction. @@ -403,6 +404,7 @@ def plot_montecarlo_weights( scale_with_data_point (bool): whether to perform the elementwise multiplication between the point to be explained and the coefficients. figure_path (Optional[str]): path to save the figure. with_colour (bool): boolean indicating whether the colour palette for positive coefficients is different than thecolour palette for negative coefficients. Default is set to True. + absolute_vals (bool): whether to plot the absolute values of the coefficients. Returns: None. @@ -454,9 +456,25 @@ def plot_montecarlo_weights( colors_up = {key: val for key, val in zip(median_up.keys(), pal_up)} colors_down = {key: val for key, val in zip(median_down.keys(), pal_down)} custom_pal = {**colors_up, **colors_down} - data_reindex = data.reindex(columns=custom_pal.keys()) - data_melt = pd.melt(data_reindex) - + if absolute_vals: + absolute_order = {**median_up, **median_down} + absolute_order = {key: np.abs(val) for key, val in absolute_order.items()} + absolute_order = dict( + sorted( + absolute_order.items(), + key=lambda item: np.abs(item[1]), + reverse=True, + ) + ) + custom_pal = {key: custom_pal[key] for key in absolute_order.keys()} + data_reindex = data.reindex(columns=list(custom_pal.keys())) + data_melt = pd.melt(data_reindex) + data_melt.value = np.abs(data_melt.value) + plot_title = "Absolute feature importance" + else: + data_reindex = data.reindex(columns=custom_pal.keys()) + data_melt = pd.melt(data_reindex) + plot_title = "Feature importance" _, ax = plt.subplots(figsize=figsize) ax.tick_params(labelrotation=90) if with_colour: @@ -481,7 +499,7 @@ def plot_montecarlo_weights( p.yaxis.grid(True) p.xaxis.grid(True) - p.set_title("Feature importance", fontsize=16, fontweight="bold") + p.set_title(plot_title, fontsize=16, fontweight="bold") plt.xticks(fontsize=16, rotation=90) plt.yticks(fontsize=14, rotation=0)