diff --git a/firedrake/ensemble/ensemble.py b/firedrake/ensemble/ensemble.py
index 49ab7e2c03..059dc9e905 100644
--- a/firedrake/ensemble/ensemble.py
+++ b/firedrake/ensemble/ensemble.py
@@ -1,12 +1,35 @@
+from functools import wraps
 import weakref
 from itertools import zip_longest
 
 from firedrake.petsc import PETSc
+from firedrake.function import Function
+from firedrake.cofunction import Cofunction
 from pyop2.mpi import MPI, internal_comm
 
 __all__ = ("Ensemble", )
 
 
+def _ensemble_mpi_dispatch(func):
+    """
+    This wrapper checks if any arg or kwarg of the wrapped
+    ensemble method is a Function or Cofunction, and if so
+    it calls the specialised Firedrake implementation.
+    Otherwise the standard mpi4py implementation is called.
+    """
+    @wraps(func)
+    def _mpi_dispatch(self, *args, **kwargs):
+        if any(isinstance(arg, (Function, Cofunction))
+               for arg in [*args, *kwargs.values()]):
+            return func(self, *args, **kwargs)
+        else:
+            mpicall = getattr(
+                self.ensemble_comm,
+                func.__name__)
+            return mpicall(*args, **kwargs)
+    return _mpi_dispatch
+
+
 class Ensemble(object):
     def __init__(self, comm, M, **kwargs):
         """
@@ -79,6 +102,7 @@ def _check_function(self, f, g=None):
             raise ValueError("Mismatching function spaces for functions")
 
     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def allreduce(self, f, f_reduced, op=MPI.SUM):
         """
         Allreduce a function f into f_reduced over ``ensemble_comm`` .
@@ -96,6 +120,7 @@ def allreduce(self, f, f_reduced, op=MPI.SUM):
         return f_reduced
 
     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def iallreduce(self, f, f_reduced, op=MPI.SUM):
         """
         Allreduce (non-blocking) a function f into f_reduced over ``ensemble_comm`` .
@@ -113,6 +138,7 @@ def iallreduce(self, f, f_reduced, op=MPI.SUM):
                 for fdat, rdat in zip(f.dat, f_reduced.dat)]
 
     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def reduce(self, f, f_reduced, op=MPI.SUM, root=0):
         """
         Reduce a function f into f_reduced over ``ensemble_comm`` to rank root
@@ -136,6 +162,7 @@ def reduce(self, f, f_reduced, op=MPI.SUM, root=0):
         return f_reduced
 
     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def ireduce(self, f, f_reduced, op=MPI.SUM, root=0):
         """
         Reduce (non-blocking) a function f into f_reduced over ``ensemble_comm`` to rank root
@@ -154,6 +181,7 @@ def ireduce(self, f, f_reduced, op=MPI.SUM, root=0):
                 for fdat, rdat in zip(f.dat, f_reduced.dat)]
 
     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def bcast(self, f, root=0):
         """
         Broadcast a function f over ``ensemble_comm`` from rank root
@@ -169,6 +197,7 @@ def bcast(self, f, root=0):
         return f
 
     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def ibcast(self, f, root=0):
         """
         Broadcast (non-blocking) a function f over ``ensemble_comm`` from rank root
@@ -184,6 +213,7 @@ def ibcast(self, f, root=0):
                 for dat in f.dat]
 
     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def send(self, f, dest, tag=0):
         """
         Send (blocking) a function f over ``ensemble_comm`` to another
@@ -199,6 +229,7 @@ def send(self, f, dest, tag=0):
             self._ensemble_comm.Send(dat.data_ro, dest=dest, tag=tag)
 
     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def recv(self, f, source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, statuses=None):
         """
         Receive (blocking) a function f over ``ensemble_comm`` from
@@ -215,8 +246,10 @@ def recv(self, f, source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, statuses=None):
             raise ValueError("Need to provide enough status objects for all parts of the Function")
         for dat, status in zip_longest(f.dat, statuses or (), fillvalue=None):
             self._ensemble_comm.Recv(dat.data, source=source, tag=tag, status=status)
+        return f
 
     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def isend(self, f, dest, tag=0):
         """
         Send (non-blocking) a function f over ``ensemble_comm`` to another
@@ -233,6 +266,7 @@ def isend(self, f, dest, tag=0):
                 for dat in f.dat]
 
     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def irecv(self, f, source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG):
         """
         Receive (non-blocking) a function f over ``ensemble_comm`` from
@@ -249,6 +283,7 @@ def irecv(self, f, source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG):
                 for dat in f.dat]
 
     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def sendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, recvtag=MPI.ANY_TAG, status=None):
         """
         Send (blocking) a function fsend and receive a function frecv over ``ensemble_comm`` to another
@@ -270,8 +305,10 @@ def sendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, re
         self._ensemble_comm.Sendrecv(sendvec, dest, sendtag=sendtag,
                                      recvbuf=recvvec, source=source, recvtag=recvtag,
                                      status=status)
+        return frecv
 
     @PETSc.Log.EventDecorator()
+    @_ensemble_mpi_dispatch
     def isendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, recvtag=MPI.ANY_TAG):
         """
         Send a function fsend and receive a function frecv over ``ensemble_comm`` to another
diff --git a/tests/firedrake/ensemble/test_ensemble.py b/tests/firedrake/ensemble/test_ensemble.py
index f7c3a9a893..6f9fef11d0 100644
--- a/tests/firedrake/ensemble/test_ensemble.py
+++ b/tests/firedrake/ensemble/test_ensemble.py
@@ -3,9 +3,6 @@
 import pytest
 from pytest_mpi.parallel_assert import parallel_assert
 
-from operator import mul
-from functools import reduce
-
 max_ncpts = 2
 
 
@@ -60,7 +57,7 @@ def W(request, mesh):
     if COMM_WORLD.size == 1:
         return
     V = FunctionSpace(mesh, "CG", 1)
-    return reduce(mul, [V for _ in range(request.param)])
+    return MixedFunctionSpace([V for _ in range(request.param)])
 
 
 # initialise unique function on each rank
diff --git a/tests/firedrake/ensemble/test_ensemble_wrapper.py b/tests/firedrake/ensemble/test_ensemble_wrapper.py
new file mode 100644
index 0000000000..69d2851338
--- /dev/null
+++ b/tests/firedrake/ensemble/test_ensemble_wrapper.py
@@ -0,0 +1,133 @@
+from firedrake import *
+import pytest
+from pytest_mpi.parallel_assert import parallel_assert
+
+
+min_root = 0
+max_root = 1
+
+roots = []
+roots.extend([pytest.param(None, id="root_none")])
+roots.extend([pytest.param(i, id=f"root_{i}")
+              for i in range(min_root, max_root + 1)])
+
+blocking = [
+    pytest.param(True, id="blocking"),
+    pytest.param(False, id="nonblocking")
+]
+
+sendrecv_pairs = [
+    pytest.param((0, 1), id="ranks01"),
+    pytest.param((1, 2), id="ranks12"),
+    pytest.param((2, 0), id="ranks20")
+]
+
+
+@pytest.fixture(scope="module")
+def ensemble():
+    if COMM_WORLD.size == 1:
+        return
+    return Ensemble(COMM_WORLD, 1)
+
+
+@pytest.mark.parallel(nprocs=2)
+def test_ensemble_allreduce(ensemble):
+    rank = ensemble.ensemble_rank
+    result = ensemble.allreduce(rank+1)
+    expected = sum([r+1 for r in range(ensemble.ensemble_size)])
+    parallel_assert(
+        result == expected,
+        msg=f"{result=} does not match {expected=}")
+
+
+@pytest.mark.parallel(nprocs=2)
+@pytest.mark.parametrize("root", roots)
+def test_ensemble_reduce(ensemble, root):
+    rank = ensemble.ensemble_rank
+
+    # check default root=0 works
+    if root is None:
+        result = ensemble.reduce(rank+1)
+        root = 0
+    else:
+        result = ensemble.reduce(rank+1, root=root)
+
+    expected = sum([r+1 for r in range(ensemble.ensemble_size)])
+
+    parallel_assert(
+        result == expected,
+        participating=(rank == root),
+        msg=f"{result=} does not match {expected=} on rank {root=}"
+    )
+    parallel_assert(
+        result is None,
+        participating=(rank != root),
+        msg=f"Unexpected {result=} on non-root rank"
+    )
+
+
+@pytest.mark.parallel(nprocs=2)
+@pytest.mark.parametrize("root", roots)
+def test_ensemble_bcast(ensemble, root):
+    rank = ensemble.ensemble_rank
+
+    # check default root=0 works
+    if root is None:
+        result = ensemble.bcast(rank+1)
+        root = 0
+    else:
+        result = ensemble.bcast(rank+1, root=root)
+
+    expected = root + 1
+
+    parallel_assert(result == expected)
+
+
+@pytest.mark.parallel(nprocs=3)
+@pytest.mark.parametrize("ranks", sendrecv_pairs)
+def test_send_and_recv(ensemble, ranks):
+    rank = ensemble.ensemble_rank
+
+    rank0, rank1 = ranks
+
+    send_data = rank + 1
+
+    if rank == rank0:
+        recv_expected = rank1 + 1
+
+        ensemble.send(send_data, dest=rank1, tag=rank0)
+        recv_data = ensemble.recv(source=rank1, tag=rank1)
+
+    elif rank == rank1:
+        recv_expected = rank0 + 1
+
+        recv_data = ensemble.recv(source=rank0, tag=rank0)
+        ensemble.send(send_data, dest=rank0, tag=rank1)
+
+    else:
+        recv_expected = None
+        recv_data = None
+
+    # Test send/recv between the two participating ensemble members,
+    # i.e. ensemble.ensemble_comm.rank == rank0 and rank1
+    parallel_assert(
+        recv_data == recv_expected,
+        participating=rank in (rank0, rank1),
+    )
+
+
+@pytest.mark.parallel(nprocs=3)
+def test_sendrecv(ensemble):
+    rank = ensemble.ensemble_rank
+    size = ensemble.ensemble_size
+    src_rank = (rank - 1) % size
+    dst_rank = (rank + 1) % size
+
+    send_data = rank + 1
+    recv_expected = src_rank + 1
+
+    recv_result = ensemble.sendrecv(
+        send_data, dst_rank, sendtag=rank,
+        source=src_rank, recvtag=src_rank)
+
+    parallel_assert(recv_result == recv_expected)
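
For context, a minimal sketch of how the `_ensemble_mpi_dispatch` wrapper behaves after this change, assuming a 2-member ensemble (e.g. `mpiexec -n 2`); the mesh and function-space setup below is illustrative only and not part of the patch:

```python
from firedrake import *

# Two ensemble members, one spatial rank each (as in the test fixture).
ensemble = Ensemble(COMM_WORLD, 1)
rank = ensemble.ensemble_rank

# No Function/Cofunction in the arguments: the wrapper falls through to
# mpi4py's allreduce on ensemble_comm and returns the reduced object.
total = ensemble.allreduce(rank + 1)  # sum of (rank + 1) over all members

# A Function argument dispatches to the specialised Firedrake path,
# which reduces dat-by-dat and returns f_reduced.
mesh = UnitSquareMesh(2, 2, comm=ensemble.comm)  # illustrative mesh/space
V = FunctionSpace(mesh, "CG", 1)
f = Function(V).assign(rank + 1)
f_reduced = ensemble.allreduce(f, Function(V))
```

Since dispatch keys only on argument types, both calls share one method name; anything that is not a Function or Cofunction (ints, numpy arrays, dicts, ...) takes the mpi4py route unchanged.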