11from contextlib import contextmanager
2- from functools import cached_property
32from operator import itemgetter
43
54import firedrake as fd
@@ -26,30 +25,33 @@ def local_vector(u, *, readonly=False):
2625
2726
2827class L2Cholesky :
29- def __init__ (self , space ):
28+ def __init__ (self , space , * , constant_jacobian = True ):
3029 self ._space = space
30+ self ._constant_jacobian = constant_jacobian
31+ self ._pc = None
3132
3233 @property
3334 def space (self ):
3435 return self ._space
3536
36- @cached_property
37- def M (self ):
38- return fd .assemble (fd .inner (fd .TrialFunction (self .space ), fd .TestFunction (self .space )) * fd .dx ,
39- mat_type = "aij" )
40-
41- @cached_property
42- def M_local (self ):
43- return self .M .petscmat .getDiagonalBlock ()
44-
45- @cached_property
37+ @property
4638 def pc (self ):
4739 import petsc4py .PETSc as PETSc
48- pc = PETSc .PC ().create (self .M_local .comm )
49- pc .setType (PETSc .PC .Type .CHOLESKY )
50- pc .setFactorSolverType (PETSc .Mat .SolverType .PETSC )
51- pc .setOperators (self .M_local )
52- pc .setUp ()
40+
41+ pc = self ._pc
42+ if self ._pc is None :
43+ M = fd .assemble (fd .inner (fd .TrialFunction (self .space ), fd .TestFunction (self .space )) * fd .dx ,
44+ mat_type = "aij" )
45+ M_local = M .petscmat .getDiagonalBlock ()
46+
47+ pc = PETSc .PC ().create (M_local .comm )
48+ pc .setType (PETSc .PC .Type .CHOLESKY )
49+ pc .setFactorSolverType (PETSc .Mat .SolverType .PETSC )
50+ pc .setOperators (M_local )
51+ pc .setUp ()
52+ if self ._constant_jacobian :
53+ self ._pc = pc
54+
5355 return pc
5456
5557 def C_inv_action (self , u ):
@@ -157,7 +159,6 @@ def __init__(self, functional, controls, *, riesz_map=None, alpha=0, tape=None):
157159 self ._space = tuple (control .control .function_space ()
158160 for control in self ._J .controls )
159161 self ._space_D = tuple (map (dg_space , self ._space ))
160- self ._C = tuple (map (L2Cholesky , self ._space_D ))
161162 self ._controls = tuple (Control (fd .Function (space_D ), riesz_map = "l2" )
162163 for space_D in self ._space_D )
163164 self ._controls = Enlist (Enlist (controls ).delist (self ._controls ))
@@ -169,6 +170,8 @@ def __init__(self, functional, controls, *, riesz_map=None, alpha=0, tape=None):
169170 self ._riesz_map = Enlist (riesz_map )
170171 if len (self ._riesz_map ) != len (self ._controls ):
171172 raise ValueError ("Invalid length" )
173+ self ._C = tuple (L2Cholesky (space_D , constant_jacobian = riesz_map .constant_jacobian )
174+ for space_D , riesz_map in zip (self ._space_D , self ._riesz_map ))
172175
173176 # Map the initial guess
174177 controls_t = self ._primal_transform (tuple (control .control for control in self ._J .controls ), apply_riesz = False )
0 commit comments