Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions firedrake/adjoint_utils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,3 +341,14 @@ class CofunctionMixin(FunctionMixin):

def _ad_dot(self, other):
return firedrake.assemble(firedrake.action(self, other))

def _ad_init_object(cls, obj):
from firedrake import Cofunction
return Cofunction(cls.function_space()).assign(obj)

def _ad_init_zero(self, dual=False):
from firedrake import Function, Cofunction
if dual:
return Function(self.function_space().dual())
else:
return Cofunction(self.function_space())
28 changes: 28 additions & 0 deletions tests/firedrake/adjoint/test_optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,34 @@ def test_tao_simple_inversion(minimize, riesz_representation):
assert_allclose(x.dat.data, source_ref.dat.data, rtol=1e-2)


@pytest.mark.parametrize("minimize", [minimize_tao_lmvm,
pytest.param(minimize_tao_nls, marks=pytest.mark.xfail)])
@pytest.mark.parametrize("riesz_representation", ["L2", "H1"])
@pytest.mark.skipcomplex
def test_tao_cofunction_control(minimize, riesz_representation):
"""Test inversion of source term in helmholtz eqn using TAO."""
mesh = UnitIntervalMesh(10)
V = FunctionSpace(mesh, "CG", 1)
source_ref = Function(V)
x = SpatialCoordinate(mesh)
source_ref.interpolate(cos(pi*x**2))

# compute reference solution
with stop_annotating():
u_ref = _simple_helmholz_model(V, source_ref)

# now rerun annotated model with zero source
source = Cofunction(V.dual())
c = Control(source, riesz_map=riesz_representation)
u = _simple_helmholz_model(V, source.riesz_representation(riesz_representation))

J = assemble(1e6 * (u - u_ref)**2*dx)
rf = ReducedFunctional(J, c)

x = minimize(rf).riesz_representation(riesz_representation)
assert_allclose(x.dat.data, source_ref.dat.data, rtol=1e-2)


class TransformType(Enum):
PRIMAL = auto()
DUAL = auto()
Expand Down
Loading