Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion demos/boussinesq/boussinesq.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ implements a boundary condition that fixes a field at a single point. ::

# Take the basis function with the largest abs value at bc_point
v = TestFunction(V)
F = assemble(Interpolate(inner(v, v), Fvom))
F = assemble(interpolate(inner(v, v), Fvom))
with F.dat.vec as Fvec:
max_index, _ = Fvec.max()
nodes = V.dof_dset.lgmap.applyInverse([max_index])
Expand Down
2 changes: 1 addition & 1 deletion demos/multicomponent/multicomponent.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ mathematically valid to do this)::

# Take the basis function with the largest abs value at bc_point
v = TestFunction(V)
F = assemble(Interpolate(inner(v, v), Fvom))
F = assemble(interpolate(inner(v, v), Fvom))
with F.dat.vec as Fvec:
max_index, _ = Fvec.max()
nodes = V.dof_dset.lgmap.applyInverse([max_index])
Expand Down
6 changes: 3 additions & 3 deletions firedrake/cofunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def interpolate(self,
Parameters
----------
expression
A dual UFL expression to interpolate.
A UFL BaseForm to adjoint interpolate.
ad_block_tag
An optional string for tagging the resulting assemble
block on the Pyadjoint tape.
Expand All @@ -331,9 +331,9 @@ def interpolate(self,
firedrake.cofunction.Cofunction
Returns `self`
"""
from firedrake import interpolation, assemble
from firedrake import interpolate, assemble
v, = self.arguments()
interp = interpolation.Interpolate(v, expression, **kwargs)
interp = interpolate(v, expression, **kwargs)
return assemble(interp, tensor=self, ad_block_tag=ad_block_tag)

@property
Expand Down
10 changes: 5 additions & 5 deletions firedrake/external_operators/point_expr_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import firedrake.ufl_expr as ufl_expr
from firedrake.assemble import assemble
from firedrake.interpolation import Interpolate
from firedrake.interpolation import interpolate
from firedrake.external_operators import AbstractExternalOperator, assemble_method


Expand Down Expand Up @@ -58,7 +58,7 @@ def assemble_operator(self, *args, **kwargs):
V = self.function_space()
expr = as_ufl(self.expr(*self.ufl_operands))
if len(V) < 2:
interp = Interpolate(expr, self.function_space())
interp = interpolate(expr, self.function_space())
return assemble(interp)
# Interpolation of UFL expressions for mixed functions is not yet supported
# -> `Function.assign` might be enough in some cases.
Expand All @@ -72,7 +72,7 @@ def assemble_operator(self, *args, **kwargs):
def assemble_Jacobian_action(self, *args, **kwargs):
V = self.function_space()
expr = as_ufl(self.expr(*self.ufl_operands))
interp = Interpolate(expr, V)
interp = interpolate(expr, V)

u, = [e for i, e in enumerate(self.ufl_operands) if self.derivatives[i] == 1]
w = self.argument_slots()[-1]
Expand All @@ -83,7 +83,7 @@ def assemble_Jacobian_action(self, *args, **kwargs):
def assemble_Jacobian(self, *args, assembly_opts, **kwargs):
V = self.function_space()
expr = as_ufl(self.expr(*self.ufl_operands))
interp = Interpolate(expr, V)
interp = interpolate(expr, V)

u, = [e for i, e in enumerate(self.ufl_operands) if self.derivatives[i] == 1]
jac = ufl_expr.derivative(interp, u)
Expand All @@ -99,7 +99,7 @@ def assemble_Jacobian_adjoint(self, *args, assembly_opts, **kwargs):
def assemble_Jacobian_adjoint_action(self, *args, **kwargs):
V = self.function_space()
expr = as_ufl(self.expr(*self.ufl_operands))
interp = Interpolate(expr, V)
interp = interpolate(expr, V)

u, = [e for i, e in enumerate(self.ufl_operands) if self.derivatives[i] == 1]
ustar = self.argument_slots()[0]
Expand Down
8 changes: 4 additions & 4 deletions firedrake/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,9 @@ def interpolate(self,
firedrake.function.Function
Returns `self`
"""
from firedrake import interpolation, assemble
from firedrake import interpolate, assemble
V = self.function_space()
interp = interpolation.Interpolate(expression, V, **kwargs)
interp = interpolate(expression, V, **kwargs)
return assemble(interp, tensor=self, ad_block_tag=ad_block_tag)

def zero(self, subset=None):
Expand Down Expand Up @@ -697,7 +697,7 @@ def __init__(self, domain, point):
self.point = point

def __str__(self):
return "domain %s does not contain point %s" % (self.domain, self.point)
return f"Domain {self.domain} does not contain point {self.point}"


class PointEvaluator:
Expand All @@ -712,7 +712,7 @@ def __init__(self, mesh: MeshGeometry, points: np.ndarray | list, tolerance: flo
The mesh on which to embed the points.
points : numpy.ndarray | list
Array or list of points to evaluate at.
tolerance : float | None
tolerance : Optional[float]
Tolerance to use when checking if a point is in a cell.
If ``None`` (the default), the ``tolerance`` of the ``mesh`` is used.
missing_points_behaviour : str
Expand Down
68 changes: 21 additions & 47 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
import abc
import warnings
from collections.abc import Iterable
from typing import Literal
from functools import partial, singledispatch
from typing import Hashable
from typing import Hashable, Literal

import FIAT
import ufl
import finat.ufl
from ufl.algorithms import extract_arguments, extract_coefficients, replace
from ufl.algorithms import extract_arguments, extract_coefficients
from ufl.domain import as_domain, extract_unique_domain
from ufl.duals import is_dual

from pyop2 import op2
from pyop2.caching import memory_and_disk_cache
Expand All @@ -25,13 +25,11 @@
import finat

import firedrake
import firedrake.bcs
from firedrake import tsfc_interface, utils, functionspaceimpl
from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint
from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology
from firedrake.petsc import PETSc
from firedrake.halo import _get_mtype as get_dat_mpi_type
from firedrake.cofunction import Cofunction
from mpi4py import MPI

from pyadjoint import stop_annotating, no_annotations
Expand All @@ -48,7 +46,7 @@

class Interpolate(ufl.Interpolate):

def __init__(self, expr, v,
def __init__(self, expr, V,
Copy link
Contributor

@pbrubeck pbrubeck Oct 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very unrelated, but I think that a much more friendly interface is to allow either or both left and right arguments to be a primal FunctionSpace.

Right now we do this under the hood

Interpolate(Function(V1), V2) -> Interpolate(Function(V1), Argument(V2.dual(), 0))

It'd be reasonable to have a similar shortcut for the adjoint. When the left argument is a FunctionSpace, we would then automatically create the Argument for it.

Interpolate(V1, Cofunction(V2.dual())) -> Interpolate(Argument(V1, 0), Cofunction(V2.dual()))

And supplying two FunctionSpaces is a perfectly natural interface:

Interpolate(V1, V2) -> Interpolate(Argument(V1, 1), Argument(V2.dual(), 0))

Of course we need to arbitrarily decide who gets the lowest number, the more intuitive numbering that produces the forward Interpolation is to go from right to left.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thoughts @dham ?

subset=None,
access=None,
allow_missing_dofs=False,
Expand All @@ -60,7 +58,7 @@ def __init__(self, expr, v,
----------
expr : ufl.core.expr.Expr or ufl.BaseForm
The UFL expression to interpolate.
v : firedrake.functionspaceimpl.WithGeometryBase or firedrake.ufl_expr.Coargument
V : firedrake.functionspaceimpl.WithGeometryBase or firedrake.ufl_expr.Coargument
The function space to interpolate into or the coargument defined
on the dual of the function space to interpolate into.
subset : pyop2.types.set.Subset
Expand Down Expand Up @@ -95,20 +93,18 @@ def __init__(self, expr, v,
between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast
and reduce operations.
"""
# Check function space
expr = ufl.as_ufl(expr)
if isinstance(v, functionspaceimpl.WithGeometry):
expr_args = extract_arguments(expr)
is_adjoint = len(expr_args) and expr_args[0].number() == 0
v = Argument(v.dual(), 1 if is_adjoint else 0)
if isinstance(V, functionspaceimpl.WithGeometry):
# Need to create a Firedrake Argument so that it has a .function_space() method
expr_arg_numbers = {arg.number() for arg in extract_arguments(expr) if not is_dual(arg)}
is_adjoint = len(expr_arg_numbers) and expr_arg_numbers == {0}
V = Argument(V.dual(), 1 if is_adjoint else 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put 0 if it is not there, otherwise keep increasing the count

Suggested change
V = Argument(V.dual(), 1 if is_adjoint else 0)
arg_numbers = [a.number() for a in expr_args]
number = 0 if len(expr_args) == 0 or min(arg_numbers) > 0 else max(arg_numbers) + 1
V = Argument(V.dual(), number)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_adjoint goes away.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar fixes will be needed in UFL

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In UFL we're checking for contiguous argument numbers, I think this approach is more fool-proof: just get the smallest number that's missing

Suggested change
V = Argument(V.dual(), 1 if is_adjoint else 0)
arg_numbers = set(a.number() for a in expr_args)
number = min(set(range(len(expr_args) + 1)) - arg_numbers)
V = Argument(V.dual(), number)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

expr_args can only contain zero or one arguments so I think this logic is more confusing than is_adjoint


V = v.arguments()[0].function_space()
if len(expr.ufl_shape) != len(V.value_shape):
raise RuntimeError(f'Rank mismatch: Expression rank {len(expr.ufl_shape)}, FunctionSpace rank {len(V.value_shape)}')
target_shape = V.arguments()[0].function_space().value_shape
if expr.ufl_shape != target_shape:
raise ValueError(f"Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {target_shape}.")

if expr.ufl_shape != V.value_shape:
raise RuntimeError('Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {V.value_shape}')
super().__init__(expr, v)
super().__init__(expr, V)

# -- Interpolate data (e.g. `subset` or `access`) -- #
self.interp_data = {"subset": subset,
Expand Down Expand Up @@ -174,32 +170,10 @@ def interpolate(expr, V, subset=None, access=None, allow_missing_dofs=False, def
reduction (hence using MIN will compute the MIN between the
existing values and any new values).
"""
if isinstance(V, (Cofunction, Coargument)):
dual_arg = V
elif isinstance(V, ufl.BaseForm):
rank = len(V.arguments())
if rank == 1:
dual_arg = V
else:
raise TypeError(f"Expected a one-form, provided form had {rank} arguments")
elif isinstance(V, functionspaceimpl.WithGeometry):
dual_arg = Coargument(V.dual(), 0)
expr_args = extract_arguments(ufl.as_ufl(expr))
if expr_args and expr_args[0].number() == 0:
warnings.warn("Passing argument numbered 0 in expression for forward interpolation is deprecated. "
"Use a TrialFunction in the expression.")
v, = expr_args
expr = replace(expr, {v: v.reconstruct(number=1)})
else:
raise TypeError(f"V must be a FunctionSpace, Cofunction, Coargument or one-form, not a {type(V).__name__}")

interp = Interpolate(expr, dual_arg,
subset=subset, access=access,
allow_missing_dofs=allow_missing_dofs,
default_missing_val=default_missing_val,
matfree=matfree)

return interp
return Interpolate(
expr, V, subset=subset, access=access, allow_missing_dofs=allow_missing_dofs,
default_missing_val=default_missing_val, matfree=matfree
)


class Interpolator(abc.ABC):
Expand Down Expand Up @@ -528,7 +502,7 @@ def __init__(

from firedrake.assemble import assemble
V_dest_vec = firedrake.VectorFunctionSpace(dest_mesh, ufl_scalar_element)
f_dest_node_coords = Interpolate(dest_mesh.coordinates, V_dest_vec)
f_dest_node_coords = interpolate(dest_mesh.coordinates, V_dest_vec)
f_dest_node_coords = assemble(f_dest_node_coords)
dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, dest_mesh_gdim)
try:
Expand All @@ -553,15 +527,15 @@ def __init__(
else:
fs_type = partial(firedrake.TensorFunctionSpace, shape=shape)
P0DG_vom = fs_type(self.vom_dest_node_coords_in_src_mesh, "DG", 0)
self.point_eval_interpolate = Interpolate(self.expr_renumbered, P0DG_vom)
self.point_eval_interpolate = interpolate(self.expr_renumbered, P0DG_vom)
# The parallel decomposition of the nodes of V_dest in the DESTINATION
# mesh (dest_mesh) is retrieved using the input_ordering attribute of the
# VOM. This again is an interpolation operation, which, under the hood
# is a PETSc SF reduce.
P0DG_vom_i_o = fs_type(
self.vom_dest_node_coords_in_src_mesh.input_ordering, "DG", 0
)
self.to_input_ordering_interpolate = Interpolate(
self.to_input_ordering_interpolate = interpolate(
firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o
)
# The P0DG function outputted by the above interpolation has the
Expand Down
2 changes: 1 addition & 1 deletion firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -4140,7 +4140,7 @@ def _parent_mesh_embedding(
# nessesary, to other processes.
P0DG = functionspace.FunctionSpace(parent_mesh, "DG", 0)
with stop_annotating():
visible_ranks = interpolation.Interpolate(
visible_ranks = interpolation.interpolate(
constant.Constant(parent_mesh.comm.rank), P0DG
)
visible_ranks = assemble(visible_ranks).dat.data_ro_with_halos.real
Expand Down
2 changes: 1 addition & 1 deletion firedrake/mg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def physical_node_locations(V):
Vc = V.collapse().reconstruct(element=finat.ufl.VectorElement(element, dim=mesh.geometric_dimension()))

# FIXME: This is unsafe for DG coordinates and CG target spaces.
locations = firedrake.assemble(firedrake.Interpolate(firedrake.SpatialCoordinate(mesh), Vc))
locations = firedrake.assemble(firedrake.interpolate(firedrake.SpatialCoordinate(mesh), Vc))
return cache.setdefault(key, locations)


Expand Down
4 changes: 2 additions & 2 deletions firedrake/preconditioners/gtmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from firedrake.petsc import PETSc
from firedrake.preconditioners.base import PCBase
from firedrake.parameters import parameters
from firedrake.interpolation import Interpolate
from firedrake.interpolation import interpolate
from firedrake.solving_utils import _SNESContext
from firedrake.matrix_free.operators import ImplicitMatrixContext
import firedrake.dmhooks as dmhooks
Expand Down Expand Up @@ -155,7 +155,7 @@ def initialize(self, pc):
# Create interpolation matrix from coarse space to fine space
fine_space = ctx.J.arguments()[0].function_space()
coarse_test, coarse_trial = coarse_operator.arguments()
interp = assemble(Interpolate(coarse_trial, fine_space))
interp = assemble(interpolate(coarse_trial, fine_space))
interp_petscmat = interp.petscmat
restr_petscmat = appctx.get("restriction_matrix", None)

Expand Down
6 changes: 3 additions & 3 deletions firedrake/preconditioners/hypre_ads.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from firedrake.preconditioners.base import PCBase
from firedrake.petsc import PETSc
from firedrake.function import Function
from firedrake.ufl_expr import TestFunction
from firedrake.ufl_expr import TrialFunction
from firedrake.dmhooks import get_function_space
from firedrake.preconditioners.hypre_ams import chop
from firedrake.interpolation import interpolate
Expand Down Expand Up @@ -31,12 +31,12 @@ def initialize(self, obj):
NC1 = V.reconstruct(family="N1curl" if mesh.ufl_cell().is_simplex() else "NCE", degree=1)
G_callback = appctx.get("get_gradient", None)
if G_callback is None:
G = chop(assemble(interpolate(grad(TestFunction(P1)), NC1)).petscmat)
G = chop(assemble(interpolate(grad(TrialFunction(P1)), NC1)).petscmat)
else:
G = G_callback(P1, NC1)
C_callback = appctx.get("get_curl", None)
if C_callback is None:
C = chop(assemble(interpolate(curl(TestFunction(NC1)), V)).petscmat)
C = chop(assemble(interpolate(curl(TrialFunction(NC1)), V)).petscmat)
else:
C = C_callback(NC1, V)

Expand Down
4 changes: 2 additions & 2 deletions firedrake/preconditioners/hypre_ams.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from firedrake.preconditioners.base import PCBase
from firedrake.petsc import PETSc
from firedrake.function import Function
from firedrake.ufl_expr import TestFunction
from firedrake.ufl_expr import TrialFunction
from firedrake.dmhooks import get_function_space
from firedrake.utils import complex_mode
from firedrake.interpolation import interpolate
Expand Down Expand Up @@ -51,7 +51,7 @@ def initialize(self, obj):
P1 = V.reconstruct(family="Lagrange", degree=1)
G_callback = appctx.get("get_gradient", None)
if G_callback is None:
G = chop(assemble(interpolate(grad(TestFunction(P1)), V)).petscmat)
G = chop(assemble(interpolate(grad(TrialFunction(P1)), V)).petscmat)
else:
G = G_callback(P1, V)

Expand Down
4 changes: 2 additions & 2 deletions firedrake/preconditioners/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from firedrake.solving_utils import _SNESContext
from firedrake.utils import cached_property, complex_mode, IntType
from firedrake.dmhooks import get_appctx, push_appctx, pop_appctx
from firedrake.interpolation import Interpolate
from firedrake.interpolation import interpolate

from collections import namedtuple
import operator
Expand Down Expand Up @@ -660,7 +660,7 @@ def sort_entities(self, dm, axis, dir, ndiv=None, divisions=None):
# with access descriptor MAX to define a consistent opinion
# about where the vertices are.
CGk = V.reconstruct(family="Lagrange")
coordinates = assemble(Interpolate(coordinates, CGk, access=op2.MAX))
coordinates = assemble(interpolate(coordinates, CGk, access=op2.MAX))

select = partial(select_entity, dm=dm, exclude="pyop2_ghost")
entities = [(p, self.coords(dm, p, coordinates)) for p in
Expand Down
Loading