-
Notifications
You must be signed in to change notification settings - Fork 173
Simplify interpolate
#4582
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Simplify interpolate
#4582
Changes from all commits
d286b36
d930a1b
7f51013
c2f5476
0c4ce97
cf81591
c1b93f6
8f764c7
e9e92dd
ba75cd5
83ef532
8adf841
9f01a4b
d8093ec
d15da75
e9d6ba9
52c6203
aff9524
93406b8
efedc48
f8c2318
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -4,15 +4,15 @@ | |||||||||||||||||
import abc | ||||||||||||||||||
import warnings | ||||||||||||||||||
from collections.abc import Iterable | ||||||||||||||||||
from typing import Literal | ||||||||||||||||||
from functools import partial, singledispatch | ||||||||||||||||||
from typing import Hashable | ||||||||||||||||||
from typing import Hashable, Literal | ||||||||||||||||||
|
||||||||||||||||||
import FIAT | ||||||||||||||||||
import ufl | ||||||||||||||||||
import finat.ufl | ||||||||||||||||||
from ufl.algorithms import extract_arguments, extract_coefficients, replace | ||||||||||||||||||
from ufl.algorithms import extract_arguments, extract_coefficients | ||||||||||||||||||
from ufl.domain import as_domain, extract_unique_domain | ||||||||||||||||||
from ufl.duals import is_dual | ||||||||||||||||||
|
||||||||||||||||||
from pyop2 import op2 | ||||||||||||||||||
from pyop2.caching import memory_and_disk_cache | ||||||||||||||||||
|
@@ -25,13 +25,11 @@ | |||||||||||||||||
import finat | ||||||||||||||||||
|
||||||||||||||||||
import firedrake | ||||||||||||||||||
import firedrake.bcs | ||||||||||||||||||
from firedrake import tsfc_interface, utils, functionspaceimpl | ||||||||||||||||||
from firedrake.ufl_expr import Argument, Coargument, action, adjoint as expr_adjoint | ||||||||||||||||||
from firedrake.mesh import MissingPointsBehaviour, VertexOnlyMeshMissingPointsError, VertexOnlyMeshTopology | ||||||||||||||||||
from firedrake.petsc import PETSc | ||||||||||||||||||
from firedrake.halo import _get_mtype as get_dat_mpi_type | ||||||||||||||||||
from firedrake.cofunction import Cofunction | ||||||||||||||||||
from mpi4py import MPI | ||||||||||||||||||
|
||||||||||||||||||
from pyadjoint import stop_annotating, no_annotations | ||||||||||||||||||
|
@@ -48,7 +46,7 @@ | |||||||||||||||||
|
||||||||||||||||||
class Interpolate(ufl.Interpolate): | ||||||||||||||||||
|
||||||||||||||||||
def __init__(self, expr, v, | ||||||||||||||||||
def __init__(self, expr, V, | ||||||||||||||||||
subset=None, | ||||||||||||||||||
access=None, | ||||||||||||||||||
allow_missing_dofs=False, | ||||||||||||||||||
|
@@ -60,7 +58,7 @@ def __init__(self, expr, v, | |||||||||||||||||
---------- | ||||||||||||||||||
expr : ufl.core.expr.Expr or ufl.BaseForm | ||||||||||||||||||
The UFL expression to interpolate. | ||||||||||||||||||
v : firedrake.functionspaceimpl.WithGeometryBase or firedrake.ufl_expr.Coargument | ||||||||||||||||||
V : firedrake.functionspaceimpl.WithGeometryBase or firedrake.ufl_expr.Coargument | ||||||||||||||||||
The function space to interpolate into or the coargument defined | ||||||||||||||||||
on the dual of the function space to interpolate into. | ||||||||||||||||||
subset : pyop2.types.set.Subset | ||||||||||||||||||
|
@@ -95,20 +93,18 @@ def __init__(self, expr, v, | |||||||||||||||||
between a VOM and its input ordering. Defaults to ``True`` which uses SF broadcast | ||||||||||||||||||
and reduce operations. | ||||||||||||||||||
""" | ||||||||||||||||||
# Check function space | ||||||||||||||||||
expr = ufl.as_ufl(expr) | ||||||||||||||||||
if isinstance(v, functionspaceimpl.WithGeometry): | ||||||||||||||||||
expr_args = extract_arguments(expr) | ||||||||||||||||||
is_adjoint = len(expr_args) and expr_args[0].number() == 0 | ||||||||||||||||||
v = Argument(v.dual(), 1 if is_adjoint else 0) | ||||||||||||||||||
if isinstance(V, functionspaceimpl.WithGeometry): | ||||||||||||||||||
# Need to create a Firedrake Argument so that it has a .function_space() method | ||||||||||||||||||
expr_arg_numbers = {arg.number() for arg in extract_arguments(expr) if not is_dual(arg)} | ||||||||||||||||||
is_adjoint = len(expr_arg_numbers) and expr_arg_numbers == {0} | ||||||||||||||||||
V = Argument(V.dual(), 1 if is_adjoint else 0) | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Put 0 if it is not there, otherwise keep increasing the count
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar fixes will be needed in UFL There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In UFL we're checking for contiguous argument numbers, I think this approach is more fool-proof: just get the smallest number that's missing
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||
|
||||||||||||||||||
V = v.arguments()[0].function_space() | ||||||||||||||||||
if len(expr.ufl_shape) != len(V.value_shape): | ||||||||||||||||||
raise RuntimeError(f'Rank mismatch: Expression rank {len(expr.ufl_shape)}, FunctionSpace rank {len(V.value_shape)}') | ||||||||||||||||||
target_shape = V.arguments()[0].function_space().value_shape | ||||||||||||||||||
if expr.ufl_shape != target_shape: | ||||||||||||||||||
raise ValueError(f"Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {target_shape}.") | ||||||||||||||||||
|
||||||||||||||||||
if expr.ufl_shape != V.value_shape: | ||||||||||||||||||
raise RuntimeError('Shape mismatch: Expression shape {expr.ufl_shape}, FunctionSpace shape {V.value_shape}') | ||||||||||||||||||
super().__init__(expr, v) | ||||||||||||||||||
super().__init__(expr, V) | ||||||||||||||||||
|
||||||||||||||||||
# -- Interpolate data (e.g. `subset` or `access`) -- # | ||||||||||||||||||
self.interp_data = {"subset": subset, | ||||||||||||||||||
|
@@ -174,32 +170,10 @@ def interpolate(expr, V, subset=None, access=None, allow_missing_dofs=False, def | |||||||||||||||||
reduction (hence using MIN will compute the MIN between the | ||||||||||||||||||
existing values and any new values). | ||||||||||||||||||
""" | ||||||||||||||||||
if isinstance(V, (Cofunction, Coargument)): | ||||||||||||||||||
dual_arg = V | ||||||||||||||||||
elif isinstance(V, ufl.BaseForm): | ||||||||||||||||||
rank = len(V.arguments()) | ||||||||||||||||||
if rank == 1: | ||||||||||||||||||
dual_arg = V | ||||||||||||||||||
else: | ||||||||||||||||||
raise TypeError(f"Expected a one-form, provided form had {rank} arguments") | ||||||||||||||||||
elif isinstance(V, functionspaceimpl.WithGeometry): | ||||||||||||||||||
dual_arg = Coargument(V.dual(), 0) | ||||||||||||||||||
expr_args = extract_arguments(ufl.as_ufl(expr)) | ||||||||||||||||||
if expr_args and expr_args[0].number() == 0: | ||||||||||||||||||
warnings.warn("Passing argument numbered 0 in expression for forward interpolation is deprecated. " | ||||||||||||||||||
"Use a TrialFunction in the expression.") | ||||||||||||||||||
v, = expr_args | ||||||||||||||||||
expr = replace(expr, {v: v.reconstruct(number=1)}) | ||||||||||||||||||
else: | ||||||||||||||||||
raise TypeError(f"V must be a FunctionSpace, Cofunction, Coargument or one-form, not a {type(V).__name__}") | ||||||||||||||||||
|
||||||||||||||||||
interp = Interpolate(expr, dual_arg, | ||||||||||||||||||
subset=subset, access=access, | ||||||||||||||||||
allow_missing_dofs=allow_missing_dofs, | ||||||||||||||||||
default_missing_val=default_missing_val, | ||||||||||||||||||
matfree=matfree) | ||||||||||||||||||
|
||||||||||||||||||
return interp | ||||||||||||||||||
return Interpolate( | ||||||||||||||||||
expr, V, subset=subset, access=access, allow_missing_dofs=allow_missing_dofs, | ||||||||||||||||||
default_missing_val=default_missing_val, matfree=matfree | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
class Interpolator(abc.ABC): | ||||||||||||||||||
|
@@ -528,7 +502,7 @@ def __init__( | |||||||||||||||||
|
||||||||||||||||||
from firedrake.assemble import assemble | ||||||||||||||||||
V_dest_vec = firedrake.VectorFunctionSpace(dest_mesh, ufl_scalar_element) | ||||||||||||||||||
f_dest_node_coords = Interpolate(dest_mesh.coordinates, V_dest_vec) | ||||||||||||||||||
f_dest_node_coords = interpolate(dest_mesh.coordinates, V_dest_vec) | ||||||||||||||||||
f_dest_node_coords = assemble(f_dest_node_coords) | ||||||||||||||||||
dest_node_coords = f_dest_node_coords.dat.data_ro.reshape(-1, dest_mesh_gdim) | ||||||||||||||||||
try: | ||||||||||||||||||
|
@@ -553,15 +527,15 @@ def __init__( | |||||||||||||||||
else: | ||||||||||||||||||
fs_type = partial(firedrake.TensorFunctionSpace, shape=shape) | ||||||||||||||||||
P0DG_vom = fs_type(self.vom_dest_node_coords_in_src_mesh, "DG", 0) | ||||||||||||||||||
self.point_eval_interpolate = Interpolate(self.expr_renumbered, P0DG_vom) | ||||||||||||||||||
self.point_eval_interpolate = interpolate(self.expr_renumbered, P0DG_vom) | ||||||||||||||||||
# The parallel decomposition of the nodes of V_dest in the DESTINATION | ||||||||||||||||||
# mesh (dest_mesh) is retrieved using the input_ordering attribute of the | ||||||||||||||||||
# VOM. This again is an interpolation operation, which, under the hood | ||||||||||||||||||
# is a PETSc SF reduce. | ||||||||||||||||||
P0DG_vom_i_o = fs_type( | ||||||||||||||||||
self.vom_dest_node_coords_in_src_mesh.input_ordering, "DG", 0 | ||||||||||||||||||
) | ||||||||||||||||||
self.to_input_ordering_interpolate = Interpolate( | ||||||||||||||||||
self.to_input_ordering_interpolate = interpolate( | ||||||||||||||||||
firedrake.TrialFunction(P0DG_vom), P0DG_vom_i_o | ||||||||||||||||||
) | ||||||||||||||||||
# The P0DG function outputted by the above interpolation has the | ||||||||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is very unrelated, but I think that a much more friendly interface is to allow either or both left and right arguments to be a primal FunctionSpace.
Right now we do this under the hood
Interpolate(Function(V1), V2) -> Interpolate(Function(V1), Argument(V2.dual(), 0))
It'd be reasonable to have a similar shortcut for the adjoint. When the left argument is a FunctionSpace, we would then automatically create the Argument for it.
Interpolate(V1, Cofunction(V2.dual())) -> Interpolate(Argument(V1, 0), Cofunction(V2.dual()))
And supplying two FunctionSpaces is a perfectly natural interface:
Interpolate(V1, V2) -> Interpolate(Argument(V1, 1), Argument(V2.dual(), 0))
Of course we need to arbitrarily decide who gets the lowest number, the more intuitive numbering that produces the forward Interpolation is to go from right to left.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thoughts @dham ?