Skip to content

Commit

Permalink
Added adjust_legend argument to solve the overlap of legend in all sc…
Browse files Browse the repository at this point in the history
…atter plot
  • Loading branch information
Starlitnightly committed Oct 17, 2024
1 parent 4ff7b14 commit e6f292b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
22 changes: 22 additions & 0 deletions dynamo/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ def _matplotlib_points(
inset_dict={},
show_colorbar=True,
projection=None, # default in matplotlib
adjust_legend=False,
**kwargs,
):
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -833,6 +834,7 @@ def _matplotlib_points(
)
elif len(unique_labels) > 1 and show_legend == "on data":
font_color = "white" if background in ["black", "#ffffff"] else "black"
texts=[]
for i in unique_labels:
if i == "other":
continue
Expand All @@ -855,6 +857,15 @@ def _matplotlib_points(
PathEffects.Normal(),
]
)
texts.append(txt)
if adjust_legend==True:
from adjustText import adjust_text
import adjustText
if adjustText.__version__<='0.8':
adjust_text(texts,only_move={'text': 'xy'},arrowprops=dict(arrowstyle='->', color='red'),)
else:
adjust_text(texts,only_move={"text": "xy", "static": "xy", "explode": "xy", "pull": "xy"},
arrowprops=dict(arrowstyle='->', color='black'))
else:
ax.legend(
handles=legend_elements,
Expand Down Expand Up @@ -885,6 +896,7 @@ def _datashade_points(
vmax=98,
sort="raw",
projection="2d",
adjust_legend=False,
**kwargs,
):
import datashader as ds
Expand Down Expand Up @@ -1007,6 +1019,7 @@ def _datashade_points(
if show_legend and legend_elements is not None:
if len(unique_labels) > 1 and show_legend == "on data":
font_color = "white" if background == "black" else "black"
texts=[]
for i in unique_labels:
color_cnt = np.nanmedian(points.iloc[np.where(labels == i)[0], :2], 0)
txt = plt.text(
Expand All @@ -1025,6 +1038,15 @@ def _datashade_points(
PathEffects.Normal(),
]
)
texts.append(txt)
if adjust_legend==True:
from adjustText import adjust_text
import adjustText
if adjustText.__version__<='0.8':
adjust_text(texts,only_move={'text': 'xy'},arrowprops=dict(arrowstyle='->', color='red'),)
else:
adjust_text(texts,only_move={"text": "xy", "static": "xy", "explode": "xy", "pull": "xy"},
arrowprops=dict(arrowstyle='->', color='black'))
else:
if type(show_legend) == "str":
ax.legend(
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ get_version>=3.5.4
openpyxl
typing-extensions
session-info>=1.0.0
adjustText

0 comments on commit e6f292b

Please sign in to comment.