|
| 1 | +from __future__ import annotations |
| 2 | + |
1 | 3 | import ufl |
2 | 4 | from itertools import chain |
3 | 5 | from contextlib import ExitStack |
@@ -131,14 +133,38 @@ def dm(self): |
131 | 133 | return self.u_restrict.function_space().dm |
132 | 134 |
|
133 | 135 | @staticmethod |
134 | | - def compute_bc_lifting(J, u): |
135 | | - """Return the action of the bilinear form J (without bcs) on a Function u.""" |
| 136 | + def compute_bc_lifting(J: ufl.BaseForm | slate.TensorBase, |
| 137 | + u: Function, |
| 138 | + L: ufl.BaseForm | slate.TensorBase | 0 = 0): |
| 139 | + """Compute the residual after lifting DirichletBCs. |
| 140 | +
|
| 141 | + Parameters |
| 142 | + ---------- |
| 143 | + J |
| 144 | + The Jacobian bilinear form. |
| 145 | + u |
| 146 | + The Function on which DirichletBCs are applied. |
| 147 | + L |
| 148 | + The unlifted residual linear form. |
| 149 | +
|
| 150 | + Return |
| 151 | + ------ |
| 152 | + F : ufl.BaseForm | slate.TensorBase |
| 153 | + The residual J*u-L after lifting DirichletBCs. |
| 154 | + """ |
136 | 155 | if isinstance(J, MatrixBase) and J.has_bcs: |
137 | 156 | # Extract the full form without bcs |
138 | 157 | if not isinstance(J.a, (ufl.BaseForm, slate.slate.TensorBase)): |
139 | 158 | raise TypeError(f"Could not remove bcs from {type(J).__name__}.") |
140 | 159 | J = J.a |
141 | | - return ufl_expr.action(J, u) |
| 160 | + F = ufl_expr.action(J, u) |
| 161 | + if isinstance(F, slate.slate.TensorBase) and not isinstance(L, (ufl.Form, slate.slate.TensorBase)): |
| 162 | + # Slate expressions should not combine with assembled Cofunctions |
| 163 | + # because assemble(AssembledVector(L)) repeats element summation on L |
| 164 | + F = ufl.FormSum((F, 1), (L, -1)) |
| 165 | + elif L != 0: |
| 166 | + F = F - L |
| 167 | + return F |
142 | 168 |
|
143 | 169 |
|
144 | 170 | class NonlinearVariationalSolver(OptionsManager, NonlinearVariationalSolverMixin): |
@@ -403,14 +429,12 @@ def __init__(self, a, L, u, bcs=None, aP=None, |
403 | 429 | """ |
404 | 430 | # In the linear case, the Jacobian is the equation LHS (J=a). |
405 | 431 | # Jacobian is checked in superclass, but let's check L here. |
406 | | - if not isinstance(L, (ufl.BaseForm, slate.slate.TensorBase)) and L == 0: |
407 | | - F = self.compute_bc_lifting(a, u) |
408 | | - else: |
409 | | - if not isinstance(L, (ufl.BaseForm, slate.slate.TensorBase)): |
410 | | - raise TypeError("Provided RHS is a '%s', not a Form or Slate Tensor" % type(L).__name__) |
| 432 | + if isinstance(L, (ufl.BaseForm, slate.slate.TensorBase)): |
411 | 433 | if len(L.arguments()) != 1 and not L.empty(): |
412 | 434 | raise ValueError("Provided RHS is not a linear form") |
413 | | - F = self.compute_bc_lifting(a, u) - L |
| 435 | + elif L != 0: |
| 436 | + raise TypeError(f"Provided RHS is a '{type(L).__name__}', not a Form or Slate Tensor") |
| 437 | + F = self.compute_bc_lifting(a, u, L=L) |
414 | 438 |
|
415 | 439 | super(LinearVariationalProblem, self).__init__(F, u, bcs=bcs, J=a, Jp=aP, |
416 | 440 | form_compiler_parameters=form_compiler_parameters, |
|
0 commit comments