fix: trim padding_out in conv_transpose x-backward (#4845) #4916

Merged

laggui merged 2 commits into tracel-ai:main from SAY-5:fix/conv-backward-padding-out-4845 on May 7, 2026

Conversation

SAY-5 (Contributor) commented May 4, 2026

Pull Request Template

Checklist

  • Confirmed that the cargo run-checks command has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Fixes #4845.

Changes

conv_transpose{1,2,3}d_x_backward ran the inverse conv on the full output_grad, but the trailing padding_out cells of a transpose-conv output do not depend on any input cell. The inverse conv therefore produced a result longer than the original input by floor(padding_out / stride) on each spatial axis, which then panicked when the gradient was registered against x.

Slice off the trailing cells after the inverse conv to recover the original input shape. Weight and bias gradients are unaffected (weight grad already slices to weight_shape; bias grad sums over spatial dims).
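
To see where the extra length comes from (a worked derivation from the standard size formulas; k is the kernel size):

L_out = (L_in - 1) * stride - 2 * padding + dilation * (k - 1) + padding_out + 1

so the inverse conv over the full output_grad has spatial length

floor((L_out + 2 * padding - dilation * (k - 1) - 1) / stride) + 1
  = floor(((L_in - 1) * stride + padding_out) / stride) + 1
  = L_in + floor(padding_out / stride),

i.e. exactly floor(padding_out / stride) cells too many per spatial axis.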

Testing

  • cargo test -p burn-backend-tests --test autodiff conv (98 conv autodiff tests pass, including the new regression test)
  • cargo test -p burn-backend-tests --test tensor conv_transpose (19 tests pass)
  • Added test_conv_transpose1d_padding_out_stride1_backward_shape mirroring the issue reproducer (stride=1, padding_out=1).
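
For reference, a minimal sketch of what such a regression test can look like (TestAutodiffBackend, the tensor shapes, and the ConvTransposeOptions argument order are assumptions for illustration, not the PR's actual test code):

use burn::tensor::{module::conv_transpose1d, ops::ConvTransposeOptions, Tensor};

#[test]
fn conv_transpose1d_padding_out_stride1_backward_shape_sketch() {
    // TestAutodiffBackend stands in for whatever autodiff backend the suite uses.
    type B = TestAutodiffBackend;
    let device = Default::default();

    // x: [batch, channels_in, length], weight: [channels_in, channels_out, kernel]
    let x = Tensor::<B, 3>::ones([1, 2, 4], &device).require_grad();
    let weight = Tensor::<B, 3>::ones([2, 2, 3], &device).require_grad();

    // Argument order assumed: stride, padding, padding_out, dilation, groups.
    let options = ConvTransposeOptions::new([1], [0], [1], [1], 1);
    let output = conv_transpose1d(x.clone(), weight, None, options);

    let grads = output.backward();
    let x_grad = x.grad(&grads).unwrap();

    // Before the fix, the inverse conv produced length 4 + floor(1 / 1) = 5
    // and gradient registration panicked; x_grad must match x's shape.
    assert_eq!(x_grad.dims(), [1, 2, 4]);
}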

laggui (Member) left a comment

Thanks for addressing the issue!

Comment on lines +219 to +220
#[test]
fn test_conv_transpose1d_padding_out_stride1_backward_shape() {

The test should also validate the correctness, not just the shape. I think the current fix is incorrect (see my other comment).

You can use pytorch as a reference to compute the reference values.

Comment thread on crates/burn-backend/src/backend/ops/modules/conv.rs (outdated)
codecov Bot commented May 4, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 65.36%. Comparing base (bde7526) to head (4e93fb1).
⚠️ Report is 15 commits behind head on main.

❌ Your project check has failed because the head coverage (65.36%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #4916      +/-   ##
==========================================
+ Coverage   65.34%   65.36%   +0.01%     
==========================================
  Files        1170     1170              
  Lines      174264   174982     +718     
==========================================
+ Hits       113869   114371     +502     
- Misses      60395    60611     +216     

SAY-5 (Contributor, Author) commented May 6, 2026

@laggui thanks for the close look — I tried implementing the "slice padding_out off output_grad before the inverse conv" approach and it produces correct values for the stride=1 / dilation=1 case the issue reports, but breaks test_conv_transpose1d_complex (stride=2, dilation=2, padding=1, padding_out=1) — position 7 of x_grad goes to 8 instead of the expected 15. Trimming output_grad shortens it before B::conv1d but the conv1d's own padding argument then pads zeros on the right, so the right-edge receptive field collapses and the trailing input cells lose part of their contribution.

A few directions I considered before pinging:

  1. Slice output_grad and also strip the right-side conv1d padding (run conv1d with asymmetric / right=0 padding, which burn's conv1d API doesn't expose directly).
  2. Slice output_grad to L_core only when padding_out < stride and fall back to slicing the conv1d output when padding_out >= stride (the path the current PR takes).
  3. Re-derive the correct L_in cells by running conv1d with symmetric padding and then taking the first L_in outputs (mathematically equivalent to the current PR).

For the correctness test, I planned to add cases mirroring test_conv_transpose1d_stride_padding_out with stride=1 and reference values computed against PyTorch's F.conv_transpose1d backward — happy to proceed once we settle on the approach.

What's the intended shape of the fix here? Specifically: should the inverse conv1d use a different padding spec when padding_out > 0, or is there a different primitive (e.g. conv1d with explicit (left, right) padding) you'd prefer me to call into?
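
For concreteness, the length bookkeeping behind these options can be checked in isolation (a standalone sketch; l_in and k are illustrative values, not the actual test's):

// Standard size formulas; integer division mirrors floor().
fn conv_transpose_len(l_in: usize, k: usize, stride: usize, padding: usize,
                      dilation: usize, padding_out: usize) -> usize {
    (l_in - 1) * stride - 2 * padding + dilation * (k - 1) + padding_out + 1
}

fn conv_len(l: usize, k: usize, stride: usize, padding: usize, dilation: usize) -> usize {
    (l + 2 * padding - dilation * (k - 1) - 1) / stride + 1
}

fn main() {
    // Parameters from test_conv_transpose1d_complex; l_in and k assumed.
    let (l_in, k, stride, padding, dilation, padding_out) = (6, 3, 2, 1, 2, 1);
    let l_out = conv_transpose_len(l_in, k, stride, padding, dilation, padding_out);

    // Current PR: inverse conv over the full output_grad, then slice to l_in.
    assert_eq!(conv_len(l_out, k, stride, padding, dilation), l_in + padding_out / stride);

    // Trimming output_grad first also lands on l_in shape-wise ...
    assert_eq!(conv_len(l_out - padding_out, k, stride, padding, dilation), l_in);
    // ... but conv1d's symmetric zero padding then substitutes zeros for the
    // trimmed (real) gradient cells at the right edge, which is why the
    // trailing x_grad values come out wrong.
}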

laggui (Member) commented May 7, 2026

@SAY-5 Ahhhh thanks for following up 🙏 my initial assessment was wrong for the conv transpose backward. The slice actually needs to happen after the conv as you initially suggested. The padding_out pixels actually contain valid gradient information.

After looking into it, I made a small change to the way the output size is computed: it can be derived from the options to match the input size.

Should only be missing a correctness test with reference values. We're planning to release 0.21 today so I might merge your fix first and the test could be added in a follow-up PR.
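
(A sketch of that derivation, with illustrative names rather than the merged code: rearranging the conv_transpose size formula recovers the input length directly from the options, so the backward conv's output size can be pinned to it.)

// Expected x length derived from the options and output_grad's spatial size.
// Inverts: l_out = (l_in - 1) * stride - 2 * padding
//                  + dilation * (kernel - 1) + padding_out + 1
fn expected_x_len(l_out: usize, kernel: usize, stride: usize, padding: usize,
                  dilation: usize, padding_out: usize) -> usize {
    (l_out + 2 * padding - dilation * (kernel - 1) - padding_out - 1) / stride + 1
}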

laggui merged commit 2a34823 into tracel-ai:main on May 7, 2026
11 checks passed

Development

Successfully merging this pull request may close these issues.

Backward Pass Panics .with_padding_out(!= 0)
