@@ -183,37 +183,41 @@ def matrix(self, o):
183183 bcs = ()
184184 return AssembledMatrix (tuple (args ), bcs , submat )
185185
186+ def zero_base_form (self , o ):
187+ return ZeroBaseForm (tuple (map (self , o .arguments ())))
188+
186189 def interpolate (self , o , operand ):
187190 if isinstance (operand , Zero ):
188- return ZeroBaseForm (o .arguments ())
191+ return self ( ZeroBaseForm (o .arguments () ))
189192
190193 dual_arg , _ = o .argument_slots ()
191- V = dual_arg .function_space ()
192- if len ( V ) == 1 or len ( dual_arg . arguments ()) == 1 :
194+ if len ( dual_arg . arguments ()) == 1 or len ( dual_arg .arguments ()[ - 1 ]. function_space ()) == 1 :
195+ # The dual argument has been contracted or does not need to be split
193196 return o ._ufl_expr_reconstruct_ (operand , dual_arg )
194197
198+ if not isinstance (dual_arg , Coargument ):
199+ raise NotImplementedError (f"I do not know how to split an Interpolate with a { type (dual_arg ).__name__ } ." )
200+
201+ indices = self .blocks [dual_arg .number ()]
202+ V = dual_arg .function_space ()
203+
195204 # Split the target (dual) argument
196- if isinstance (dual_arg , Coargument ):
197- dual_arg = self (dual_arg )
198- indices = self .blocks [dual_arg .number ()]
199- else :
200- raise NotImplementedError ()
205+ sub_dual_arg = self (dual_arg )
206+ W = sub_dual_arg .function_space ()
201207
202- # Unflatten the expression into the target shapes
208+ # Unflatten the expression into the target shape
203209 cur = 0
204- cindices = []
210+ components = []
205211 for i , Vi in enumerate (V ):
206212 if i in indices :
207- cindices .extend (range (cur , cur + Vi .value_size ))
213+ components .extend (operand [ i ] for i in range (cur , cur + Vi .value_size ))
208214 cur += Vi .value_size
209215
210- W = dual_arg .function_space ()
211- components = [operand [i ] for i in cindices ]
212216 operand = as_tensor (numpy .reshape (components , W .value_shape ))
213217 if isinstance (operand , Zero ):
214- return ZeroBaseForm (o .arguments ())
218+ return self ( ZeroBaseForm (o .arguments () ))
215219
216- return o ._ufl_expr_reconstruct_ (operand , dual_arg )
220+ return o ._ufl_expr_reconstruct_ (operand , sub_dual_arg )
217221
218222
219223SplitForm = collections .namedtuple ("SplitForm" , ["indices" , "form" ])
0 commit comments