Skip to content

Commit

Permalink
Use SubFigures
Browse files Browse the repository at this point in the history
  • Loading branch information
has2k1 committed Jan 7, 2025
1 parent 9fe97bd commit 54ec707
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 196 deletions.
29 changes: 16 additions & 13 deletions plotnine/animation.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from __future__ import annotations

import typing
from copy import deepcopy
from typing import TYPE_CHECKING, cast

from matplotlib.animation import ArtistAnimation

from .exceptions import PlotnineError

if typing.TYPE_CHECKING:
if TYPE_CHECKING:
from typing import Iterable

from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.figure import Figure, SubFigure

from plotnine import ggplot
from plotnine.scales.scale import scale
Expand Down Expand Up @@ -87,6 +87,7 @@ def _draw_plots(
List of [](`Matplotlib.artist.Artist`)
"""
import matplotlib.pyplot as plt
from matplotlib.figure import Figure, SubFigure

# For keeping track of artists for each frame
artist_offsets: dict[str, list[int]] = {
Expand Down Expand Up @@ -189,6 +190,7 @@ def check_scale_limits(scales: list[scale], frame_no: int):
)

figure: Figure | None = None
subfigure: SubFigure | None = None
axs: list[Axes] = []
artists = []
scales = None # Will hold the scales of the first frame
Expand All @@ -198,14 +200,19 @@ def check_scale_limits(scales: list[scale], frame_no: int):
# onto the figure and axes created by the first ggplot and
# they create the subsequent frames.
for frame_no, p in enumerate(plots):
if figure is None:
figure = p.draw()
axs = figure.get_axes()
if frame_no == 0:
p._create_figure()
p.draw()
figure, subfigure = p.figure, p.subfigure
axs = subfigure.get_axes()
initialise_artist_offsets(len(axs))
scales = p._build_objs.scales
set_scale_limits(scales)
else:
plot = self._draw_animation_plot(p, figure, axs)
p.figure = cast(Figure, figure)
p.subfigure = cast(SubFigure, subfigure)
p.axs = axs
plot = self._draw_animation_plot(p)
check_scale_limits(plot.scales, frame_no)

artists.append(get_frame_artists(axs))
Expand All @@ -218,9 +225,7 @@ def check_scale_limits(scales: list[scale], frame_no: int):
plt.close(figure)
return figure, artists

def _draw_animation_plot(
self, plot: ggplot, figure: Figure, axs: list[Axes]
) -> ggplot:
def _draw_animation_plot(self, plot: ggplot) -> ggplot:
"""
Draw a plot/frame of the animation
Expand All @@ -229,10 +234,8 @@ def _draw_animation_plot(
from ._utils.context import plot_context

plot = deepcopy(plot)
plot.figure = figure
plot.axs = axs
with plot_context(plot):
plot._build()
plot.figure, plot.axs = plot.facet.setup(plot)
plot.facet.setup(plot)
plot._draw_layers()
return plot
44 changes: 16 additions & 28 deletions plotnine/facets/facet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import numpy.typing as npt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.figure import SubFigure
from matplotlib.gridspec import GridSpec

from plotnine import ggplot, theme
Expand Down Expand Up @@ -82,9 +82,6 @@ class facet:
# Theme object, automatically updated before drawing the plot
theme: theme

# Figure object on which the facet panels are created
figure: Figure

# coord object, automatically updated before drawing the plot
coordinates: coord

Expand All @@ -100,8 +97,6 @@ class facet:
# Facet strips
strips: Strips

grid_spec: GridSpec

# The plot environment
environment: Environment

Expand Down Expand Up @@ -138,16 +133,16 @@ def setup(self, plot: ggplot):
self.plot = plot
self.layout = plot.layout

if hasattr(plot, "figure"):
self.figure, self.axs = plot.figure, plot.axs
if hasattr(plot, "axs"):
self.axs = plot.axs
else:
self.figure, self.axs = self.make_figure()
self.axs = self._make_axes(plot.subfigure)

self.coordinates = plot.coordinates
self.theme = plot.theme
self.layout.axs = self.axs
self.strips = Strips.from_facet(self)
return self.figure, self.axs
return self.axs

def setup_data(self, data: list[pd.DataFrame]) -> list[pd.DataFrame]:
"""
Expand Down Expand Up @@ -363,7 +358,7 @@ def __deepcopy__(self, memo: dict[Any, Any]) -> facet:
new = result.__dict__

# don't make a deepcopy of the figure & the axes
shallow = {"figure", "axs", "first_ax", "last_ax"}
shallow = {"axs", "first_ax", "last_ax"}
for key, item in old.items():
if key in shallow:
new[key] = item
Expand All @@ -373,35 +368,28 @@ def __deepcopy__(self, memo: dict[Any, Any]) -> facet:

return result

def _make_figure(self) -> tuple[Figure, GridSpec]:
def _get_gridspec(self) -> GridSpec:
"""
Create figure & gridspec
Create gridspec for the panels
"""
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

return plt.figure(), GridSpec(self.nrow, self.ncol)
return GridSpec(self.nrow, self.ncol)

def make_figure(self) -> tuple[Figure, list[Axes]]:
def _make_axes(self, subfigure: SubFigure) -> list[Axes]:
"""
Create and return Matplotlib figure and subplot axes
Create and return subplot axes
"""
num_panels = len(self.layout.layout)
axsarr = np.empty((self.nrow, self.ncol), dtype=object)

# Create figure & gridspec
figure, gs = self._make_figure()
self.grid_spec = gs
# Create gridspec
gs = self._get_gridspec()

# Create axes
it = itertools.product(range(self.nrow), range(self.ncol))
for i, (row, col) in enumerate(it):
axsarr[row, col] = figure.add_subplot(gs[i])

# axsarr = np.array([
# figure.add_subplot(gs[i])
# for i in range(self.nrow * self.ncol)
# ]).reshape((self.nrow, self.ncol))
axsarr[row, col] = subfigure.add_subplot(gs[i])

# Rearrange axes
# They are ordered to match the positions in the layout table
Expand All @@ -420,9 +408,9 @@ def make_figure(self) -> tuple[Figure, list[Axes]]:

# Delete unused axes
for ax in axs[num_panels:]:
figure.delaxes(ax)
subfigure.delaxes(ax)
axs = axs[:num_panels]
return figure, list(axs)
return list(axs)

def _aspect_ratio(self) -> Optional[float]:
"""
Expand Down
8 changes: 5 additions & 3 deletions plotnine/facets/facet_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,10 @@ def __init__(
self.space = space
self.margins = margins

def _make_figure(self):
import matplotlib.pyplot as plt
def _get_gridspec(self):
"""
Create gridspec for the panels
"""
from matplotlib.gridspec import GridSpec

layout = self.layout
Expand Down Expand Up @@ -155,7 +157,7 @@ def _make_figure(self):
ratios["width_ratios"] = self.space.get("x")
ratios["height_ratios"] = self.space.get("y")

return plt.figure(), GridSpec(self.nrow, self.ncol, **ratios)
return GridSpec(self.nrow, self.ncol, **ratios)

def compute_layout(self, data: list[pd.DataFrame]) -> pd.DataFrame:
if not self.rows and not self.cols:
Expand Down
3 changes: 1 addition & 2 deletions plotnine/facets/strips.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def __init__(
self.ax = ax
self.position = position
self.facet = facet
self.figure = facet.figure
self.theme = facet.theme
self.layout_info = layout_info
label_info = strip_label_details.make(layout_info, vars, position)
Expand Down Expand Up @@ -135,7 +134,7 @@ def draw(self):
text = StripText(draw_info)
rect = text.patch

self.figure.add_artist(text)
self.facet.plot.subfigure.add_artist(text)

if draw_info.position == "right":
targets.strip_background_y.append(rect)
Expand Down
Loading

0 comments on commit 54ec707

Please sign in to comment.