NOTE(review): line breaks and indentation were reconstructed from a
whitespace-collapsed copy of this patch. Context-line indentation is inferred;
use `git apply --ignore-whitespace` if a hunk is rejected. The blob hashes on
the `index` lines predate the edits made to the added code below.

diff --git a/pytensor/tensor/signal/conv.py b/pytensor/tensor/signal/conv.py
index 59a5a9ea9c..f12e5e0a63 100644
--- a/pytensor/tensor/signal/conv.py
+++ b/pytensor/tensor/signal/conv.py
@@ -1,13 +1,14 @@
 from typing import TYPE_CHECKING, Literal, cast
 
 from numpy import convolve as numpy_convolve
+from scipy.signal import convolve2d as scipy_convolve2d
 
 from pytensor.graph import Apply, Op
 from pytensor.scalar.basic import upcast
 from pytensor.tensor.basic import as_tensor_variable, join, zeros
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.math import maximum, minimum
-from pytensor.tensor.type import vector
+from pytensor.tensor.type import matrix, vector
 from pytensor.tensor.variable import TensorVariable
 
 
@@ -131,3 +132,134 @@ def convolve1d(
         mode = "valid"
 
     return cast(TensorVariable, Blockwise(Convolve1d(mode=mode))(in1, in2))
+
+
+class Convolve2D(Op):
+    """Core Op computing the 2D convolution of two matrices.
+
+    The numerical work is delegated to `scipy.signal.convolve2d`; `mode`,
+    `boundary` and `fillvalue` are forwarded to it unchanged.
+    """
+
+    __props__ = ("mode", "boundary", "fillvalue")
+    gufunc_signature = "(n,m),(k,l)->(o,p)"
+
+    def __init__(
+        self,
+        mode: Literal["full", "valid", "same"] = "full",
+        boundary: Literal["fill", "wrap", "symm"] = "fill",
+        fillvalue: float | int = 0,
+    ):
+        if mode not in ("full", "valid", "same"):
+            raise ValueError(f"Invalid mode: {mode}")
+        if boundary not in ("fill", "wrap", "symm"):
+            raise ValueError(f"Invalid boundary: {boundary}")
+
+        self.mode = mode
+        self.boundary = boundary
+        self.fillvalue = fillvalue
+
+    def make_node(self, in1, in2):
+        """Build the Apply node and derive the static output shape."""
+        in1, in2 = map(as_tensor_variable, (in1, in2))
+
+        assert in1.ndim == 2
+        assert in2.ndim == 2
+
+        dtype = upcast(in1.dtype, in2.dtype)
+
+        n, m = in1.type.shape
+        k, l = in2.type.shape
+
+        # A static output length is only known when both contributing input
+        # lengths are known; otherwise it stays None.
+        if self.mode == "full":
+            shape_1 = None if (n is None or k is None) else n + k - 1
+            shape_2 = None if (m is None or l is None) else m + l - 1
+
+        elif self.mode == "valid":
+            # Same formula as the symbolic `infer_shape`: max - min + 1.
+            shape_1 = None if (n is None or k is None) else max(n, k) - min(n, k) + 1
+            shape_2 = None if (m is None or l is None) else max(m, l) - min(m, l) + 1
+
+        else:  # mode == "same"
+            shape_1 = n
+            shape_2 = m
+
+        out_shape = (shape_1, shape_2)
+        out = matrix(dtype=dtype, shape=out_shape)
+        return Apply(self, [in1, in2], [out])
+
+    def perform(self, node, inputs, outputs):
+        """Evaluate the convolution numerically with SciPy."""
+        in1, in2 = inputs
+        outputs[0][0] = scipy_convolve2d(
+            in1, in2, mode=self.mode, boundary=self.boundary, fillvalue=self.fillvalue
+        )
+
+    def infer_shape(self, fgraph, node, shapes):
+        """Return the symbolic output shape for the configured mode."""
+        in1_shape, in2_shape = shapes
+        n, m = in1_shape
+        k, l = in2_shape
+
+        if self.mode == "full":
+            shape = (n + k - 1, m + l - 1)
+        elif self.mode == "valid":
+            shape = (
+                maximum(n, k) - minimum(n, k) + 1,
+                maximum(m, l) - minimum(m, l) + 1,
+            )
+        else:  # self.mode == 'same':
+            shape = (n, m)
+
+        return [shape]
+
+    def L_op(self, inputs, outputs, output_grads):
+        # Gradients are not implemented for this Op yet.
+        raise NotImplementedError
+
+
+def convolve2d(
+    in1: "TensorLike",
+    in2: "TensorLike",
+    mode: Literal["full", "valid", "same"] = "full",
+    boundary: Literal["fill", "wrap", "symm"] = "fill",
+    fillvalue: float | int = 0,
+) -> TensorVariable:
+    """Convolve two two-dimensional arrays.
+
+    Convolve in1 and in2, with the output size determined by the mode argument.
+
+    Parameters
+    ----------
+    in1 : (..., N, M) tensor_like
+        First input.
+    in2 : (..., K, L) tensor_like
+        Second input.
+    mode : {'full', 'valid', 'same'}, optional
+        A string indicating the size of the output:
+        - 'full': The output is the full discrete linear convolution of the inputs, with shape (..., N+K-1, M+L-1).
+        - 'valid': The output consists only of elements that do not rely on zero-padding, with shape (..., max(N, K) - min(N, K) + 1, max(M, L) - min(M, L) + 1).
+        - 'same': The output is the same size as in1, centered with respect to the 'full' output.
+    boundary : {'fill', 'wrap', 'symm'}, optional
+        A string indicating how to handle boundaries:
+        - 'fill': Pads the input arrays with fillvalue.
+        - 'wrap': Circularly wraps the input arrays.
+        - 'symm': Symmetrically reflects the input arrays.
+    fillvalue : float or int, optional
+        The value to use for padding when boundary is 'fill'. Default is 0.
+
+    Returns
+    -------
+    out : tensor_variable
+        The discrete linear convolution of in1 with in2.
+
+    """
+    in1 = as_tensor_variable(in1)
+    in2 = as_tensor_variable(in2)
+
+    blockwise_convolve = Blockwise(
+        Convolve2D(mode=mode, boundary=boundary, fillvalue=fillvalue)
+    )
+    return cast(TensorVariable, blockwise_convolve(in1, in2))
diff --git a/tests/tensor/signal/test_conv.py b/tests/tensor/signal/test_conv.py
index d6b0d69d7c..fec764e3a8 100644
--- a/tests/tensor/signal/test_conv.py
+++ b/tests/tensor/signal/test_conv.py
@@ -3,13 +3,14 @@
 import numpy as np
 import pytest
 from scipy.signal import convolve as scipy_convolve
+from scipy.signal import convolve2d as scipy_convolve2d
 
 from pytensor import config, function, grad
 from pytensor.graph.basic import ancestors, io_toposort
 from pytensor.graph.rewriting import rewrite_graph
 from pytensor.tensor import matrix, vector
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.signal.conv import Convolve1d, convolve1d
+from pytensor.tensor.signal.conv import Convolve1d, Convolve2D, convolve1d, convolve2d
 from tests import unittest_tools as utt
 
 
@@ -109,3 +110,42 @@ def test_convolve1d_valid_grad_rewrite(static_shape):
         if isinstance(node.op, Convolve1d)
     ]
     assert conv_op.mode == ("valid" if static_shape else "full")
+
+
+@pytest.mark.parametrize(
+    "kernel_shape", [(3, 3), (5, 3), (5, 8)], ids=lambda x: f"kernel_shape={x}"
+)
+@pytest.mark.parametrize(
+    "data_shape", [(3, 3), (5, 5), (8, 8)], ids=lambda x: f"data_shape={x}"
+)
+@pytest.mark.parametrize("mode", ["full", "valid", "same"])
+@pytest.mark.parametrize("boundary", ["fill", "wrap", "symm"])
+def test_convolve2d(kernel_shape, data_shape, mode, boundary):
+    data = matrix("data")
+    kernel = matrix("kernel")
+    op = partial(convolve2d, mode=mode, boundary=boundary, fillvalue=0)
+
+    rng = np.random.default_rng((26, kernel_shape, data_shape, sum(map(ord, mode))))
+    data_val = rng.normal(size=data_shape).astype(data.dtype)
+    kernel_val = rng.normal(size=kernel_shape).astype(kernel.dtype)
+
+    fn = function([data, kernel], op(data, kernel))
+    np.testing.assert_allclose(
+        fn(data_val, kernel_val),
+        scipy_convolve2d(
+            data_val, kernel_val, mode=mode, boundary=boundary, fillvalue=0
+        ),
+        rtol=1e-6 if config.floatX == "float32" else 1e-15,
+    )
+
+
+@pytest.mark.parametrize(
+    "mode, expected_shape",
+    [("full", (10, 7)), ("valid", (6, 3)), ("same", (8, 5))],
+)
+def test_convolve2d_static_shape(mode, expected_shape):
+    # Static output shape must agree with what scipy produces at runtime
+    data = matrix("data", shape=(8, 5))
+    kernel = matrix("kernel", shape=(3, 3))
+    out = Convolve2D(mode=mode)(data, kernel)
+    assert out.type.shape == expected_shape