diff --git a/src/emcpy/plots/create_plots.py b/src/emcpy/plots/create_plots.py index c8a450c..afc28b0 100644 --- a/src/emcpy/plots/create_plots.py +++ b/src/emcpy/plots/create_plots.py @@ -333,7 +333,7 @@ def add_suptitle(self, text, **kwargs): self.fig.suptitle(text, **kwargs) def plot_logo(self, loc, which='noaa/nws', - single_logo=True, zoom=1, alpha=0.5): + subplot_orientation='last', zoom=1, alpha=0.5): """ Add branding logo on all axes. """ @@ -346,9 +346,17 @@ def plot_logo(self, loc, which='noaa/nws', image_path = os.path.join(emcpy.emcpy_directory, 'logos', image_dict[which]) im = Image.open(image_path) + if subplot_orientation.lower() not in ['first', 'last', 'all']: + raise TypeError(f"{subplot_orientation} is not a valid input. " + + "Valid inputs include 'first', 'last', or 'all'") + ax_list = self.fig.axes - if single_logo: + if subplot_orientation.lower() == 'first': + ax = ax_list[0] + self._display_logo(ax, im, loc, zoom, alpha) + + elif subplot_orientation.lower() == 'last': ax = ax_list[-1] self._display_logo(ax, im, loc, zoom, alpha)