diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py index 691081f1952..58f3ab618b0 100644 --- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py @@ -182,10 +182,6 @@ def __init__( self.conv2_config_override = {} if (out_channels, out_channels, input_height, input_width) in config_override: self.conv2_config_override = config_override[(out_channels, out_channels, input_height, input_width)] - # if use_in_shortcut: - # self.conv2_config_override["grid_size"] = self.conv_shortcut.conv.grid_size - # self.conv2_config_override["per_core_out_matrix_height"] = self.conv_shortcut.conv.per_core_out_matrix_height - # self.conv2_config_override["per_core_weight_matrix_width"] = self.conv_shortcut.conv.per_core_out_matrix_width self.conv2_input_height = conv2_input_height self.conv2_input_width = conv2_input_width diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py index 610cd0ef6e3..082cb3c90fa 100644 --- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py @@ -384,9 +384,6 @@ def test_conv_features( if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat8_b: pytest.skip("Row major layout not compatible with bfloat8_b") - if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat16 and packer_l1_acc and fp32_accum: - pytest.skip("skipping due to pack_untilize_dst issue!") - run_conv( device, torch_tensor_map, @@ -2592,205 +2589,6 @@ def test_conv_for_vanilla_unet( ) -@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) -@pytest.mark.parametrize( - "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override", - ( - # unique convs in rn50 (complete list) - # first conv post folding and input_channels padding to tile width - (16, 64, 64, 14, 14, 3, 3, 1, 1, 1, 1, HS, None), - # rn50 layer1 - (8, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, HS, None), - (16, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, HS, None), - (20, 64, 64, 56, 56, 3, 3, 1, 1, 1, 1, HS, None), - # rn50 layer2 - (8, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, HS, None), - (16, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, HS, None), - (20, 128, 128, 56, 56, 3, 3, 2, 2, 1, 1, HS, None), - (8, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, HS, None), - (16, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, HS, None), - (20, 128, 128, 28, 28, 3, 3, 1, 1, 1, 1, HS, None), - (1, 32, 32, 240, 320, 3, 3, 1, 1, 1, 1, HS, None), - (1, 64, 32, 240, 320, 3, 3, 1, 1, 1, 1, HS, None), - ), -) -@pytest.mark.parametrize( - "weights_dtype", - [ttnn.bfloat8_b, ttnn.bfloat16], -) -@pytest.mark.parametrize( - "activations_dtype", - [ttnn.bfloat16, ttnn.float32], -) -@pytest.mark.parametrize("fp32_accum", [False, True], ids=["no_fp32_accum", "fp32_accum"]) -@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) -@pytest.mark.parametrize("packer_l1_acc", [True, False], ids=["pack_l1", "no_pack_l1"]) -@pytest.mark.parametrize("has_bias", [True, False], ids=["with_bias", "no_bias"]) -def test_non_tile_multiple_height_conv_wh( - device, - torch_tensor_map, - use_program_cache, - math_fidelity, - activations_dtype, - weights_dtype, - batch_size, - 
output_channels, - input_channels, - input_height, - input_width, - filter_height, - filter_width, - stride_h, - stride_w, - pad_h, - pad_w, - shard_layout, - config_override, - fp32_accum, - packer_l1_acc, - has_bias, -): - if device.core_grid.y == 7: - pytest.skip("Issue #6992: Statically allocated circular buffers in program clash with L1 buffers on core range") - - if ( - is_grayskull() - and activations_dtype == ttnn.bfloat16 - and batch_size == 20 - and ( - output_channels == 64 - or ( - stride_h == 2 - and (output_channels == 256 or (output_channels == 128 and weights_dtype == ttnn.bfloat16)) - ) - ) - ): - pytest.skip("Skipping test because it won't fit in L1!") - - if activations_dtype == ttnn.float32 and (batch_size >= 16 or (output_channels == 64 or input_height >= 240)): - pytest.skip("Skipping test because it won't fit in L1!") - - if ( - (weights_dtype == ttnn.bfloat16 and batch_size == 20 and output_channels == 128 and input_height == 56) - or (weights_dtype == ttnn.bfloat16 and batch_size == 20 and output_channels == 64) - or (weights_dtype == ttnn.bfloat8_b and batch_size == 20 and output_channels == 128 and input_height == 56) - ): - pytest.skip("Skipping test because it won't fit in L1!") - - if has_bias and packer_l1_acc and (fp32_accum or activations_dtype is ttnn.float32): - pytest.skip("skipping due to pack_untilize_dst issue! --> #14236") - - use_shallow_conv_variant = (input_channels == 16) and device.arch() != ttnn.device.Arch.WORMHOLE_B0 - run_conv( - device, - torch_tensor_map, - math_fidelity, - activations_dtype, - weights_dtype, - batch_size, - output_channels, - input_channels, - input_height, - input_width, - filter_height, - filter_width, - stride_h, - stride_w, - pad_h, - pad_w, - config_override=config_override, - shard_layout=shard_layout, - use_shallow_conv_variant=use_shallow_conv_variant, - packer_l1_acc=packer_l1_acc, - fp32_accum=fp32_accum, - has_bias=has_bias, - output_layout=ttnn.ROW_MAJOR_LAYOUT, - ) - - -@skip_for_grayskull() -@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) -@pytest.mark.parametrize( - "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override", - ( - (1, 64, 64, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 64, 128, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 64, 192, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 64, 256, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 64, 320, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 64, 384, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 64, 448, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 64, 512, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 64, 576, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 64, 640, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 128, 64, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 128, 128, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 128, 192, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 128, 256, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 128, 320, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 128, 384, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 128, 448, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 128, 512, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 128, 576, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 128, 640, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 320, 320, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - (1, 640, 640, 16, 16, 3, 3, 1, 1, 1, 1, BS, None), - ), -) -@pytest.mark.parametrize( - "weights_dtype", - [ttnn.bfloat16, ttnn.bfloat8_b], -) 
-@pytest.mark.parametrize( - "activations_dtype", - [ttnn.bfloat16], -) -@pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) -def test_non_tile_multiple_width_conv_wh( - device, - torch_tensor_map, - use_program_cache, - math_fidelity, - activations_dtype, - weights_dtype, - batch_size, - output_channels, - input_channels, - input_height, - input_width, - filter_height, - filter_width, - stride_h, - stride_w, - pad_h, - pad_w, - shard_layout, - config_override, -): - run_conv( - device, - torch_tensor_map, - math_fidelity, - activations_dtype, - weights_dtype, - batch_size, - output_channels, - input_channels, - input_height, - input_width, - filter_height, - filter_width, - stride_h, - stride_w, - pad_h, - pad_w, - config_override, - shard_layout=shard_layout, - use_shallow_conv_variant=(input_channels == 16), - output_layout=ttnn.ROW_MAJOR_LAYOUT, - ) - - @skip_for_grayskull() @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) def test_shallow_conv_with_tiled_input(device): diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp index 50b5c017a41..a3928a36629 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp @@ -90,21 +90,18 @@ Result conv2d( ShardOrientation shard_orientation = conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; - bool is_non_tile_mul_width = check_non_tile_mul_width(compute_grid_size, conv_config, in_channels); - auto [input_tensor_post_tm, parallel_config, output_parallel_config, use_non_tile_height] = - shard_or_reshard_tensor_if_required( - device, - input_tensor, - conv_config, - batch_size, - output_height, - output_width, - in_channels, - out_channels, - mm_conv, - auto_shard, - is_non_tile_mul_width); + auto [input_tensor_post_tm, parallel_config, output_parallel_config] = shard_or_reshard_tensor_if_required( + device, + input_tensor, + conv_config, + batch_size, + output_height, + output_width, + in_channels, + out_channels, + mm_conv, + auto_shard); auto [opt_conv_op_parallel_config, opt_conv_op_block_config, conv_out_memory_config] = get_conv_configs( conv_config, @@ -137,8 +134,7 @@ Result conv2d( groups, opt_conv_op_block_config.act_block_h_ntiles, input_width, - true, - is_non_tile_mul_width); + true); } // if 1x1 conv w/ stride 1, convert input tensor to tile layout if required if (mm_conv) { @@ -160,7 +156,7 @@ Result conv2d( .dilation_hw = {dilation[0], dilation[1]}, .num_cores_nhw = opt_conv_op_parallel_config.num_cores_nhw, .core_range_set = input_tensor_post_tm.memory_config().shard_spec.value().grid, - .snap_to_tile = !use_non_tile_height, + .snap_to_tile = true, }; bool bypass_halo = @@ -185,7 +181,7 @@ Result conv2d( parallel_config.shard_orientation == ShardOrientation::COL_MAJOR, 0, input_tensor_post_tm.memory_config(), - !use_non_tile_height); + true); if (conv_config.deallocate_activation) { input_tensor_post_tm.deallocate(/*force*/ true); @@ -217,9 +213,7 @@ Result conv2d( compute_config, conv_config.enable_act_double_buffer, conv_config.enable_weights_double_buffer, - conv_config.enable_split_reader, - conv_config.enable_subblock_padding, - use_non_tile_height); + conv_config.enable_split_reader); if (memory_config.has_value() && memory_config.value() != conv_output.memory_config()) { conv_output = ttnn::to_memory_config(conv_output, memory_config.value(), std::nullopt); diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp 
b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp index ef664e12add..0591ed02d0c 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp @@ -295,8 +295,7 @@ void py_bind_conv2d(py::module& module) { compute_grid_size, block_shard_orientation, enable_channels_padding, - is_out_tiled, - false); + is_out_tiled); }, py::arg("shard_layout"), py::arg("batch_size"), @@ -384,16 +383,16 @@ void py_bind_conv2d(py::module& module) { py::arg("grid_size"), py::arg("num_cores_nhw") = 1, py::arg("num_cores_c") = 1, - py::arg("per_core_out_matrix_height").noconvert(), - py::arg("per_core_out_matrix_width").noconvert()) + py::arg("per_core_out_matrix_height_ntiles").noconvert(), + py::arg("per_core_out_matrix_width_ntiles").noconvert()) .def_property_readonly("grid_size", [](const OptimizedConvParallelizationConfig& c) { return c.grid_size; }) .def_property_readonly( "num_cores_nhw", [](const OptimizedConvParallelizationConfig& c) { return c.num_cores_nhw; }) .def_property_readonly( - "per_core_out_matrix_height", - [](const OptimizedConvParallelizationConfig& c) { return c.per_core_out_matrix_height; }) - .def_property_readonly("per_core_out_matrix_width", [](const OptimizedConvParallelizationConfig& c) { - return c.per_core_out_matrix_width; + "per_core_out_matrix_height_ntiles", + [](const OptimizedConvParallelizationConfig& c) { return c.per_core_out_matrix_height_ntile; }) + .def_property_readonly("per_core_out_matrix_width_ntiles", [](const OptimizedConvParallelizationConfig& c) { + return c.per_core_out_matrix_width_ntile; }); py::class_(module, "OptimizedConvBlockConfig") diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp index 959acd36d04..6f67fb238a6 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp @@ -10,6 +10,7 @@ #include "conv2d_utils.hpp" #include +#include "tt-metalium/constants.hpp" #include "tt-metalium/hal.hpp" #include "ttnn/operations/conv/conv2d/device/conv2d_op.hpp" #include "ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp" @@ -80,28 +81,6 @@ uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num1, uint32_t n return divisor; } -bool check_non_tile_mul_width( - const CoreCoord& compute_grid, const Conv2dConfig& conv_config, const uint32_t in_channels) { - auto num_cores_c = conv_config.transpose_shards ? compute_grid.y : compute_grid.x; - auto elem_size = conv_config.weights_dtype == DataType::BFLOAT8_B ? 
1 : 2; - bool is_non_tile_mul_width = - (conv_config.shard_layout.has_value() && conv_config.shard_layout == TensorMemoryLayout::BLOCK_SHARDED) && - conv_config.act_block_h_override == 0 && - (conv_config.weights_dtype == DataType::BFLOAT8_B || conv_config.weights_dtype == DataType::BFLOAT16) && - conv_config.output_layout == Layout::ROW_MAJOR && ((elem_size * in_channels) % (16 * num_cores_c)) == 0; - return is_non_tile_mul_width; -} - -bool check_non_tile_height(const Conv2dConfig& conv_config, const uint32_t out_channels) { - bool use_non_tile_height = (conv_config.shard_layout.has_value() && - conv_config.shard_layout.value() == TensorMemoryLayout::HEIGHT_SHARDED) && - out_channels <= 256 && conv_config.act_block_h_override == 0 && - (conv_config.dtype == DataType::BFLOAT16 || conv_config.dtype == DataType::FLOAT32) && - conv_config.output_layout == Layout::ROW_MAJOR; - use_non_tile_height = use_non_tile_height && conv_config.input_channels_alignment != 16; - return use_non_tile_height; -} - ParallelConfig determine_parallel_config( const TensorMemoryLayout shard_layout, uint32_t batch_size, @@ -113,17 +92,9 @@ ParallelConfig determine_parallel_config( ShardOrientation block_shard_orientation, bool enable_channels_padding, bool is_out_tiled, - bool is_non_tile_mul_shard_width, uint32_t act_block_h_override) { uint32_t effective_tile_height = is_out_tiled ? tt::constants::TILE_HEIGHT : 1; uint32_t effective_tile_width = is_out_tiled ? tt::constants::TILE_WIDTH : 1; - // If the shard is not tile-multiplicatively along the width dimension, - // set the effective tile width to 1 and disable channel padding. - // Required(if any) paddings are added while creating the matrices. - if (is_non_tile_mul_shard_width) { - effective_tile_width = 1; - enable_channels_padding = false; - } uint32_t out_nhw_ntiles = tt::round_up(batch_size * output_height * output_width, tt::constants::TILE_HEIGHT) / effective_tile_height; uint32_t input_channles_ntiles = tt::div_up(input_channels, effective_tile_width); @@ -277,13 +248,12 @@ OptimizedConvParallelizationConfig determine_conv_op_parallel_config_from_conv_o TT_ASSERT(conv_output_mem_config.shard_spec.has_value()); const auto& shard_spec = conv_output_mem_config.shard_spec.value(); const auto& shard_shape = shard_spec.shape; - uint32_t per_core_out_matrix_height_ntiles = div_up(shard_shape[0], 32); return { .grid_size = shard_spec.grid.bounding_box().grid_size(), .num_cores_nhw = num_cores_nhw, .num_cores_c = num_cores_c, - .per_core_out_matrix_height = shard_shape[0], - .per_core_out_matrix_width = shard_shape[1], + .per_core_out_matrix_height_ntile = div_up(shard_shape[0], tt::constants::TILE_HEIGHT), + .per_core_out_matrix_width_ntile = div_up(shard_shape[1], tt::constants::TILE_WIDTH), }; } @@ -341,8 +311,7 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( "Config Error: act_block_h_override must be a multiple of 32 (tile height)."); } - uint32_t act_block_h_ntiles = - div_up(conv_op_parallel_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT); + uint32_t act_block_h_ntiles = conv_op_parallel_config.per_core_out_matrix_height_ntile; if (act_block_h_override > 0) { uint32_t act_block_h_override_ntiles = act_block_h_override / constants::TILE_HEIGHT; @@ -379,10 +348,8 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( } TT_ASSERT(act_block_w % 32 == 0); uint32_t act_block_w_ntiles = act_block_w / 32; - uint32_t out_block_h_ntiles = - div_up(conv_op_parallel_config.per_core_out_matrix_height, 
tt::constants::TILE_HEIGHT); - uint32_t weight_block_w_ntiles = - div_up(conv_op_parallel_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH); + uint32_t out_block_h_ntiles = conv_op_parallel_config.per_core_out_matrix_height_ntile; + uint32_t weight_block_w_ntiles = conv_op_parallel_config.per_core_out_matrix_width_ntile; auto [out_subblock_h_ntiles, out_subblock_w_ntiles] = determine_largest_subblock_size(act_block_h_ntiles, weight_block_w_ntiles, fp32_accum, split_reader_enabled); return { @@ -418,7 +385,7 @@ DeviceComputeKernelConfig get_conv_default_compute_kernel_config(DeviceType* dev } template -static std::tuple get_conv_padded_input_shape_and_mem_config( +static std::tuple get_conv_padded_input_shape_and_mem_config( T* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, @@ -427,8 +394,7 @@ static std::tuple get_conv_padded_i uint32_t width, uint32_t in_channels, uint32_t out_channels, - bool is_mm_conv, - bool is_non_tile_mul_width) { + bool is_mm_conv) { ttnn::Tensor input_tensor = input_tensor_; // tensor to return bool input_tensor_on_device = ttnn::is_tensor_on_device_or_multidevice(input_tensor_); bool needs_shard_or_reshard = false; @@ -494,11 +460,6 @@ static std::tuple get_conv_padded_i } } - // shallow conv variriant not supported - // out_channels <= 256 incorrect output from pack_untilize_dst if output > 256 Tracking --> #14236 - // bf8 not supported due to limation of sharding dim multipl of 32 - const bool use_non_tile_height = check_non_tile_height(conv_config, out_channels); - ParallelConfig parallel_config = input_tensor_parallel_config; if (conv_config.reshard_if_not_optimal || needs_shard_or_reshard) { auto block_shard_orientation = @@ -513,8 +474,7 @@ static std::tuple get_conv_padded_i device->compute_with_storage_grid_size(), block_shard_orientation, !is_mm_conv, - !use_non_tile_height, - is_non_tile_mul_width, + true, conv_config.act_block_h_override); if (conv_config.override_sharding_config) { @@ -541,18 +501,13 @@ static std::tuple get_conv_padded_i const auto& input_shape = input_tensor.get_logical_shape(); uint32_t tensor_height = input_shape[0] * input_shape[1] * input_shape[2]; uint32_t round_up_size = tt::constants::TILE_HEIGHT; - if ((use_non_tile_height || shard_layout == TensorMemoryLayout::WIDTH_SHARDED) && - input_tensor_.layout() == Layout::ROW_MAJOR) { + if (shard_layout == TensorMemoryLayout::WIDTH_SHARDED && input_tensor_.layout() == Layout::ROW_MAJOR) { round_up_size = 1; } uint32_t input_tensor_height_snapped_to_tile = tt::round_up(tensor_height, input_num_cores_nhw * round_up_size); TT_ASSERT(input_tensor_height_snapped_to_tile >= tensor_height); uint32_t input_tensor_width_snapped_to_channels_alignment = tt::round_up(input_shape[3], input_num_cores_c * conv_config.input_channels_alignment); - if (is_non_tile_mul_width) { - input_tensor_width_snapped_to_channels_alignment = - tt::round_up(input_shape[3], conv_config.input_channels_alignment); - } auto input_padded_shape = ttnn::Shape( {1, @@ -566,13 +521,9 @@ static std::tuple get_conv_padded_i parallel_config, round_up_size); - return {input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard, use_non_tile_height}; + return {input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard}; } else { - return { - input_tensor.get_logical_shape(), - input_tensor.memory_config(), - needs_shard_or_reshard, - use_non_tile_height}; + return {input_tensor.get_logical_shape(), input_tensor.memory_config(), 
needs_shard_or_reshard}; } } @@ -584,7 +535,7 @@ static ttnn::Shape flatten_4d_shape(const ttnn::Shape& input_shape) { } template -std::tuple shard_or_reshard_tensor_if_required( +std::tuple shard_or_reshard_tensor_if_required( T* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, @@ -594,24 +545,14 @@ std::tuple shard_or_reshard_ uint32_t in_channels, uint32_t out_channels, bool is_mm_conv, - bool auto_shard, - bool is_non_tile_mul_width) { + bool auto_shard) { ttnn::Tensor input_tensor = input_tensor_; // tensor to return bool input_tensor_on_device = ttnn::is_tensor_on_device_or_multidevice(input_tensor_); auto compute_grid_size = device->compute_with_storage_grid_size(); - auto [input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard, use_non_tile_height] = + auto [input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard] = get_conv_padded_input_shape_and_mem_config( - device, - input_tensor_, - conv_config, - batch_size, - height, - width, - in_channels, - out_channels, - is_mm_conv, - is_non_tile_mul_width); + device, input_tensor_, conv_config, batch_size, height, width, in_channels, out_channels, is_mm_conv); ParallelConfig parallel_config = { .grid = input_tensor_sharded_memory_config.shard_spec.value().grid, .shard_scheme = input_tensor_sharded_memory_config.memory_layout, @@ -675,7 +616,7 @@ std::tuple shard_or_reshard_ input_tensor, device, (auto_shard_mm ? ttnn::DRAM_MEMORY_CONFIG : input_tensor_sharded_memory_config)); } } - return {input_tensor, parallel_config, output_parallel_config, use_non_tile_height}; + return {input_tensor, parallel_config, output_parallel_config}; } void validate_weight_and_bias_tensors( @@ -707,10 +648,10 @@ ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_co .in0_block_w = conv_blocking_config.act_block_w_ntiles, .out_subblock_h = conv_blocking_config.out_subblock_h_ntiles, .out_subblock_w = conv_blocking_config.out_subblock_w_ntiles, - .out_block_h = div_up(conv_parallelization_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT), - .out_block_w = div_up(conv_parallelization_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH), - .per_core_M = div_up(conv_parallelization_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT), - .per_core_N = div_up(conv_parallelization_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH), + .out_block_h = conv_parallelization_config.per_core_out_matrix_height_ntile, + .out_block_w = conv_parallelization_config.per_core_out_matrix_width_ntile, + .per_core_M = conv_parallelization_config.per_core_out_matrix_height_ntile, + .per_core_N = conv_parallelization_config.per_core_out_matrix_width_ntile, .fuse_batch = true, .mcast_in0 = false}; if (activation != "") { @@ -723,10 +664,10 @@ ttnn::operations::matmul::MatmulProgramConfig determine_matmul_op_config_from_co .in0_block_w = conv_blocking_config.act_block_w_ntiles, .out_subblock_h = conv_blocking_config.out_subblock_h_ntiles, .out_subblock_w = conv_blocking_config.out_subblock_w_ntiles, - .out_block_h = div_up(conv_parallelization_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT), - .out_block_w = div_up(conv_parallelization_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH), - .per_core_M = div_up(conv_parallelization_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT), - .per_core_N = div_up(conv_parallelization_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH), + .out_block_h = 
conv_parallelization_config.per_core_out_matrix_height_ntile, + .out_block_w = conv_parallelization_config.per_core_out_matrix_width_ntile, + .per_core_M = conv_parallelization_config.per_core_out_matrix_height_ntile, + .per_core_N = conv_parallelization_config.per_core_out_matrix_width_ntile, .transpose_mcast = transpose_mcast}; if (activation != "") { matmul_config.fused_activation = ttnn::operations::unary::utils::string_to_unary_with_param(activation); @@ -795,9 +736,6 @@ Conv2dConfig determine_conv_config_for_auto_shard( conv_config.act_block_h_override = constants::TILE_HEIGHT; } - const bool is_non_tile_shard_width = check_non_tile_mul_width(compute_grid_size, conv_config, in_channels); - const bool use_non_tile_height = check_non_tile_height(conv_config, out_channels); - const uint32_t in_channels_padded = round_up(in_channels, conv_config.input_channels_alignment); const uint32_t output_channels_padded = round_up(out_channels, constants::TILE_WIDTH); // Note: These are not exact shapes for weights as prepare_conv_weights will pad the weights depending on the @@ -816,7 +754,6 @@ Conv2dConfig determine_conv_config_for_auto_shard( shard_orientation, !is_mm_conv, is_out_tiled, - is_non_tile_shard_width, conv_config.act_block_h_override); const ParallelConfig output_parallel_config = determine_output_parallel_config( @@ -854,7 +791,6 @@ Conv2dConfig determine_conv_config_for_auto_shard( conv_config, conv_out_memory_config, enable_bias, - use_non_tile_height, conv_is_1d_deptwise); // Since we don't have L1 usage for halo output (input to conv2d) @@ -917,16 +853,10 @@ std::tuple kernel_size, const CoreCoord& compute_grid) { - bool is_non_tile_mul_width = check_non_tile_mul_width(compute_grid, conv_config, in_channels); - const bool use_non_tile_height = check_non_tile_height(conv_config, out_channels); - - uint32_t round_up_size = !use_non_tile_height ? 
tt::constants::TILE_HEIGHT : 1; + uint32_t round_up_size = tt::constants::TILE_HEIGHT; uint32_t nhw_out = batch_size * output_height * output_width; uint32_t out_channels_padded = tt::round_up( out_channels, get_num_cores_channels_from_parallel_config(output_parallel_config) * tt::constants::TILE_WIDTH); - if (is_non_tile_mul_width) { - out_channels_padded = tt::round_up(out_channels, 32); - } MemoryConfig conv_out_memory_config = create_sharded_memory_config_from_parallel_config( ttnn::Shape({1, 1, nhw_out, out_channels_padded}), output_parallel_config, round_up_size); ParallelConfig largest_parallel_config = @@ -942,9 +872,6 @@ std::tuple 1) || (in0_num_blocks_w > 2)); } -template std::tuple get_conv_padded_input_shape_and_mem_config( +template std::tuple get_conv_padded_input_shape_and_mem_config( IDevice* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, @@ -1289,10 +1201,9 @@ template std::tuple get_conv_padded uint32_t width, uint32_t in_channels, uint32_t out_channels, - bool is_mm_conv, - bool is_non_tile_mul_width); + bool is_mm_conv); -template std::tuple get_conv_padded_input_shape_and_mem_config( +template std::tuple get_conv_padded_input_shape_and_mem_config( MeshDevice* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, @@ -1301,10 +1212,9 @@ template std::tuple get_conv_padded uint32_t width, uint32_t in_channels, uint32_t out_channels, - bool is_mm_conv, - bool is_non_tile_mul_width); + bool is_mm_conv); -template std::tuple shard_or_reshard_tensor_if_required( +template std::tuple shard_or_reshard_tensor_if_required( IDevice* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, @@ -1314,10 +1224,9 @@ template std::tuple shard_or uint32_t in_channels, uint32_t out_channels, bool is_mm_conv, - bool auto_shard, - bool is_non_tile_mul_width); + bool auto_shard); -template std::tuple shard_or_reshard_tensor_if_required( +template std::tuple shard_or_reshard_tensor_if_required( MeshDevice* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, @@ -1327,8 +1236,7 @@ template std::tuple shard_or uint32_t in_channels, uint32_t out_channel, bool is_mm_conv, - bool auto_shard, - bool is_non_tile_mul_width); + bool auto_shard); template DeviceComputeKernelConfig get_conv_default_compute_kernel_config( tt::tt_metal::IDevice* device); diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp index b3d4a0b5553..440521121d5 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp @@ -37,12 +37,6 @@ bool use_matmul_for_1x1_conv( bool is_1d_deptwise_conv( uint32_t groups, uint32_t input_channels, uint32_t output_channels, uint32_t kernel_width, uint32_t image_width); - -bool check_non_tile_mul_width( - const CoreCoord& compute_grid, const Conv2dConfig& conv_config, const uint32_t in_channels); - -bool check_non_tile_height(const Conv2dConfig& conv_config, const uint32_t out_channels); - sliding_window::ParallelConfig determine_parallel_config( const TensorMemoryLayout shard_layout, uint32_t batch_size, @@ -54,7 +48,6 @@ sliding_window::ParallelConfig determine_parallel_config( ShardOrientation block_shard_orientation, bool enable_channels_padding, bool is_out_tiled = true, - bool is_non_tile_mul_shard_width = false, uint32_t act_block_h_override = 0); sliding_window::ParallelConfig determine_output_parallel_config( @@ -113,7 +106,7 @@ std::tuple -static std::tuple 
get_conv_padded_input_shape_and_mem_config( +static std::tuple get_conv_padded_input_shape_and_mem_config( T* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, @@ -122,8 +115,7 @@ static std::tuple get_conv_padded_i uint32_t width, uint32_t in_channels, uint32_t out_channels, - bool is_mm_conv, - bool is_non_tile_mul_width = false); + bool is_mm_conv); template DeviceComputeKernelConfig get_conv_default_compute_kernel_config(DeviceType* device); @@ -148,7 +140,7 @@ Conv2dConfig determine_conv_config_for_auto_shard( const DeviceComputeKernelConfig& compute_config); template -std::tuple +std::tuple shard_or_reshard_tensor_if_required( T* device, const ttnn::Tensor& input_tensor_, @@ -159,8 +151,7 @@ shard_or_reshard_tensor_if_required( uint32_t in_channels, uint32_t out_channels, bool is_mm_conv, - bool auto_shard, - bool is_non_tile_mul_width = false); + bool auto_shard); std::ostream& operator<<(std::ostream& os, const Conv2dConfig& config); diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp index a7f1c2a774a..249fab4d7c3 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.cpp @@ -71,8 +71,7 @@ Tensor optimized_conv_new( bool enable_act_double_buffer, bool enable_weights_double_buffer, bool enable_split_reader, - bool enable_subblock_padding, - bool use_non_tile_height) { + bool enable_subblock_padding) { std::vector output_tensors = {Tensor(tt::tt_metal::operation::get_workers_for_op_output({a, b}))}; operation::launch_op( @@ -91,8 +90,7 @@ Tensor optimized_conv_new( enable_act_double_buffer, enable_weights_double_buffer, enable_split_reader, - enable_subblock_padding, - use_non_tile_height]( + enable_subblock_padding]( const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { @@ -138,8 +136,7 @@ Tensor optimized_conv_new( enable_act_double_buffer, enable_weights_double_buffer, enable_split_reader, - enable_subblock_padding, - use_non_tile_height); + enable_subblock_padding); IDevice* device = a.device(); optimized_conv_op.pre_op_l1_allocation_size_bytes = @@ -163,10 +160,8 @@ void OptimizedConvNew::validate( TT_FATAL((this->dtype == DataType::BFLOAT16) || (this->dtype == DataType::FLOAT32), "Error"); } if (this->memory_config.is_sharded()) { - uint32_t out_block_h_ntiles = - optimized_conv_op_utils::div_up(parallelization_config.per_core_out_matrix_height, TILE_HEIGHT); - uint32_t per_core_out_matrix_width_ntiles = - optimized_conv_op_utils::div_up(parallelization_config.per_core_out_matrix_width, TILE_WIDTH); + uint32_t out_block_h_ntiles = parallelization_config.per_core_out_matrix_height_ntile; + uint32_t per_core_out_matrix_width_ntiles = parallelization_config.per_core_out_matrix_width_ntile; auto [act_matrix_shape, act_matrix_shape_unpadded] = optimized_conv_op_utils::compute_opt_conv_activation_as_mm_shape( input_tensor_a.get_padded_shape(), @@ -207,10 +202,8 @@ std::vector OptimizedConvNew::compute_output_specs(const std::vector // Tiled output shape is padded shape. Padded to tile shape. auto shape_w = batch_size * conv_output_h * conv_output_w; auto shape_c = output_channels; - auto padded_shape_w = this->use_non_tile_height - ? 
parallelization_config.num_cores_nhw * parallelization_config.per_core_out_matrix_height - : parallelization_config.num_cores_nhw * - tt::round_up(parallelization_config.per_core_out_matrix_height, TILE_HEIGHT); + auto padded_shape_w = + parallelization_config.num_cores_nhw * parallelization_config.per_core_out_matrix_height_ntile * TILE_HEIGHT; auto padded_shape_c = tt::round_up(this->output_channels, TILE_WIDTH); ttnn::Shape output_shape({1, 1, shape_w, shape_c}); ttnn::Shape padded_output_shape({1, 1, padded_shape_w, padded_shape_c}); @@ -219,24 +212,9 @@ std::vector OptimizedConvNew::compute_output_specs(const std::vector if (this->memory_config.is_sharded()) { if (this->memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { uint32_t total_height_tiles = padded_output_shape.volume() / padded_output_shape[-1] / TILE_HEIGHT; - uint32_t num_cores; - std::array shard_shape; - if (this->use_non_tile_height) { - num_cores = this->parallelization_config.num_cores_nhw; - uint32_t total_height = padded_output_shape.volume() / padded_output_shape[-1]; - shard_shape = {(uint32_t)(total_height / num_cores), padded_output_shape[-1]}; - } else { - num_cores = total_height_tiles / - tt::div_up(this->parallelization_config.per_core_out_matrix_height, TILE_HEIGHT); - CoreRangeSet shard_grid = - tt::tt_metal::num_cores_to_corerangeset(num_cores, this->parallelization_config.grid_size, true); - - shard_shape = { - optimized_conv_op_utils::div_up( - this->parallelization_config.per_core_out_matrix_height, TILE_HEIGHT) * - TILE_HEIGHT, - padded_output_shape[-1]}; - } + uint32_t num_cores = total_height_tiles / this->parallelization_config.per_core_out_matrix_height_ntile; + std::array shard_shape = { + this->parallelization_config.per_core_out_matrix_height_ntile * TILE_HEIGHT, padded_output_shape[-1]}; CoreRangeSet shard_grid = tt::tt_metal::num_cores_to_corerangeset(num_cores, this->parallelization_config.grid_size, true); auto shard_spec = ShardSpec{shard_grid, shard_shape, ShardOrientation::ROW_MAJOR}; @@ -249,8 +227,8 @@ std::vector OptimizedConvNew::compute_output_specs(const std::vector } else if (this->memory_config.memory_layout == TensorMemoryLayout::WIDTH_SHARDED) { uint32_t total_height_tiles = padded_output_shape.volume() / padded_output_shape[-1] / TILE_HEIGHT; std::array shard_shape = { - tt::div_up(this->parallelization_config.per_core_out_matrix_height, TILE_HEIGHT) * TILE_HEIGHT, - tt::div_up(this->parallelization_config.per_core_out_matrix_width, TILE_WIDTH) * TILE_WIDTH}; + this->parallelization_config.per_core_out_matrix_height_ntile * TILE_HEIGHT, + this->parallelization_config.per_core_out_matrix_width_ntile * TILE_WIDTH}; auto shard_grid = this->memory_config.shard_spec.value().grid; auto shard_spec = ShardSpec{shard_grid, shard_shape, this->memory_config.shard_spec.value().orientation}; auto mem_config = this->memory_config; @@ -314,8 +292,7 @@ operation::ProgramWithCallbacks OptimizedConvNew::create_program( enable_act_double_buffer, enable_weights_double_buffer, enable_split_reader, - enable_subblock_padding, - use_non_tile_height); + enable_subblock_padding); const uint32_t post_op_l1_allocation_size = device->allocator()->get_statistics(tt::tt_metal::BufferType::L1).total_allocated_bytes; @@ -340,7 +317,6 @@ operation::ProgramWithCallbacks OptimizedConvNew::create_program( .enable_subblock_padding = enable_subblock_padding}, this->memory_config, has_bias, - use_non_tile_height, is_1d_deptwise_conv( groups, input_tensor_shape[3], diff --git 
a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp index 6f804922950..04557524b76 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op.hpp @@ -122,13 +122,8 @@ struct OptimizedConvParallelizationConfig { CoreCoord grid_size; // (x,y) uint32_t num_cores_nhw = 1; uint32_t num_cores_c = 1; - uint32_t per_core_out_matrix_height = 1; - uint32_t per_core_out_matrix_width = 1; - // std::size_t in0_block_w; - // std::size_t out_subblock_h; - // std::size_t out_subblock_w; - // std::size_t per_core_M; - // std::size_t per_core_N; + uint32_t per_core_out_matrix_height_ntile = 1; + uint32_t per_core_out_matrix_width_ntile = 1; CoreCoord get_grid_size() const { return this->grid_size; } }; @@ -159,8 +154,7 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_ bool enable_act_double_buffer, bool enable_weights_double_buffer, bool enable_split_reader, - bool enable_subblock_padding, - bool use_non_tile_height); + bool enable_subblock_padding); // new micro op struct OptimizedConvNew { @@ -179,7 +173,6 @@ struct OptimizedConvNew { bool enable_weights_double_buffer; bool enable_split_reader; bool enable_subblock_padding; - bool use_non_tile_height; uint32_t pre_op_l1_allocation_size_bytes; OptimizedConvNew( const sliding_window::SlidingWindowConfig& sliding_window_config, @@ -198,8 +191,7 @@ struct OptimizedConvNew { bool enable_act_double_buffer, bool enable_weights_double_buffer, bool enable_split_reader, - bool enable_subblock_padding, - bool use_non_tile_height) : + bool enable_subblock_padding) : output_channels(output_channels), groups(groups), sliding_window_config(sliding_window_config), @@ -216,8 +208,7 @@ struct OptimizedConvNew { enable_act_double_buffer(enable_act_double_buffer), enable_weights_double_buffer(enable_weights_double_buffer), enable_split_reader(enable_split_reader), - enable_subblock_padding(enable_subblock_padding), - use_non_tile_height(use_non_tile_height) {} + enable_subblock_padding(enable_subblock_padding) {} void validate( const std::vector& input_tensors, @@ -290,8 +281,7 @@ Tensor optimized_conv_new( bool enable_act_double_buffer = false, bool enable_weights_double_buffer = false, bool enable_split_reader = false, - bool enable_subblock_padding = false, - bool use_non_tile_height = false); + bool enable_subblock_padding = false); // Only enable packer l1 accumulation when there are in0_num_blocks_w > 2, otherwise // unnecessary overhead for reconfigs are added. 
Last iteration of l1 accumulation @@ -317,7 +307,6 @@ conv_op_l1_usage calculate_L1_usage( const Conv2dConfig& conv_config, const MemoryConfig& output_memory_config, bool enable_bias, - bool use_non_tile_height, bool is_1d_depthwise_conv); } // namespace conv2d diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp index d0e917aee50..32fd24971e8 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_sharded_program_factory.cpp @@ -33,7 +33,6 @@ const uint32_t tilize_mode_tilized_act_cb = CBIndex::c_25; const uint32_t untilize_mode_reblock_cb = CBIndex::c_26; const uint32_t out0_cb = CBIndex::c_16; const uint32_t temp_sum_cb = CBIndex::c_27; -const uint32_t untilized_padded_out_cb = CBIndex::c_28; } // namespace CMAKE_UNIQUE_NAMESPACE } // namespace @@ -84,8 +83,7 @@ std::tuple create_CBs_for_sharded_input_v2( bool with_bias, bool split_reader, bool fp32_dest_acc_en, - bool packer_l1_acc_en, - bool use_non_tile_height) { + bool packer_l1_acc_en) { using namespace CMAKE_UNIQUE_NAMESPACE; tt::DataFormat interm0_df = @@ -199,42 +197,15 @@ std::tuple create_CBs_for_sharded_input_v2( bool need_unpad_after_untilize = output_shard_shape[1] * output_shard_shape[0] < num_writer_output_tiles * TILE_HW; - // If only width is non-tile multiple - if (need_unpad_after_untilize && !use_non_tile_height && weight_width_sliced) { - uint32_t num_bytes_for_df = datum_size(out_df); - CircularBufferConfig compute_cb_output_config = - CircularBufferConfig(num_writer_output_tiles * out_tile_size, {{untilized_padded_out_cb, out_df}}) - .set_page_size(untilized_padded_out_cb, out_tile_size); - auto compute_cb_output = tt_metal::CreateCircularBuffer(program, core, compute_cb_output_config); - log_debug( - LogOp, - "untilized padded out CB(shard width non-tile multiple): {}, npages: {}, pagesize: {}", - untilized_padded_out_cb, - num_writer_output_tiles, - out_tile_size); - CircularBufferConfig cb_output_config = - CircularBufferConfig( - num_bytes_for_df * output_shard_shape[0] * output_shard_shape[1], {{out0_cb, out_df}}) - .set_page_size(out0_cb, output_shard_shape[1] * num_bytes_for_df); - cb_output_config = cb_output_config.set_globally_allocated_address(*output.buffer()); - cb_output = tt_metal::CreateCircularBuffer(program, core, cb_output_config); - log_debug( - LogOp, - "output CB(shard widht non-tile multiple): {}, npages: {}, pagesize: {}", - out0_cb, - output_shard_shape[0], - output_shard_shape[1] * num_bytes_for_df); - } else { - auto shard_shape = output.shard_spec().value().shape; - uint32_t aligned_output_stick_nbytes = - use_non_tile_height ? shard_shape[1] * output.element_size() : out_tile_size; - uint32_t aligned_output_num_pages = use_non_tile_height ? 
shard_shape[0] : num_writer_output_tiles; - CircularBufferConfig cb_output_config = - CircularBufferConfig(aligned_output_num_pages * aligned_output_stick_nbytes, {{out0_cb, out_df}}) - .set_page_size(out0_cb, aligned_output_stick_nbytes); - cb_output_config = cb_output_config.set_globally_allocated_address(*output.buffer()); - cb_output = tt_metal::CreateCircularBuffer(program, core, cb_output_config); - } + + auto shard_shape = output.shard_spec().value().shape; + uint32_t aligned_output_stick_nbytes = out_tile_size; + uint32_t aligned_output_num_pages = num_writer_output_tiles; + CircularBufferConfig cb_output_config = + CircularBufferConfig(aligned_output_num_pages * aligned_output_stick_nbytes, {{out0_cb, out_df}}) + .set_page_size(out0_cb, aligned_output_stick_nbytes); + cb_output_config = cb_output_config.set_globally_allocated_address(*output.buffer()); + cb_output = tt_metal::CreateCircularBuffer(program, core, cb_output_config); } else { // Share buffer if same data format if (interm0_df == out_df) { @@ -425,8 +396,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( bool enable_act_double_buffer, bool enable_weights_double_buffer, bool enable_split_reader, - bool enable_subblock_padding, - bool use_non_tile_height) { + bool enable_subblock_padding) { using namespace CMAKE_UNIQUE_NAMESPACE; bool pass = true; tt_metal::IDevice* device = a.device(); @@ -435,8 +405,8 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( TT_FATAL(output_channels <= b.get_padded_shape()[3], "Invalid weight shape. Incorrect weight tensor."); uint32_t act_block_h_ntiles = block_config.act_block_h_ntiles; uint32_t act_block_w_ntiles = block_config.act_block_w_ntiles; - uint32_t weight_block_w_ntiles = div_up(parallelization_config.per_core_out_matrix_width, TILE_WIDTH); - uint32_t out_block_h_ntiles = div_up(parallelization_config.per_core_out_matrix_height, TILE_HEIGHT); + uint32_t weight_block_w_ntiles = parallelization_config.per_core_out_matrix_width_ntile; + uint32_t out_block_h_ntiles = parallelization_config.per_core_out_matrix_height_ntile; uint32_t out_subblock_h_ntiles = block_config.out_subblock_h_ntiles; uint32_t out_subblock_w_ntiles = block_config.out_subblock_w_ntiles; @@ -535,8 +505,8 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( uint32_t num_cores_y = p_config.grid_size.y; uint32_t total_num_cores = num_cores_x * num_cores_y; - uint32_t per_core_out_matrix_width_ntiles = div_up(parallelization_config.per_core_out_matrix_width, TILE_WIDTH); - uint32_t per_core_out_matrix_height_ntiles = div_up(parallelization_config.per_core_out_matrix_height, TILE_HEIGHT); + uint32_t per_core_out_matrix_width_ntiles = parallelization_config.per_core_out_matrix_width_ntile; + uint32_t per_core_out_matrix_height_ntiles = parallelization_config.per_core_out_matrix_height_ntile; bool block_sharded = a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED; bool height_sharded = a.memory_config().memory_layout == TensorMemoryLayout::HEIGHT_SHARDED; @@ -919,14 +889,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( log_debug(LogOp, "num_blocks_out_h_per_core: {}", num_blocks_out_h_per_core); TT_FATAL(act_matrix_height_ntiles % per_core_out_matrix_height_ntiles == 0, "Error"); - uint32_t total_active_num_cores_per_weight_slice; - if (use_non_tile_height) { - total_active_num_cores_per_weight_slice = - tt::round_up(act_matrix_height_unpadded, parallelization_config.num_cores_nhw) / - 
parallelization_config.per_core_out_matrix_height; - } else { - total_active_num_cores_per_weight_slice = act_matrix_height_ntiles / per_core_out_matrix_height_ntiles; - } + uint32_t total_active_num_cores_per_weight_slice = act_matrix_height_ntiles / per_core_out_matrix_height_ntiles; TT_FATAL(total_active_num_cores_per_weight_slice <= total_num_cores_per_weight_slice, "Error"); uint32_t total_noop_cores = total_num_cores_per_weight_slice - total_active_num_cores_per_weight_slice; uint32_t total_active_num_cores = total_active_num_cores_per_weight_slice * num_weight_slices_width; @@ -1074,8 +1037,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( uint32_t output_block_num_tiles = enable_subblock_padding ? (act_block_h_ntiles_padded * weight_block_w_ntiles) : writer_output_block_num_tiles; - uint32_t aligned_output_num_pages = - use_non_tile_height ? output.shard_spec().value().shape[0] : writer_output_block_num_tiles; + uint32_t aligned_output_num_pages = writer_output_block_num_tiles; std::vector reader_rt_args; std::vector reader_compile_time_args; @@ -1157,8 +1119,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( has_bias, split_reader, fp32_dest_acc_en, - packer_l1_acc_en, - use_non_tile_height); + packer_l1_acc_en); } CBHandle cb_sharded_act = std::get<0>(input_output_cbs); CBHandle cb_output = std::get<1>(input_output_cbs); @@ -1391,20 +1352,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( writer_compile_time_args.insert( writer_compile_time_args.end(), split_reader_args.begin(), split_reader_args.end()); } - bool need_unpad_after_untilize = - parallelization_config.per_core_out_matrix_width < per_core_out_matrix_width_ntiles * TILE_WIDTH; - if (need_unpad_after_untilize) { - TT_FATAL(block_sharded, "Need to handle this case for non-sliced weights"); - TT_FATAL(untilize_out, "Cannot support non-tile multiple shard width with tilized output"); - writer_compile_time_args.push_back(per_core_out_matrix_width_ntiles); - writer_compile_time_args.push_back(per_core_out_matrix_width_ntiles * TILE_WIDTH * 2); - writer_compile_time_args.push_back(parallelization_config.per_core_out_matrix_width * 2); - writer_compile_time_args.push_back(untilized_padded_out_cb); - writer_defines["UNPAD_UNTILIZE_OUT"] = 1; - writer_mcast_sender_defines["UNPAD_UNTILIZE_OUT"] = 1; - } - uint32_t compute_output_cb = need_unpad_after_untilize ? untilized_padded_out_cb : out0_cb; std::vector compute_kernel_args = { in0_block_w, act_num_subblocks, @@ -1428,9 +1376,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl( untilize_out, bias_ntiles_per_core, - compute_output_cb, - aligned_output_num_pages, - use_non_tile_height}; + out0_cb}; auto writer_mcast_noc = NOC::NOC_0; auto reader_noc = writer_mcast_noc == NOC::NOC_0 ? 
NOC::NOC_1 : NOC::NOC_0; @@ -1816,8 +1762,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new( bool enable_act_double_buffer, bool enable_weights_double_buffer, bool enable_split_reader, - bool enable_subblock_padding, - bool use_non_tile_height) { + bool enable_subblock_padding) { tt_metal::Program program = tt_metal::CreateProgram(); ttnn::operations::sliding_window::ParallelConfig parallel_config; @@ -1889,8 +1834,7 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_new( enable_act_double_buffer, enable_weights_double_buffer, enable_split_reader, - enable_subblock_padding, - use_non_tile_height); + enable_subblock_padding); } } // namespace conv2d diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp index 3ed850823b9..84d7bc017aa 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/conv2d_op_width_sharded_program_factory.cpp @@ -62,9 +62,8 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_width_sh TT_FATAL(output_channels <= b.get_padded_shape()[3], "Invalid weight shape. Incorrect weight tensor."); uint32_t act_block_h_ntiles = block_config.act_block_h_ntiles; uint32_t act_block_w_ntiles = block_config.act_block_w_ntiles; - uint32_t weight_block_w_ntiles = - div_up(parallelization_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH); - uint32_t out_block_h_ntiles = div_up(parallelization_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT); + uint32_t weight_block_w_ntiles = parallelization_config.per_core_out_matrix_width_ntile; + uint32_t out_block_h_ntiles = parallelization_config.per_core_out_matrix_height_ntile; uint32_t out_subblock_h_ntiles = block_config.out_subblock_h_ntiles; uint32_t out_subblock_w_ntiles = block_config.out_subblock_w_ntiles; @@ -168,12 +167,10 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_width_sh const auto& p_config = parallelization_config; uint32_t num_cores_x = p_config.grid_size.x; uint32_t num_cores_y = p_config.grid_size.y; - uint32_t per_core_out_matrix_height_ntiles = - div_up(p_config.per_core_out_matrix_height, tt::constants::TILE_HEIGHT); - uint32_t per_core_out_matrix_width_ntiles = div_up(p_config.per_core_out_matrix_width, tt::constants::TILE_WIDTH); + uint32_t per_core_out_matrix_height_ntiles = p_config.per_core_out_matrix_height_ntile; // weight_width_sliced determines is 1d-sysarr-conv or 2d-sysarr-conv - bool weight_width_sliced = per_core_out_matrix_width_ntiles < weight_matrix_width_ntiles; - // uint32_t conv_act_c_blocks = weight_matrix_width_ntiles / per_core_out_matrix_width_ntiles; + bool weight_width_sliced = p_config.per_core_out_matrix_width_ntile < weight_matrix_width_ntiles; + // uint32_t conv_act_c_blocks = weight_matrix_width_ntiles / p_config.per_core_out_matrix_width_ntile; uint32_t input_channels_padded = shard_shape[1] * input_num_cores; // TT_FATAL(conv_act_c_blocks == p_config.num_cores_c, "Error"); TT_FATAL(input_channels_padded >= ashape[3], "Incorrect padding of input channels!"); @@ -443,10 +440,10 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_width_sh bias_in_dram = bias_buffer->buffer_type() == BufferType::DRAM; } - uint32_t num_weight_slices_width = weight_matrix_width_ntiles / per_core_out_matrix_width_ntiles; + uint32_t 
num_weight_slices_width = weight_matrix_width_ntiles / p_config.per_core_out_matrix_width_ntile; uint32_t num_blocks_act_h_per_core = - (per_core_out_matrix_height_ntiles + act_block_h_ntiles - 1) / act_block_h_ntiles; - uint32_t num_blocks_weight_w_per_core = per_core_out_matrix_width_ntiles / weight_block_w_ntiles; + (p_config.per_core_out_matrix_height_ntile + act_block_h_ntiles - 1) / act_block_h_ntiles; + uint32_t num_blocks_weight_w_per_core = p_config.per_core_out_matrix_width_ntile / weight_block_w_ntiles; uint32_t bias_ntiles_per_core = bias_ntiles / num_weight_slices_width; auto output_shape = sliding_window_config.get_output_shape(); @@ -511,8 +508,8 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_width_sh log_debug(LogOp, "act_matrix_height_ntiles: {}", act_matrix_height_ntiles); log_debug(LogOp, "act_matrix_width_ntiles: {}", act_matrix_width_ntiles); log_debug(LogOp, "weight_matrix_width_ntiles: {}", weight_matrix_width_ntiles); - log_debug(LogOp, "per_core_out_matrix_height_ntiles: {}", per_core_out_matrix_height_ntiles); - log_debug(LogOp, "per_core_out_matrix_width_ntiles: {}", per_core_out_matrix_width_ntiles); + log_debug(LogOp, "per_core_out_matrix_height_ntiles: {}", p_config.per_core_out_matrix_height_ntile); + log_debug(LogOp, "per_core_out_matrix_width_ntiles: {}", p_config.per_core_out_matrix_width_ntile); log_debug(LogOp, "per_core_num_blocks_act_w: {}", per_core_num_blocks_act_w); log_debug(LogOp, "num_blocks_act_h: {}", num_blocks_act_h); @@ -648,8 +645,7 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_width_sh if (packer_l1_acc) { compute_defines["PACKER_L1_ACC"] = "1"; } - uint32_t num_output_tiles = per_core_out_matrix_height_ntiles * per_core_out_matrix_width_ntiles; - uint32_t use_non_tile_height = false; + uint32_t num_output_tiles = per_core_out_matrix_height_ntiles * p_config.per_core_out_matrix_width_ntile; compute_kernel_args = { act_block_w_ntiles, // in0_block_w act_num_subblocks, // in0_num_sublocks @@ -675,8 +671,6 @@ tt::tt_metal::operation::ProgramWithCallbacks multi_core_optimized_conv_width_sh bias_ntiles, out0_cb, - num_output_tiles, - use_non_tile_height, input_num_cores, // in0_nblocks_w_tilize. Repeat tilize after all cores have done one round of MCAST. }; diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp index 94ea5310615..94545fc3704 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/conv_bmm_tilize_col_major_out_blocks.cpp @@ -40,24 +40,19 @@ inline void tilize_in( tilize_uninit(in_cb_id, out_cb_id); } // tilize_in() -template +template inline void reblock_and_untilize( uint32_t num_out_subblocks_in_col, uint32_t out_subblock_num_tiles, uint32_t out_subblock_h, - uint32_t output_rows_h, uint32_t interm_cb_id, uint32_t out_cb_id) { - constexpr bool is_non_tile_height_ = is_non_tile_height; - uint32_t TILE_SIZE = is_non_tile_height_ ? 32 : out_block_w; uint32_t num_tiles_in_row_of_subblocks = mulsi3(out_subblock_num_tiles, num_out_subblocks_in_col); cb_wait_front(interm_cb_id, num_tiles_in_row_of_subblocks); uint32_t within_block_index = 0; for (uint32_t h = 0; h < out_subblock_h; h++) { uint32_t block_offset = 0; - uint32_t out_sub_block_rows_h = output_rows_h <= TILE_SIZE ? 
output_rows_h : TILE_SIZE; - uint32_t rows_to_copy = is_non_tile_height_ ? out_sub_block_rows_h : 16; - cb_reserve_back(out_cb_id, out_sub_block_rows_h); + cb_reserve_back(out_cb_id, out_block_w); for (uint32_t n = 0; n < num_out_subblocks_in_col; n++) { tile_regs_acquire(); for (uint32_t w = 0; w < out_subblock_w; w++) { @@ -66,12 +61,11 @@ inline void reblock_and_untilize( } tile_regs_commit(); tile_regs_wait(); - pack_untilize_dst(out_cb_id, 1, n, rows_to_copy); + pack_untilize_dst(out_cb_id, 1, n); tile_regs_release(); block_offset += out_subblock_num_tiles; } - cb_push_back(out_cb_id, out_sub_block_rows_h); - output_rows_h -= out_sub_block_rows_h; + cb_push_back(out_cb_id, out_block_w); within_block_index += out_subblock_w; } cb_pop_front(interm_cb_id, num_tiles_in_row_of_subblocks); @@ -100,11 +94,9 @@ void MAIN { constexpr bool tilize_in0 = get_compile_time_arg_val(14); constexpr bool untilize_out = get_compile_time_arg_val(15); constexpr uint32_t out_cb_id = get_compile_time_arg_val(17); - uint32_t output_rows_h = get_compile_time_arg_val(18); - constexpr bool is_non_tile_height = get_compile_time_arg_val(19); #ifdef WIDTH_SHARDED - constexpr uint32_t in0_nblocks_w_tilize = get_compile_time_arg_val(20); + constexpr uint32_t in0_nblocks_w_tilize = get_compile_time_arg_val(18); #endif constexpr uint32_t out_block_num_tiles = in0_num_subblocks * in1_num_subblocks * out_subblock_num_tiles; @@ -118,7 +110,6 @@ void MAIN { constexpr uint32_t in0_cb_second_reader_id = tt::CBIndex::c_7; constexpr uint32_t matmul_partials_cb = tt::CBIndex::c_24; constexpr uint32_t tilized_in0_cb_id = tt::CBIndex::c_25; - // constexpr uint32_t untilize_mode_reblock_cb = tt::CBIndex::c_26; constexpr uint32_t untilize_mode_out_cb_id = untilize_out ? matmul_partials_cb : out_cb_id; @@ -439,19 +430,9 @@ void MAIN { #endif pack_untilize_dst_init_short(out_cb_id); copy_tile_to_dst_init_short(matmul_partials_cb); - uint32_t curr_tile_output_rows_h = 0; - uint32_t TILE_SIZE = is_non_tile_height ? 32 : out_block_w; - TILE_SIZE = TILE_SIZE * out_subblock_h; for (uint32_t in0_subblock_i = 0; in0_subblock_i < in0_num_subblocks; ++in0_subblock_i) { - curr_tile_output_rows_h = output_rows_h < TILE_SIZE ? 
@@ -439,19 +430,9 @@ void MAIN {
 #endif
             pack_untilize_dst_init_short<out_subblock_w>(out_cb_id);
             copy_tile_to_dst_init_short(matmul_partials_cb);
-            uint32_t curr_tile_output_rows_h = 0;
-            uint32_t TILE_SIZE = is_non_tile_height ? 32 : out_block_w;
-            TILE_SIZE = TILE_SIZE * out_subblock_h;
             for (uint32_t in0_subblock_i = 0; in0_subblock_i < in0_num_subblocks; ++in0_subblock_i) {
-                curr_tile_output_rows_h = output_rows_h < TILE_SIZE ? output_rows_h : TILE_SIZE;
-                reblock_and_untilize<out_subblock_w, out_block_w, is_non_tile_height>(
-                    in1_num_subblocks,
-                    out_subblock_num_tiles,
-                    out_subblock_h,
-                    curr_tile_output_rows_h,
-                    matmul_partials_cb,
-                    out_cb_id);
-                output_rows_h -= curr_tile_output_rows_h;
+                reblock_and_untilize<out_subblock_w, out_block_w>(
+                    in1_num_subblocks, out_subblock_num_tiles, out_subblock_h, matmul_partials_cb, out_cb_id);
             }
             pack_untilize_uninit(matmul_partials_cb);
         }
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp
index b4760a862f5..37c8edb7701 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_receiver_conv_weights_tiled_col_to_rm_blocks.cpp
@@ -48,12 +48,6 @@ void kernel_main() {
 
     constexpr uint32_t out_addr = get_compile_time_arg_val(29);
 
-#ifdef UNPAD_UNTILIZE_OUT
-    constexpr uint32_t out_block_width_ntiles = get_compile_time_arg_val(33);
-    constexpr uint32_t out_block_width_padded_bytes = get_compile_time_arg_val(34);
-    constexpr uint32_t out_block_width_bytes = get_compile_time_arg_val(35);
-    constexpr uint32_t untilized_padded_out_cb = get_compile_time_arg_val(36);
-#endif
     uint32_t i = 0;
     i += 19;
     uint32_t out_start_tile_id = get_arg_val<uint32_t>(i);
@@ -194,30 +188,8 @@ void kernel_main() {
     }  // out_num_blocks_w
 
 #ifdef SHARDED_OUT
-#ifdef UNPAD_UNTILIZE_OUT
-    uint32_t dst_cb_addr = get_write_ptr(cb_id_out0);
-
-    uint32_t src_cb_addr = get_read_ptr(untilized_padded_out_cb);
-    for (uint32_t nbw = 0; nbw < out_num_blocks_w; nbw++) {
-        for (uint32_t nbh = 0; nbh < out_num_blocks_h; nbh++) {
-            for (uint32_t bh = 0; bh < out_block_height_num_tiles; bh++) {
-                cb_wait_front(untilized_padded_out_cb, out_block_width_ntiles);
-                uint32_t src_cb_addr = get_read_ptr(untilized_padded_out_cb);
-                for (uint32_t r = 0; r < 32; r++) {
-                    noc_async_read(get_noc_addr(src_cb_addr), dst_cb_addr, out_block_width_bytes);
-                    noc_async_read_barrier();
-                    src_cb_addr += out_block_width_padded_bytes;
-
-                    dst_cb_addr += out_aligned_page_size;
-                }
-                cb_pop_front(untilized_padded_out_cb, out_block_width_ntiles);
-            }
-        }
-    }
-#else
     cb_wait_front(
         cb_id_out0,
         out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h);
 #endif
-#endif
 }
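For context, the UNPAD_UNTILIZE_OUT branch deleted above copied only the valid bytes of each tile row out of a padded circular buffer; with the output height now always tile-aligned, the writer just waits for the finished tiles. A standalone sketch of the address arithmetic the removed loop performed, with hypothetical sizes (the names mirror the removed kernel arguments, the values are illustrative only):

#include <cstdint>
#include <cstdio>

int main() {
    uint32_t src_cb_addr = 0x1000;                      // hypothetical padded-CB read pointer
    uint32_t dst_cb_addr = 0x8000;                      // hypothetical output-shard write pointer
    const uint32_t out_block_width_bytes = 96;          // valid bytes per untilized row
    const uint32_t out_block_width_padded_bytes = 128;  // tile-padded row pitch in the CB
    const uint32_t out_aligned_page_size = 96;          // destination row pitch
    for (uint32_t r = 0; r < 32; r++) {  // 32 rows per tile height
        // the kernel issued noc_async_read(get_noc_addr(src_cb_addr), dst_cb_addr, out_block_width_bytes) here
        src_cb_addr += out_block_width_padded_bytes;
        dst_cb_addr += out_aligned_page_size;
    }
    std::printf("src advanced 0x%x bytes, dst advanced 0x%x bytes\n",
                (unsigned)(src_cb_addr - 0x1000), (unsigned)(dst_cb_addr - 0x8000));
    return 0;
}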
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp
index 0053e2c68d2..88744e90369 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/device/kernels/writer_tiled_out_2d_mcast_sender_conv_weights_tiled_col_to_rm_blocks.cpp
@@ -49,12 +49,6 @@ void kernel_main() {
 
     constexpr uint32_t out_addr = get_compile_time_arg_val(29);
 
-#ifdef UNPAD_UNTILIZE_OUT
-    constexpr uint32_t out_block_width_ntiles = get_compile_time_arg_val(33);
-    constexpr uint32_t out_block_width_padded_bytes = get_compile_time_arg_val(34);
-    constexpr uint32_t out_block_width_bytes = get_compile_time_arg_val(35);
-    constexpr uint32_t untilized_padded_out_cb = get_compile_time_arg_val(36);
-#endif
     uint32_t i = 0;
     i += 1;
     const uint32_t weight_addr_dram_base = get_arg_val<uint32_t>(i);
@@ -337,30 +331,8 @@ void kernel_main() {
         weight_start_tile_id += weight_next_block_stride_w;
     }  // out_num_blocks_w
 
 #ifdef SHARDED_OUT
-#ifdef UNPAD_UNTILIZE_OUT
-    uint32_t dst_cb_addr = get_write_ptr(cb_id_out0);
-
-    uint32_t src_cb_addr = get_read_ptr(untilized_padded_out_cb);
-    for (uint32_t nbw = 0; nbw < out_num_blocks_w; nbw++) {
-        for (uint32_t nbh = 0; nbh < out_num_blocks_h; nbh++) {
-            for (uint32_t bh = 0; bh < out_block_height_num_tiles; bh++) {
-                cb_wait_front(untilized_padded_out_cb, out_block_width_ntiles);
-                uint32_t src_cb_addr = get_read_ptr(untilized_padded_out_cb);
-                for (uint32_t r = 0; r < 32; r++) {
-                    noc_async_read(get_noc_addr(src_cb_addr), dst_cb_addr, out_block_width_bytes);
-                    noc_async_read_barrier();
-                    src_cb_addr += out_block_width_padded_bytes;
-
-                    dst_cb_addr += out_aligned_page_size;
-                }
-                cb_pop_front(untilized_padded_out_cb, out_block_width_ntiles);
-            }
-        }
-    }
-#else
     cb_wait_front(
         cb_id_out0,
         out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w * out_num_blocks_h);
 #endif
-#endif
 }
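With the padded-untilize path gone, both writer kernels end in a single cb_wait_front covering every output tile of the shard. A sketch of that tile-count product, using a hypothetical block and subblock split (the values are illustrative only, not taken from the patch):

#include <cassert>
#include <cstdint>

// The product the writers pass to cb_wait_front above.
constexpr uint32_t total_out_tiles(
    uint32_t out_subblock_tile_count,
    uint32_t out_num_subblocks_h,
    uint32_t out_num_subblocks_w,
    uint32_t out_num_blocks_w,
    uint32_t out_num_blocks_h) {
    return out_subblock_tile_count * out_num_subblocks_h * out_num_subblocks_w * out_num_blocks_w *
           out_num_blocks_h;
}

int main() {
    // Hypothetical: 8 tiles per subblock, a 2x2 subblock grid, one block each way.
    static_assert(total_out_tiles(8, 2, 2, 1, 1) == 32, "one full shard of tiles");
    assert(total_out_tiles(4, 1, 2, 2, 1) == 16);
    return 0;
}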
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
index 2678a4ce2af..2f7b82a170e 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
@@ -504,24 +504,17 @@ ttnn::Tensor conv_bias_layout_convert(
     uint32_t weight_block_w_ntiles,
     const ParallelConfig& parallel_config,
     T* device,
-    uint32_t out_channels,
-    bool is_non_tile_mul_width) {
+    uint32_t out_channels) {
     ttnn::Tensor bias_tensor_ = bias_tensor;
     validate_bias_tensor(bias_tensor_);
-    if (!is_non_tile_mul_width) {
-        const auto& bias_shape = bias_tensor_.get_logical_shape();
-        TT_FATAL(bias_shape[0] == 1 && bias_shape[1] == 1 && bias_shape[2] == 1, "bias shape is not correct");
-        ttnn::Shape bias_channels_padded_shape({1, 1, 32, round_up(out_channels, weight_block_w_ntiles * 32)});
-        bias_tensor_ =
-            ttnn::pad(bias_tensor_, bias_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D{0, 0, 0, 0}, 0);
-        bias_tensor_ = ttnn::to_layout(bias_tensor_, Layout::TILE, std::nullopt, std::nullopt, (T*)nullptr);
-        if (bias_tensor_.get_dtype() != bias_dtype) {
-            bias_tensor_ = ttnn::to_dtype(bias_tensor_, bias_dtype);
-        }
-    } else {
-        uint32_t num_cores_channels = get_num_cores_channels_from_parallel_config(parallel_config);
-        bias_tensor_ =
-            convert_conv_bias_tensor_to_tiled_layout_block_sharded(bias_tensor_, num_cores_channels, bias_dtype);
+    const auto& bias_shape = bias_tensor_.get_logical_shape();
+    TT_FATAL(bias_shape[0] == 1 && bias_shape[1] == 1 && bias_shape[2] == 1, "bias shape is not correct");
+    ttnn::Shape bias_channels_padded_shape({1, 1, 32, round_up(out_channels, weight_block_w_ntiles * 32)});
+    bias_tensor_ =
+        ttnn::pad(bias_tensor_, bias_channels_padded_shape.to_array_4D(), tt::tt_metal::Array4D{0, 0, 0, 0}, 0);
+    bias_tensor_ = ttnn::to_layout(bias_tensor_, Layout::TILE, std::nullopt, std::nullopt, (T*)nullptr);
+    if (bias_tensor_.get_dtype() != bias_dtype) {
+        bias_tensor_ = ttnn::to_dtype(bias_tensor_, bias_dtype);
     }
     return bias_tensor_;
 }
@@ -569,10 +562,6 @@ static OptimizedConvBlockConfig get_opt_block_config(
     ShardOrientation shard_orientation =
         conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR;
 
-    const bool use_non_tile_height = check_non_tile_height(conv_config, out_channels);
-
-    bool is_non_tile_mul_width = check_non_tile_mul_width(compute_grid_size, conv_config, in_channels);
-
     if (input_memory_config.is_sharded() && !conv_config.reshard_if_not_optimal) {
         conv_config.shard_layout = input_memory_config.memory_layout;
     }
@@ -593,8 +582,7 @@ static OptimizedConvBlockConfig get_opt_block_config(
             compute_grid_size,
             shard_orientation,
             !mm_conv,
-            !use_non_tile_height,
-            is_non_tile_mul_width,
+            true,
             conv_config.act_block_h_override);
     }
     auto output_parallel_config = parallel_config;
@@ -610,11 +598,11 @@ static OptimizedConvBlockConfig get_opt_block_config(
         log_debug(tt::LogOp, "Changing width sharded output grid to {}", output_parallel_config.grid);
     }
 
-    uint32_t round_up_size = !use_non_tile_height ? tt::constants::TILE_HEIGHT : 1;
     auto conv_out_memory_config = create_sharded_memory_config_from_parallel_config(
-        ttnn::Shape({1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, 32)}),
+        ttnn::Shape(
+            {1, 1, batch_size * output_height * output_width, tt::round_up(out_channels, tt::constants::TILE_WIDTH)}),
         output_parallel_config,
-        round_up_size);
+        tt::constants::TILE_HEIGHT);
     auto largest_parallel_config = output_parallel_config.grid.num_cores() > parallel_config.grid.num_cores()
                                        ? output_parallel_config
                                        : parallel_config;
@@ -657,8 +645,7 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
     uint32_t groups,
     uint32_t act_block_h_ntiles,
     uint32_t input_width,
-    const bool parameters_on_device,
-    bool is_non_tile_mul_width) {
+    const bool parameters_on_device) {
     validate_weight_tensor(weight_tensor);
     ttnn::Tensor weight_tensor_;  // tensor to return
     ttnn::Tensor bias_tensor_;
@@ -701,11 +688,7 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
     uint32_t out_channel_padding = out_channels_padded - out_channels;
 
     ttnn::Shape weights_channels_padded_shape({out_channels_padded, in_channels_padded, window_h, window_w});
-    if (is_non_tile_mul_width) {
-        weights_channels_padded_shape = ttnn::Shape(
-            {round_up(out_channels, 32), round_up(in_channels, input_channels_alignment), window_h, window_w});
-        out_channels_padded = tt::round_up(out_channels, 32);
-    }
+
     if (weights_bias_dtype == DataType::BFLOAT8_B) {
         TT_ASSERT(weight_tensor_.get_dtype() == DataType::FLOAT32);
         if (bias_tensor.has_value()) {
@@ -757,8 +740,7 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
                 weight_block_w_ntiles,
                 output_parallel_config,
                 device,
-                out_channels_padded,
-                is_non_tile_mul_width);
+                out_channels_padded);
             bias_tensor_ = ttnn::operations::core::to_device(bias_tensor_, device, std::nullopt);
         }
     }
@@ -819,10 +801,6 @@ ttnn::Tensor prepare_conv_weights(
     ShardOrientation shard_orientation =
         conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR;
-    const bool use_non_tile_height = check_non_tile_height(conv_config, out_channels);
-    bool is_non_tile_mul_width =
-        check_non_tile_mul_width(device->compute_with_storage_grid_size(), conv_config, in_channels);
-
     if (input_memory_config.is_sharded() && !conv_config.reshard_if_not_optimal) {
         conv_config.shard_layout = input_memory_config.memory_layout;
     }
@@ -844,8 +822,7 @@ ttnn::Tensor prepare_conv_weights(
             device->compute_with_storage_grid_size(),
             shard_orientation,
             !mm_conv,
-            !use_non_tile_height,
-            is_non_tile_mul_width,
+            true,
             conv_config.act_block_h_override);
     }
 
@@ -867,9 +844,7 @@ ttnn::Tensor prepare_conv_weights(
         device,
         groups,
         opt_conv_op_block_config.act_block_h_ntiles,
-        input_width,
-        false,
-        is_non_tile_mul_width);
+        input_width);
 
     return weight_tensor_on_device;
 }
@@ -928,13 +903,10 @@ ttnn::Tensor prepare_conv_bias(
     ShardOrientation shard_orientation =
         conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR;
-    const bool use_non_tile_height = check_non_tile_height(conv_config, out_channels);
-
     if (input_memory_config.is_sharded() && !conv_config.reshard_if_not_optimal) {
         conv_config.shard_layout = input_memory_config.memory_layout;
     }
     CoreCoord compute_grid = device->compute_with_storage_grid_size();
-    bool is_non_tile_mul_width = check_non_tile_mul_width(compute_grid, conv_config, in_channels);
     ParallelConfig parallel_config;
     if (input_memory_config.shard_spec.has_value() && !conv_config.reshard_if_not_optimal) {
         parallel_config = {
@@ -952,8 +924,7 @@ ttnn::Tensor prepare_conv_bias(
             compute_grid,
             shard_orientation,
             !mm_conv,
-            !use_non_tile_height,
-            is_non_tile_mul_width,
+            true,
             conv_config.act_block_h_override);
     }
 
@@ -970,8 +941,7 @@ ttnn::Tensor prepare_conv_bias(
         weight_block_w_ntiles,
         output_parallel_config,
         device,
-        out_channels,
-        is_non_tile_mul_width);
+        out_channels);
     return bias_tensor_;
 }
 
@@ -1028,8 +998,7 @@ template std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weigh
     uint32_t groups,
     uint32_t act_block_h_ntiles,
     uint32_t input_width,
-    const bool parameters_on_device,
-    bool is_non_tile_mul_width);
+    const bool parameters_on_device);
 
 template std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>>
 prepare_conv_weights_biases_and_move_to_device(
@@ -1045,8 +1014,7 @@ prepare_conv_weights_biases_and_move_to_device(
     uint32_t groups,
     uint32_t act_block_h_ntiles,
     uint32_t input_width,
-    const bool parameters_on_device,
-    bool is_non_tile_mul_width);
+    const bool parameters_on_device);
 
 template ttnn::Tensor prepare_conv_bias(
     const ttnn::Tensor& bias_tensor,
@@ -1091,8 +1059,7 @@ template ttnn::Tensor conv_bias_layout_convert(
     uint32_t weight_block_w_ntiles,
     const sliding_window::ParallelConfig& parallel_config,
     IDevice* device,
-    uint32_t out_channels,
-    bool is_non_tile_mul_width);
+    uint32_t out_channels);
 
 template ttnn::Tensor conv_bias_layout_convert(
     const ttnn::Tensor& bias_tensor,
@@ -1101,8 +1068,7 @@ template ttnn::Tensor conv_bias_layout_convert(
     uint32_t weight_block_w_ntiles,
     const sliding_window::ParallelConfig& parallel_config,
     MeshDevice* device,
-    uint32_t out_channels,
-    bool is_non_tile_mul_width);
+    uint32_t out_channels);
 
 }  // namespace conv2d
 }  // namespace operations::conv
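The bias path above now always pads channels up to a multiple of weight_block_w_ntiles * 32 and pads the bias to a full 32-row tile before the TILE layout conversion. A standalone sketch of that rounding rule, with hypothetical channel counts (not taken from the patch):

#include <cassert>
#include <cstdint>

// Same rounding as round_up(out_channels, weight_block_w_ntiles * 32) above.
constexpr uint32_t round_up(uint32_t v, uint32_t multiple) { return ((v + multiple - 1) / multiple) * multiple; }

int main() {
    constexpr uint32_t TILE_WIDTH = 32;
    // Hypothetical conv: 100 output channels, weight blocks two tiles wide.
    static_assert(round_up(100, 2 * TILE_WIDTH) == 128, "padded bias shape becomes {1, 1, 32, 128}");
    assert(round_up(64, 64) == 64);  // already-aligned channels are unchanged
    return 0;
}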
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp
index d1951b8bb33..5377a62a345 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.hpp
@@ -117,8 +117,7 @@ std::pair<ttnn::Tensor, std::optional<ttnn::Tensor>> prepare_conv_weights_biases
     uint32_t groups,
     uint32_t act_block_h_ntiles,
     uint32_t input_width,
-    const bool parameters_on_device = true,
-    bool is_non_tile_mul_width = false);
+    const bool parameters_on_device = true);
 
 }  // namespace conv2d
 }  // namespace operations::conv
diff --git a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp
index 7c5ab221a0e..d9e4f831fb5 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv_transpose2d/conv_transpose2d.cpp
@@ -202,26 +202,25 @@ Result conv_transpose2d(
     }
 
     // Call Halo Transpose
-    auto [input_tensor_post_tm, parallel_config, output_parallel_config, use_non_tile_height] =
-        shard_or_reshard_tensor_if_required(
-            device,
-            input_tensor,
-            conv_config,
-            batch_size,
-            output_height,
-            output_width,
-            in_channels,
-            out_channels,
-            mm_conv,
-            auto_shard);
+    auto [input_tensor_post_tm, parallel_config, output_parallel_config] = shard_or_reshard_tensor_if_required(
+        device,
+        input_tensor,
+        conv_config,
+        batch_size,
+        output_height,
+        output_width,
+        in_channels,
+        out_channels,
+        mm_conv,
+        auto_shard);
 
-    uint32_t round_up_size = !use_non_tile_height ? tt::constants::TILE_HEIGHT : 1;
+    uint32_t round_up_size = tt::constants::TILE_HEIGHT;
     Tensor halo_output;
     if (!mm_conv) {
         sliding_window_config.num_cores_nhw = get_num_cores_nhw_from_parallel_config(parallel_config);
         sliding_window_config.core_range_set = input_tensor_post_tm.memory_config().shard_spec.value().grid;
-        sliding_window_config.snap_to_tile = !use_non_tile_height;
+        sliding_window_config.snap_to_tile = true;
         halo_output = ttnn::halo(
             DefaultQueueId,