Skip to content

Commit e4d8276

Browse files
committed
Fix cphyc#87
Only label those lines that are passed to the labelLines function
1 parent b7f5a30 commit e4d8276

File tree

3 files changed

+76
-17
lines changed

3 files changed

+76
-17
lines changed

labellines/core.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import warnings
22

3+
import matplotlib.pyplot as plt
34
import numpy as np
45
from matplotlib.container import ErrorbarContainer
56
from matplotlib.dates import DateConverter, num2date
@@ -81,7 +82,7 @@ def labelLine(
8182

8283

8384
def labelLines(
84-
lines,
85+
lines=None,
8586
align=True,
8687
xvals=None,
8788
drop_label=False,
@@ -95,8 +96,8 @@ def labelLines(
9596
9697
Parameters
9798
----------
98-
lines : list of matplotlib lines
99-
The lines to label
99+
lines : list of matplotlib lines, optional.
100+
Lines to label. If empty, label all lines that have a label.
100101
align : boolean, optional
101102
If True, the label will be aligned with the slope of the line
102103
at the location of the label. If False, they will be horizontal.
@@ -119,16 +120,34 @@ def labelLines(
119120
kwargs : dict, optional
120121
Optional arguments passed to ax.text
121122
"""
122-
ax = lines[0].axes
123+
if lines:
124+
ax = lines[0].axes
125+
else:
126+
ax = plt.gca()
123127

124128
handles, allLabels = ax.get_legend_handles_labels()
125129

126130
all_lines = []
127131
for h in handles:
128132
if isinstance(h, ErrorbarContainer):
129-
all_lines.append(h.lines[0])
133+
line = h.lines[0]
130134
else:
131-
all_lines.append(h)
135+
line = h
136+
137+
if (lines is not None) and (line not in lines):
138+
continue
139+
all_lines.append(line)
140+
141+
# Check that the lines passed to the function have all a label
142+
if lines is not None:
143+
for line in lines:
144+
if line in all_lines:
145+
continue
146+
147+
warnings.warn(
148+
"Tried to label line %s, but could not find a label for it.",
149+
UserWarning,
150+
)
132151

133152
# In case no x location was provided, we need to use some heuristics
134153
# to generate them.

labellines/test.py

+50-11
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
import warnings
21
from datetime import datetime
32

43
import matplotlib.pyplot as plt
54
import numpy as np
65
import pytest
76
from matplotlib.dates import UTC, DateFormatter, DayLocator
87
from matplotlib.testing import setup
9-
from numpy.testing import assert_raises
108

119
from .core import labelLine, labelLines
1210

@@ -203,22 +201,22 @@ def test_nan_warning():
203201

204202
line = plt.plot(x, y, label="test")[0]
205203

206-
with warnings.catch_warnings(record=True) as w:
204+
warn_msg = (
205+
".* could not be annotated due to `nans` values. "
206+
"Consider using another location via the `x` argument."
207+
)
208+
with pytest.warns(UserWarning, match=warn_msg):
207209
labelLine(line, 0.5)
208-
assert issubclass(w[-1].category, UserWarning)
209-
assert "could not be annotated" in str(w[-1].message)
210210

211-
with warnings.catch_warnings(record=True) as w:
212-
labelLine(line, 2.5)
213-
assert len(w) == 0
211+
labelLine(line, 2.5)
214212

215213

216214
def test_nan_failure():
217215
x = np.array([0, 1])
218216
y = np.array([np.nan, np.nan])
219217

220218
line = plt.plot(x, y, label="test")[0]
221-
with assert_raises(Exception):
219+
with pytest.raises(Exception):
222220
labelLine(line, 0.5)
223221

224222

@@ -228,9 +226,9 @@ def test_label_range(setupMpl):
228226
line = plt.plot(x, x**2, label="lorem ipsum")[0]
229227

230228
# This should fail
231-
with assert_raises(Exception):
229+
with pytest.raises(Exception):
232230
labelLine(line, -1)
233-
with assert_raises(Exception):
231+
with pytest.raises(Exception):
234232
labelLine(line, 2)
235233

236234
# This should work
@@ -365,3 +363,44 @@ def test_errorbars(setupMpl):
365363

366364
labelLines(ax.get_lines(), align=False, xvals=pos)
367365
return fig
366+
367+
368+
@pytest.fixture
369+
def create_plot():
370+
fig, ax = plt.subplots()
371+
X = [0, 1]
372+
Y = [0, 1]
373+
374+
lines = (
375+
*ax.plot(X, Y, label="label1"),
376+
*ax.plot(X, Y), # no label
377+
*ax.plot(X, Y, label="label2"),
378+
)
379+
return fig, ax, lines
380+
381+
382+
def test_warning_line_labeling(create_plot):
383+
_fig, _ax, lines = create_plot
384+
385+
warn_msg = "Tried to label line .*, but could not find a label for it."
386+
with pytest.warns(UserWarning, match=warn_msg):
387+
txts = labelLines(lines)
388+
# Make sure only two lines have been labeled
389+
assert len(txts) == 2
390+
391+
with pytest.warns(UserWarning, match=warn_msg):
392+
txts = labelLines(lines[1:])
393+
# Make sure only one line has been labeled
394+
assert len(txts) == 1
395+
396+
397+
def test_no_warning_line_labeling(create_plot):
398+
_fig, _ax, lines = create_plot
399+
400+
txts = labelLines(lines[0:1])
401+
assert len(txts) == 1
402+
403+
404+
def test_labeling_by_axis(create_plot):
405+
txts = labelLines()
406+
assert len(txts) == 2

pytest.ini

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
[pytest]
22
python_files = test*.py
3+
filterwarnings = error

0 commit comments

Comments
 (0)