@@ -1521,10 +1521,13 @@ def __init__(self, sf, forward_reduce, V, source_vom, expr, arguments):
15211521 self .arguments = arguments
15221522 # Calculate correct local and global sizes for the matrix
15231523 nroots , leaves , _ = sf .getGraph ()
1524- nleaves = len (leaves )
1524+ self . nleaves = len (leaves )
15251525 self ._local_sizes = V .comm .allgather (nroots )
1526- self .source_size = (nroots , sum (self ._local_sizes ))
1527- self .target_size = (nleaves , self .V .comm .allreduce (nleaves , op = MPI .SUM ))
1526+ self .source_size = (self .V .block_size * nroots , self .V .block_size * sum (self ._local_sizes ))
1527+ self .target_size = (
1528+ self .V .block_size * self .nleaves ,
1529+ self .V .block_size * V .comm .allreduce (self .nleaves , op = MPI .SUM ),
1530+ )
15281531
15291532 @property
15301533 def mpi_type (self ):
@@ -1565,7 +1568,7 @@ def expr_as_coeff(self, source_vec=None):
15651568 raise ValueError ("Need to provide a source dat for the argument!" )
15661569 arg = self .arguments [0 ]
15671570 arg_coeff = firedrake .Function (arg .function_space ())
1568- arg_coeff .dat .data_wo [:] = source_vec .getArray ().reshape (
1571+ arg_coeff .dat .data_wo [:] = source_vec .getArray (readonly = True ).reshape (
15691572 arg_coeff .dat .data_wo .shape
15701573 )
15711574 coeff_expr = ufl .replace (self .expr , {arg : arg_coeff })
@@ -1643,14 +1646,6 @@ def multTranspose(self, mat, source_vec, target_vec):
16431646 target_vec .zeroEntries ()
16441647 self .reduce (source_vec , target_vec )
16451648
1646- def _get_sizes (self ):
1647- nroots , leaves , _ = self .sf .getGraph ()
1648- nleaves = len (leaves )
1649- local_sizes = self .V .comm .allgather (nroots )
1650- source_size = (nroots , sum (local_sizes ))
1651- target_size = (nleaves , self .V .comm .allreduce (nleaves , op = MPI .SUM ))
1652- return source_size , target_size
1653-
16541649 def _create_permutation_mat (self ):
16551650 """Creates the PETSc matrix that represents the interpolation operator from a vertex-only mesh to
16561651 its input ordering vertex-only mesh"""
@@ -1659,25 +1654,23 @@ def _create_permutation_mat(self):
16591654 start = sum (self ._local_sizes [:self .V .comm .rank ])
16601655 end = start + self .source_size [0 ]
16611656 contiguous_indices = numpy .arange (start , end , dtype = utils .IntType )
1662- perm = numpy .zeros (self .target_size [ 0 ] , dtype = utils .IntType )
1657+ perm = numpy .zeros (self .nleaves , dtype = utils .IntType )
16631658 self .sf .bcastBegin (MPI .INT , contiguous_indices , perm , MPI .REPLACE )
16641659 self .sf .bcastEnd (MPI .INT , contiguous_indices , perm , MPI .REPLACE )
16651660 rows = numpy .arange (self .target_size [0 ] + 1 , dtype = utils .IntType )
1666- mat .setValuesCSR (rows , perm , numpy .ones_like (perm , dtype = utils .IntType ))
1661+ cols = (self .V .block_size * perm [:, None ] + numpy .arange (self .V .block_size , dtype = utils .IntType )[None , :]).reshape (- 1 )
1662+ mat .setValuesCSR (rows , cols , numpy .ones_like (cols , dtype = utils .IntType ))
16671663 mat .assemble ()
16681664 if self .forward_reduce :
16691665 mat .transpose ()
16701666 return mat
16711667
16721668 def _wrap_dummy_mat (self ):
16731669 mat = PETSc .Mat ().create (comm = self .V .comm )
1674- dim = self .V .value_size
1675- source_size = tuple (dim * i for i in self .source_size )
1676- target_size = tuple (dim * i for i in self .target_size )
16771670 if self .forward_reduce :
1678- mat_size = (source_size , target_size )
1671+ mat_size = (self . source_size , self . target_size )
16791672 else :
1680- mat_size = (target_size , source_size )
1673+ mat_size = (self . target_size , self . source_size )
16811674 mat .setSizes (mat_size )
16821675 mat .setType (mat .Type .PYTHON )
16831676 mat .setPythonContext (self )
0 commit comments