66
77from pytensor .gradient import DisconnectedType
88from pytensor .graph import Apply , Constant
9+ from pytensor .graph .op import Op
910from pytensor .link .c .op import COp
1011from pytensor .scalar import as_scalar
1112from pytensor .scalar .basic import upcast
@@ -220,18 +221,16 @@ class Convolve2D(Op):
220221
221222 def __init__ (
222223 self ,
223- mode : Literal ["full" , "valid" , "same" ] = "full" ,
224+ mode : Literal ["full" , "valid" ] = "full" ,
224225 boundary : Literal ["fill" , "wrap" , "symm" ] = "fill" ,
225226 fillvalue : float | int = 0 ,
226227 ):
227- if mode not in ("full" , "valid" , "same" ):
228+ if mode not in ("full" , "valid" ):
228229 raise ValueError (f"Invalid mode: { mode } " )
229- if boundary not in ("fill" , "wrap" , "symm" ):
230- raise ValueError (f"Invalid boundary: { boundary } " )
231230
232231 self .mode = mode
233- self .boundary = boundary
234232 self .fillvalue = fillvalue
233+ self .boundary = boundary
235234
236235 def make_node (self , in1 , in2 ):
237236 in1 , in2 = map (as_tensor_variable , (in1 , in2 ))
@@ -262,8 +261,13 @@ def make_node(self, in1, in2):
262261
263262 def perform (self , node , inputs , outputs ):
264263 in1 , in2 = inputs
264+
265+ # if all(inpt.dtype.kind in ['f', 'c'] for inpt in inputs):
266+ # outputs[0][0] = scipy_convolve(in1, in2, mode=self.mode, method='fft')
267+ #
268+ # else:
265269 outputs [0 ][0 ] = scipy_convolve2d (
266- in1 , in2 , mode = self .mode , boundary = self .boundary , fillvalue = self .fillvalue
270+ in1 , in2 , mode = self .mode , fillvalue = self .fillvalue , boundary = self .boundary
267271 )
268272
269273 def infer_shape (self , fgraph , node , shapes ):
@@ -284,7 +288,18 @@ def infer_shape(self, fgraph, node, shapes):
284288 return [shape ]
285289
286290 def L_op (self , inputs , outputs , output_grads ):
287- raise NotImplementedError
291+ in1 , in2 = inputs
292+ incoming_grads = output_grads [0 ]
293+
294+ if self .mode == "full" :
295+ prop_dict = self ._props_dict ()
296+ prop_dict ["mode" ] = "valid"
297+ conv_valid = type (self )(** prop_dict )
298+
299+ in1_grad = conv_valid (in2 , incoming_grads )
300+ in2_grad = conv_valid (in1 , incoming_grads )
301+
302+ return [in1_grad , in2_grad ]
288303
289304
290305def convolve2d (
@@ -325,6 +340,9 @@ def convolve2d(
325340 in1 = as_tensor_variable (in1 )
326341 in2 = as_tensor_variable (in2 )
327342
343+ # TODO: Handle boundaries symbolically
344+ # TODO: Handle 'same' symbolically
345+
328346 blockwise_convolve = Blockwise (
329347 Convolve2D (mode = mode , boundary = boundary , fillvalue = fillvalue )
330348 )
0 commit comments