Skip to content

Commit 8ba4657

Browse files
authored
No tape-able operations at module level in adjoint tests (#4867)
1 parent ec342a5 commit 8ba4657

19 files changed

+144
-312
lines changed

tests/firedrake/adjoint/test_assemble.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,16 @@
66
from firedrake.adjoint import *
77

88

9+
@pytest.fixture(autouse=True)
10+
def autouse_set_test_tape(set_test_tape):
11+
pass
12+
13+
914
@pytest.fixture
1015
def rg():
1116
return RandomGenerator(PCG64(seed=1234))
1217

1318

14-
@pytest.fixture(autouse=True)
15-
def handle_taping():
16-
yield
17-
tape = get_working_tape()
18-
tape.clear_tape()
19-
20-
21-
@pytest.fixture(autouse=True, scope="module")
22-
def handle_annotation():
23-
if not annotate_tape():
24-
continue_annotation()
25-
yield
26-
# Ensure annotation is paused when we finish.
27-
if annotate_tape():
28-
pause_annotation()
29-
30-
3119
@pytest.mark.skipcomplex
3220
def test_assemble_0_forms():
3321
mesh = IntervalMesh(10, 0, 1)

tests/firedrake/adjoint/test_assignment.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,14 @@
77
from numpy.testing import assert_approx_equal, assert_allclose
88

99

10-
@pytest.fixture
11-
def rg():
12-
return RandomGenerator(PCG64(seed=1234))
13-
14-
1510
@pytest.fixture(autouse=True)
16-
def handle_taping():
17-
yield
18-
tape = get_working_tape()
19-
tape.clear_tape()
11+
def autouse_set_test_tape(set_test_tape):
12+
pass
2013

2114

22-
@pytest.fixture(autouse=True, scope="module")
23-
def handle_annotation():
24-
if not annotate_tape():
25-
continue_annotation()
26-
yield
27-
# Ensure annotation is paused when we finish.
28-
if annotate_tape():
29-
pause_annotation()
15+
@pytest.fixture
16+
def rg():
17+
return RandomGenerator(PCG64(seed=1234))
3018

3119

3220
@pytest.mark.skipcomplex

tests/firedrake/adjoint/test_burgers_newton.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,8 @@
1313

1414

1515
@pytest.fixture(autouse=True)
16-
def handle_taping():
17-
yield
18-
tape = get_working_tape()
19-
tape.clear_tape()
20-
21-
22-
@pytest.fixture(autouse=True, scope="module")
23-
def handle_annotation():
24-
if not annotate_tape():
25-
continue_annotation()
26-
yield
27-
# Ensure annotation is paused when we finish.
28-
if annotate_tape():
29-
pause_annotation()
16+
def autouse_set_test_tape(set_test_tape):
17+
pass
3018

3119

3220
@pytest.fixture
@@ -119,7 +107,7 @@ def J(ic, nu, solve_type, timestep, total_steps, V, nu_time_dependent=False):
119107
# The comment below and the others like it are used to generate the
120108
# documentation for the firedrake/docs/source/chekpointing.rst file.
121109
# [test_disk_checkpointing 10]
122-
for step in tape.timestepper(range(total_steps)):
110+
for step in tape.timestepper(iter(range(total_steps))):
123111
# Advance the forward model
124112
# [test_disk_checkpointing 11]
125113
if nu_time_dependent and step > 4:

tests/firedrake/adjoint/test_checkpointing_multistep.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,20 @@
1010

1111

1212
@pytest.fixture(autouse=True)
13-
def handle_taping():
14-
yield
15-
tape = get_working_tape()
16-
tape.clear_tape()
17-
18-
19-
@pytest.fixture(autouse=True, scope="module")
20-
def handle_annotation():
21-
if not annotate_tape():
22-
continue_annotation()
23-
yield
24-
# Ensure annotation is paused when we finish.
25-
if annotate_tape():
26-
pause_annotation()
13+
def autouse_set_test_tape(set_test_tape):
14+
pass
2715

2816

2917
total_steps = 20
3018
dt = 0.01
31-
mesh = UnitIntervalMesh(1)
32-
V = FunctionSpace(mesh, "DG", 0)
3319

3420

35-
def J(displacement_0):
21+
@pytest.fixture
22+
def V():
23+
return FunctionSpace(UnitIntervalMesh(1), "DG", 0)
24+
25+
26+
def J(displacement_0, V):
3627
stiff = Constant(2.5)
3728
damping = Constant(0.3)
3829
rho = Constant(1.0)
@@ -59,12 +50,12 @@ def J(displacement_0):
5950

6051

6152
@pytest.mark.skipcomplex
62-
def test_multisteps():
53+
def test_multisteps(V):
6354
tape = get_working_tape()
6455
tape.progress_bar = ProgressBar
6556
tape.enable_checkpointing(MixedCheckpointSchedule(total_steps, 2, storage=StorageType.RAM))
6657
displacement_0 = Function(V).assign(1.0)
67-
val = J(displacement_0)
58+
val = J(displacement_0, V)
6859
_check_forward(tape)
6960
c = Control(displacement_0)
7061
J_hat = ReducedFunctional(val, c)
@@ -82,20 +73,20 @@ def test_multisteps():
8273

8374

8475
@pytest.mark.skipcomplex
85-
def test_validity():
76+
def test_validity(V):
8677
tape = get_working_tape()
8778
tape.progress_bar = ProgressBar
8879
displacement_0 = Function(V).assign(1.0)
8980
# Without checkpointing.
90-
val0 = J(displacement_0)
81+
val0 = J(displacement_0, V)
9182
J_hat0 = ReducedFunctional(val0, Control(displacement_0))
9283
dJ0 = J_hat0.derivative()
93-
val_recomputed0 = J(displacement_0)
84+
val_recomputed0 = J(displacement_0, V)
9485
tape.clear_tape()
9586

9687
# With checkpointing.
9788
tape.enable_checkpointing(MixedCheckpointSchedule(total_steps, 2, storage=StorageType.RAM))
98-
val = J(displacement_0)
89+
val = J(displacement_0, V)
9990
J_hat = ReducedFunctional(val, Control(displacement_0))
10091
dJ = J_hat.derivative()
10192
val_recomputed = J_hat(displacement_0)

tests/firedrake/adjoint/test_disk_checkpointing.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,8 @@
99

1010

1111
@pytest.fixture(autouse=True)
12-
def handle_taping():
12+
def autouse_test_taping(set_test_tape):
1313
yield
14-
tape = get_working_tape()
15-
tape.clear_tape()
16-
tape._package_data = {}
17-
18-
19-
@pytest.fixture(autouse=True, scope="module")
20-
def handle_annotation():
21-
if not annotate_tape():
22-
continue_annotation()
23-
yield
24-
# Ensure annotation is paused when we finish.
25-
if annotate_tape():
26-
pause_annotation()
27-
2814
if disk_checkpointing():
2915
pause_disk_checkpointing()
3016

tests/firedrake/adjoint/test_dynamic_meshes.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,13 @@
66

77

88
@pytest.fixture(autouse=True)
9-
def handle_taping():
10-
yield
11-
tape = get_working_tape()
12-
tape.clear_tape()
13-
14-
15-
@pytest.fixture(autouse=True, scope="module")
16-
def handle_annotation():
17-
if not annotate_tape():
18-
continue_annotation()
19-
yield
20-
# Ensure annotation is paused when we finish.
21-
if annotate_tape():
22-
pause_annotation()
9+
def autouse_set_test_tape(set_test_tape):
10+
pass
2311

2412

2513
@pytest.mark.skipcomplex
26-
@pytest.mark.parametrize("mesh", [UnitSquareMesh(10, 10)])
27-
def test_dynamic_meshes_2D(mesh):
14+
def test_dynamic_meshes_2D():
15+
mesh = UnitSquareMesh(10, 10)
2816
S = mesh.coordinates.function_space()
2917
s = [Function(S), Function(S), Function(S)]
3018
mesh.coordinates.assign(mesh.coordinates + s[0])
@@ -71,13 +59,26 @@ def test_dynamic_meshes_2D(mesh):
7159

7260

7361
@pytest.mark.skipcomplex
74-
@pytest.mark.parametrize("mesh", [UnitCubeMesh(4, 4, 5),
75-
UnitOctahedralSphereMesh(3),
76-
UnitIcosahedralSphereMesh(3),
77-
UnitCubedSphereMesh(3),
78-
TorusMesh(25, 10, 1, 0.5),
79-
CylinderMesh(10, 25, radius=0.5, depth=0.8)])
80-
def test_dynamic_meshes_3D(mesh):
62+
@pytest.mark.parametrize("mesh_type", ["UnitCubeMesh",
63+
"UnitOctahedralSphereMesh",
64+
"UnitIcosahedralSphereMesh",
65+
"UnitCubedSphereMesh",
66+
"TorusMesh",
67+
"CylinderMesh"])
68+
def test_dynamic_meshes_3D(mesh_type):
69+
if mesh_type == "UnitCubeMesh":
70+
mesh = UnitCubeMesh(4, 4, 5)
71+
if mesh_type == "UnitOctahedralSphereMesh":
72+
mesh = UnitOctahedralSphereMesh(3)
73+
if mesh_type == "UnitIcosahedralSphereMesh":
74+
mesh = UnitIcosahedralSphereMesh(3)
75+
if mesh_type == "UnitCubedSphereMesh":
76+
mesh = UnitCubedSphereMesh(3)
77+
if mesh_type == "TorusMesh":
78+
mesh = TorusMesh(25, 10, 1, 0.5)
79+
if mesh_type == "CylinderMesh":
80+
mesh = CylinderMesh(10, 25, radius=0.5, depth=0.8)
81+
8182
S = mesh.coordinates.function_space()
8283
s = [Function(S), Function(S), Function(S)]
8384
mesh.coordinates.assign(mesh.coordinates + s[0])

tests/firedrake/adjoint/test_ensemble_reduced_functional.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,8 @@
66

77

88
@pytest.fixture(autouse=True)
9-
def handle_taping():
10-
yield
11-
tape = get_working_tape()
12-
tape.clear_tape()
13-
14-
15-
@pytest.fixture(autouse=True, scope="module")
16-
def handle_annotation():
17-
if not annotate_tape():
18-
continue_annotation()
19-
yield
20-
# Ensure annotation is paused when we finish.
21-
if annotate_tape():
22-
pause_annotation()
9+
def autouse_set_test_tape(set_test_tape):
10+
pass
2311

2412

2513
@pytest.mark.parallel(nprocs=4)

tests/firedrake/adjoint/test_external_modification.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,8 @@
66

77

88
@pytest.fixture(autouse=True)
9-
def handle_taping():
10-
yield
11-
tape = get_working_tape()
12-
tape.clear_tape()
13-
14-
15-
@pytest.fixture(autouse=True, scope="module")
16-
def handle_annotation():
17-
if not annotate_tape():
18-
continue_annotation()
19-
yield
20-
# Ensure annotation is paused when we finish.
21-
if annotate_tape():
22-
pause_annotation()
9+
def autouse_set_test_tape(set_test_tape):
10+
pass
2311

2412

2513
@pytest.mark.skipcomplex

tests/firedrake/adjoint/test_hessian.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,32 +3,17 @@
33
from firedrake import *
44
from firedrake.adjoint import *
55

6-
from numpy.random import default_rng
7-
rng = default_rng()
6+
7+
@pytest.fixture(autouse=True)
8+
def autouse_set_test_tape(set_test_tape):
9+
pass
810

911

1012
@pytest.fixture
1113
def rg():
1214
return RandomGenerator(PCG64(seed=1234))
1315

1416

15-
@pytest.fixture(autouse=True)
16-
def handle_taping():
17-
yield
18-
tape = get_working_tape()
19-
tape.clear_tape()
20-
21-
22-
@pytest.fixture(autouse=True, scope="module")
23-
def handle_annotation():
24-
if not annotate_tape():
25-
continue_annotation()
26-
yield
27-
# Ensure annotation is paused when we finish.
28-
if annotate_tape():
29-
pause_annotation()
30-
31-
3217
@pytest.mark.skipcomplex
3318
def test_simple_solve(rg):
3419
tape = Tape()

tests/firedrake/adjoint/test_optimisation.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,8 @@
1212

1313

1414
@pytest.fixture(autouse=True)
15-
def handle_taping():
16-
yield
17-
tape = get_working_tape()
18-
tape.clear_tape()
19-
20-
21-
@pytest.fixture(autouse=True, scope="module")
22-
def handle_annotation():
23-
if not annotate_tape():
24-
continue_annotation()
25-
yield
26-
# Ensure annotation is paused when we finish.
27-
if annotate_tape():
28-
pause_annotation()
15+
def autouse_set_test_tape(set_test_tape):
16+
pass
2917

3018

3119
@pytest.mark.skipcomplex

0 commit comments

Comments
 (0)