@@ -386,6 +386,16 @@ def allocation_integral_types(self):
386386 else :
387387 return self ._allocation_integral_types
388388
389+ @staticmethod
390+ def _as_pyop2_type (tensor , indices = None ):
391+ if isinstance (tensor , (firedrake .Cofunction , firedrake .Function )):
392+ return OneFormAssembler ._as_pyop2_type (tensor , indices = indices )
393+ elif isinstance (tensor , ufl .Matrix ):
394+ return ExplicitMatrixAssembler ._as_pyop2_type (tensor , indices = indices )
395+ else :
396+ assert indices is None
397+ return tensor
398+
389399 def assemble (self , tensor = None , current_state = None ):
390400 """Assemble the form.
391401
@@ -410,21 +420,22 @@ def assemble(self, tensor=None, current_state=None):
410420 """
411421 def visitor (e , * operands ):
412422 t = tensor if e is self ._form else None
413- return self .base_form_assembly_visitor (e , t , * operands )
423+ # Deal with 2-form bcs inside the visitor
424+ bcs = self ._bcs if isinstance (e , ufl .BaseForm ) and len (e .arguments ()) == 2 else ()
425+ return self .base_form_assembly_visitor (e , t , bcs , * operands )
414426
415427 # DAG assembly: traverse the DAG in a post-order fashion and evaluate the node on the fly.
416428 visited = {}
417429 result = BaseFormAssembler .base_form_postorder_traversal (self ._form , visitor , visited )
418430
419- # Apply BCs after assembly
431+ # Deal with 1-form bcs outside the visitor
420432 rank = len (self ._form .arguments ())
421433 if rank == 1 and not isinstance (result , ufl .ZeroBaseForm ):
422434 for bc in self ._bcs :
423435 OneFormAssembler ._apply_bc (self , result , bc , u = current_state )
424-
425436 return result
426437
427- def base_form_assembly_visitor (self , expr , tensor , * args ):
438+ def base_form_assembly_visitor (self , expr , tensor , bcs , * args ):
428439 r"""Assemble a :class:`~ufl.classes.BaseForm` object given its assembled operands.
429440
430441 This functions contains the assembly handlers corresponding to the different nodes that
@@ -445,7 +456,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
445456 assembler = OneFormAssembler (form , form_compiler_parameters = self ._form_compiler_params ,
446457 zero_bc_nodes = self ._zero_bc_nodes , diagonal = self ._diagonal , weight = self ._weight )
447458 elif rank == 2 :
448- assembler = TwoFormAssembler (form , bcs = self . _bcs , form_compiler_parameters = self ._form_compiler_params ,
459+ assembler = TwoFormAssembler (form , bcs = bcs , form_compiler_parameters = self ._form_compiler_params ,
449460 mat_type = self ._mat_type , sub_mat_type = self ._sub_mat_type ,
450461 options_prefix = self ._options_prefix , appctx = self ._appctx , weight = self ._weight ,
451462 allocation_integral_types = self .allocation_integral_types )
@@ -456,13 +467,12 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
456467 if len (args ) != 1 :
457468 raise TypeError ("Not enough operands for Adjoint" )
458469 mat , = args
459- res = tensor .petscmat if tensor else PETSc .Mat ()
460- petsc_mat = mat .petscmat
470+ result = tensor .petscmat if tensor else PETSc .Mat ()
461471 # Out-of-place Hermitian transpose
462- petsc_mat . hermitianTranspose (out = res )
463- ( row , col ) = mat . arguments ()
464- return matrix . AssembledMatrix (( col , row ), self ._bcs , res ,
465- options_prefix = self . _options_prefix )
472+ mat . petscmat . hermitianTranspose (out = result )
473+ if tensor is None :
474+ tensor = self .assembled_matrix ( expr , bcs , result )
475+ return tensor
466476 elif isinstance (expr , ufl .Action ):
467477 if len (args ) != 2 :
468478 raise TypeError ("Not enough operands for Action" )
@@ -480,7 +490,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
480490 result = tensor .petscmat if tensor else PETSc .Mat ()
481491 lhs .petscmat .matMult (rhs .petscmat , result = result )
482492 if tensor is None :
483- tensor = self .assembled_matrix (expr , result )
493+ tensor = self .assembled_matrix (expr , bcs , result )
484494 return tensor
485495 else :
486496 raise TypeError ("Incompatible RHS for Action." )
@@ -499,9 +509,6 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
499509 raise TypeError ("Mismatching weights and operands in FormSum" )
500510 if len (args ) == 0 :
501511 raise TypeError ("Empty FormSum" )
502- if tensor :
503- tensor .zero ()
504-
505512 # Assemble weights
506513 weights = []
507514 for w in expr .weights ():
@@ -519,27 +526,54 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
519526 raise ValueError ("Expecting a scalar weight expression" )
520527 weights .append (w )
521528
529+ # Scalar FormSum
522530 if all (isinstance (op , numbers .Complex ) for op in args ):
523- result = sum (weight * arg for weight , arg in zip (weights , args ))
524- return tensor .assign (result ) if tensor else result
525- elif (all (isinstance (op , firedrake .Cofunction ) for op in args )
531+ result = numpy .dot (weights , args )
532+ return tensor .assign (result ) if tensor else result .item ()
533+
534+ # Accumulate coefficients in a dictionary for each unique Dat/Mat
535+ terms = defaultdict (PETSc .ScalarType )
536+ for arg , weight in zip (args , weights ):
537+ t = self ._as_pyop2_type (arg )
538+ terms [t ] += weight
539+
540+ # Zero the output tensor, or rescale it if it appears in the sum
541+ tensor_scale = terms .pop (self ._as_pyop2_type (tensor ), 0 )
542+ if tensor is None or tensor_scale == 1 :
543+ pass
544+ elif tensor_scale == 0 :
545+ tensor .zero ()
546+ elif isinstance (tensor , (firedrake .Cofunction , firedrake .Function )):
547+ with tensor .dat .vec as v :
548+ v .scale (tensor_scale )
549+ elif isinstance (tensor , ufl .Matrix ):
550+ tensor .petscmat .scale (tensor_scale )
551+ else :
552+ raise ValueError ("Expecting tensor to be None, Function, Cofunction, or Matrix" )
553+
554+ # Compute the linear combination
555+ if (all (isinstance (op , firedrake .Cofunction ) for op in args )
526556 or all (isinstance (op , firedrake .Function ) for op in args )):
557+ # Vector FormSum
527558 V , = set (a .function_space () for a in args )
528559 result = tensor if tensor else firedrake .Function (V )
529- result .dat .maxpy (weights , [a .dat for a in args ])
560+ weights = terms .values ()
561+ dats = terms .keys ()
562+ result .dat .maxpy (weights , dats )
530563 return result
531564 elif all (isinstance (op , ufl .Matrix ) for op in args ):
565+ # Matrix FormSum
532566 result = tensor .petscmat if tensor else PETSc .Mat ()
533- for (op , w ) in zip ( args , weights ):
567+ for (op , w ) in terms . items ( ):
534568 if result :
535569 # If result is not void, then accumulate on it
536- result .axpy (w , op .petscmat )
570+ result .axpy (w , op .handle )
537571 else :
538572 # If result is void, then allocate it with first term
539- op .petscmat .copy (result = result )
573+ op .handle .copy (result = result )
540574 result .scale (w )
541575 if tensor is None :
542- tensor = self .assembled_matrix (expr , result )
576+ tensor = self .assembled_matrix (expr , bcs , result )
543577 return tensor
544578 else :
545579 raise TypeError ("Mismatching FormSum shapes" )
@@ -571,9 +605,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
571605 # Occur in situations such as Interpolate composition
572606 operand = assembled_operand [0 ]
573607
574- reconstruct_interp = expr ._ufl_expr_reconstruct_
575608 if (v , operand ) != expr .argument_slots ():
576- expr = reconstruct_interp (operand , v = v )
609+ expr = expr . _ufl_expr_reconstruct_ (operand , v = v )
577610
578611 rank = len (expr .arguments ())
579612 if rank > 2 :
@@ -586,7 +619,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
586619 default_missing_val = interp_data .pop ('default_missing_val' , None )
587620 if rank == 1 and isinstance (tensor , firedrake .Function ):
588621 V = tensor
589- interpolator = firedrake .Interpolator (expr , V , ** interp_data )
622+ interpolator = firedrake .Interpolator (expr , V , bcs = bcs , ** interp_data )
590623 # Assembly
591624 return interpolator .assemble (tensor = tensor , default_missing_val = default_missing_val )
592625 elif tensor and isinstance (expr , (firedrake .Function , firedrake .Cofunction , firedrake .MatrixBase )):
@@ -598,8 +631,8 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
598631 else :
599632 raise TypeError (f"Unrecognised BaseForm instance: { expr } " )
600633
601- def assembled_matrix (self , expr , petscmat ):
602- return matrix .AssembledMatrix (expr .arguments (), self . _bcs , petscmat ,
634+ def assembled_matrix (self , expr , bcs , petscmat ):
635+ return matrix .AssembledMatrix (expr .arguments (), bcs , petscmat ,
603636 options_prefix = self ._options_prefix )
604637
605638 @staticmethod
@@ -1448,10 +1481,11 @@ def _apply_bc(self, tensor, bc, u=None):
14481481 index = 0 if V .index is None else V .index
14491482 space = V if V .parent is None else V .parent
14501483 if isinstance (bc , DirichletBC ):
1451- if space != spaces [0 ]:
1452- raise TypeError ("bc space does not match the test function space" )
1453- elif space != spaces [1 ]:
1454- raise TypeError ("bc space does not match the trial function space" )
1484+ if not any (space == fs for fs in spaces ):
1485+ raise TypeError ("bc space does not match the test or trial function space" )
1486+ if spaces [0 ] != spaces [1 ]:
1487+ # Not on a diagonal block, we cannot set diagonal entries
1488+ return
14551489
14561490 # Set diagonal entries on bc nodes to 1 if the current
14571491 # block is on the matrix diagonal and its index matches the
0 commit comments