
Implement Convolve2D Op #1397


Draft · jessegrabowski wants to merge 2 commits into main

Conversation

@jessegrabowski (Member) commented May 9, 2025

A natural follow-up to #1318 is to also get convolve2d in.

This PR is just to get the ball rolling; for now it's just a super dumb wrapper around scipy.signal.convolve2d. Remaining TODOs:

  • Work out the relevant gradients
  • Is mode == "same" required, or can we handle it as a special case like in convolve1d?
  • scipy has arguments for boundary and fillvalue. Currently I'm just forwarding these arguments, but for computing gradients it might make more sense to handle them symbolically on the inputs using pt.pad and then always use fixed arguments in perform (see the sketch below).

I'm sure there's more to do that I am omitting.
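
A rough sketch of the pt.pad idea from the list above (my reading of it, not code from the PR; it assumes pt.pad accepts numpy-style constant_values, and convolve2d here is the new Op):

import pytensor.tensor as pt


def convolve2d_full_with_fill(in1, in2, fillvalue=0.0):
    # Hypothetical helper: emulate scipy's boundary="fill" symbolically by
    # padding the image with the fill value, then calling the Op with fixed
    # arguments ("valid" on the padded input reproduces "full" with a filled
    # boundary).
    pad_r = in2.shape[0] - 1
    pad_c = in2.shape[1] - 1
    padded = pt.pad(
        in1,
        ((pad_r, pad_r), (pad_c, pad_c)),
        mode="constant",
        constant_values=fillvalue,
    )
    return convolve2d(padded, in2, mode="valid")

With something like this, mode/boundary handling stays in the graph and the gradient only has to deal with a single fixed-mode Op.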

One big open question is that the convolution operator is defined differently in the signal processing literature and the deep learning literature. JAX handles this by having an abstract_convolve op buried down in lax, with front-ends that convert the domain-specific formulations to that inner representation. Not sure what we want to do.
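
For concreteness (a standalone scipy illustration, not part of the PR): the deep-learning "convolution" is cross-correlation, so the two conventions differ only by a kernel flip.

import numpy as np
from scipy.signal import convolve2d, correlate2d

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 8))
k = rng.normal(size=(3, 3))

# Signal-processing convolution flips the kernel; cross-correlation does not.
np.testing.assert_allclose(
    correlate2d(x, k, mode="valid"),
    convolve2d(x, k[::-1, ::-1], mode="valid"),
)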

Also, the old Conv2d Op from the Theano days had an extremely complex gradient Op. I'm curious how much of that is over-engineering, and how much reflects that convolutions are inherently tricky.
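
On the gradients TODO: for the non-batched valid mode, the standard adjoint identities can be checked numerically with plain scipy (a sketch I put together, not code from the old Op):

import numpy as np
from scipy.signal import convolve2d

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 7))   # image
k = rng.normal(size=(3, 2))   # kernel
g = rng.normal(size=(4, 6))   # output gradient, same shape as the valid output


def loss(x_, k_):
    return np.sum(g * convolve2d(x_, k_, mode="valid"))


def numeric_grad(f, a, eps=1e-5):
    # central-difference gradient of a scalar function with respect to `a`
    out = np.zeros_like(a)
    for idx in np.ndindex(*a.shape):
        da = np.zeros_like(a)
        da[idx] = eps
        out[idx] = (f(a + da) - f(a - da)) / (2 * eps)
    return out


# dL/dx: full convolution of the output gradient with the flipped kernel
np.testing.assert_allclose(
    numeric_grad(lambda x_: loss(x_, k), x),
    convolve2d(g, k[::-1, ::-1], mode="full"),
    rtol=1e-6, atol=1e-8,
)
# dL/dk: valid convolution of the flipped input with the output gradient
np.testing.assert_allclose(
    numeric_grad(lambda k_: loss(x, k_), k),
    convolve2d(x[::-1, ::-1], g, mode="valid"),
    rtol=1e-6, atol=1e-8,
)

If these hold up, the gradient could likely be expressed with the same convolve2d Op (flipping inputs and switching modes) rather than needing a dedicated gradient Op.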

@jessegrabowski added the enhancement, help wanted, SciPy compatibility, and convolution labels on May 9, 2025
codecov bot commented May 9, 2025

Codecov Report

Attention: Patch coverage is 90.19608% with 5 lines in your changes missing coverage. Please review.

Project coverage is 82.09%. Comparing base (5335a68) to head (4230f63).
Report is 12 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/signal/conv.py 90.19% 3 Missing and 2 partials ⚠️

❌ Your patch check has failed because the patch coverage (90.19%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #1397      +/-   ##
==========================================
+ Coverage   82.02%   82.09%   +0.06%     
==========================================
  Files         207      208       +1     
  Lines       49301    49609     +308     
  Branches     8747     8798      +51     
==========================================
+ Hits        40440    40725     +285     
- Misses       6695     6708      +13     
- Partials     2166     2176      +10     
Files with missing lines Coverage Δ
pytensor/tensor/signal/conv.py 94.06% <90.19%> (-3.00%) ⬇️

... and 13 files with indirect coverage changes


Review comments on perform in pytensor/tensor/signal/conv.py:

def perform(self, node, inputs, outputs):
    in1, in2 = inputs
    outputs[0][0] = scipy_convolve2d(

Member: Now this is where I would like to compare with the old C stuff we had.


Member Author: See below.

@jessegrabowski (Member Author) commented:

Benchmarks for the non-batched case against the old Theano implementation. This isn't super meaningful on its own, because the old implementation really, really cared about the batched case.

--------------------------------------------------------------------------------------------------------------------------------------- benchmark: 6 tests ---------------------------------------------------------------------------------------------------------------------------------------
Name (time in us)                                                                                 Min                       Max                      Mean                 StdDev                    Median                     IQR            Outliers           OPS            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_conv2d_benchmark[new-impl-kernel_shape=(10, 10)-kernel_shape=(3, 3)]                      5.2000 (1.0)            250.5990 (1.0)              5.9652 (1.0)           4.2769 (1.0)              5.6000 (1.0)            0.1001 (1.0)      250;1205  167,637.6976 (1.0)       11696           1
test_conv2d_benchmark[theano-impl-kernel_shape=(10, 10)-kernel_shape=(3, 3)]                  24.4990 (4.71)           621.7840 (2.48)            27.8714 (4.67)         10.9988 (2.57)            25.9990 (4.64)           0.8990 (8.98)     515;1098   35,879.0114 (0.21)       9834           1

test_conv2d_benchmark[new-impl-kernel_shape=(100, 100)-kernel_shape=(10, 10)]              1,322.1980 (254.27)       2,194.2970 (8.76)         1,377.6138 (230.94)       88.6022 (20.72)        1,341.5485 (239.56)        52.4990 (524.68)      91;96      725.8928 (0.00)        738           1
test_conv2d_benchmark[theano-impl-kernel_shape=(100, 100)-kernel_shape=(10, 10)]           1,353.2660 (260.24)       2,167.3440 (8.65)         1,416.2842 (237.42)       97.7110 (22.85)        1,375.7150 (245.66)        59.0990 (590.64)      94;94      706.0730 (0.00)        674           1

test_conv2d_benchmark[new-impl-kernel_shape=(1000, 1000)-kernel_shape=(50, 50)]        4,151,984.7820 (>1000.0)  4,273,707.2180 (>1000.0)  4,207,943.9498 (>1000.0)  50,938.5541 (>1000.0)  4,224,029.2110 (>1000.0)   81,616.0525 (>1000.0)       2;0        0.2376 (0.00)          5           1
test_conv2d_benchmark[theano-impl-kernel_shape=(1000, 1000)-kernel_shape=(50, 50)]     4,136,063.3770 (>1000.0)  4,355,675.9200 (>1000.0)  4,208,198.4226 (>1000.0)  99,650.2559 (>1000.0)  4,141,102.1190 (>1000.0)  151,135.4925 (>1000.0)       1;0        0.2376 (0.00)          5           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Code for the benchmark:

# Imports assumed for this snippet (exact module paths may differ in the PR branch):
import numpy as np
import pytest

from pytensor import function
from pytensor.tensor import matrix
from pytensor.tensor.conv import conv2d  # old Theano-style conv2d
from pytensor.tensor.signal import convolve2d  # new Op from this PR


@pytest.mark.parametrize(
    "data_shape, kernel_shape", [[(10, 10), (3, 3)],
                                 [(100, 100), (10, 10)],
                                 [(1000, 1000), (50, 50)]],
    ids=lambda x: f"kernel_shape={x}"
)
@pytest.mark.parametrize('func', ['new', 'theano'], ids=['new-impl', 'theano-impl'])
def test_conv2d_benchmark(data_shape, kernel_shape, func, benchmark):

    x = matrix("x")
    y = matrix("y")

    if func == 'new':
        out = convolve2d(x, y, mode="valid")
    else:
        out = conv2d(input=x[None, None], filters=y[None, None], border_mode="valid")
        out = out[0, 0]

    rng = np.random.default_rng(38)
    x_test = rng.normal(size=data_shape).astype(x.dtype)
    y_test = rng.normal(size=kernel_shape).astype(y.dtype)

    fn = function([x, y], out, trust_input=True)

    benchmark(fn, x_test, y_test)

The old rewrites are there, but I didn't carefully check which (if any) were used.

@jessegrabowski (Member Author) commented:

Here are the results if the perform method uses scipy.signal.convolve instead of convolve2d (so that it can choose between direct and fft):

-------------------------------------------------------------------------------------------------------------------------------------- benchmark: 6 tests --------------------------------------------------------------------------------------------------------------------------------------
Name (time in us)                                                                                 Min                       Max                      Mean                 StdDev                    Median                    IQR            Outliers          OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_conv2d_benchmark[new-impl-kernel_shape=(10, 10)-kernel_shape=(3, 3)]                     30.7000 (1.24)           522.9960 (1.0)             36.3384 (1.20)         15.2515 (1.0)             32.5000 (1.23)          1.3991 (1.0)       400;961  27,519.0703 (0.83)       6143           1
test_conv2d_benchmark[theano-impl-kernel_shape=(10, 10)-kernel_shape=(3, 3)]                  24.8000 (1.0)          1,082.3920 (2.07)            30.2780 (1.0)          20.2872 (1.33)            26.4000 (1.0)           1.4010 (1.00)     444;1452  33,027.2273 (1.0)        9681           1

test_conv2d_benchmark[new-impl-kernel_shape=(100, 100)-kernel_shape=(10, 10)]                287.5980 (11.60)        1,692.2900 (3.24)           334.4822 (11.05)        78.7191 (5.16)           309.5480 (11.73)        41.6000 (29.73)     127;146   2,989.6958 (0.09)       1412           1
test_conv2d_benchmark[theano-impl-kernel_shape=(100, 100)-kernel_shape=(10, 10)]           1,352.8900 (54.55)        3,885.0710 (7.43)         1,450.9978 (47.92)       160.3017 (10.51)        1,403.1890 (53.15)       105.0991 (75.12)       64;58     689.1809 (0.02)        694           1

test_conv2d_benchmark[new-impl-kernel_shape=(1000, 1000)-kernel_shape=(50, 50)]           32,648.5110 (>1000.0)     44,534.5410 (85.15)       36,218.6597 (>1000.0)   2,614.7971 (171.45)      36,145.6900 (>1000.0)   3,311.7563 (>1000.0)       6;1      27.6101 (0.00)         23           1
test_conv2d_benchmark[theano-impl-kernel_shape=(1000, 1000)-kernel_shape=(50, 50)]     4,144,777.8940 (>1000.0)  4,245,045.5250 (>1000.0)  4,203,307.0746 (>1000.0)  37,183.2344 (>1000.0)  4,207,583.0600 (>1000.0)  42,366.2383 (>1000.0)       2;0       0.2379 (0.00)          5           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
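
For reference, a sketch of what that variant of perform might look like (assuming the Op stores mode as an attribute; illustrative, not necessarily the exact code in the branch):

from scipy.signal import convolve as scipy_convolve


def perform(self, node, inputs, outputs):
    in1, in2 = inputs
    # method="auto" lets scipy pick direct vs. FFT based on the input sizes
    outputs[0][0] = scipy_convolve(in1, in2, mode=self.mode, method="auto")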

@jessegrabowski (Member Author) commented:

Interesting blog post that suggests the approach in this PR might be on the wrong track:

https://petewarden.com/2015/04/20/why-gemm-is-at-the-heart-of-deep-learning/

@ricardoV94 (Member) commented:

Can you test the batched case? The FFT choice is something I want to do explicitly based on input size.
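
One way to make that choice explicit (a sketch built on scipy's public helper, not something in this PR) would be to query scipy's own size heuristic up front:

import numpy as np
from scipy.signal import choose_conv_method

rng = np.random.default_rng(0)
small = (rng.normal(size=(10, 10)), rng.normal(size=(3, 3)))
large = (rng.normal(size=(1000, 1000)), rng.normal(size=(50, 50)))

# Returns "direct" or "fft" based on the same heuristic scipy uses internally.
print(choose_conv_method(*small, mode="valid"))
print(choose_conv_method(*large, mode="valid"))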

@jessegrabowski (Member Author) commented May 10, 2025

I did; the pattern above holds for batches. Or do you want the batch benchmarks using scipy.signal.convolve2d, so that FFT is never used? I ran them with the FFT-capable version.

What is more important is to test the multi-channel case. That needs a new wrapper function, but it's not so complex.

@ricardoV94 (Member) commented:

Isn't multi-channel just batching the second argument?

@ricardoV94 (Member) commented May 10, 2025

> Interesting blog post that suggests the approach in this PR might be on the wrong track:

Why? Under the hood, scipy/numpy use BLAS when not going through the FFT path. numpy's 1D convolve, for instance, is a loop of vector-vector BLAS dot products:
https://github.com/numpy/numpy/blob/e9e981e9d45335fd6b758e620812772e19143f35/numpy/_core/src/multiarray/multiarraymodule.c#L1235
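
To see that concretely, a tiny numpy sketch of valid-mode 1D convolution as a loop of dot products (illustrative only, not numpy's actual C code):

import numpy as np


def conv1d_valid(x, k):
    # each output element is a dot product between a window of x and the flipped kernel
    k_flipped = k[::-1]
    n_out = x.size - k.size + 1
    return np.array([x[i:i + k.size] @ k_flipped for i in range(n_out)])


x = np.arange(10.0)
k = np.array([1.0, 2.0, 3.0])
np.testing.assert_allclose(conv1d_valid(x, k), np.convolve(x, k, mode="valid"))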

@jessegrabowski (Member Author) commented:

> Isn't multi-channel just batching the second argument?

You also have to sum-reduce over the input channels.

@jessegrabowski (Member Author) commented May 10, 2025

Here's what the "neural net" version looks like using the impl in this PR:

# Imports assumed for this snippet (exact module paths may differ in the PR branch):
from pytensor.tensor import expand_dims
from pytensor.tensor.signal import convolve2d
from pytensor.tensor.variable import TensorVariable


def nn_conv2d(data: TensorVariable, filters: TensorVariable) -> TensorVariable:
    """Convolve 2D data with 2D filters.

    Parameters
    ----------
    data : (n, c_in, w, h) tensor_variable
        Input data.
    filters : (c_out, c_in, wf, hf) tensor_variable
        Convolution filters.

    Returns
    -------
    out : (n, c_out, w - wf + 1, h - hf + 1) tensor_variable
        The result of convolving data with filters.

    """
    # Broadcast to (n, c_out, c_in, ...) batch dims, convolve the last two
    # core dims, then sum-reduce over the input channels.
    result = convolve2d(expand_dims(data, axis=1),
                        expand_dims(filters, axis=0),
                        mode="valid")
    return result.sum(axis=2)

Here's a benchmark suite:

# Reuses the imports from the first benchmark above, plus nn_conv2d as defined above.
@pytest.mark.parametrize(
    "data_shape, kernel_shape", [[(10, 1, 8, 8), (3, 1, 3, 3)], # 8x8 grayscale
                                 [(1000, 1, 8, 8), (3, 1, 1, 3)], # same, but with 1000 images
                                 [(10, 3, 64, 64), (10, 3, 8, 8)], # 64x64 RGB
                                 [(1000, 3, 64, 64), (10, 3, 8, 8)], # same, but with 1000 images
                                 [(3, 100, 100, 100), (250, 100, 50, 50)]], # Very large, deep hidden layer or something

    ids=lambda x: f"data_shape={x[0]}, kernel_shape={x[1]}"
)
@pytest.mark.parametrize('func', ['new', 'theano'], ids=['new-impl', 'theano-impl'])
def test_conv2d_nn_benchmark(data_shape, kernel_shape, func, benchmark):
    import pytensor.tensor as pt
    x = pt.tensor("x", shape=data_shape)
    y = pt.tensor("y", shape=kernel_shape)

    if func == 'new':
        out = nn_conv2d(x, y)
    else:
        out = conv2d(input=x, filters=y, border_mode="valid")

    rng = np.random.default_rng(38)
    x_test = rng.normal(size=data_shape).astype(x.dtype)
    y_test = rng.normal(size=kernel_shape).astype(y.dtype)

    fn = function([x, y], out, trust_input=True)

    benchmark(fn, x_test, y_test)

And results:

----------------------------------------------------------------------------------------------------------------------------------------------------- benchmark: 10 tests ------------------------------------------------------------------------------------------------------------------------------------------------------
Name (time in us)                                                                                                      Min                         Max                        Mean                    StdDev                      Median                       IQR            Outliers         OPS            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_conv2d_nn_benchmark[new-impl-data_shape=(10, 1, 8, 8), kernel_shape= (3, 1, 3, 3)]                          936.4921 (9.83)           1,380.5890 (1.17)             994.1993 (9.05)            66.9125 (1.87)             970.0910 (9.69)            36.1000 (4.25)         9;12  1,005.8346 (0.11)        118           1
test_conv2d_nn_benchmark[theano-impl-data_shape=(10, 1, 8, 8), kernel_shape= (3, 1, 3, 3)]                        95.2990 (1.0)            1,182.5920 (1.0)              109.8578 (1.0)             35.7867 (1.0)              100.1000 (1.0)              8.5010 (1.0)       275;629  9,102.6749 (1.0)        4270           1

test_conv2d_nn_benchmark[new-impl-data_shape=(1000, 1, 8, 8),  kernel_shape=(3, 1, 1, 3)]                     55,901.9330 (586.60)        84,205.5950 (71.20)         62,783.0746 (571.49)       7,240.1447 (202.31)        60,402.6950 (603.42)       3,931.4420 (462.47)        2;2     15.9279 (0.00)         15           1
test_conv2d_nn_benchmark[theano-impl-data_shape=(1000, 1, 8, 8),  kernel_shape=(3, 1, 1, 3)]                   6,673.1410 (70.02)         12,769.2180 (10.80)          7,506.3460 (68.33)          695.0093 (19.42)          7,425.5810 (74.18)          613.8975 (72.21)        18;6    133.2206 (0.01)        140           1

test_conv2d_nn_benchmark[new-impl-data_shape=(10, 3, 64, 64), kernel_shape=(10, 3, 8, 8)]                     81,628.8160 (856.55)       103,832.8290 (87.80)         87,618.6407 (797.56)       5,752.8892 (160.76)        86,344.6270 (862.58)       4,498.6125 (529.19)        2;1     11.4131 (0.00)         12           1
test_conv2d_nn_benchmark[theano-impl-data_shape=(10, 3, 64, 64), kernel_shape=(10, 3, 8, 8)]                 112,968.2180 (>1000.0)      121,295.8730 (102.57)       117,034.1081 (>1000.0)      2,838.9201 (79.33)        116,256.8010 (>1000.0)      4,449.5125 (523.41)        3;0      8.5445 (0.00)          9           1

test_conv2d_nn_benchmark[new-impl-data_shape=(1000, 3, 64, 64),  kernel_shape=(10, 3, 8, 8)]               5,882,476.3960 (>1000.0)    6,362,158.9680 (>1000.0)    6,068,239.5658 (>1000.0)    186,465.2903 (>1000.0)    6,029,093.1700 (>1000.0)    246,610.8448 (>1000.0)       1;0      0.1648 (0.00)          5           1
test_conv2d_nn_benchmark[theano-impl-data_shape=(1000, 3, 64, 64),  kernel_shape=(10, 3, 8, 8)]           11,253,140.1460 (>1000.0)   11,692,865.5460 (>1000.0)   11,389,335.7826 (>1000.0)    192,363.5367 (>1000.0)   11,274,060.4120 (>1000.0)    269,707.3488 (>1000.0)       1;0      0.0878 (0.00)          5           1

test_conv2d_nn_benchmark[new-impl-data_shape=(3, 100, 100, 100), kernel_shape=(250, 100, 50, 50)]         33,364,977.9930 (>1000.0)   36,696,836.2380 (>1000.0)   35,023,641.8976 (>1000.0)  1,265,627.7147 (>1000.0)   35,188,259.2120 (>1000.0)  1,789,538.3190 (>1000.0)       2;0      0.0286 (0.00)          5           1
test_conv2d_nn_benchmark[theano-impl-data_shape=(3, 100, 100, 100), kernel_shape=(250, 100, 50, 50)]     916,582,280.2420 (>1000.0)  927,166,598.2280 (>1000.0)  922,773,638.3670 (>1000.0)  4,064,912.3530 (>1000.0)  924,311,938.7580 (>1000.0)  5,257,726.4838 (>1000.0)       2;0      0.0011 (0.00)          5           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

@ricardoV94 (Member) commented:

Are those using FFT on the new impl? I want to know if the old C code is useful for the non-FFT path; so far it seems like it is not, and we may want to roll something much simpler.

@jessegrabowski (Member Author) commented May 11, 2025

> Are those using FFT on the new impl? I want to know if the old C code is useful for the non-FFT path; so far it seems like it is not, and we may want to roll something much simpler.

The new impl is using scipy.signal.convolve, which automatically chooses between direct and FFT based on the input size. For the small-kernel tests it should be choosing direct, and the old C impl is winning by a lot there.

Both are really really miserable on the big tests :)

@ricardoV94 (Member) commented May 11, 2025

Is the Theano one still winning?

"Miserable" may just be the inherent cost of the computation, or is JAX on CPU doing considerably better?
