Skip to content

Commit 54ec707

Browse files
committed
Use SubFigures
1 parent 9fe97bd commit 54ec707

File tree

8 files changed

+162
-196
lines changed

8 files changed

+162
-196
lines changed

plotnine/animation.py

+16-13
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
from __future__ import annotations
22

3-
import typing
43
from copy import deepcopy
4+
from typing import TYPE_CHECKING, cast
55

66
from matplotlib.animation import ArtistAnimation
77

88
from .exceptions import PlotnineError
99

10-
if typing.TYPE_CHECKING:
10+
if TYPE_CHECKING:
1111
from typing import Iterable
1212

1313
from matplotlib.artist import Artist
1414
from matplotlib.axes import Axes
15-
from matplotlib.figure import Figure
15+
from matplotlib.figure import Figure, SubFigure
1616

1717
from plotnine import ggplot
1818
from plotnine.scales.scale import scale
@@ -87,6 +87,7 @@ def _draw_plots(
8787
List of [](`Matplotlib.artist.Artist`)
8888
"""
8989
import matplotlib.pyplot as plt
90+
from matplotlib.figure import Figure, SubFigure
9091

9192
# For keeping track of artists for each frame
9293
artist_offsets: dict[str, list[int]] = {
@@ -189,6 +190,7 @@ def check_scale_limits(scales: list[scale], frame_no: int):
189190
)
190191

191192
figure: Figure | None = None
193+
subfigure: SubFigure | None = None
192194
axs: list[Axes] = []
193195
artists = []
194196
scales = None # Will hold the scales of the first frame
@@ -198,14 +200,19 @@ def check_scale_limits(scales: list[scale], frame_no: int):
198200
# onto the figure and axes created by the first ggplot and
199201
# they create the subsequent frames.
200202
for frame_no, p in enumerate(plots):
201-
if figure is None:
202-
figure = p.draw()
203-
axs = figure.get_axes()
203+
if frame_no == 0:
204+
p._create_figure()
205+
p.draw()
206+
figure, subfigure = p.figure, p.subfigure
207+
axs = subfigure.get_axes()
204208
initialise_artist_offsets(len(axs))
205209
scales = p._build_objs.scales
206210
set_scale_limits(scales)
207211
else:
208-
plot = self._draw_animation_plot(p, figure, axs)
212+
p.figure = cast(Figure, figure)
213+
p.subfigure = cast(SubFigure, subfigure)
214+
p.axs = axs
215+
plot = self._draw_animation_plot(p)
209216
check_scale_limits(plot.scales, frame_no)
210217

211218
artists.append(get_frame_artists(axs))
@@ -218,9 +225,7 @@ def check_scale_limits(scales: list[scale], frame_no: int):
218225
plt.close(figure)
219226
return figure, artists
220227

221-
def _draw_animation_plot(
222-
self, plot: ggplot, figure: Figure, axs: list[Axes]
223-
) -> ggplot:
228+
def _draw_animation_plot(self, plot: ggplot) -> ggplot:
224229
"""
225230
Draw a plot/frame of the animation
226231
@@ -229,10 +234,8 @@ def _draw_animation_plot(
229234
from ._utils.context import plot_context
230235

231236
plot = deepcopy(plot)
232-
plot.figure = figure
233-
plot.axs = axs
234237
with plot_context(plot):
235238
plot._build()
236-
plot.figure, plot.axs = plot.facet.setup(plot)
239+
plot.facet.setup(plot)
237240
plot._draw_layers()
238241
return plot

plotnine/facets/facet.py

+16-28
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import numpy.typing as npt
2121
from matplotlib.axes import Axes
22-
from matplotlib.figure import Figure
22+
from matplotlib.figure import SubFigure
2323
from matplotlib.gridspec import GridSpec
2424

2525
from plotnine import ggplot, theme
@@ -82,9 +82,6 @@ class facet:
8282
# Theme object, automatically updated before drawing the plot
8383
theme: theme
8484

85-
# Figure object on which the facet panels are created
86-
figure: Figure
87-
8885
# coord object, automatically updated before drawing the plot
8986
coordinates: coord
9087

@@ -100,8 +97,6 @@ class facet:
10097
# Facet strips
10198
strips: Strips
10299

103-
grid_spec: GridSpec
104-
105100
# The plot environment
106101
environment: Environment
107102

@@ -138,16 +133,16 @@ def setup(self, plot: ggplot):
138133
self.plot = plot
139134
self.layout = plot.layout
140135

141-
if hasattr(plot, "figure"):
142-
self.figure, self.axs = plot.figure, plot.axs
136+
if hasattr(plot, "axs"):
137+
self.axs = plot.axs
143138
else:
144-
self.figure, self.axs = self.make_figure()
139+
self.axs = self._make_axes(plot.subfigure)
145140

146141
self.coordinates = plot.coordinates
147142
self.theme = plot.theme
148143
self.layout.axs = self.axs
149144
self.strips = Strips.from_facet(self)
150-
return self.figure, self.axs
145+
return self.axs
151146

152147
def setup_data(self, data: list[pd.DataFrame]) -> list[pd.DataFrame]:
153148
"""
@@ -363,7 +358,7 @@ def __deepcopy__(self, memo: dict[Any, Any]) -> facet:
363358
new = result.__dict__
364359

365360
# don't make a deepcopy of the figure & the axes
366-
shallow = {"figure", "axs", "first_ax", "last_ax"}
361+
shallow = {"axs", "first_ax", "last_ax"}
367362
for key, item in old.items():
368363
if key in shallow:
369364
new[key] = item
@@ -373,35 +368,28 @@ def __deepcopy__(self, memo: dict[Any, Any]) -> facet:
373368

374369
return result
375370

376-
def _make_figure(self) -> tuple[Figure, GridSpec]:
371+
def _get_gridspec(self) -> GridSpec:
377372
"""
378-
Create figure & gridspec
373+
Create gridspec for the panels
379374
"""
380-
import matplotlib.pyplot as plt
381375
from matplotlib.gridspec import GridSpec
382376

383-
return plt.figure(), GridSpec(self.nrow, self.ncol)
377+
return GridSpec(self.nrow, self.ncol)
384378

385-
def make_figure(self) -> tuple[Figure, list[Axes]]:
379+
def _make_axes(self, subfigure: SubFigure) -> list[Axes]:
386380
"""
387-
Create and return Matplotlib figure and subplot axes
381+
Create and return subplot axes
388382
"""
389383
num_panels = len(self.layout.layout)
390384
axsarr = np.empty((self.nrow, self.ncol), dtype=object)
391385

392-
# Create figure & gridspec
393-
figure, gs = self._make_figure()
394-
self.grid_spec = gs
386+
# Create gridspec
387+
gs = self._get_gridspec()
395388

396389
# Create axes
397390
it = itertools.product(range(self.nrow), range(self.ncol))
398391
for i, (row, col) in enumerate(it):
399-
axsarr[row, col] = figure.add_subplot(gs[i])
400-
401-
# axsarr = np.array([
402-
# figure.add_subplot(gs[i])
403-
# for i in range(self.nrow * self.ncol)
404-
# ]).reshape((self.nrow, self.ncol))
392+
axsarr[row, col] = subfigure.add_subplot(gs[i])
405393

406394
# Rearrange axes
407395
# They are ordered to match the positions in the layout table
@@ -420,9 +408,9 @@ def make_figure(self) -> tuple[Figure, list[Axes]]:
420408

421409
# Delete unused axes
422410
for ax in axs[num_panels:]:
423-
figure.delaxes(ax)
411+
subfigure.delaxes(ax)
424412
axs = axs[:num_panels]
425-
return figure, list(axs)
413+
return list(axs)
426414

427415
def _aspect_ratio(self) -> Optional[float]:
428416
"""

plotnine/facets/facet_grid.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,10 @@ def __init__(
107107
self.space = space
108108
self.margins = margins
109109

110-
def _make_figure(self):
111-
import matplotlib.pyplot as plt
110+
def _get_gridspec(self):
111+
"""
112+
Create gridspec for the panels
113+
"""
112114
from matplotlib.gridspec import GridSpec
113115

114116
layout = self.layout
@@ -155,7 +157,7 @@ def _make_figure(self):
155157
ratios["width_ratios"] = self.space.get("x")
156158
ratios["height_ratios"] = self.space.get("y")
157159

158-
return plt.figure(), GridSpec(self.nrow, self.ncol, **ratios)
160+
return GridSpec(self.nrow, self.ncol, **ratios)
159161

160162
def compute_layout(self, data: list[pd.DataFrame]) -> pd.DataFrame:
161163
if not self.rows and not self.cols:

plotnine/facets/strips.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def __init__(
4444
self.ax = ax
4545
self.position = position
4646
self.facet = facet
47-
self.figure = facet.figure
4847
self.theme = facet.theme
4948
self.layout_info = layout_info
5049
label_info = strip_label_details.make(layout_info, vars, position)
@@ -135,7 +134,7 @@ def draw(self):
135134
text = StripText(draw_info)
136135
rect = text.patch
137136

138-
self.figure.add_artist(text)
137+
self.facet.plot.subfigure.add_artist(text)
139138

140139
if draw_info.position == "right":
141140
targets.strip_background_y.append(rect)

0 commit comments

Comments
 (0)