Skip to content

ENH: speed up wide DataFrame.line plots by using a single LineCollection #61764

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,7 @@ Performance improvements
- Performance improvement in :meth:`DataFrame.stack` when using ``future_stack=True`` and the DataFrame does not have a :class:`MultiIndex` (:issue:`58391`)
- Performance improvement in :meth:`DataFrame.where` when ``cond`` is a :class:`DataFrame` with many columns (:issue:`61010`)
- Performance improvement in :meth:`to_hdf` avoid unnecessary reopenings of the HDF5 file to speedup data addition to files with a very large number of groups . (:issue:`58248`)
- Performance improvement in ``DataFrame.plot(kind="line")``: very wide DataFrames (more than 200 columns) are now rendered with a single :class:`matplotlib.collections.LineCollection` instead of one ``Line2D`` per column, reducing draw time by roughly 5x on a 2000-column frame. (:issue:`61532`)
- Performance improvement in ``DataFrameGroupBy.__len__`` and ``SeriesGroupBy.__len__`` (:issue:`57595`)
- Performance improvement in indexing operations for string dtypes (:issue:`56997`)
- Performance improvement in unary methods on a :class:`RangeIndex` returning a :class:`RangeIndex` instead of a :class:`Index` when possible. (:issue:`57825`)
Expand Down
121 changes: 86 additions & 35 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@
Series,
)

import itertools

from matplotlib.collections import LineCollection


def holds_integer(column: Index) -> bool:
return column.inferred_type in {"integer", "mixed-integer"}
Expand Down Expand Up @@ -1549,66 +1553,113 @@ def __init__(self, data, **kwargs) -> None:
self.data = self.data.fillna(value=0)

def _make_plot(self, fig: Figure) -> None:
"""
Draw a DataFrame line plot. For very wide frames (> 200 columns) that are
*not* time-series and have no stacking or error bars, all columns are
rendered with a single LineCollection for a large speed-up while keeping
public behaviour identical to the original per-column path.
"""
# decide once whether we can use the LineCollection fast draw
threshold = 200
use_collection = (
not self._is_ts_plot()
and not self.stacked
and not com.any_not_none(*self.errors.values())
and len(self.data.columns) > threshold
)

# choose ts-plot helper vs. regular helper
if self._is_ts_plot():
data = maybe_convert_index(self._get_ax(0), self.data)

x = data.index # dummy, not used
x = data.index # dummy; _ts_plot ignores it
plotf = self._ts_plot
it = data.items()
else:
x = self._get_xticks()
# error: Incompatible types in assignment (expression has type
# "Callable[[Any, Any, Any, Any, Any, Any, KwArg(Any)], Any]", variable has
# type "Callable[[Any, Any, Any, Any, KwArg(Any)], Any]")
plotf = self._plot # type: ignore[assignment]
# error: Incompatible types in assignment (expression has type
# "Iterator[tuple[Hashable, ndarray[Any, Any]]]", variable has
# type "Iterable[tuple[Hashable, Series]]")
it = self._iter_data(data=self.data) # type: ignore[assignment]

# shared state
stacking_id = self._get_stacking_id()
is_errorbar = com.any_not_none(*self.errors.values())

colors = self._get_colors()
segments: list[np.ndarray] = [] # vertices for LineCollection

# unified per-column loop
for i, (label, y) in enumerate(it):
ax = self._get_ax(i)
ax = self._get_ax(i if not use_collection else 0)

kwds = self.kwds.copy()
if self.color is not None:
kwds["color"] = self.color

style, kwds = self._apply_style_colors(
colors,
kwds,
i,
# error: Argument 4 to "_apply_style_colors" of "MPLPlot" has
# incompatible type "Hashable"; expected "str"
label, # type: ignore[arg-type]
)
kwds.update(self._get_errorbars(label=label, index=i))

label_str = self._mark_right_label(pprint_thing(label), index=i)
kwds["label"] = label_str

if use_collection:
# collect vertices; defer drawing
segments.append(np.column_stack((x, y)))

# tiny proxy only if legend is requested
if self.legend:
proxy = mpl.lines.Line2D(
[],
[],
color=kwds.get("color"),
linewidth=kwds.get(
"linewidth", mpl.rcParams["lines.linewidth"]
),
linestyle=kwds.get("linestyle", "-"),
marker=kwds.get("marker"),
)
self._append_legend_handles_labels(proxy, label_str)
else:
newlines = plotf(
ax,
x,
y,
style=style,
column_num=i,
stacking_id=stacking_id,
is_errorbar=is_errorbar,
**kwds,
)
self._append_legend_handles_labels(newlines[0], label_str)

errors = self._get_errorbars(label=label, index=i)
kwds = dict(kwds, **errors)
# reset x-limits for true ts plots
if self._is_ts_plot():
lines = get_all_lines(ax)
left, right = get_xlim(lines)
ax.set_xlim(left, right)

label = pprint_thing(label)
label = self._mark_right_label(label, index=i)
kwds["label"] = label

newlines = plotf(
ax,
x,
y,
style=style,
column_num=i,
stacking_id=stacking_id,
is_errorbar=is_errorbar,
**kwds,
# single draw call for fast path
if use_collection and segments:
if self.legend:
lc_colors = [
cast(mpl.lines.Line2D, h).get_color() # mypy: h is Line2D
for h in self.legend_handles
]
else:
# no legend - repeat default colour cycle
base = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
lc_colors = list(itertools.islice(itertools.cycle(base), len(segments)))

lc = LineCollection(
segments,
colors=lc_colors,
linewidths=self.kwds.get("linewidth", mpl.rcParams["lines.linewidth"]),
)
self._append_legend_handles_labels(newlines[0], label)

if self._is_ts_plot():
# reset of xlim should be used for ts data
# TODO: GH28021, should find a way to change view limit on xaxis
lines = get_all_lines(ax)
left, right = get_xlim(lines)
ax.set_xlim(left, right)
ax0 = self._get_ax(0)
ax0.add_collection(lc)
ax0.margins(0.05)

# error: Signature of "_plot" incompatible with supertype "MPLPlot"
@classmethod
Expand Down
27 changes: 27 additions & 0 deletions pandas/tests/plotting/frame/test_linecollection_speedup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""
Ensure wide DataFrame.line plots use a single LineCollection
instead of one Line2D per column (PR #61764).
"""

import numpy as np
import pytest

import pandas as pd

# Skip this entire module if matplotlib is not installed
mpl = pytest.importorskip("matplotlib")
plt = pytest.importorskip("matplotlib.pyplot")
from matplotlib.collections import LineCollection


def test_linecollection_used_for_wide_dataframe():
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.standard_normal((10, 201)).cumsum(axis=0))

ax = df.plot(legend=False)

# exactly one LineCollection, and no Line2D artists
assert sum(isinstance(c, LineCollection) for c in ax.collections) == 1
assert len(ax.lines) == 0

plt.close(ax.figure)
Loading