
Commit 3c657d5

[release/2.8] fix miopen batchnorm changing output format
cherry pick of pytorch#162112
1 parent db3ba66 commit 3c657d5
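
What this fixes, in user-visible terms: on a ROCm build with MIOpen batchnorm opted into NHWC (see the env vars added to the tests below), batch norm could hand back an output in a different memory format than its channels-last input, silently changing the layout mid-network. A minimal repro sketch, assuming a ROCm/MIOpen build with the flags set before the first dispatch:

    import os
    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"

    import torch
    import torch.nn as nn

    # Channels-last input through batch norm on the GPU; after this fix the
    # output should come back channels-last as well.
    x = torch.randn(8, 32, 16, 16, device="cuda").to(memory_format=torch.channels_last)
    bn = nn.BatchNorm2d(32).cuda()
    y = bn(x)
    assert y.is_contiguous(memory_format=torch.channels_last)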

File tree

3 files changed: +27 -36 lines changed

aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
test/nn/test_convolution.py
test/test_nn.py

aten/src/ATen/native/miopen/BatchNorm_miopen.cpp

Lines changed: 8 additions & 14 deletions

@@ -7,6 +7,7 @@
 #include <ATen/NativeFunctions.h>
 #else
 #include <ATen/ops/empty.h>
+#include <ATen/ops/empty_like.h>
 #include <ATen/ops/miopen_batch_norm_native.h>
 #include <ATen/ops/miopen_batch_norm_backward_native.h>
 #endif
@@ -102,7 +103,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
     mode = miopenBNSpatial;
   }

-  auto output_t = at::empty(input->sizes(), input->options(), input->suggest_memory_format());
+  auto output_t = at::empty_like(input_t, input_t.options(), input_t.suggest_memory_format());
   TensorArg output{ output_t, "output", 0 };

   auto handle = getMiopenHandle();
@@ -170,22 +171,15 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
     const std::optional<Tensor>& save_var_t_opt,
     double epsilon) {
   // See [Note: hacky wrapper removal for optional tensor]
-  const Tensor& running_mean =
-      running_mean_opt.value_or(Tensor());
-  const Tensor& running_var =
-      running_var_opt.value_or(Tensor());
-  const Tensor& save_mean_t =
-      save_mean_t_opt.value_or(Tensor());
-  const Tensor& save_var_t =
-      save_var_t_opt.value_or(Tensor());
+  const Tensor& save_mean_t = save_mean_t_opt.value_or(Tensor());
+  const Tensor& save_var_t = save_var_t_opt.value_or(Tensor());

   auto grad_output_contig =
       grad_output_t.contiguous(input_t.suggest_memory_format());
-  TensorArg input{ input_t, "input", 1 },
-            grad_output{ grad_output_contig, "grad_output", 2 },
-            weight{ weight_t, "weight", 3 },
-            save_mean{ save_mean_t, "save_mean", 4 },
-            save_var{ save_var_t, "save_var", 5 };
+  TensorArg input{input_t, "input", 1},
+      grad_output{grad_output_contig, "grad_output", 2},
+      weight{weight_t, "weight", 3}, save_mean{save_mean_t, "save_mean", 4},
+      save_var{save_var_t, "save_var", 5};
   CheckedFrom c = "miopen_batch_norm_backward";

   checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
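
The crux of the C++ change: the output is now allocated with at::empty_like on the input tensor itself (hence the new empty_like.h include) rather than reconstructed from its sizes, keeping the output's layout tied to the actual input. A rough Python-level sketch of the general empty vs empty_like layout behavior (the real change lives in ATen, so this only illustrates the idea, not the exact before/after):

    import torch

    x = torch.randn(2, 8, 4, 4).to(memory_format=torch.channels_last)

    # Allocating from sizes alone yields a plain contiguous (NCHW) tensor
    # unless a memory format is threaded through explicitly.
    out_from_sizes = torch.empty(x.shape)
    # empty_like can carry the source tensor's layout over.
    out_like = torch.empty_like(x, memory_format=torch.preserve_format)

    assert not out_from_sizes.is_contiguous(memory_format=torch.channels_last)
    assert out_like.is_contiguous(memory_format=torch.channels_last)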

test/nn/test_convolution.py

Lines changed: 13 additions & 20 deletions

@@ -29,7 +29,6 @@
     skipCUDAIfMiopen,
     skipCUDAIfNoCudnn,
     skipCUDAIfNoMiopen,
-    skipCUDAIfNotMiopenSuggestNHWC,
     skipCUDAIfRocm,
     skipMeta,
     skipMPS,
@@ -50,8 +49,6 @@
     parametrize as parametrize_test,
     run_tests,
     set_default_dtype,
-    skipIfNotMiopenSuggestNHWC,
-    skipIfRocmVersionLessThan,
     subtest,
     TEST_SCIPY,
     TEST_WITH_ROCM,
@@ -61,6 +58,11 @@
 AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported()


+if TEST_WITH_ROCM:
+    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
+    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"
+
+
 if TEST_SCIPY:
     import scipy.ndimage
     import scipy.signal
@@ -710,7 +712,6 @@ def test_ConvTranspose2d_half_cublas_gemm(self):
     # Almost identical to the above `test_Conv2d_naive_groups`
     @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
     @tf32_on_and_off(0.001)
-    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
     def test_Conv2d_groups_nobias(self):
         dev_dtypes = [("cpu", torch.float)]
         if TEST_CUDA:
@@ -756,7 +757,6 @@ def test_Conv2d_groups_nobias(self):
     # and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024
     @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
     @tf32_on_and_off(0.001)
-    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
     def test_Conv2d_groups_nobias_v2(self):
         torch.manual_seed(123)
         dev_dtypes = [("cpu", torch.float)]
@@ -891,7 +891,6 @@ def test_conv_tbc(self):

     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
-    @skipIfNotMiopenSuggestNHWC
     def test_grouped_conv_cudnn_nhwc_support(self):
         # in order to catch the hols in grouped convolution in nhwc support for earlier cudnn version
         input = torch.randn((16, 16, 8, 8), dtype=torch.float16, device="cuda").to(
@@ -3140,7 +3139,6 @@ def test_conv_noncontig_weights_and_bias(self, device):

     @onlyCUDA
     @largeTensorTest("12GB")
-    @skipIfRocmVersionLessThan((6, 0))
     def test_conv_transposed_large(self, device):
         dtype = torch.half if self.device_type == "cuda" else torch.float
         conv = nn.ConvTranspose2d(1, 1, 1, 1, bias=False).to(device).to(dtype)
@@ -3184,7 +3182,6 @@ def test_conv_transposed_large(self, device):
         self.assertEqual(maxdiff3, 0)

     @onlyCUDA
-    @skipCUDAIfRocm
     @largeTensorTest("12GB")
     def test_conv_large(self, device):
         dtype = torch.half if self.device_type == "cuda" else torch.float
@@ -3217,7 +3214,6 @@ def test_conv_large(self, device):
         self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3)

     @onlyCUDA
-    @skipCUDAIfRocm
     @largeTensorTest("20GB", "cpu")
     @largeTensorTest("60GB", "cuda")
     def test_conv_large_batch_1(self, device):
@@ -3365,7 +3361,6 @@ def test_ConvTranspose3d_size_1_kernel(self, device):
     @dtypes(torch.float)
     @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False)
     @tf32_on_and_off(0.001)
-    @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7")
     def test_Conv2d_naive_groups(self, device, dtype):
         # Check that grouped convolutions matches two half convolutions
         m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype)
@@ -3634,19 +3629,21 @@ def helper(
         )

     @onlyCUDA
-    @skipCUDAIfNotMiopenSuggestNHWC
     @dtypes(torch.half, torch.float, torch.cfloat)
     def test_conv_cudnn_nhwc(self, device, dtype):
         def helper(n, c, h, w, out_channels, kernel_size, groups):
-            input = torch.randint(-3, 3, (n, c, h, w), dtype=dtype, device=device).to(
-                memory_format=torch.channels_last
-            )
+            # randint with dtype=torch.cfloat fails with
+            # RuntimeError: check_random_bounds handles only integral, floating-point and boolean types
+            # must create randint and randint_like using default int64, then cast to desired
+            input = torch.randint(
+                -3, 3, (n, c, h, w), dtype=torch.int64, device=device
+            ).to(dtype, memory_format=torch.channels_last)
             input.requires_grad_()
             conv = nn.Conv2d(c, out_channels, kernel_size, groups=groups).to(
                 device="cuda", dtype=dtype, memory_format=torch.channels_last
             )
             for p in conv.parameters():
-                p.data = torch.randint_like(p, -3, 3)
+                p.data = torch.randint_like(p, -3, 3, dtype=torch.int64).to(p.dtype)

             # use FP64 channels-first conv as reference
             ref_input = input.detach().clone().contiguous().double().requires_grad_()
@@ -3660,7 +3657,7 @@ def helper(n, c, h, w, out_channels, kernel_size, groups):
             out = conv(input)
             ref_out = ref_conv(ref_input)

-            grad = torch.randint_like(out, -3, 3)
+            grad = torch.randint_like(out, -3, 3, dtype=torch.int64).to(out.dtype)
             ref_grad = grad.detach().clone().double().contiguous()

             out.backward(grad)
@@ -3687,7 +3684,6 @@ def helper(n, c, h, w, out_channels, kernel_size, groups):
         helper(1, 16, 56, 56, out_channels=16, kernel_size=3, groups=16)

     @onlyCUDA
-    @skipCUDAIfRocm
     @dtypes(torch.half, torch.float)
     def test_conv_cudnn_ndhwc(self, device, dtype):
         def helper(n, c, d, h, w, out_channels, kernel_size, groups):
@@ -3817,7 +3813,6 @@ def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device):
         )

     @onlyCUDA
-    @skipCUDAIfNotMiopenSuggestNHWC
     @tf32_on_and_off(0.05)
     def test_conv_cudnn_mismatch_memory_format(self, device):
         configs = [
@@ -3950,7 +3945,6 @@ def test_cudnn_convolution_add_relu(self, device, dtype):
         self.assertEqual(F.relu(conv2d_out + alpha * z), cudnn_out)

     @onlyCUDA
-    @skipCUDAIfRocm
     def test_convert_conv2d_weight_memory_format(self, device):
         input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device=device)
         model = nn.Sequential(nn.Conv2d(8, 4, 3), nn.BatchNorm2d(4)).to(device).float()
@@ -3970,7 +3964,6 @@ def test_convert_conv2d_weight_memory_format(self, device):
         self.assertTrue(out.is_contiguous(memory_format=memory_format))

     @onlyCUDA
-    @skipCUDAIfRocm
     def test_convert_conv3d_weight_memory_format(self, device):
         input = torch.randint(
             1, 10, (2, 8, 4, 4, 4), dtype=torch.float32, device=device
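
The test_conv_cudnn_nhwc hunks above work around torch.randint rejecting complex dtypes now that torch.cfloat is in the test's dtype list. A standalone sketch of the pattern (runs on CPU):

    import torch

    # torch.randint only supports integral, floating-point and boolean dtypes,
    # so for complex tensors: sample int64 first, then cast in the same .to()
    # call that applies the memory format.
    x = torch.randint(-3, 3, (2, 4, 8, 8), dtype=torch.int64).to(
        torch.cfloat, memory_format=torch.channels_last
    )
    assert x.dtype == torch.cfloat
    assert x.is_contiguous(memory_format=torch.channels_last)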

test/test_nn.py

Lines changed: 6 additions & 2 deletions

@@ -58,6 +58,10 @@

 AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported()

+if TEST_WITH_ROCM:
+    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1"
+    os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM"] = "1"
+
 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
 load_tests = load_tests
@@ -3493,15 +3497,15 @@ def test_cudnn_forward_exception(self):
         self.assertRaisesRegex(RuntimeError, re.escape("input.size(-1) must be equal to input_size"), rnn, x_wrong)

     @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
-    @skipIfRocm
     def test_cudnn_weight_format(self):
         rnns = [
             nn.LSTM(10, 20, batch_first=True),
             nn.LSTM(10, 20, batch_first=True, proj_size=10),
             nn.GRU(10, 20, batch_first=True),
             nn.RNN(10, 20, batch_first=True)
         ]
-        first_warn = True
+        # ROCm RNN does not issue warning about single contig chunk of memory, so don't assert it
+        first_warn = False if torch.version.hip else True
         for rnn in rnns:
             rnn.cuda()
             input = torch.randn(5, 4, 10, requires_grad=True, device="cuda")
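
The first_warn change keys off torch.version.hip, which doubles as a runtime ROCm predicate. A hedged sketch of the idiom:

    import torch

    # torch.version.hip is None on CUDA/CPU builds and a version string on
    # ROCm builds, so its truthiness distinguishes the backends. MIOpen's RNN
    # path does not emit cuDNN's warning about weights not being a single
    # contiguous chunk of memory, hence the test stops asserting it on ROCm.
    on_rocm = torch.version.hip is not None
    first_warn = not on_rocm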
