#17662: Conv2d fix split reader #17936

Merged: 1 commit on Feb 19, 2025
4 changes: 1 addition & 3 deletions models/demos/ttnn_resnet/tt/ttnn_functional_resnet50.py
@@ -698,9 +698,7 @@ def __init__(
         if type(device) == ttnn.MeshDevice and device.get_num_devices() > 8:
             self.conv1_config.act_block_h_override = 64
         else:
-            # Todo: restore after issue #16895 is fixed
-            # self.conv1_config.act_block_h_override = 49 * 32
-            self.conv1_config.act_block_h_override = 2 * 32
+            self.conv1_config.act_block_h_override = 49 * 32
         if is_blackhole():
             # self.conv1_config.act_block_h_override = 7 * 32
             # self.conv1_config.act_block_h_override = 2 * 32
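For orientation, act_block_h_override is expressed in activation rows, so the restored value is 49 tile-rows versus the 2 tile-rows of the temporary workaround tied to issue #16895. A quick sanity sketch of that arithmetic, assuming the usual 32-row tile:

TILE_HEIGHT = 32  # assumed tile height

restored = 49 * 32   # value restored by this PR: 49 tile-rows
workaround = 2 * 32  # temporary value from the issue #16895 workaround

assert restored // TILE_HEIGHT == 49
assert workaround // TILE_HEIGHT == 2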
51 changes: 50 additions & 1 deletion tests/ttnn/unit_tests/operations/test_new_conv2d.py
@@ -71,6 +71,7 @@ def run_conv(
     input_mesh_mapper=None,
     weight_mesh_mapper=None,
     output_mesh_composer=None,
+    enable_split_reader=False,
 ):
     if isinstance(device, ttnn.MeshDevice):
         assert input_mesh_mapper is not None, "Expected mesh mapper for input tensor when using device mesh"
@@ -130,7 +131,7 @@
         input_channels_alignment=8 if use_shallow_conv_variant and not auto_shard else 32,
         deallocate_activation=deallocate_activation,
         enable_act_double_buffer=False,
-        enable_split_reader=False,
+        enable_split_reader=enable_split_reader,
         enable_subblock_padding=False,
         output_layout=output_layout,
     )
@@ -2917,3 +2918,51 @@ def test_dram_input_mm_conv(device, tiled_input, input_on_device):
     passing, pcc_msg = check_with_pcc_without_tensor_printout(torch_output_tensor, torch_out_golden_tensor, pcc=0.99)
     logger.info(f"PCC = {pcc_msg}. Threshold = 0.99")
     assert passing
+
+
+@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
+@pytest.mark.parametrize(
+    "batch_size, output_channels, input_channels, input_height, input_width, filter_height, filter_width, stride_h, stride_w, pad_h, pad_w, shard_layout, config_override",
+    ((16, 64, 16, 115, 115, 4, 4, 1, 1, 0, 0, HS, {"act_block_h": 32 * 49}),),
+)
+def test_split_reader_regression(
+    device,
+    torch_tensor_map,
+    use_program_cache,
+    batch_size,
+    output_channels,
+    input_channels,
+    input_height,
+    input_width,
+    filter_height,
+    filter_width,
+    stride_h,
+    stride_w,
+    pad_h,
+    pad_w,
+    shard_layout,
+    config_override,
+):
+    run_conv(
+        device,
+        torch_tensor_map,
+        ttnn.MathFidelity.LoFi,
+        ttnn.bfloat8_b,
+        ttnn.bfloat8_b,
+        batch_size,
+        output_channels,
+        input_channels,
+        input_height,
+        input_width,
+        filter_height,
+        filter_width,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        config_override=config_override,
+        use_shallow_conv_variant=True,
+        has_bias=False,
+        shard_layout=shard_layout,
+        enable_split_reader=True,
+    )
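The new test drives run_conv with enable_split_reader=True and act_block_h = 32 * 49, matching the value restored in the ResNet-50 demo above. A hedged sketch of running just this regression through pytest's Python API (a local tt-metal checkout and its usual test environment are assumed):

# Select the new regression test by node id; the path comes from this diff.
import pytest

if __name__ == "__main__":
    pytest.main([
        "tests/ttnn/unit_tests/operations/test_new_conv2d.py::test_split_reader_regression",
        "-q",
    ])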
8 changes: 7 additions & 1 deletion ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
@@ -314,7 +314,13 @@ static std::pair<uint32_t, uint32_t> determine_largest_subblock_size(
             break;
         }
     }
-    TT_ASSERT(subblock_h > 0 && subblock_w > 0);
+    TT_FATAL(
+        subblock_h > 0 && subblock_w > 0,
+        "Could not find valid subblock size for block size {}x{}, split_reader_enabled: {}, fp32_accum: {}",
+        block_height,
+        block_width,
+        split_reader_enabled,
+        fp32_accum);
     return {subblock_h, subblock_w};
 }
@@ -1299,7 +1299,8 @@ operation::ProgramWithCallbacks multi_core_optimized_conv_sharded_v2_impl(
         (uint32_t)act_mcast_receiver_semaphore_id,
         (uint32_t)in0_block_num_tiles * tilized_act_tile_size,  // act_mcast_sender_size_bytes
         (uint32_t)(transpose_mcast ? 1 : 0),
-        (uint32_t)act_block_h_datums_last_block};
+        (uint32_t)act_block_h_datums_last_block,
+        (uint32_t)act_block_h_datums_split_last};

     // define for bias
     std::map<string, string> writer_defines;
@@ -31,6 +31,8 @@ void kernel_main() {

     constexpr uint32_t act_block_h_datums_read_last_block =
         act_block_h_datums_last_block > act_block_h_datums ? act_block_h_datums / 2 : act_block_h_datums_last_block / 2;
+    constexpr uint32_t act_block_h_datums_second_reader = get_compile_time_arg_val(26);
+    constexpr uint32_t act_block_h_datums_second_reader_read = act_block_h_datums_second_reader / 2;

     uint32_t i = 0;
     uint32_t noop = get_arg_val<uint32_t>(i);
@@ -150,7 +152,7 @@

         start_reader_idx = reader_idx;
 #ifdef SPLIT_READER
-        start_reader_idx += act_block_h_datums_read;
+        start_reader_idx += act_block_h_datums_second_reader_read;
 #endif
     }
 }
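This is the heart of the fix: under SPLIT_READER, each reader now advances its start index by its own share of the packed index stream instead of a shared act_block_h_datums_read. A hedged sketch of the stride arithmetic (the 25/24 split of a 49-tile block is an assumption for illustration; the division by 2 is the packed-uint16 factor noted in the kernel comments):

# Illustrative per-reader strides; names mirror the kernel constants.
TILE_HEIGHT = 32

act_block_h_datums_first_reader = 25 * TILE_HEIGHT   # assumed reader-0 share
act_block_h_datums_second_reader = 24 * TILE_HEIGHT  # assumed reader-1 share

# Two uint16 indices are packed per uint32 word, hence the // 2.
first_reader_stride = act_block_h_datums_first_reader // 2
second_reader_stride = act_block_h_datums_second_reader // 2

# With an odd tile count the shares differ, so a single shared stride
# cannot be correct for both readers.
assert first_reader_stride != second_reader_stride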
@@ -62,6 +62,8 @@ void kernel_main() {
     constexpr uint32_t total_weight_num_tiles =
         weight_block_height_num_outer * num_blocks_weight_h * weight_block_num_tiles;

+    constexpr uint32_t act_block_h_datums_first_reader_read = act_block_h_datums_first_reader / 2;
+
     uint32_t i = 0;
     i += 19;
     uint32_t out_start_tile_id = get_arg_val<uint32_t>(i);
@@ -254,7 +256,7 @@
             out_block_h_start_tile_id_h += out_block_height_num_tiles;
 #endif

-            start_reader_idx = reader_idx + act_block_h_datums_read;
+            start_reader_idx = reader_idx + act_block_h_datums_first_reader_read;
         }  // out_num_blocks_h
         out_block_w_start_tile_id += out_next_block_stride_w;
         out_block_w_start_tile_id_w += weight_block_width_ntiles;
@@ -107,6 +107,8 @@ void kernel_main() {
     constexpr uint32_t cb_id_act_second_reader = 7;
     constexpr uint32_t cb_id_sharded_act = 3;
     constexpr uint32_t act_block_h_datums_read = act_block_h_datums / 2;  // Extra /2 because of packed uint16 reads
+    constexpr uint32_t act_block_h_datums_first_reader_read =
+        act_block_h_datums_first_reader / 2;  // Extra /2 because of packed uint16 reads
     constexpr uint32_t act_block_num_tiles_read = act_block_num_tiles;

     constexpr uint32_t cb_reader_indices = tt::CBIndex::c_4;
@@ -401,8 +403,7 @@ void kernel_main() {
             out_block_h_start_tile_id += out_next_block_stride_h;
             out_block_h_start_tile_id_h += out_block_height_num_tiles;
 #endif
-
-            start_reader_idx = reader_idx + act_block_h_datums_read;
+            start_reader_idx = reader_idx + act_block_h_datums_first_reader_read;
         }  // out_num_blocks_h
         out_block_w_start_tile_id += out_next_block_stride_w;
         out_block_w_start_tile_id_w += weight_block_width_ntiles;
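The repeated "Extra /2 because of packed uint16 reads" comments describe the reader-indices stream: each uint32 word carries two uint16 indices, so N datums cost N / 2 word reads. A minimal sketch of that packing (the low-half-first order is an assumption):

# Illustrative pack/unpack of two uint16 indices per uint32 word.
def pack_pair(lo: int, hi: int) -> int:
    return ((hi & 0xFFFF) << 16) | (lo & 0xFFFF)

def unpack_pair(word: int) -> tuple[int, int]:
    return word & 0xFFFF, (word >> 16) & 0xFFFF

indices = [100, 101, 102, 103]
words = [pack_pair(a, b) for a, b in zip(indices[::2], indices[1::2])]

assert len(words) == len(indices) // 2  # N datums -> N/2 uint32 reads
assert [v for w in words for v in unpack_pair(w)] == indices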