Remove non-tile-multiple width/height from conv2d
These two features are non-critical for conv2d: they do not
enable any model, improve performance on any model, or improve
the pass rate on any sweep.

The problem with these features is that they kick in under
conditions that are unpredictable for both users and developers,
since they are gated by many limits and conditions.

They also add to the conv2d test matrix, yet they are hard to
test, because deriving tests that trigger them on multiple
hardware platforms is not easy.

Moreover, they are a source of bugs such as #17647, and it is
often not obvious that a bug originates from these features;
when faced with a conv2d bug, the first step is frequently to go
into the code and manually disable them to rule that out.

For the reasons above, these features are removed, and removing
them fixes #17647.
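
For context, tensors on Tenstorrent devices are laid out in 32x32 tiles,
and "non-tile-multiple" width/height refers to per-core output shard
dimensions that are not multiples of 32. With these code paths removed,
such dimensions are always snapped (padded) up to the next tile boundary.
A minimal illustrative sketch of that rounding; the TILE constant and the
pad_to_tile helper are illustrative and not part of the codebase:

    import math

    TILE = 32  # Tenstorrent tiles are 32x32 elements

    def pad_to_tile(dim: int) -> int:
        # Round a shard dimension up to the nearest tile multiple.
        return math.ceil(dim / TILE) * TILE

    # e.g. a per-core output shard height of 50 rows snaps to 64
    assert pad_to_tile(50) == 64
    assert pad_to_tile(64) == 64
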
Pavle Josipovic authored and pavlejosipovic committed Feb 19, 2025
1 parent 1aba0a5 commit c69f5c0
Showing 16 changed files with 163 additions and 685 deletions.
@@ -182,10 +182,6 @@ def __init__(
self.conv2_config_override = {}
if (out_channels, out_channels, input_height, input_width) in config_override:
self.conv2_config_override = config_override[(out_channels, out_channels, input_height, input_width)]
# if use_in_shortcut:
# self.conv2_config_override["grid_size"] = self.conv_shortcut.conv.grid_size
# self.conv2_config_override["per_core_out_matrix_height"] = self.conv_shortcut.conv.per_core_out_matrix_height
# self.conv2_config_override["per_core_weight_matrix_width"] = self.conv_shortcut.conv.per_core_out_matrix_width

self.conv2_input_height = conv2_input_height
self.conv2_input_width = conv2_input_width
202 changes: 0 additions & 202 deletions tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -384,9 +384,6 @@ def test_conv_features(
if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat8_b:
pytest.skip("Row major layout not compatible with bfloat8_b")

if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat16 and packer_l1_acc and fp32_accum:
pytest.skip("skipping due to pack_untilize_dst issue!")

run_conv(
device,
torch_tensor_map,
@@ -2592,205 +2589,6 @@ def test_conv_for_vanilla_unet(
)


@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize(
"batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override",
(
# unique convs in rn50 (complete list)
# first conv post folding and input_channels padding to tile width
(16, 64, 64, 14, 14, 3, 3, 1, 1, 1, 1, HS, None),
# rn50 layer1
(8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, HS, None),
(16, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, HS, None),
(20, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, HS, None),
# rn50 layer2
(8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, HS, None),
(16, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, HS, None),
(20, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, HS, None),
(8, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, HS, None),
(16, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, HS, None),
(20, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, HS, None),
(1, 32, 32, 240, 320, 3, 3, 1, 1, 1, 1, HS, None),
(1, 64, 32, 240, 320, 3, 3, 1, 1, 1, 1, HS, None),
),
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat8_b, ttnn.bfloat16],
)
@pytest.mark.parametrize(
"activations_dtype",
[ttnn.bfloat16, ttnn.float32],
)
@pytest.mark.parametrize("fp32_accum", [False, True], ids=["no_fp32_accum", "fp32_accum"])
@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi])
@pytest.mark.parametrize("packer_l1_acc", [True, False], ids=["pack_l1", "no_pack_l1"])
@pytest.mark.parametrize("has_bias", [True, False], ids=["with_bias", "no_bias"])
def test_non_tile_multiple_height_conv_wh(
device,
torch_tensor_map,
use_program_cache,
math_fidelity,
activations_dtype,
weights_dtype,
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
pad_h,
pad_w,
shard_layout,
config_override,
fp32_accum,
packer_l1_acc,
has_bias,
):
if device.core_grid.y == 7:
pytest.skip("Issue #6992: Statically allocated circular buffers in program clash with L1 buffers on core range")

if (
is_grayskull()
and activations_dtype == ttnn.bfloat16
and batch_size == 20
and (
output_channels == 64
or (
stride_h == 2
and (output_channels == 256 or (output_channels == 128 and weights_dtype == ttnn.bfloat16))
)
)
):
pytest.skip("Skipping test because it won't fit in L1!")

if activations_dtype == ttnn.float32 and (batch_size >= 16 or (output_channels == 64 or input_height >= 240)):
pytest.skip("Skipping test because it won't fit in L1!")

if (
(weights_dtype == ttnn.bfloat16 and batch_size == 20 and output_channels == 128 and input_height == 56)
or (weights_dtype == ttnn.bfloat16 and batch_size == 20 and output_channels == 64)
or (weights_dtype == ttnn.bfloat8_b and batch_size == 20 and output_channels == 128 and input_height == 56)
):
pytest.skip("Skipping test because it won't fit in L1!")

if has_bias and packer_l1_acc and (fp32_accum or activations_dtype is ttnn.float32):
pytest.skip("skipping due to pack_untilize_dst issue! --> #14236")

use_shallow_conv_variant = (input_channels == 16) and device.arch() != ttnn.device.Arch.WORMHOLE_B0
run_conv(
device,
torch_tensor_map,
math_fidelity,
activations_dtype,
weights_dtype,
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
pad_h,
pad_w,
config_override=config_override,
shard_layout=shard_layout,
use_shallow_conv_variant=use_shallow_conv_variant,
packer_l1_acc=packer_l1_acc,
fp32_accum=fp32_accum,
has_bias=has_bias,
output_layout=ttnn.ROW_MAJOR_LAYOUT,
)


@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
@pytest.mark.parametrize(
"batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override",
(
(1, 64, 64, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 64, 128, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 64, 192, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 64, 256, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 64, 320, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 64, 384, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 64, 448, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 64, 512, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 64, 576, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 64, 640, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 128, 64, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 128, 128, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 128, 192, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 128, 256, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 128, 320, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 128, 384, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 128, 448, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 128, 512, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 128, 576, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 128, 640, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 320, 320, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
(1, 640, 640, 16, 16, 3, 3, 1, 1, 1, 1, BS, None),
),
)
@pytest.mark.parametrize(
"weights_dtype",
[ttnn.bfloat16, ttnn.bfloat8_b],
)
@pytest.mark.parametrize(
"activations_dtype",
[ttnn.bfloat16],
)
@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi])
def test_non_tile_multiple_width_conv_wh(
device,
torch_tensor_map,
use_program_cache,
math_fidelity,
activations_dtype,
weights_dtype,
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
pad_h,
pad_w,
shard_layout,
config_override,
):
run_conv(
device,
torch_tensor_map,
math_fidelity,
activations_dtype,
weights_dtype,
batch_size,
output_channels,
input_channels,
input_height,
input_width,
filter_height,
filter_width,
stride_h,
stride_w,
pad_h,
pad_w,
config_override,
shard_layout=shard_layout,
use_shallow_conv_variant=(input_channels == 16),
output_layout=ttnn.ROW_MAJOR_LAYOUT,
)


@skip_for_grayskull()
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_shallow_conv_with_tiled_input(device):
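
The removed tests above exercised conv2d with ROW_MAJOR output layouts
whose per-core shard shapes were not tile multiples. For reference, the
removed (1, 64, 32, 240, 320, 3x3, height-sharded) case would go through
the public ttnn API roughly as sketched below. The keyword names and
Conv2dConfig fields are assumptions based on typical ttnn.conv2d usage,
not taken from this diff, so treat this as an illustrative sketch rather
than a drop-in replacement test:

    # Illustrative sketch only; parameter names are assumed, not from this diff.
    conv_config = ttnn.Conv2dConfig(
        dtype=ttnn.bfloat16,
        weights_dtype=ttnn.bfloat16,
        shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
        output_layout=ttnn.ROW_MAJOR_LAYOUT,
    )
    output = ttnn.conv2d(
        input_tensor=tt_input,    # NHWC activation tensor already on device
        weight_tensor=tt_weight,  # OIHW weight tensor
        in_channels=32,
        out_channels=64,
        device=device,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding=(1, 1),
        batch_size=1,
        input_height=240,
        input_width=320,
        conv_config=conv_config,
    )
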
36 changes: 15 additions & 21 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp
@@ -90,21 +90,18 @@ Result conv2d(

ShardOrientation shard_orientation =
conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR;
bool is_non_tile_mul_width = check_non_tile_mul_width(compute_grid_size, conv_config, in_channels);

auto [input_tensor_post_tm, parallel_config, output_parallel_config, use_non_tile_height] =
shard_or_reshard_tensor_if_required(
device,
input_tensor,
conv_config,
batch_size,
output_height,
output_width,
in_channels,
out_channels,
mm_conv,
auto_shard,
is_non_tile_mul_width);
auto [input_tensor_post_tm, parallel_config, output_parallel_config] = shard_or_reshard_tensor_if_required(
device,
input_tensor,
conv_config,
batch_size,
output_height,
output_width,
in_channels,
out_channels,
mm_conv,
auto_shard);

auto [opt_conv_op_parallel_config, opt_conv_op_block_config, conv_out_memory_config] = get_conv_configs(
conv_config,
@@ -137,8 +134,7 @@ Result conv2d(
groups,
opt_conv_op_block_config.act_block_h_ntiles,
input_width,
true,
is_non_tile_mul_width);
true);
}
// if 1x1 conv w/ stride 1, convert input tensor to tile layout if required
if (mm_conv) {
@@ -160,7 +156,7 @@
.dilation_hw = {dilation[0], dilation[1]},
.num_cores_nhw = opt_conv_op_parallel_config.num_cores_nhw,
.core_range_set = input_tensor_post_tm.memory_config().shard_spec.value().grid,
.snap_to_tile = !use_non_tile_height,
.snap_to_tile = true,
};

bool bypass_halo =
@@ -185,7 +181,7 @@
parallel_config.shard_orientation == ShardOrientation::COL_MAJOR,
0,
input_tensor_post_tm.memory_config(),
!use_non_tile_height);
true);

if (conv_config.deallocate_activation) {
input_tensor_post_tm.deallocate(/*force*/ true);
@@ -217,9 +213,7 @@
compute_config,
conv_config.enable_act_double_buffer,
conv_config.enable_weights_double_buffer,
conv_config.enable_split_reader,
conv_config.enable_subblock_padding,
use_non_tile_height);
conv_config.enable_split_reader);

if (memory_config.has_value() && memory_config.value() != conv_output.memory_config()) {
conv_output = ttnn::to_memory_config(conv_output, memory_config.value(), std::nullopt);
15 changes: 7 additions & 8 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
@@ -295,8 +295,7 @@ void py_bind_conv2d(py::module& module) {
compute_grid_size,
block_shard_orientation,
enable_channels_padding,
is_out_tiled,
false);
is_out_tiled);
},
py::arg("shard_layout"),
py::arg("batch_size"),
@@ -384,16 +383,16 @@
py::arg("grid_size"),
py::arg("num_cores_nhw") = 1,
py::arg("num_cores_c") = 1,
py::arg("per_core_out_matrix_height").noconvert(),
py::arg("per_core_out_matrix_width").noconvert())
py::arg("per_core_out_matrix_height_ntiles").noconvert(),
py::arg("per_core_out_matrix_width_ntiles").noconvert())
.def_property_readonly("grid_size", [](const OptimizedConvParallelizationConfig& c) { return c.grid_size; })
.def_property_readonly(
"num_cores_nhw", [](const OptimizedConvParallelizationConfig& c) { return c.num_cores_nhw; })
.def_property_readonly(
"per_core_out_matrix_height",
[](const OptimizedConvParallelizationConfig& c) { return c.per_core_out_matrix_height; })
.def_property_readonly("per_core_out_matrix_width", [](const OptimizedConvParallelizationConfig& c) {
return c.per_core_out_matrix_width;
"per_core_out_matrix_height_ntiles",
[](const OptimizedConvParallelizationConfig& c) { return c.per_core_out_matrix_height_ntile; })
.def_property_readonly("per_core_out_matrix_width_ntiles", [](const OptimizedConvParallelizationConfig& c) {
return c.per_core_out_matrix_width_ntile;
});

py::class_<OptimizedConvBlockConfig>(module, "OptimizedConvBlockConfig")
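
For Python callers, the binding change above renames the
OptimizedConvParallelizationConfig accessors from per_core_out_matrix_height
and per_core_out_matrix_width to per_core_out_matrix_height_ntiles and
per_core_out_matrix_width_ntiles, making explicit that the values are tile
counts. A small sketch of reading the renamed properties; the helper
function and the 32-element tile size below are illustrative, not part of
this diff:

    def log_per_core_output_shape(parallel_cfg) -> None:
        # parallel_cfg is assumed to be an OptimizedConvParallelizationConfig
        # obtained from existing conv2d tooling; only the renamed accessors
        # come from this diff.
        h_ntiles = parallel_cfg.per_core_out_matrix_height_ntiles
        w_ntiles = parallel_cfg.per_core_out_matrix_width_ntiles
        print(f"per-core output: {h_ntiles} x {w_ntiles} tiles "
              f"({h_ntiles * 32} x {w_ntiles * 32} elements)")
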