Skip to content

Commit ad78681

Browse files
Rewrite using numpy.testing.assert_equal
1 parent f952d52 commit ad78681

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

tests/module/mobject/graphing/test_coordinate_system.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import math
44

55
import numpy as np
6+
import numpy.testing as nt
67
import pytest
78

89
from manim import (
@@ -197,42 +198,42 @@ def test_input_to_graph_point():
197198

198199
def test_matmul_operations():
199200
ax = Axes()
200-
assert (ax @ (1, 2) == ax.coords_to_point(1, 2)).all()
201+
nt.assert_equal(ax @ (1, 2), ax.coords_to_point(1, 2))
201202
# should work with mobjects too, using their center
202203
mob = Dot().move_to((1, 2, 0))
203-
assert (ax @ mob == ax.coords_to_point(1, 2)).all()
204+
nt.assert_equal(ax @ mob, ax.coords_to_point(1, 2))
204205

205206
# other coordinate systems like PolarPlane and ComplexPlane should override __matmul__ indirectly
206207
polar = PolarPlane()
207-
assert (polar @ (1, 2) == polar.polar_to_point(1, 2)).all()
208+
nt.assert_equal(polar @ (1, 2), polar.polar_to_point(1, 2))
208209

209210
complx = ComplexPlane()
210-
assert (complx @ (1 + 2j) == complx.number_to_point(1 + 2j)).all()
211+
nt.assert_equal(complx @ (1 + 2j), complx.number_to_point(1 + 2j))
211212

212213
# Numberline doesn't inherit from CoordinateSystem, but it should still work
213214
n = NumberLine()
214-
assert (n @ 3 == n.number_to_point(3)).all()
215+
nt.assert_equal(n @ 3, n.number_to_point(3))
215216

216217

217218
def test_rmatmul_operations():
218219
point = (1, 2, 0)
219220

220221
ax = Axes()
221-
assert (point @ ax == ax.point_to_coords(point)).all()
222+
nt.assert_equal(point @ ax, ax.point_to_coords(point))
222223

223224
polar = PolarPlane()
224225
assert point @ polar == polar.point_to_polar(point)
225226

226227
complx = ComplexPlane()
227-
assert point @ complx == complx.point_to_number(point)
228+
nt.assert_equal(point @ complx, complx.point_to_number(point))
228229

229230
n = NumberLine()
230231
point = n @ 4
231232

232-
assert (
233-
tuple(point) @ n # ndarray overrides __matmul__
234-
== n.point_to_number(point)
235-
).all()
233+
nt.assert_equal(
234+
tuple(point) @ n, # ndarray overrides __matmul__
235+
n.point_to_number(point),
236+
)
236237

237238
mob = Dot().move_to(point)
238-
assert (mob @ n == n.point_to_number(mob.get_center())).all()
239+
nt.assert_equal(mob @ n, n.point_to_number(mob.get_center()))

0 commit comments

Comments
 (0)