Skip to content

Commit c01eb01

Browse files
authored
ENH: Streamplot control for integration max step and error (matplotlib#29333)
* Add arguments to streamplot to control integration max step and error
1 parent 5c3f2f6 commit c01eb01

File tree

7 files changed

+234
-10
lines changed

7 files changed

+234
-10
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
Streamplot integration control
2+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
3+
4+
Two new options have been added to the `~.axes.Axes.streamplot` function that
5+
give the user better control of the streamline integration. The first is called
6+
``integration_max_step_scale`` and multiplies the default max step computed by the
7+
integrator. The second is called ``integration_max_error_scale`` and multiplies the
8+
default max error set by the integrator. Values for these parameters between
9+
zero and one reduce (tighten) the max step or error to improve streamline
10+
accuracy by performing more computation. Values greater than one increase
11+
(loosen) the max step or error to reduce computation time at the cost of lower
12+
streamline accuracy.
13+
14+
The integrator defaults are both hand-tuned values and may not be applicable to
15+
all cases, so this allows customizing the behavior to specific use cases.
16+
Modifying only ``integration_max_step_scale`` has proved effective, but it may be useful
17+
to control the error as well.

galleries/examples/images_contours_and_fields/plot_streamplot.py

+87
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
* Unbroken streamlines even when exceeding the limit of lines within a single
1515
grid cell.
1616
"""
17+
import time
18+
1719
import matplotlib.pyplot as plt
1820
import numpy as np
1921

@@ -74,6 +76,91 @@
7476
axs[7].streamplot(X, Y, U, V, broken_streamlines=False)
7577
axs[7].set_title('Streamplot with unbroken streamlines')
7678

79+
plt.tight_layout()
80+
# plt.show()
81+
82+
# %%
83+
# Streamline computation
84+
# ----------------------
85+
#
86+
# The streamlines are computed by integrating along the provided vector field
87+
# from the seed points, which are either automatically generated or manually
88+
# specified. The accuracy and smoothness of the streamlines can be adjusted using
89+
# the ``integration_max_step_scale`` and ``integration_max_error_scale`` optional
90+
# parameters. See the `~.axes.Axes.streamplot` function documentation for more
91+
# details.
92+
#
93+
# This example shows how adjusting the maximum allowed step size and error for
94+
# the integrator changes the appearance of the streamline. The differences can
95+
# be subtle, but can be observed particularly where the streamlines have
96+
# high curvature (as shown in the zoomed in region).
97+
98+
# Linear potential flow over a lifting cylinder
99+
n = 50
100+
x, y = np.meshgrid(np.linspace(-2, 2, n), np.linspace(-3, 3, n))
101+
th = np.arctan2(y, x)
102+
r = np.sqrt(x**2 + y**2)
103+
vr = -np.cos(th) / r**2
104+
vt = -np.sin(th) / r**2 - 1 / r
105+
vx = vr * np.cos(th) - vt * np.sin(th) + 1.0
106+
vy = vr * np.sin(th) + vt * np.cos(th)
107+
108+
# Seed points
109+
n_seed = 50
110+
seed_pts = np.column_stack((np.full(n_seed, -1.75), np.linspace(-2, 2, n_seed)))
111+
112+
_, axs = plt.subplots(3, 1, figsize=(6, 14))
113+
th_circ = np.linspace(0, 2 * np.pi, 100)
114+
for ax, max_val in zip(axs, [0.05, 1, 5]):
115+
ax_ins = ax.inset_axes([0.0, 0.7, 0.3, 0.35])
116+
for ax_curr, is_inset in zip([ax, ax_ins], [False, True]):
117+
t_start = time.time()
118+
ax_curr.streamplot(
119+
x,
120+
y,
121+
vx,
122+
vy,
123+
start_points=seed_pts,
124+
broken_streamlines=False,
125+
arrowsize=1e-10,
126+
linewidth=2 if is_inset else 0.6,
127+
color="k",
128+
integration_max_step_scale=max_val,
129+
integration_max_error_scale=max_val,
130+
)
131+
if is_inset:
132+
t_total = time.time() - t_start
133+
134+
# Draw the cylinder
135+
ax_curr.fill(
136+
np.cos(th_circ),
137+
np.sin(th_circ),
138+
color="w",
139+
ec="k",
140+
lw=6 if is_inset else 2,
141+
)
142+
143+
# Set axis properties
144+
ax_curr.set_aspect("equal")
145+
146+
# Label properties of each circle
147+
text = f"integration_max_step_scale: {max_val}\n" \
148+
f"integration_max_error_scale: {max_val}\n" \
149+
f"streamplot time: {t_total:.2f} sec"
150+
if max_val == 1:
151+
text += "\n(default)"
152+
ax.text(0.0, 0.0, text, ha="center", va="center")
153+
154+
# Set axis limits and show zoomed region
155+
ax_ins.set_xlim(-1.2, -0.7)
156+
ax_ins.set_ylim(-0.8, -0.4)
157+
ax_ins.set_yticks(())
158+
ax_ins.set_xticks(())
159+
160+
ax.set_ylim(-1.5, 1.5)
161+
ax.axis("off")
162+
ax.indicate_inset_zoom(ax_ins, ec="k")
163+
77164
plt.tight_layout()
78165
plt.show()
79166
# %%

lib/matplotlib/pyplot.py

+4
Original file line numberDiff line numberDiff line change
@@ -4128,6 +4128,8 @@ def streamplot(
41284128
integration_direction="both",
41294129
broken_streamlines=True,
41304130
*,
4131+
integration_max_step_scale=1.0,
4132+
integration_max_error_scale=1.0,
41314133
num_arrows=1,
41324134
data=None,
41334135
):
@@ -4150,6 +4152,8 @@ def streamplot(
41504152
maxlength=maxlength,
41514153
integration_direction=integration_direction,
41524154
broken_streamlines=broken_streamlines,
4155+
integration_max_step_scale=integration_max_step_scale,
4156+
integration_max_error_scale=integration_max_error_scale,
41534157
num_arrows=num_arrows,
41544158
**({"data": data} if data is not None else {}),
41554159
)

lib/matplotlib/streamplot.py

+50-8
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
1919
cmap=None, norm=None, arrowsize=1, arrowstyle='-|>',
2020
minlength=0.1, transform=None, zorder=None, start_points=None,
2121
maxlength=4.0, integration_direction='both',
22-
broken_streamlines=True, *, num_arrows=1):
22+
broken_streamlines=True, *, integration_max_step_scale=1.0,
23+
integration_max_error_scale=1.0, num_arrows=1):
2324
"""
2425
Draw streamlines of a vector flow.
2526
@@ -73,6 +74,24 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
7374
If False, forces streamlines to continue until they
7475
leave the plot domain. If True, they may be terminated if they
7576
come too close to another streamline.
77+
integration_max_step_scale : float, default: 1.0
78+
Multiplier on the maximum allowable step in the streamline integration routine.
79+
A value between zero and one results in a max integration step smaller than
80+
the default max step, resulting in more accurate streamlines at the cost
81+
of greater computation time; a value greater than one does the converse. Must be
82+
greater than zero.
83+
84+
.. versionadded:: 3.11
85+
86+
integration_max_error_scale : float, default: 1.0
87+
Multiplier on the maximum allowable error in the streamline integration routine.
88+
A value between zero and one results in a tighter max integration error than
89+
the default max error, resulting in more accurate streamlines at the cost
90+
of greater computation time; a value greater than one does the converse. Must be
91+
greater than zero.
92+
93+
.. versionadded:: 3.11
94+
7695
num_arrows : int
7796
Number of arrows per streamline. The arrows are spaced equally along the steps
7897
each streamline takes. Note that this can be different to being spaced equally
@@ -97,6 +116,18 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
97116
mask = StreamMask(density)
98117
dmap = DomainMap(grid, mask)
99118

119+
if integration_max_step_scale <= 0.0:
120+
raise ValueError(
121+
"The value of integration_max_step_scale must be > 0, " +
122+
f"got {integration_max_step_scale}"
123+
)
124+
125+
if integration_max_error_scale <= 0.0:
126+
raise ValueError(
127+
"The value of integration_max_error_scale must be > 0, " +
128+
f"got {integration_max_error_scale}"
129+
)
130+
100131
if num_arrows < 0:
101132
raise ValueError(f"The value of num_arrows must be >= 0, got {num_arrows=}")
102133

@@ -159,7 +190,9 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
159190
for xm, ym in _gen_starting_points(mask.shape):
160191
if mask[ym, xm] == 0:
161192
xg, yg = dmap.mask2grid(xm, ym)
162-
t = integrate(xg, yg, broken_streamlines)
193+
t = integrate(xg, yg, broken_streamlines,
194+
integration_max_step_scale,
195+
integration_max_error_scale)
163196
if t is not None:
164197
trajectories.append(t)
165198
else:
@@ -187,7 +220,8 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
187220
xg = np.clip(xg, 0, grid.nx - 1)
188221
yg = np.clip(yg, 0, grid.ny - 1)
189222

190-
t = integrate(xg, yg, broken_streamlines)
223+
t = integrate(xg, yg, broken_streamlines, integration_max_step_scale,
224+
integration_max_error_scale)
191225
if t is not None:
192226
trajectories.append(t)
193227

@@ -480,7 +514,8 @@ def backward_time(xi, yi):
480514
dxi, dyi = forward_time(xi, yi)
481515
return -dxi, -dyi
482516

483-
def integrate(x0, y0, broken_streamlines=True):
517+
def integrate(x0, y0, broken_streamlines=True, integration_max_step_scale=1.0,
518+
integration_max_error_scale=1.0):
484519
"""
485520
Return x, y grid-coordinates of trajectory based on starting point.
486521
@@ -500,14 +535,18 @@ def integrate(x0, y0, broken_streamlines=True):
500535
return None
501536
if integration_direction in ['both', 'backward']:
502537
s, xyt = _integrate_rk12(x0, y0, dmap, backward_time, maxlength,
503-
broken_streamlines)
538+
broken_streamlines,
539+
integration_max_step_scale,
540+
integration_max_error_scale)
504541
stotal += s
505542
xy_traj += xyt[::-1]
506543

507544
if integration_direction in ['both', 'forward']:
508545
dmap.reset_start_point(x0, y0)
509546
s, xyt = _integrate_rk12(x0, y0, dmap, forward_time, maxlength,
510-
broken_streamlines)
547+
broken_streamlines,
548+
integration_max_step_scale,
549+
integration_max_error_scale)
511550
stotal += s
512551
xy_traj += xyt[1:]
513552

@@ -524,7 +563,9 @@ class OutOfBounds(IndexError):
524563
pass
525564

526565

527-
def _integrate_rk12(x0, y0, dmap, f, maxlength, broken_streamlines=True):
566+
def _integrate_rk12(x0, y0, dmap, f, maxlength, broken_streamlines=True,
567+
integration_max_step_scale=1.0,
568+
integration_max_error_scale=1.0):
528569
"""
529570
2nd-order Runge-Kutta algorithm with adaptive step size.
530571
@@ -550,7 +591,7 @@ def _integrate_rk12(x0, y0, dmap, f, maxlength, broken_streamlines=True):
550591
# This error is below that needed to match the RK4 integrator. It
551592
# is set for visual reasons -- too low and corners start
552593
# appearing ugly and jagged. Can be tuned.
553-
maxerror = 0.003
594+
maxerror = 0.003 * integration_max_error_scale
554595

555596
# This limit is important (for all integrators) to avoid the
556597
# trajectory skipping some mask cells. We could relax this
@@ -559,6 +600,7 @@ def _integrate_rk12(x0, y0, dmap, f, maxlength, broken_streamlines=True):
559600
# nature of the interpolation, this doesn't boost speed by much
560601
# for quite a bit of complexity.
561602
maxds = min(1. / dmap.mask.nx, 1. / dmap.mask.ny, 0.1)
603+
maxds *= integration_max_step_scale
562604

563605
ds = maxds
564606
stotal = 0

lib/matplotlib/streamplot.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ def streamplot(
2929
integration_direction: Literal["forward", "backward", "both"] = ...,
3030
broken_streamlines: bool = ...,
3131
*,
32+
integration_max_step_scale: float = ...,
33+
integration_max_error_scale: float = ...,
3234
num_arrows: int = ...,
3335
) -> StreamplotSet: ...
3436

Loading

lib/matplotlib/tests/test_streamplot.py

+74-2
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,66 @@ def test_direction():
100100
linewidth=2, density=2)
101101

102102

103+
@image_comparison(['streamplot_integration.png'], style='mpl20', tol=0.05)
104+
def test_integration_options():
105+
# Linear potential flow over a lifting cylinder
106+
n = 50
107+
x, y = np.meshgrid(np.linspace(-2, 2, n), np.linspace(-3, 3, n))
108+
th = np.arctan2(y, x)
109+
r = np.sqrt(x**2 + y**2)
110+
vr = -np.cos(th) / r**2
111+
vt = -np.sin(th) / r**2 - 1 / r
112+
vx = vr * np.cos(th) - vt * np.sin(th) + 1.0
113+
vy = vr * np.sin(th) + vt * np.cos(th)
114+
115+
# Seed points
116+
n_seed = 50
117+
seed_pts = np.column_stack((np.full(n_seed, -1.75), np.linspace(-2, 2, n_seed)))
118+
119+
fig, axs = plt.subplots(3, 1, figsize=(6, 14))
120+
th_circ = np.linspace(0, 2 * np.pi, 100)
121+
for ax, max_val in zip(axs, [0.05, 1, 5]):
122+
ax_ins = ax.inset_axes([0.0, 0.7, 0.3, 0.35])
123+
for ax_curr, is_inset in zip([ax, ax_ins], [False, True]):
124+
ax_curr.streamplot(
125+
x,
126+
y,
127+
vx,
128+
vy,
129+
start_points=seed_pts,
130+
broken_streamlines=False,
131+
arrowsize=1e-10,
132+
linewidth=2 if is_inset else 0.6,
133+
color="k",
134+
integration_max_step_scale=max_val,
135+
integration_max_error_scale=max_val,
136+
)
137+
138+
# Draw the cylinder
139+
ax_curr.fill(
140+
np.cos(th_circ),
141+
np.sin(th_circ),
142+
color="w",
143+
ec="k",
144+
lw=6 if is_inset else 2,
145+
)
146+
147+
# Set axis properties
148+
ax_curr.set_aspect("equal")
149+
150+
# Set axis limits and show zoomed region
151+
ax_ins.set_xlim(-1.2, -0.7)
152+
ax_ins.set_ylim(-0.8, -0.4)
153+
ax_ins.set_yticks(())
154+
ax_ins.set_xticks(())
155+
156+
ax.set_ylim(-1.5, 1.5)
157+
ax.axis("off")
158+
ax.indicate_inset_zoom(ax_ins, ec="k")
159+
160+
fig.tight_layout()
161+
162+
103163
def test_streamplot_limits():
104164
ax = plt.axes()
105165
x = np.linspace(-5, 10, 20)
@@ -156,8 +216,20 @@ def test_streamplot_grid():
156216
x = np.array([0, 20, 40])
157217
y = np.array([0, 20, 10])
158218

159-
with pytest.raises(ValueError, match="'y' must be strictly increasing"):
160-
plt.streamplot(x, y, u, v)
219+
220+
def test_streamplot_integration_params():
221+
x = np.array([[10, 20], [10, 20]])
222+
y = np.array([[10, 10], [20, 20]])
223+
u = np.ones((2, 2))
224+
v = np.zeros((2, 2))
225+
226+
err_str = "The value of integration_max_step_scale must be > 0, got -0.5"
227+
with pytest.raises(ValueError, match=err_str):
228+
plt.streamplot(x, y, u, v, integration_max_step_scale=-0.5)
229+
230+
err_str = "The value of integration_max_error_scale must be > 0, got 0.0"
231+
with pytest.raises(ValueError, match=err_str):
232+
plt.streamplot(x, y, u, v, integration_max_error_scale=0.0)
161233

162234

163235
def test_streamplot_inputs(): # test no exception occurs.

0 commit comments

Comments
 (0)