Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions firedrake/ensemble/ensemble.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,32 @@
from functools import wraps
import weakref
from itertools import zip_longest

from firedrake.petsc import PETSc
from firedrake.function import Function
from firedrake.cofunction import Cofunction
from pyop2.mpi import MPI, internal_comm

__all__ = ("Ensemble", )


def _ensemble_mpi_dispatch(func):
    """Route a call to the Firedrake-specialised implementation or plain MPI.

    If any positional or keyword argument is a :class:`Function` or
    :class:`Cofunction`, the decorated Firedrake-aware method is used;
    otherwise the call is forwarded to the ``ensemble_comm`` method of
    the same name, i.e. the default mpi4py implementation.
    """
    @wraps(func)
    def _mpi_dispatch(self, *args, **kwargs):
        candidates = list(args) + list(kwargs.values())
        if any(isinstance(c, (Function, Cofunction)) for c in candidates):
            return func(self, *args, **kwargs)
        # Fall back to the mpi4py call with the matching name.
        mpicall = getattr(self.ensemble_comm, func.__name__)
        return mpicall(*args, **kwargs)
    return _mpi_dispatch


class Ensemble(object):
def __init__(self, comm, M, **kwargs):
"""
Expand Down Expand Up @@ -79,6 +99,7 @@ def _check_function(self, f, g=None):
raise ValueError("Mismatching function spaces for functions")

@PETSc.Log.EventDecorator()
@_ensemble_mpi_dispatch
def allreduce(self, f, f_reduced, op=MPI.SUM):
"""
Allreduce a function f into f_reduced over ``ensemble_comm`` .
Expand All @@ -96,6 +117,7 @@ def allreduce(self, f, f_reduced, op=MPI.SUM):
return f_reduced

@PETSc.Log.EventDecorator()
@_ensemble_mpi_dispatch
def iallreduce(self, f, f_reduced, op=MPI.SUM):
"""
Allreduce (non-blocking) a function f into f_reduced over ``ensemble_comm`` .
Expand All @@ -113,6 +135,7 @@ def iallreduce(self, f, f_reduced, op=MPI.SUM):
for fdat, rdat in zip(f.dat, f_reduced.dat)]

@PETSc.Log.EventDecorator()
@_ensemble_mpi_dispatch
def reduce(self, f, f_reduced, op=MPI.SUM, root=0):
"""
Reduce a function f into f_reduced over ``ensemble_comm`` to rank root
Expand All @@ -136,6 +159,7 @@ def reduce(self, f, f_reduced, op=MPI.SUM, root=0):
return f_reduced

@PETSc.Log.EventDecorator()
@_ensemble_mpi_dispatch
def ireduce(self, f, f_reduced, op=MPI.SUM, root=0):
"""
Reduce (non-blocking) a function f into f_reduced over ``ensemble_comm`` to rank root
Expand All @@ -154,6 +178,7 @@ def ireduce(self, f, f_reduced, op=MPI.SUM, root=0):
for fdat, rdat in zip(f.dat, f_reduced.dat)]

@PETSc.Log.EventDecorator()
@_ensemble_mpi_dispatch
def bcast(self, f, root=0):
"""
Broadcast a function f over ``ensemble_comm`` from rank root
Expand All @@ -169,6 +194,7 @@ def bcast(self, f, root=0):
return f

@PETSc.Log.EventDecorator()
@_ensemble_mpi_dispatch
def ibcast(self, f, root=0):
"""
Broadcast (non-blocking) a function f over ``ensemble_comm`` from rank root
Expand All @@ -184,6 +210,7 @@ def ibcast(self, f, root=0):
for dat in f.dat]

@PETSc.Log.EventDecorator()
@_ensemble_mpi_dispatch
def send(self, f, dest, tag=0):
"""
Send (blocking) a function f over ``ensemble_comm`` to another
Expand All @@ -199,6 +226,7 @@ def send(self, f, dest, tag=0):
self._ensemble_comm.Send(dat.data_ro, dest=dest, tag=tag)

@PETSc.Log.EventDecorator()
@_ensemble_mpi_dispatch
def recv(self, f, source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, statuses=None):
"""
Receive (blocking) a function f over ``ensemble_comm`` from
Expand All @@ -215,8 +243,10 @@ def recv(self, f, source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, statuses=None):
raise ValueError("Need to provide enough status objects for all parts of the Function")
for dat, status in zip_longest(f.dat, statuses or (), fillvalue=None):
self._ensemble_comm.Recv(dat.data, source=source, tag=tag, status=status)
return f

@PETSc.Log.EventDecorator()
@_ensemble_mpi_dispatch
def isend(self, f, dest, tag=0):
"""
Send (non-blocking) a function f over ``ensemble_comm`` to another
Expand All @@ -233,6 +263,7 @@ def isend(self, f, dest, tag=0):
for dat in f.dat]

@PETSc.Log.EventDecorator()
@_ensemble_mpi_dispatch
def irecv(self, f, source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG):
"""
Receive (non-blocking) a function f over ``ensemble_comm`` from
Expand All @@ -249,6 +280,7 @@ def irecv(self, f, source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG):
for dat in f.dat]

@PETSc.Log.EventDecorator()
@_ensemble_mpi_dispatch
def sendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, recvtag=MPI.ANY_TAG, status=None):
"""
Send (blocking) a function fsend and receive a function frecv over ``ensemble_comm`` to another
Expand All @@ -270,8 +302,10 @@ def sendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, re
self._ensemble_comm.Sendrecv(sendvec, dest, sendtag=sendtag,
recvbuf=recvvec, source=source, recvtag=recvtag,
status=status)
return frecv

@PETSc.Log.EventDecorator()
@_ensemble_mpi_dispatch
def isendrecv(self, fsend, dest, sendtag=0, frecv=None, source=MPI.ANY_SOURCE, recvtag=MPI.ANY_TAG):
"""
Send a function fsend and receive a function frecv over ``ensemble_comm`` to another
Expand Down
5 changes: 1 addition & 4 deletions tests/firedrake/ensemble/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
import pytest
from pytest_mpi.parallel_assert import parallel_assert

from operator import mul
from functools import reduce


max_ncpts = 2

Expand Down Expand Up @@ -60,7 +57,7 @@ def W(request, mesh):
if COMM_WORLD.size == 1:
return
V = FunctionSpace(mesh, "CG", 1)
return reduce(mul, [V for _ in range(request.param)])
return MixedFunctionSpace([V for _ in range(request.param)])


# initialise unique function on each rank
Expand Down
126 changes: 126 additions & 0 deletions tests/firedrake/ensemble/test_ensemble_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from firedrake import *
import pytest
from pytest_mpi.parallel_assert import parallel_assert


min_root = 1
max_root = 1

# Roots to parametrise the broadcast/reduce tests over: None exercises
# the default root argument, the integers exercise explicit roots.
# (The previous plain-list initialisation of `roots` was dead code —
# it was immediately overwritten below — so it has been removed.)
roots = []
roots.extend([pytest.param(None, id="root_none")])
roots.extend([pytest.param(i, id="root_%d" % (i))
              for i in range(min_root, max_root + 1)])

# Parametrisation over blocking vs non-blocking communication calls.
blocking = [pytest.param(True, id="blocking"),
            pytest.param(False, id="nonblocking")]


@pytest.fixture(scope="module")
def ensemble():
    """Ensemble over COMM_WORLD with one rank per spatial communicator.

    Returns None in serial runs, where no ensemble can be formed.
    """
    serial = COMM_WORLD.size == 1
    return None if serial else Ensemble(COMM_WORLD, 1)


@pytest.mark.parallel(nprocs=2)
def test_ensemble_allreduce(ensemble):
    """Allreduce of plain integers goes through the mpi4py fallback."""
    my_value = ensemble.ensemble_rank + 1
    result = ensemble.allreduce(my_value)
    # Every rank contributes rank+1, so the total is 1 + 2 + ... + size.
    expected = sum(range(1, ensemble.ensemble_size + 1))
    parallel_assert(
        result == expected,
        msg=f"{result=} does not match {expected=}")


@pytest.mark.parallel(nprocs=2)
@pytest.mark.parametrize("root", roots)
def test_ensemble_reduce(ensemble, root):
    """Reduce of plain integers lands only on the root rank."""
    rank = ensemble.ensemble_rank
    send_value = rank + 1

    # A root of None exercises the default root=0 argument.
    if root is None:
        result = ensemble.reduce(send_value)
        root = 0
    else:
        result = ensemble.reduce(send_value, root=root)

    # Every rank contributes rank+1, so the total is 1 + 2 + ... + size.
    expected = sum(range(1, ensemble.ensemble_size + 1))

    # The root rank receives the reduced value ...
    parallel_assert(
        result == expected,
        participating=(rank == root),
        msg=f"{result=} does not match {expected=} on rank {root=}"
    )
    # ... and every other rank receives nothing.
    parallel_assert(
        result is None,
        participating=(rank != root),
        msg=f"Unexpected {result=} on non-root rank"
    )


@pytest.mark.parallel(nprocs=2)
@pytest.mark.parametrize("root", roots)
def test_ensemble_bcast(ensemble, root):
    """Broadcast of a plain integer goes through the mpi4py fallback."""
    rank = ensemble.ensemble_rank
    my_value = rank + 1

    # A root of None exercises the default root=0 argument.
    if root is None:
        result = ensemble.bcast(my_value)
        root = 0
    else:
        result = ensemble.bcast(my_value, root=root)

    # Every rank should end up with the root rank's value.
    parallel_assert(result == root + 1)


@pytest.mark.parallel(nprocs=3)
def test_send_and_recv(ensemble):
    """Exchange plain integers between the first two ensemble ranks.

    Integer payloads should be routed through the mpi4py send/recv
    fallback rather than the Function-specialised path.  The third
    rank takes no part in the exchange.
    """
    rank = ensemble.ensemble_rank

    rank0 = 0
    rank1 = 1

    # Each rank's payload is its (1-based) rank number.
    send_data = rank + 1

    if rank == rank0:
        recv_expected = rank1 + 1

        # rank0 sends first then receives; rank1 does the opposite,
        # so the two blocking calls pair up without deadlocking.
        ensemble.send(send_data, dest=rank1, tag=rank0)
        recv_data = ensemble.recv(source=rank1, tag=rank1)

    elif rank == rank1:
        recv_expected = rank0 + 1

        recv_data = ensemble.recv(source=rank0, tag=rank0)
        ensemble.send(send_data, dest=rank0, tag=rank1)

    else:
        # Ranks beyond the first two do not participate.
        recv_expected = None
        recv_data = None

    # Test send/recv between first two spatial comms
    # ie: ensemble.ensemble_comm.rank == 0 and 1
    parallel_assert(
        recv_data == recv_expected,
        participating=rank in (rank0, rank1),
    )


@pytest.mark.parallel(nprocs=3)
def test_sendrecv(ensemble):
    """Combined sendrecv of integers around a ring of ensemble ranks."""
    rank = ensemble.ensemble_rank
    size = ensemble.ensemble_size
    # Send to the next rank and receive from the previous one (mod size).
    left = (rank - 1) % size
    right = (rank + 1) % size

    recv_result = ensemble.sendrecv(
        rank + 1, right, sendtag=rank,
        source=left, recvtag=left)

    # The left neighbour sent its own rank+1.
    parallel_assert(recv_result == left + 1)
Loading