46 changes: 46 additions & 0 deletions crates/burn-backend-tests/tests/autodiff/conv_transpose1d.rs
@@ -211,6 +211,52 @@ fn test_conv_transpose1d_complex() {
    test.assert_grads(grads);
}

/// Regression test for #4845.
///
/// `ConvTranspose1d` with `padding_out != 0` and `stride == 1` used to panic in
/// the backward pass because `conv_transpose1d_x_backward` did not account for
/// the trailing `padding_out` cells, producing a gradient longer than `x`.
#[test]
fn test_conv_transpose1d_padding_out_stride1_backward_shape() {
Comment on lines +219 to +220

The test should also validate correctness, not just the shape. I think the current fix is incorrect (see my other comment).

You can use PyTorch to compute reference values.
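A minimal sketch of how such reference values could be computed with PyTorch (hypothetical script, not part of this PR). One caveat: PyTorch requires `output_padding < max(stride, dilation)`, so the test's exact `stride = 1`, `dilation = 1`, `padding_out = 1` configuration cannot be fed to it directly; here `dilation = 2` stands in for the test's `dilation = 1` to keep `stride == 1` valid:

```python
import torch
import torch.nn.functional as F

# Same arange-filled shapes as the Rust test: x is [2, 2, 4], weight is
# [channels_in, channels_out, kernel_size] = [2, 2, 3].
x = torch.arange(16, dtype=torch.float64).reshape(2, 2, 4).requires_grad_()
weight = torch.arange(12, dtype=torch.float64).reshape(2, 2, 3).requires_grad_()

# dilation=2 instead of the test's dilation=1 (see note above): PyTorch
# rejects output_padding=1 when both stride and dilation are 1.
out = F.conv_transpose1d(x, weight, stride=1, padding=0, output_padding=1, dilation=2)

# Seed the backward pass with an all-ones cotangent, like `output.backward()`
# in the test, then read off the reference gradients.
out.sum().backward()
print(x.grad)       # expected x_grad values
print(weight.grad)  # expected weight_grad values
```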

    let device = AutodiffDevice::new();
    let batch_size = 2;
    let channels_in = 2;
    let channels_out = 2;
    let kernel_size = 3;
    let size_in = 4;
    let padding_out = 1;

    let shape_x = Shape::new([batch_size, channels_in, size_in]);
    let shape_weight = Shape::new([channels_in, channels_out, kernel_size]);
    let weight = TestTensor::from_data(
        TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
            .reshape::<3, _>(shape_weight.clone())
            .into_data(),
        &device,
    )
    .require_grad();
    let x = TestTensor::from_data(
        TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
            .reshape::<3, _>(shape_x.clone())
            .into_data(),
        &device,
    )
    .require_grad();

    let output = conv_transpose1d(
        x.clone(),
        weight.clone(),
        None,
        ConvTransposeOptions::new([1], [0], [padding_out], [1], 1),
    );
    let grads = output.backward();

    let x_grad = x.grad(&grads).unwrap();
    let weight_grad = weight.grad(&grads).unwrap();
    assert_eq!(x_grad.shape(), shape_x);
    assert_eq!(weight_grad.shape(), shape_weight);
}

struct ConvTranspose1dTestCase {
    batch_size: usize,
    channels: [usize; 2],
97 changes: 91 additions & 6 deletions crates/burn-backend/src/backend/ops/modules/conv.rs
@@ -474,7 +474,7 @@ pub(crate) fn conv_transpose1d_x_backward<B: Backend>(
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<1>,
) -> FloatTensor<B> {
-    B::conv1d(
+    let grad = B::conv1d(
        output_grad,
        weight,
        None,
@@ -484,7 +484,37 @@
            options.dilation,
            options.groups,
        ),
-    )
+    );
    // The forward transpose conv extends its output by `padding_out` zero cells
    // on the trailing edge; those cells carry gradient but do not depend on
    // any input cell. The inverse conv1d above therefore produces a result
    // longer than the original `x` by `floor(padding_out / stride)`. Trim the
    // trailing cells to recover the input length.
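    // For example, with `size_in = 4`, `kernel_size = 3`, `stride = 1`,
    // `dilation = 1`, `padding = 0` and `padding_out = 1` (the regression
    // test above): the forward output has length 7, the conv1d yields a
    // gradient of length 5, and trimming one trailing cell recovers 4.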
    trim_padding_out_1d::<B>(grad, options.padding_out[0], options.stride[0])
laggui marked this conversation as resolved. (Outdated)
}

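/// Drops the trailing gradient cells introduced by `padding_out`; returns
/// `grad` unchanged when no trimming is needed.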
fn trim_padding_out_1d<B: Backend>(
    grad: FloatTensor<B>,
    padding_out: usize,
    stride: usize,
) -> FloatTensor<B> {
    if padding_out == 0 {
        return grad;
    }
    let trim = padding_out / stride.max(1);
    if trim == 0 {
        return grad;
    }
    let [batch_size, channels_in, length] = grad.shape().dims();
    if trim >= length {
        return grad;
    }
    let slices = [
        Slice::from(0..batch_size),
        Slice::from(0..channels_in),
        Slice::from(0..length - trim),
    ];
    B::float_slice(grad, &slices)
}

/// Calculate the [1D convolution transpose](crate::ops::ModuleOps::conv_transpose1d) backward pass, returning the gradient for `weight`.
@@ -531,7 +561,7 @@ pub(crate) fn conv_transpose2d_x_backward<B: Backend>(
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<2>,
) -> FloatTensor<B> {
-    B::conv2d(
+    let grad = B::conv2d(
        output_grad,
        weight,
        None,
@@ -541,7 +571,7 @@
            options.dilation,
            options.groups,
        ),
-    )
+    );
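    // As in the 1D case, trim the trailing `padding_out` cells
    // (`floor(padding_out / stride)` per spatial dimension).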
    trim_padding_out_2d::<B>(grad, options.padding_out, options.stride)
}

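/// 2D counterpart of `trim_padding_out_1d`: trims trailing rows and columns
/// introduced by `padding_out` along the height and width dimensions.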
fn trim_padding_out_2d<B: Backend>(
    grad: FloatTensor<B>,
    padding_out: [usize; 2],
    stride: [usize; 2],
) -> FloatTensor<B> {
    if padding_out[0] == 0 && padding_out[1] == 0 {
        return grad;
    }
    let h_trim = padding_out[0] / stride[0].max(1);
    let w_trim = padding_out[1] / stride[1].max(1);
    if h_trim == 0 && w_trim == 0 {
        return grad;
    }
    let [batch_size, channels_in, height, width] = grad.shape().dims();
    let h_trim = h_trim.min(height);
    let w_trim = w_trim.min(width);
    let slices = [
        Slice::from(0..batch_size),
        Slice::from(0..channels_in),
        Slice::from(0..height - h_trim),
        Slice::from(0..width - w_trim),
    ];
    B::float_slice(grad, &slices)
}

/// Calculate the [2D convolution transpose](crate::ops::ModuleOps::conv_transpose2d) backward pass, returning the gradient for `weight`.
@@ -591,7 +647,7 @@ pub(crate) fn conv_transpose3d_x_backward<B: Backend>(
    output_grad: FloatTensor<B>,
    options: ConvTransposeOptions<3>,
) -> FloatTensor<B> {
-    B::conv3d(
+    let grad = B::conv3d(
        output_grad,
        weight,
        None,
@@ -601,7 +657,7 @@
            options.dilation,
            options.groups,
        ),
-    )
+    );
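    // Same trailing `padding_out` trim, applied over depth, height, and width.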
    trim_padding_out_3d::<B>(grad, options.padding_out, options.stride)
}

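/// 3D counterpart of `trim_padding_out_1d`: trims trailing cells along the
/// depth, height, and width dimensions.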
fn trim_padding_out_3d<B: Backend>(
    grad: FloatTensor<B>,
    padding_out: [usize; 3],
    stride: [usize; 3],
) -> FloatTensor<B> {
    if padding_out[0] == 0 && padding_out[1] == 0 && padding_out[2] == 0 {
        return grad;
    }
    let d_trim = padding_out[0] / stride[0].max(1);
    let h_trim = padding_out[1] / stride[1].max(1);
    let w_trim = padding_out[2] / stride[2].max(1);
    if d_trim == 0 && h_trim == 0 && w_trim == 0 {
        return grad;
    }
    let [batch_size, channels_in, depth, height, width] = grad.shape().dims();
    let d_trim = d_trim.min(depth);
    let h_trim = h_trim.min(height);
    let w_trim = w_trim.min(width);
    let slices = [
        Slice::from(0..batch_size),
        Slice::from(0..channels_in),
        Slice::from(0..depth - d_trim),
        Slice::from(0..height - h_trim),
        Slice::from(0..width - w_trim),
    ];
    B::float_slice(grad, &slices)
}

/// Calculate the [3D convolution transpose](crate::ops::ModuleOps::conv_transpose3d) backward pass, returning the gradient for `weight`.