Skip to content

Implement Convolve2D Op #1397

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 119 additions & 1 deletion pytensor/tensor/signal/conv.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import TYPE_CHECKING, Literal, cast

from numpy import convolve as numpy_convolve
from scipy.signal import convolve2d as scipy_convolve2d

from pytensor.graph import Apply, Op
from pytensor.scalar.basic import upcast
from pytensor.tensor.basic import as_tensor_variable, join, zeros
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import maximum, minimum
from pytensor.tensor.type import vector
from pytensor.tensor.type import matrix, vector
from pytensor.tensor.variable import TensorVariable


Expand Down Expand Up @@ -131,3 +132,120 @@
mode = "valid"

return cast(TensorVariable, Blockwise(Convolve1d(mode=mode))(in1, in2))


class Convolve2D(Op):
__props__ = ("mode", "boundary", "fillvalue")
gufunc_signature = "(n,m),(k,l)->(o,p)"

def __init__(
self,
mode: Literal["full", "valid", "same"] = "full",
boundary: Literal["fill", "wrap", "symm"] = "fill",
fillvalue: float | int = 0,
):
if mode not in ("full", "valid", "same"):
raise ValueError(f"Invalid mode: {mode}")

Check warning on line 148 in pytensor/tensor/signal/conv.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/signal/conv.py#L148

Added line #L148 was not covered by tests
if boundary not in ("fill", "wrap", "symm"):
raise ValueError(f"Invalid boundary: {boundary}")

Check warning on line 150 in pytensor/tensor/signal/conv.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/signal/conv.py#L150

Added line #L150 was not covered by tests

self.mode = mode
self.boundary = boundary
self.fillvalue = fillvalue

def make_node(self, in1, in2):
in1, in2 = map(as_tensor_variable, (in1, in2))

assert in1.ndim == 2
assert in2.ndim == 2

dtype = upcast(in1.dtype, in2.dtype)

n, m = in1.type.shape
k, l = in2.type.shape

if self.mode == "full":
shape_1 = None if (n is None or k is None) else n + k - 1
shape_2 = None if (m is None or l is None) else m + l - 1

elif self.mode == "valid":
shape_1 = None if (n is None or k is None) else max(n, k) - max(n, k) + 1
shape_2 = None if (m is None or l is None) else max(m, l) - min(m, l) + 1

else: # mode == "same"
shape_1 = n
shape_2 = m

out_shape = (shape_1, shape_2)
out = matrix(dtype=dtype, shape=out_shape)
return Apply(self, [in1, in2], [out])

def perform(self, node, inputs, outputs):
in1, in2 = inputs
outputs[0][0] = scipy_convolve2d(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now this is where I would like to compare with the old C stuff we had

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See below

in1, in2, mode=self.mode, boundary=self.boundary, fillvalue=self.fillvalue
)

def infer_shape(self, fgraph, node, shapes):
in1_shape, in2_shape = shapes
n, m = in1_shape
k, l = in2_shape

if self.mode == "full":
shape = (n + k - 1, m + l - 1)
elif self.mode == "valid":
shape = (
maximum(n, k) - minimum(n, k) + 1,
maximum(m, l) - minimum(m, l) + 1,
)
else: # self.mode == 'same':
shape = (n, m)

return [shape]

def L_op(self, inputs, outputs, output_grads):
raise NotImplementedError

Check warning on line 207 in pytensor/tensor/signal/conv.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/signal/conv.py#L207

Added line #L207 was not covered by tests


def convolve2d(
in1: "TensorLike",
in2: "TensorLike",
mode: Literal["full", "valid", "same"] = "full",
boundary: Literal["fill", "wrap", "symm"] = "fill",
fillvalue: float | int = 0,
) -> TensorVariable:
"""Convolve two two-dimensional arrays.

Convolve in1 and in2, with the output size determined by the mode argument.

Parameters
----------
in1 : (..., N, M) tensor_like
First input.
in2 : (..., K, L) tensor_like
Second input.
mode : {'full', 'valid', 'same'}, optional
A string indicating the size of the output:
- 'full': The output is the full discrete linear convolution of the inputs, with shape (..., N+K-1, M+L-1).
- 'valid': The output consists only of elements that do not rely on zero-padding, with shape (..., max(N, K) - min(N, K) + 1, max(M, L) - min(M, L) + 1).
- 'same': The output is the same size as in1, centered with respect to the 'full' output.
boundary : {'fill', 'wrap', 'symm'}, optional
A string indicating how to handle boundaries:
- 'fill': Pads the input arrays with fillvalue.
- 'wrap': Circularly wraps the input arrays.
- 'symm': Symmetrically reflects the input arrays.
fillvalue : float or int, optional
The value to use for padding when boundary is 'fill'. Default is 0.
Returns
-------
out: tensor_variable
The discrete linear convolution of in1 with in2.

"""
in1 = as_tensor_variable(in1)
in2 = as_tensor_variable(in2)

blockwise_convolve = Blockwise(
Convolve2D(mode=mode, boundary=boundary, fillvalue=fillvalue)
)
return cast(TensorVariable, blockwise_convolve(in1, in2))
30 changes: 29 additions & 1 deletion tests/tensor/signal/test_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import numpy as np
import pytest
from scipy.signal import convolve as scipy_convolve
from scipy.signal import convolve2d as scipy_convolve2d

from pytensor import config, function, grad
from pytensor.graph.basic import ancestors, io_toposort
from pytensor.graph.rewriting import rewrite_graph
from pytensor.tensor import matrix, vector
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.signal.conv import Convolve1d, convolve1d
from pytensor.tensor.signal.conv import Convolve1d, convolve1d, convolve2d
from tests import unittest_tools as utt


Expand Down Expand Up @@ -109,3 +110,30 @@ def test_convolve1d_valid_grad_rewrite(static_shape):
if isinstance(node.op, Convolve1d)
]
assert conv_op.mode == ("valid" if static_shape else "full")


@pytest.mark.parametrize(
"kernel_shape", [(3, 3), (5, 3), (5, 8)], ids=lambda x: f"kernel_shape={x}"
)
@pytest.mark.parametrize(
"data_shape", [(3, 3), (5, 5), (8, 8)], ids=lambda x: f"data_shape={x}"
)
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
@pytest.mark.parametrize("boundary", ["fill", "wrap", "symm"])
def test_convolve2d(kernel_shape, data_shape, mode, boundary):
data = matrix("data")
kernel = matrix("kernel")
op = partial(convolve2d, mode=mode, boundary=boundary, fillvalue=0)

rng = np.random.default_rng((26, kernel_shape, data_shape, sum(map(ord, mode))))
data_val = rng.normal(size=data_shape).astype(data.dtype)
kernel_val = rng.normal(size=kernel_shape).astype(kernel.dtype)

fn = function([data, kernel], op(data, kernel))
np.testing.assert_allclose(
fn(data_val, kernel_val),
scipy_convolve2d(
data_val, kernel_val, mode=mode, boundary=boundary, fillvalue=0
),
rtol=1e-6 if config.floatX == "float32" else 1e-15,
)