diff --git a/firedrake/adjoint/__init__.py b/firedrake/adjoint/__init__.py
index d3d28e6129..e2e57fc452 100644
--- a/firedrake/adjoint/__init__.py
+++ b/firedrake/adjoint/__init__.py
@@ -38,6 +38,7 @@
 from firedrake.adjoint.ufl_constraints import UFLInequalityConstraint, \
     UFLEqualityConstraint  # noqa F401
 from firedrake.adjoint.ensemble_reduced_functional import EnsembleReducedFunctional  # noqa F401
+from firedrake.adjoint.transformed_functional import L2RieszMap, L2TransformedFunctional  # noqa: F401
 import numpy_adjoint  # noqa F401
 import firedrake.ufl_expr
 import types
diff --git a/firedrake/adjoint/transformed_functional.py b/firedrake/adjoint/transformed_functional.py
new file mode 100644
index 0000000000..cbe40667cc
--- /dev/null
+++ b/firedrake/adjoint/transformed_functional.py
@@ -0,0 +1,459 @@
+from contextlib import contextmanager
+from operator import itemgetter
+
+import firedrake as fd
+from firedrake.adjoint import Control, ReducedFunctional
+import finat
+from pyadjoint import no_annotations
+from pyadjoint.enlisting import Enlist
+from pyadjoint.reduced_functional import AbstractReducedFunctional
+import ufl
+
+__all__ = \
+    [
+        "L2RieszMap",
+        "L2TransformedFunctional"
+    ]
+
+
+@contextmanager
+def local_vector(u, *, readonly=False):
+    """Context manager providing the local form of a PETSc vector.
+
+    Parameters
+    ----------
+
+    u : petsc4py.PETSc.Vec
+        The vector.
+    readonly : bool
+        Passed to `getLocalVector` and `restoreLocalVector`.
+
+    Yields
+    ------
+
+    petsc4py.PETSc.Vec
+        The local vector. Restored on exit, including when the body of
+        the ``with`` block raises.
+    """
+
+    u_local = u.createLocalVector()
+    u.getLocalVector(u_local, readonly=readonly)
+    try:
+        yield u_local
+    finally:
+        u.restoreLocalVector(u_local, readonly=readonly)
+
+
+class L2Cholesky:
+    """Mass matrix Cholesky factorization for a (real) DG space.
+
+    Parameters
+    ----------
+
+    space : WithGeometry
+        DG space.
+    constant_jacobian : bool
+        Whether the mass matrix is constant.
+    """
+
+    def __init__(self, space, *, constant_jacobian=True):
+        if fd.utils.complex_mode:
+            raise NotImplementedError("complex not supported")
+
+        self._space = space
+        self._constant_jacobian = constant_jacobian
+        self._cached_pc = None
+
+    @property
+    def space(self) -> fd.functionspaceimpl.WithGeometry:
+        """Function space.
+        """
+
+        return self._space
+
+    def _pc(self):
+        """Return a Cholesky preconditioner for the diagonal mass matrix block.
+
+        Cached only if the mass matrix is constant.
+        """
+
+        import petsc4py.PETSc as PETSc
+
+        if self._cached_pc is None:
+            M = fd.assemble(fd.inner(fd.TrialFunction(self.space), fd.TestFunction(self.space)) * fd.dx,
+                            mat_type="aij")
+            M_local = M.petscmat.getDiagonalBlock()
+
+            pc = PETSc.PC().create(M_local.comm)
+            pc.setType(PETSc.PC.Type.CHOLESKY)
+            pc.setFactorSolverType(PETSc.Mat.SolverType.PETSC)
+            pc.setOperators(M_local)
+            pc.setUp()
+
+            if self._constant_jacobian:
+                self._cached_pc = M, M_local, pc
+        else:
+            _, _, pc = self._cached_pc
+
+        return pc
+
+    def C_inv_action(self, u):
+        r"""For the Cholesky factorization
+
+        .. math::
+
+            M = C C^T,
+
+        compute the action of :math:`C^{-1}`.
+
+        Parameters
+        ----------
+
+        u : Function or Cofunction
+            Compute :math:`C^{-1} \tilde{u}` where :math:`\tilde{u}` is the
+            vector of degrees of freedom for :math:`u`.
+
+        Returns
+        -------
+
+        v : Cofunction
+            Has vector of degrees of freedom :math:`C^{-1} \tilde{u}`.
+        """
+
+        pc = self._pc()
+        v = fd.Cofunction(self.space.dual())
+        with u.dat.vec_ro as u_v, v.dat.vec_wo as v_v:
+            with local_vector(u_v, readonly=True) as u_v_s, local_vector(v_v) as v_v_s:
+                pc.applySymmetricLeft(u_v_s, v_v_s)
+        return v
+
+    def C_T_inv_action(self, u):
+        r"""For the Cholesky factorization
+
+        .. math::
+
+            M = C C^T,
+
+        compute the action of :math:`C^{-T}`.
+
+        Parameters
+        ----------
+
+        u : Function or Cofunction
+            Compute :math:`C^{-T} \tilde{u}` where :math:`\tilde{u}` is the
+            vector of degrees of freedom for :math:`u`.
+
+        Returns
+        -------
+
+        v : Function
+            Has vector of degrees of freedom :math:`C^{-T} \tilde{u}`.
+        """
+
+        pc = self._pc()
+        v = fd.Function(self.space)
+        with u.dat.vec_ro as u_v, v.dat.vec_wo as v_v:
+            with local_vector(u_v, readonly=True) as u_v_s, local_vector(v_v) as v_v_s:
+                pc.applySymmetricRight(u_v_s, v_v_s)
+        return v
+
+
+class L2RieszMap(fd.RieszMap):
+    """An :math:`L^2` Riesz map.
+
+    Parameters
+    ----------
+
+    target : WithGeometry
+        Function space.
+
+    Keyword arguments are passed to the :class:`firedrake.RieszMap`
+    constructor.
+    """
+
+    def __init__(self, target, **kwargs):
+        if not isinstance(target, fd.functionspaceimpl.WithGeometry):
+            raise TypeError("Target must be a WithGeometry")
+        super().__init__(target, ufl.L2, **kwargs)
+
+
+def is_dg_space(space):
+    """Return whether a function space is DG.
+
+    Parameters
+    ----------
+
+    space : WithGeometry
+        The function space.
+
+    Returns
+    -------
+
+    bool
+        Whether the function space is DG.
+    """
+
+    e, _ = finat.element_factory.convert(space.ufl_element())
+    return e.is_dg()
+
+
+def dg_space(space):
+    """Construct a DG space containing a given function space as a subspace.
+
+    Parameters
+    ----------
+
+    space : WithGeometry
+        A function space.
+
+    Returns
+    -------
+
+    WithGeometry
+        A DG space containing `space` as a subspace. May be `space`.
+    """
+
+    if is_dg_space(space):
+        return space
+    else:
+        return fd.FunctionSpace(space.mesh(), finat.ufl.BrokenElement(space.ufl_element()))
+
+
+class L2TransformedFunctional(AbstractReducedFunctional):
+    r"""Represents the functional
+
+    .. math::
+
+        J \circ \Pi \circ \Xi
+
+    where
+
+        - :math:`J` is the functional defining an optimization problem.
+        - :math:`\Pi` is the :math:`L^2` projection from a DG space containing
+          the control space as a subspace.
+        - :math:`\Xi` represents a change of basis from an :math:`L^2`
+          orthonormal basis to the finite element basis for the DG space.
+
+    The optimization is therefore transformed into an optimization problem
+    using an :math:`L^2` orthonormal basis for a DG finite element space.
+
+    The transformation is related to the factorization in section 4.1 of
+    https://doi.org/10.1137/18M1175239 -- specifically the factorization
+    in their equation (4.2) can be related to :math:`\Pi \circ \Xi`.
+
+    Parameters
+    ----------
+
+    functional : OverloadedType
+        Functional defining the optimization problem, :math:`J`.
+    controls : Control or Sequence[Control]
+        Controls. Must be :class:`firedrake.Function` objects.
+    space_D : None, WithGeometry, or Sequence[None or WithGeometry]
+        DG space containing the control space.
+    riesz_map : L2RieszMap or Sequence[L2RieszMap]
+        Used for projecting from the DG space onto the control space. Ignored
+        for DG controls.
+    alpha : Real
+        Modifies the functional, equivalent to adding an extra term to
+        :math:`J \circ \Pi`
+
+        .. math::
+
+            \frac{1}{2} \alpha \left\| m_D - \Pi ( m_D ) \right\|_{L^2}^2.
+
+        e.g. in a minimization problem this adds a penalty term which can
+        be used to avoid ill-posedness due to the use of a larger DG space.
+    tape : Tape
+        Tape used in evaluations involving :math:`J`.
+    """
+
+    @no_annotations
+    def __init__(self, functional, controls, *, space_D=None, riesz_map=None, alpha=0, tape=None):
+        if not all(isinstance(control.control, fd.Function) for control in Enlist(controls)):
+            raise TypeError("controls must be Function objects")
+
+        super().__init__()
+        self._J = ReducedFunctional(functional, controls, tape=tape)
+
+        self._space = tuple(control.control.function_space()
+                            for control in self._J.controls)
+        if space_D is None:
+            space_D = tuple(None for _ in self._space)
+        self._space_D = Enlist(space_D)
+        if len(self._space_D) != len(self._space):
+            raise ValueError("Invalid length")
+        self._space_D = tuple(dg_space(space) if space_D is None else space_D
+                              for space, space_D in zip(self._space, self._space_D))
+
+        self._controls = tuple(Control(fd.Function(space_D), riesz_map="l2")
+                               for space_D in self._space_D)
+        self._controls = Enlist(Enlist(controls).delist(self._controls))
+
+        if riesz_map is None:
+            riesz_map = tuple(map(L2RieszMap, self._space))
+        self._riesz_map = Enlist(riesz_map)
+        if len(self._riesz_map) != len(self._controls):
+            raise ValueError("Invalid length")
+        self._C = tuple(L2Cholesky(space_D, constant_jacobian=riesz_map.constant_jacobian)
+                        for space_D, riesz_map in zip(self._space_D, self._riesz_map))
+
+        self._alpha = alpha
+        self._m_k = None
+
+        # Map the initial guess
+        controls_t = self._dual_transform(tuple(control.control for control in self._J.controls), apply_riesz=False)
+        for control, control_t in zip(self._controls, controls_t):
+            control.control.assign(control_t)
+
+    @property
+    def controls(self) -> Enlist[Control]:
+        return Enlist(self._controls.delist())
+
+    def _dual_transform(self, u, u_D=None, *, apply_riesz=False):
+        # Maps each `u` onto the DG dual space, adds `u_D` if supplied,
+        # applies C^{-1}, and returns the 'l2' Riesz representation of the
+        # result
+        u = Enlist(u)
+        if len(u) != len(self.controls):
+            raise ValueError("Invalid length")
+        if u_D is None:
+            u_D = tuple(None for _ in u)
+        else:
+            u_D = Enlist(u_D)
+        if len(u_D) != len(self.controls):
+            raise ValueError("Invalid length")
+
+        def transform(C, u, u_D, space, space_D, riesz_map):
+            if apply_riesz:
+                if space is space_D:
+                    v = u
+                else:
+                    v = fd.assemble(fd.inner(riesz_map(u), fd.TestFunction(space_D)) * fd.dx)
+            else:
+                v = fd.assemble(fd.inner(u, fd.TestFunction(space_D)) * fd.dx)
+            if u_D is not None:
+                v.dat.axpy(1, u_D.dat)
+            v = C.C_inv_action(v)
+            return v.riesz_representation("l2")
+
+        v = tuple(map(transform, self._C, u, u_D, self._space, self._space_D, self._riesz_map))
+        return u.delist(v)
+
+    def _primal_transform(self, u):
+        # For each transformed control returns the DG function with degrees
+        # of freedom C^{-T} u, together with its image in the original
+        # control space (the same object if the control space is DG)
+        u = Enlist(u)
+        if len(u) != len(self.controls):
+            raise ValueError("Invalid length")
+
+        def transform(C, u, space, space_D, riesz_map):
+            if fd.utils.complex_mode:
+                # Would need to be adjoint
+                raise NotImplementedError("complex not supported")
+            v = C.C_T_inv_action(u)
+            if space is space_D:
+                w = v
+            else:
+                w = riesz_map(fd.assemble(fd.inner(v, fd.TestFunction(space)) * fd.dx))
+            return v, w
+
+        vw = tuple(map(transform, self._C, u, self._space, self._space_D, self._riesz_map))
+        return u.delist(tuple(map(itemgetter(0), vw))), u.delist(tuple(map(itemgetter(1), vw)))
+
+    @no_annotations
+    def map_result(self, m):
+        """Map the result of an optimization.
+
+        Parameters
+        ----------
+
+        m : firedrake.Function or Sequence[firedrake.Function]
+            The result of the optimization. Represents an expansion in an
+            :math:`L^2` orthonormal basis for the DG space.
+
+        Returns
+        -------
+
+        firedrake.Function or Sequence[firedrake.Function]
+            The mapped result in the original control space.
+        """
+
+        _, m_J = self._primal_transform(m)
+        return m_J
+
+    @no_annotations
+    def __call__(self, values):
+        values = Enlist(values)
+        m_D, m_J = self._primal_transform(values)
+        J = self._J(m_J)
+        if self._alpha != 0:
+            for space, space_D, m_D_i, m_J_i in zip(self._space, self._space_D, m_D, m_J):
+                if space is not space_D:
+                    J += fd.assemble(0.5 * fd.Constant(self._alpha) * fd.inner(m_D_i - m_J_i, m_D_i - m_J_i) * fd.dx)
+        self._m_k = m_D, m_J
+        return J
+
+    @no_annotations
+    def derivative(self, adj_input=1.0, apply_riesz=False):
+        if adj_input != 1:
+            raise NotImplementedError("adj_input != 1 not supported")
+
+        u = Enlist(self._J.derivative())
+
+        if self._alpha == 0:
+            v_alpha = None
+        else:
+            v_alpha = []
+            for space, space_D, m_D, m_J in zip(self._space, self._space_D, *self._m_k):
+                if space is space_D:
+                    v_alpha.append(None)
+                else:
+                    if fd.utils.complex_mode:
+                        raise RuntimeError("Not complex differentiable")
+                    v_alpha.append(fd.assemble(fd.Constant(self._alpha) * fd.inner(m_D - m_J, fd.TestFunction(space_D)) * fd.dx))
+        v = self._dual_transform(u, v_alpha, apply_riesz=True)
+        if apply_riesz:
+            v = tuple(v_i._ad_convert_riesz(v_i, riesz_map=control.riesz_map)
+                      for v_i, control in zip(v, self.controls))
+        return u.delist(v)
+
+    @no_annotations
+    def hessian(self, m_dot, hessian_input=None, evaluate_tlm=True, apply_riesz=False):
+        if hessian_input is not None:
+            raise NotImplementedError("hessian_input not None not supported")
+
+        m_dot = Enlist(m_dot)
+        m_dot_D, m_dot_J = self._primal_transform(m_dot)
+        u = Enlist(self._J.hessian(m_dot.delist(m_dot_J), evaluate_tlm=evaluate_tlm))
+
+        if self._alpha == 0:
+            v_alpha = None
+        else:
+            v_alpha = []
+            for space, space_D, m_dot_D_i, m_dot_J_i in zip(self._space, self._space_D, m_dot_D, m_dot_J):
+                if space is space_D:
+                    v_alpha.append(None)
+                else:
+                    if fd.utils.complex_mode:
+                        raise RuntimeError("Not complex differentiable")
+                    v_alpha.append(fd.assemble(fd.Constant(self._alpha) * fd.inner(m_dot_D_i - m_dot_J_i, fd.TestFunction(space_D)) * fd.dx))
+        v = self._dual_transform(u, v_alpha, apply_riesz=True)
+        if apply_riesz:
+            v = tuple(v_i._ad_convert_riesz(v_i, riesz_map=control.riesz_map)
+                      for v_i, control in zip(v, self.controls))
+        return u.delist(v)
+
+    @no_annotations
+    def tlm(self, m_dot):
+        m_dot = Enlist(m_dot)
+        m_dot_D, m_dot_J = self._primal_transform(m_dot)
+        tau_J = self._J.tlm(m_dot.delist(m_dot_J))
+
+        if self._alpha != 0:
+            for space, space_D, m_dot_D_i, m_D, m_J in zip(self._space, self._space_D, m_dot_D, *self._m_k):
+                if space is not space_D:
+                    if fd.utils.complex_mode:
+                        raise RuntimeError("Not complex differentiable")
+                    tau_J += fd.assemble(fd.Constant(self._alpha) * fd.inner(m_D - m_J, m_dot_D_i) * fd.dx)
+        return tau_J
diff --git a/firedrake/cofunction.py b/firedrake/cofunction.py
index 47bb0b6ba4..3a80e599cd 100644
--- a/firedrake/cofunction.py
+++ b/firedrake/cofunction.py
@@ -394,6 +394,8 @@ class RieszMap:
         variational problem that solves for the Riesz map.
     restrict: bool
         If `True`, use restricted function spaces in the Riesz map solver.
+    constant_jacobian : bool
+        Whether the matrix associated with the map is constant.
     """
 
     def __init__(self, function_space_or_inner_product=None,
@@ -498,3 +500,10 @@ def __call__(self, value):
                 f"Unable to ascertain if {value} is primal or dual."
             )
         return output
+
+    @property
+    def constant_jacobian(self) -> bool:
+        """Whether the matrix associated with the map is constant.
+        """
+
+        return self._constant_jacobian
diff --git a/tests/firedrake/adjoint/test_transformed_functional.py b/tests/firedrake/adjoint/test_transformed_functional.py
new file mode 100644
index 0000000000..ada587b5e4
--- /dev/null
+++ b/tests/firedrake/adjoint/test_transformed_functional.py
@@ -0,0 +1,270 @@
+from collections.abc import Sequence
+from functools import partial
+
+import firedrake as fd
+from firedrake.adjoint import (
+    Control, L2TransformedFunctional, MinimizationProblem, ReducedFunctional,
+    continue_annotation, minimize, pause_annotation, set_working_tape)
+import numpy as np
+from pyadjoint import TAOSolver
+from pyadjoint.reduced_functional_numpy import ReducedFunctionalNumPy
+import pytest
+import ufl
+
+
+@pytest.fixture(scope="module", autouse=True)
+def setup_tape():
+    with set_working_tape():
+        pause_annotation()
+        yield
+        pause_annotation()
+
+
+class ReducedFunctional(ReducedFunctional):
+    def __init__(self, *args, **kwargs):
+        self._test_transformed_functional__ncalls = 0
+        super().__init__(*args, **kwargs)
+
+    def __call__(self, *args, **kwargs):
+        self._test_transformed_functional__ncalls += 1
+        return super().__call__(*args, **kwargs)
+
+
+class L2TransformedFunctional(L2TransformedFunctional):
+    def __init__(self, *args, **kwargs):
+        self._test_transformed_functional__ncalls = 0
+        super().__init__(*args, **kwargs)
+
+    def __call__(self, *args, **kwargs):
+        self._test_transformed_functional__ncalls += 1
+        return super().__call__(*args, **kwargs)
+
+
+class MinimizeCallback(Sequence):
+    def __init__(self, m_0, error_norm):
+        self._space = m_0.function_space()
+        self._error_norm = error_norm
+        self._data = []
+
+        self(np.asarray(m_0._ad_to_list(m_0)))
+
+    def __len__(self):
+        return len(self._data)
+
+    def __getitem__(self, key):
+        return self._data[key]
+
+    def __call__(self, xk):
+        k = len(self)
+        if ufl.duals.is_primal(self._space):
+            m_k = fd.Function(self._space, name="m_k")
+        elif ufl.duals.is_dual(self._space):
+            m_k = fd.Cofunction(self._space, name="m_k")
+        else:
+            raise ValueError("space is neither primal nor dual")
+        m_k._ad_assign_numpy(m_k, xk, 0)
+        error_norm = self._error_norm(m_k)
+        print(f"{k=} {error_norm=:6g}")
+        self._data.append(error_norm)
+
+
+@pytest.mark.parametrize("family", ("Lagrange", "Discontinuous Lagrange"))
+def test_transformed_functional_mass_inverse(family):
+    mesh = fd.UnitSquareMesh(5, 5, diagonal="crossed")
+    x, y = fd.SpatialCoordinate(mesh)
+    space = fd.FunctionSpace(mesh, family, 1, variant="equispaced")
+
+    def forward(m):
+        return fd.assemble(fd.inner(m - m_ref, m - m_ref) * fd.dx)
+
+    m_ref = fd.Function(space, name="m_ref").interpolate(
+        fd.exp(x) * fd.sin(fd.pi * x) * fd.cos(fd.pi * y))
+
+    continue_annotation()
+    m_0 = fd.Function(space, name="m_0")
+    J = forward(m_0)
+    pause_annotation()
+    c = Control(m_0, riesz_map="l2")
+
+    J_hat = ReducedFunctional(J, c)
+
+    def error_norm(m):
+        return fd.norm(m - m_ref, norm_type="L2")
+
+    cb = MinimizeCallback(m_0, error_norm)
+    _ = minimize(J_hat, method="L-BFGS-B",
+                 callback=cb,
+                 options={"ftol": 0,
+                          "gtol": 1e-6})
+    assert 1e-6 < cb[-1] < 1e-5
+    if family == "Lagrange":
+        assert len(cb) > 12  # == 15
+        assert J_hat._test_transformed_functional__ncalls > 12  # == 15
+    elif family == "Discontinuous Lagrange":
+        assert len(cb) == 5
+        assert J_hat._test_transformed_functional__ncalls == 6
+    else:
+        raise ValueError(f"Invalid element family: '{family}'")
+
+    J_hat = L2TransformedFunctional(J, c, alpha=1)
+
+    def error_norm(m):
+        m = J_hat.map_result(m)
+        return fd.norm(m - m_ref, norm_type="L2")
+
+    cb = MinimizeCallback(J_hat.controls[0].control, error_norm)
+    _ = minimize(ReducedFunctionalNumPy(J_hat), method="L-BFGS-B",
+                 callback=cb,
+                 options={"ftol": 0,
+                          "gtol": 1e-6})
+    assert cb[-1] < 1e-10
+    assert len(cb) == 3
+    assert J_hat._test_transformed_functional__ncalls == 3
+
+
+def test_transformed_functional_poisson():
+    mesh = fd.UnitSquareMesh(5, 5, diagonal="crossed")
+    x, y = fd.SpatialCoordinate(mesh)
+    space = fd.FunctionSpace(mesh, "Lagrange", 1)
+    test = fd.TestFunction(space)
+    trial = fd.TrialFunction(space)
+    bc = fd.DirichletBC(space, 0, "on_boundary")
+
+    def pre_process(m):
+        m_0 = fd.Function(space, name="m_0").assign(m)
+        bc.apply(m_0)
+        m_1 = fd.Function(space, name="m_1").assign(m - m_0)
+        return m_0, m_1
+
+    def forward(m):
+        m_0, m_1 = pre_process(m)
+        u = fd.Function(space, name="u")
+        fd.solve(fd.inner(fd.grad(trial), fd.grad(test)) * fd.dx
+                 == fd.inner(m_0, test) * fd.dx,
+                 u, bc)
+        return m_0, m_1, u
+
+    def forward_J(m, u_ref, alpha):
+        _, m_1, u = forward(m)
+        return fd.assemble(fd.inner(u - u_ref, u - u_ref) * fd.dx
+                           + fd.Constant(alpha ** 2) * fd.inner(m_1, m_1) * fd.ds)
+
+    m_ref = fd.Function(space, name="m_ref").interpolate(
+        fd.exp(x) * fd.sin(fd.pi * x) * fd.sin(fd.pi * y))
+    m_ref, _, u_ref = forward(m_ref)
+    forward_J = partial(forward_J, u_ref=u_ref, alpha=1)
+
+    continue_annotation()
+    m_0 = fd.Function(space, name="m_0")
+    J = forward_J(m_0)
+    pause_annotation()
+    c = Control(m_0, riesz_map="l2")
+
+    J_hat = ReducedFunctional(J, c)
+
+    def error_norm(m):
+        m, _ = pre_process(m)
+        return fd.norm(m - m_ref, norm_type="L2")
+
+    cb = MinimizeCallback(m_0, error_norm)
+    _ = minimize(J_hat, method="L-BFGS-B",
+                 callback=cb,
+                 options={"ftol": 0,
+                          "gtol": 1e-10})
+    assert 1e-2 < cb[-1] < 5e-2
+    assert len(cb) > 80  # == 85
+    assert J_hat._test_transformed_functional__ncalls > 90  # == 95
+
+    J_hat = L2TransformedFunctional(J, c, alpha=1e-5)
+
+    def error_norm(m):
+        m = J_hat.map_result(m)
+        m, _ = pre_process(m)
+        return fd.norm(m - m_ref, norm_type="L2")
+
+    cb = MinimizeCallback(J_hat.controls[0].control, error_norm)
+    _ = minimize(ReducedFunctionalNumPy(J_hat), method="L-BFGS-B",
+                 callback=cb,
+                 options={"ftol": 0,
+                          "gtol": 1e-10})
+    assert 1e-4 < cb[-1] < 5e-4
+    assert len(cb) < 55  # == 51
+    assert J_hat._test_transformed_functional__ncalls < 60  # == 55
+
+
+def test_transformed_functional_poisson_tao_nls():
+    mesh = fd.UnitSquareMesh(5, 5, diagonal="crossed")
+    x, y = fd.SpatialCoordinate(mesh)
+    space = fd.FunctionSpace(mesh, "Lagrange", 1)
+    test = fd.TestFunction(space)
+    trial = fd.TrialFunction(space)
+    bc = fd.DirichletBC(space, 0, "on_boundary")
+
+    def pre_process(m):
+        m_0 = fd.Function(space, name="m_0").assign(m)
+        bc.apply(m_0)
+        m_1 = fd.Function(space, name="m_1").assign(m - m_0)
+        return m_0, m_1
+
+    def forward(m):
+        m_0, m_1 = pre_process(m)
+        u = fd.Function(space, name="u")
+        fd.solve(fd.inner(fd.grad(trial), fd.grad(test)) * fd.dx
+                 == fd.inner(m_0, test) * fd.dx,
+                 u, bc)
+        return m_0, m_1, u
+
+    def forward_J(m, u_ref, alpha):
+        _, m_1, u = forward(m)
+        return fd.assemble(fd.inner(u - u_ref, u - u_ref) * fd.dx
+                           + fd.Constant(alpha ** 2) * fd.inner(m_1, m_1) * fd.ds)
+
+    m_ref = fd.Function(space, name="m_ref").interpolate(
+        fd.exp(x) * fd.sin(fd.pi * x) * fd.sin(fd.pi * y))
+    m_ref, _, u_ref = forward(m_ref)
+    forward_J = partial(forward_J, u_ref=u_ref, alpha=1)
+
+    continue_annotation()
+    m_0 = fd.Function(space, name="m_0")
+    J = forward_J(m_0)
+    pause_annotation()
+    c = Control(m_0)
+
+    J_hat = ReducedFunctional(J, c)
+
+    def error_norm(m):
+        m, _ = pre_process(m)
+        return fd.norm(m - m_ref, norm_type="L2")
+
+    problem = MinimizationProblem(J_hat)
+    solver = TAOSolver(problem, {"tao_type": "nls",
+                                 "tao_monitor": None,
+                                 "tao_converged_reason": None,
+                                 "tao_gatol": 1.0e-5,
+                                 "tao_grtol": 0.0,
+                                 "tao_gttol": 1.0e-6})
+    m_opt = solver.solve()
+    error_norm_opt = error_norm(m_opt)
+    print(f"{error_norm_opt=:.6g}")
+    assert 1e-2 < error_norm_opt < 5e-2
+    assert J_hat._test_transformed_functional__ncalls > 22  # == 24
+
+    J_hat = L2TransformedFunctional(J, c, alpha=1e-5)
+
+    def error_norm(m):
+        m = J_hat.map_result(m)
+        m, _ = pre_process(m)
+        return fd.norm(m - m_ref, norm_type="L2")
+
+    problem = MinimizationProblem(J_hat)
+    solver = TAOSolver(problem, {"tao_type": "nls",
+                                 "tao_monitor": None,
+                                 "tao_converged_reason": None,
+                                 "tao_gatol": 1.0e-5,
+                                 "tao_grtol": 0.0,
+                                 "tao_gttol": 1.0e-6})
+    m_opt = solver.solve()
+    error_norm_opt = error_norm(m_opt)
+    print(f"{error_norm_opt=:.6g}")
+    assert 1e-3 < error_norm_opt < 1e-2
+    assert J_hat._test_transformed_functional__ncalls < 18  # == 16