Skip to content

Commit 93d48ec

Browse files
authoredNov 5, 2024··
Merge pull request #384 from jcapriot/random_warnings
Warn for non-repeatable random tests in a testing environment
2 parents 636e3ff + 58ff79e commit 93d48ec

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed
 

‎discretize/tests.py

+25
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
assert_isadjoint
2626
""" # NOQA D205
2727

28+
import warnings
29+
2830
import numpy as np
2931
import scipy.sparse as sp
3032

@@ -81,6 +83,23 @@
8183
_happiness_rng = np.random.default_rng()
8284

8385

86+
def _warn_random_test():
87+
stack = inspect.stack()
88+
in_pytest = any(x[0].f_globals["__name__"].startswith("_pytest.") for x in stack)
89+
in_nosetest = any(x[0].f_globals["__name__"].startswith("nose.") for x in stack)
90+
91+
if in_pytest or in_nosetest:
92+
test = "pytest" if in_pytest else "nosetest"
93+
warnings.warn(
94+
f"You are running a {test} without setting a random seed, the results might not be"
95+
"repeatable. For repeatable tests please pass an argument to `random seed` that is"
96+
"not `None`.",
97+
UserWarning,
98+
stacklevel=3,
99+
)
100+
return in_pytest or in_nosetest
101+
102+
84103
def setup_mesh(mesh_type, nC, nDim, random_seed=None):
85104
"""Generate arbitrary mesh for testing.
86105
@@ -110,6 +129,8 @@ def setup_mesh(mesh_type, nC, nDim, random_seed=None):
110129
A discretize mesh of class specified by the input argument *mesh_type*
111130
"""
112131
if "random" in mesh_type:
132+
if random_seed is None:
133+
_warn_random_test()
113134
rng = np.random.default_rng(random_seed)
114135
if "TensorMesh" in mesh_type:
115136
if "uniform" in mesh_type:
@@ -649,6 +670,8 @@ def check_derivative(
649670
x0 = mkvc(x0)
650671

651672
if dx is None:
673+
if random_seed is None:
674+
_warn_random_test()
652675
rng = np.random.default_rng(random_seed)
653676
dx = rng.standard_normal(len(x0))
654677

@@ -867,6 +890,8 @@ def assert_isadjoint(
867890
"""
868891
__tracebackhide__ = True
869892

893+
if random_seed is None:
894+
_warn_random_test()
870895
rng = np.random.default_rng(random_seed)
871896

872897
def random(size, iscomplex):

‎tests/base/test_tests.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44
import subprocess
55
import numpy as np
66
import scipy.sparse as sp
7-
from discretize.tests import assert_isadjoint, check_derivative, assert_expected_order
7+
from discretize.tests import (
8+
assert_isadjoint,
9+
check_derivative,
10+
assert_expected_order,
11+
_warn_random_test,
12+
setup_mesh,
13+
)
814

915

1016
class TestAssertIsAdjoint:
@@ -166,3 +172,22 @@ def test_import_time():
166172

167173
# Currently we check t < 1.0s.
168174
assert float(out.stderr.decode("utf-8")[:-1]) < 1.0
175+
176+
177+
def test_random_test_warning():
178+
179+
match = r"You are running a pytest without setting a random seed.*"
180+
with pytest.warns(UserWarning, match=match):
181+
_warn_random_test()
182+
183+
def simple_deriv(x):
184+
return np.sin(x), lambda y: np.cos(x) * y
185+
186+
with pytest.warns(UserWarning, match=match):
187+
check_derivative(simple_deriv, np.zeros(10), plotIt=False)
188+
189+
with pytest.warns(UserWarning, match=match):
190+
setup_mesh("randomTensorMesh", 10, 1)
191+
192+
with pytest.warns(UserWarning, match=match):
193+
assert_isadjoint(lambda x: x, lambda x: x, 5, 5)

0 commit comments

Comments
 (0)
Please sign in to comment.