diff --git a/firedrake/function.py b/firedrake/function.py index a628ac6599..2714957de0 100644 --- a/firedrake/function.py +++ b/firedrake/function.py @@ -27,7 +27,7 @@ from firedrake import utils from firedrake.adjoint_utils import FunctionMixin from firedrake.petsc import PETSc -from firedrake.mesh import MeshGeometry, VertexOnlyMesh +from firedrake.mesh import MeshGeometry, VertexOnlyMesh, VertexOnlyMeshTopology from firedrake.functionspace import FunctionSpace, VectorFunctionSpace, TensorFunctionSpace @@ -282,6 +282,9 @@ def __init__(self, function_space, val=None, name=None, dtype=ScalarType, if isinstance(function_space, Function): self.assign(function_space) + if isinstance(V._mesh, VertexOnlyMeshTopology): + V._mesh.register_field(self) + @property def topological(self): r"""The underlying coordinateless function.""" diff --git a/firedrake/mesh.py b/firedrake/mesh.py index 87757a2e54..762bb2476a 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -1904,8 +1904,17 @@ def __init__(self, swarm, parentmesh, name, reorder, input_ordering_swarm=None, "overlap_type": (DistributedMeshOverlapType.NONE, 0)} self.input_ordering_swarm = input_ordering_swarm self._parent_mesh = parentmesh + self._fields = weakref.WeakSet() super().__init__(swarm, name, reorder, None, perm_is, distribution_name, permutation_name, parentmesh.comm) + def register_field(self, f) -> None: + self._fields.add(f) + + def update_fields(self) -> None: + # Update all registered fields after VOM has moved + for field in self._fields: + print(f"Hello from {field.name()}, defined on {field.function_space()}") + def _distribute(self): pass @@ -2175,6 +2184,70 @@ def input_ordering_without_halos_sf(self): # cells first; self.cell_set.size is the number of rank-local non-halo cells. return self.input_ordering_sf.createEmbeddedLeafSF(np.arange(self.cell_set.size, dtype=IntType)) + def _update_swarm( + self, new_coords: np.ndarray, new_global_idxs: np.ndarray, new_ref_coords: np.ndarray, + new_parent_cell_nums: np.ndarray, new_ranks: np.ndarray + ) -> None: + """Updates the VOM's DMSwarm to new coordinates. Assumes there are N new coordinates, + where N is the total number of vertices in the VOM. + + Parameters + ---------- + new_coords : np.array + An (N, gdim) array of new global coordinates for the N vertices, + where gdim is the geometric dimension of the VOM / parent mesh. + new_global_idxs : np.array + An (N,) array of new global indices for the N vertices. + new_ref_coords : np.array + An (N, tdim) array of new reference coordinates for the N vertices, + where tdim is the topological dimension of the parent mesh. + new_parent_cell_nums : np.array + An (N,) array of new parent cell numbers for the N vertices. + new_ranks : np.array + An (N,) array of new MPI ranks for the N vertices. + + """ + num_vertices = new_global_idxs.shape[0] + gdim = self.geometric_dimension() + + if num_vertices != new_coords.shape[0]: + raise ValueError("Number of new coordinates does not match number of global indices") + if gdim != new_coords.shape[1]: + raise ValueError("New coordinates do not have the same geometric dimension as the mesh") + if num_vertices == 0: + raise ValueError("No points to move") + if np.unique(new_global_idxs).shape[0] != num_vertices: + raise ValueError("global_indices must be unique") + if isinstance(self._parent_mesh, ExtrudedMeshTopology): + raise NotImplementedError("move_points is not implemented for extruded meshes yet") + + # mesh.topology.cell_closure[:, -1] maps Firedrake cell numbers to DMplex numbers + new_plex_parent_cell_nums = self._parent_mesh.topology.cell_closure[new_parent_cell_nums, -1] + + tdim = self._parent_mesh.topological_dimension + swarm = self.topology_dm + + current_coords = swarm.getField("DMSwarmPIC_coor").reshape((num_vertices, gdim)) + current_dmplex_parent_cell_nums = swarm.getField("DMSwarm_cellid").ravel() + current_parent_cell_nums = swarm.getField("parentcellnum").ravel() + current_ref_coords = swarm.getField("refcoord").reshape((num_vertices, tdim)) + current_global_idxs = swarm.getField("globalindex").ravel() + current_ranks = swarm.getField("DMSwarm_rank").ravel() + + current_coords[...] = new_coords + current_dmplex_parent_cell_nums[...] = new_plex_parent_cell_nums + current_parent_cell_nums[...] = new_parent_cell_nums + current_ref_coords[...] = new_ref_coords + current_global_idxs[...] = new_global_idxs + current_ranks[...] = new_ranks + + swarm.restoreField("DMSwarm_rank") + swarm.restoreField("globalindex") + swarm.restoreField("refcoord") + swarm.restoreField("parentcellnum") + swarm.restoreField("DMSwarmPIC_coor") + swarm.restoreField("DMSwarm_cellid") + class CellOrientationsRuntimeError(RuntimeError): """Exception raised when there are problems with cell orientations.""" @@ -3456,22 +3529,23 @@ def other_fields(self, fields): def _pic_swarm_in_mesh( - parent_mesh, - coords, - fields=None, - tolerance=None, - redundant=True, - exclude_halos=True, -): - """Create a Particle In Cell (PIC) DMSwarm immersed in a Mesh + parent_mesh: AbstractMeshTopology, + coords: np.ndarray, + fields: list[Tuple[str, int, np.dtype]] | None = None, + tolerance: float | None = None, + redundant: bool = True, + exclude_halos: bool = True, +) -> Tuple[FiredrakeDMSwarm, FiredrakeDMSwarm, int]: + """Creates a Particle In Cell (PIC) DMSwarm immersed in a Mesh. - This should only by used for meshes with straight edges. If not, the - particles may be placed in the wrong cells. - - :arg parent_mesh: the :class:`Mesh` within with the DMSwarm should be - immersed. - :arg coords: an ``ndarray`` of (npoints, coordsdim) shape. - :kwarg fields: An optional list of named data which can be stored for each + Parameters + ---------- + parent_mesh + The parent mesh in which the DMSwarm should be immersed. + coords + An array of shape (npoints, coordsdim) defining the point coordinates. + fields + An optional list of named data which can be stored for each point in the DMSwarm. The format should be:: [(fieldname1, blocksize1, dtype1), @@ -3484,7 +3558,8 @@ def _pic_swarm_in_mesh( RealType)]``. All fields must have the same number of points. For more information see `the DMSWARM API reference _. - :kwarg tolerance: The relative tolerance (i.e. as defined on the reference + tolerance + The relative tolerance (i.e. as defined on the reference cell) for the distance a point can be from a cell and still be considered to be in the cell. Note that this tolerance uses an L1 distance (aka 'manhattan', 'taxicab' or rectilinear distance) so @@ -3492,21 +3567,26 @@ def _pic_swarm_in_mesh( mesh's ``tolerance`` property. Changing this from default will cause the parent mesh's spatial index to be rebuilt which can take some time. - :kwarg redundant: If True, the DMSwarm will be created using only the - points specified on MPI rank 0. - :kwarg exclude_halos: If True, the DMSwarm will not contain any points in + redundant + If True, the DMSwarm will be created using only the + points specified on MPI rank 0. Defaults to True. + exclude_halos + If True, the DMSwarm will not contain any points in the mesh halos. If False, it will but the global index of the points in the halos will match a global index of a point which is not in the - halo. - :returns: (swarm, input_ordering_swarm, n_missing_points) - - swarm: the immersed DMSwarm - - input_ordering_swarm: a DMSwarm with points in the same order and with the - same rank decomposition as the supplied ``coords`` argument. This - includes any points which are not found in the parent mesh! Note - that if ``redundant=True``, all points in the generated DMSwarm - will be found on rank 0 since that was where they were taken from. - - n_missing_points: the number of points in the supplied ``coords`` - argument which were not found in the parent mesh. + halo. Defaults to True. + + Returns + ------- + (swarm, input_ordering_swarm, n_missing_points) + - swarm: the immersed DMSwarm + - input_ordering_swarm: a DMSwarm with points in the same order and with the + same rank decomposition as the supplied ``coords`` argument. This + includes any points which are not found in the parent mesh! Note + that if ``redundant=True``, all points in the generated DMSwarm + will be found on rank 0 since that was where they were taken from. + - n_missing_points: the number of points in the supplied ``coords`` + argument which were not found in the parent mesh. .. note:: @@ -3569,9 +3649,7 @@ def _pic_swarm_in_mesh( directly with PETSc's DMSwarm API. For the ``swarm`` output, this is the parent mesh's topology DM (in most cases a DMPlex). For the ``input_ordering_swarm`` output, this is the ``swarm`` itself. - """ - if tolerance is None: tolerance = parent_mesh.tolerance else: @@ -4113,12 +4191,9 @@ def _parent_mesh_embedding( (ncoords_global, coords.shape[1]), dtype=coords_local.dtype ) parent_mesh._comm.Allgatherv(coords_local, (coords_global, coords_local_sizes)) - # # ncoords_local_allranks is in rank order so we can just sum up the - # # previous ranks to get the starting index for the global numbering. - # # For rank 0 we make use of the fact that sum([]) = 0. - # startidx = sum(ncoords_local_allranks[:parent_mesh._comm.rank]) - # endidx = startidx + ncoords_local - # global_idxs_global = np.arange(startidx, endidx) + # ncoords_local_allranks is in rank order so we can just sum up the + # previous ranks to get the starting index for the global numbering. + # For rank 0 we make use of the fact that sum([]) = 0. global_idxs_global = np.arange(coords_global.shape[0]) input_coords_idxs_local = np.arange(ncoords_local) input_coords_idxs_global = np.empty(ncoords_global, dtype=int) diff --git a/tests/firedrake/vertexonly/test_swarm.py b/tests/firedrake/vertexonly/test_swarm.py index 96c4602511..7ccf817a19 100644 --- a/tests/firedrake/vertexonly/test_swarm.py +++ b/tests/firedrake/vertexonly/test_swarm.py @@ -1,4 +1,5 @@ from firedrake import * +import firedrake.mesh as fd_mesh from firedrake.utils import IntType, RealType import pytest import numpy as np @@ -170,9 +171,9 @@ def test_pic_swarm_in_mesh(parentmesh, redundant, exclude_halos): # global cell midpoints only on rank 0. Note that this is the default # behaviour so it needn't be specified explicitly. if MPI.COMM_WORLD.rank == 0: - swarm, original_swarm, n_missing_coords = mesh._pic_swarm_in_mesh(parentmesh, inputpointcoords, fields=other_fields, exclude_halos=exclude_halos) + swarm, original_swarm, n_missing_coords = fd_mesh._pic_swarm_in_mesh(parentmesh, inputpointcoords, fields=other_fields, exclude_halos=exclude_halos) else: - swarm, original_swarm, n_missing_coords = mesh._pic_swarm_in_mesh(parentmesh, np.empty(inputpointcoords.shape), fields=other_fields, exclude_halos=exclude_halos) + swarm, original_swarm, n_missing_coords = fd_mesh._pic_swarm_in_mesh(parentmesh, np.empty(inputpointcoords.shape), fields=other_fields, exclude_halos=exclude_halos) input_rank = 0 # inputcoordindices is the correct set of input indices for # redundant==True but I need to work out where they will be after @@ -191,7 +192,7 @@ def test_pic_swarm_in_mesh(parentmesh, redundant, exclude_halos): # When redundant == False we expect the same behaviour by only # supplying the local cell midpoints on each MPI ranks. Note that this # is not the default behaviour so it must be specified explicitly. - swarm, original_swarm, n_missing_coords = mesh._pic_swarm_in_mesh(parentmesh, inputlocalpointcoords, fields=other_fields, redundant=redundant, exclude_halos=exclude_halos) + swarm, original_swarm, n_missing_coords = fd_mesh._pic_swarm_in_mesh(parentmesh, inputlocalpointcoords, fields=other_fields, redundant=redundant, exclude_halos=exclude_halos) input_rank = parentmesh.comm.rank input_local_coord_indices = np.arange(len(inputlocalpointcoords))