#17972 and #17975: Fixing PCC and Program Cache issues in Repeat and Expand (#18002)

Merged 17 commits on Feb 20, 2025
Changes from 5 commits
36 changes: 28 additions & 8 deletions tests/ttnn/unit_tests/operations/test_repeat.py
@@ -12,20 +12,17 @@
from models.utility_functions import comp_pcc
from tests.ttnn.utils_for_testing import assert_with_pcc

from tqdm import tqdm

layouts = [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]

dtypes = [(torch.float32, ttnn.float32), (torch.bfloat16, ttnn.bfloat16), (torch.bfloat16, ttnn.bfloat8_b)]
shapes = [(1,), (2,), (2, 1), (2, 3), (2, 1, 3), (4, 16, 3, 2), (4, 3, 1, 2, 2)]
dtypes = [(torch.float32, ttnn.float32), (torch.bfloat16, ttnn.bfloat16)]
shapes = [(1,), (2,), (2, 3), (4, 16, 3, 1), (4, 3, 1, 2, 2)]
repeat_shapes = [
(1,),
(2,),
(1, 2),
(1, 4),
(2, 1, 3),
(1, 2, 3),
(4, 3, 2, 1),
(2, 3, 4, 5, 2),
(2, 1, 3, 1, 3, 1),
(2048,),
]

@@ -75,4 +72,27 @@ def test_repeat(device, layout, dtype, shape, repeat_shape):
assert_with_pcc(torch_result, output, 0.9999)


# TODO! test program cache when it is implemented
@pytest.mark.parametrize("layout", layouts)
@pytest.mark.parametrize("shape", shapes)
@pytest.mark.parametrize("repeat_shape", repeat_shapes)
def test_pc_repeat(device, layout, shape, repeat_shape):
# Skip cases that would hit the `buffer not divisible by page size` error in tile layout.
if layout == ttnn.TILE_LAYOUT and (
prod(shape) % ttnn.TILE_SIZE != 0 or _get_final_size(shape, repeat_shape) % ttnn.TILE_SIZE != 0
):
pytest.skip("Tensor not suitable for tile layout")

if len(repeat_shape) < len(shape):
pytest.skip("PyTorch repeat dim must be >= tensor dim (although we can handle this).")
device.enable_program_cache()
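# Run the op three times so the later iterations exercise the program cache (reused program, new buffers).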
for _ in tqdm(range(3)):
torch_tensor = torch.rand(shape, dtype=torch.bfloat16)
torch_result = torch_tensor.repeat(repeat_shape)
input_tensor = ttnn.from_torch(torch_tensor, layout=layout, device=device, dtype=ttnn.bfloat16)
output = ttnn.repeat(input_tensor, ttnn.Shape(repeat_shape))
output = ttnn.to_torch(output)
assert (
output.shape == torch_result.shape
), f"Output shape {output.shape} does not match torch shape {torch_result.shape}"

assert_with_pcc(torch_result, output, 0.9999)
@@ -119,7 +119,55 @@ tt::tt_metal::operation::ProgramWithCallbacks rm_repeater_last_dim(
}
}
}
return {.program = std::move(program)};
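// Override callback: on a program cache hit, recompute the reader runtime args
// (buffer addresses and per-core page ranges) for the new input/output tensors.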
auto override_runtime_args_callback = [reader_kernel_id](
const void* operation,
const tt::tt_metal::Program& program,
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>&,
const std::vector<Tensor>& output_tensors) {
auto input = input_tensors.at(0);
auto output = output_tensors.at(0);
ttnn::Shape input_log_shape = ttnn::Shape(input.get_logical_shape().view());
ttnn::Shape output_log_shape = ttnn::Shape(output.get_logical_shape().view());
const uint32_t data_size = input.element_size();
tt::tt_metal::IDevice* device = input.device();
auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
uint32_t num_cores_x = compute_with_storage_grid_size.x;
uint32_t num_cores_y = compute_with_storage_grid_size.y;
uint32_t num_cores_total = num_cores_x * num_cores_y;
uint32_t read_start_page = 0;
tt::tt_metal::Buffer* src_buffer = input.buffer();
tt::tt_metal::Buffer* dst_buffer = output.buffer();
uint32_t number_of_pages = input_log_shape[-2];
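// Ceil-divide the pages evenly across all cores in the grid.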
uint32_t responsibility = ((number_of_pages - 1) / num_cores_total) + 1;
uint32_t done = 0;
for (int core_x = 0; core_x < num_cores_x; core_x++) {
for (int core_y = 0; core_y < num_cores_y; core_y++) {
CoreCoord core = {core_x, core_y};
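// Once every page has been assigned, the remaining cores get no-op runtime args.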
if (done == 1) {
const std::vector<uint32_t> reader_runtime_args = {
src_buffer->address(), dst_buffer->address(), 0, 0, 1};
tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args);
} else {
// Assign this core its range of pages and set the runtime args.
const uint32_t start_of_read = read_start_page;
uint32_t end_of_read = read_start_page + responsibility;
end_of_read = end_of_read < number_of_pages ? end_of_read : number_of_pages;

const std::vector<uint32_t> reader_runtime_args = {
src_buffer->address(), dst_buffer->address(), start_of_read, end_of_read, 0};
read_start_page = end_of_read;
done = (end_of_read == input_log_shape[-2]) ? 1 : 0;
tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args);
}
}
}
};

return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback};
}

tt::tt_metal::operation::ProgramWithCallbacks rm_repeater(
@@ -162,15 +210,17 @@ tt::tt_metal::operation::ProgramWithCallbacks rm_repeater(
uint32_t cb_size_bytes = READ_ALIGNMENT * 2 + page_size_bytes;
uint32_t src0_cb_index = 0;
uint32_t src1_cb_index = 1;

tt::tt_metal::CircularBufferConfig cb_src0_config =
tt::tt_metal::CircularBufferConfig(cb_size_bytes, {{src0_cb_index, cb_data_format}})
.set_page_size(src0_cb_index, cb_size_bytes);
auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src0_config);

tt::tt_metal::CircularBufferConfig cb_src1_config =
tt::tt_metal::CircularBufferConfig(cb_size_bytes, {{src1_cb_index, cb_data_format}})
.set_page_size(src1_cb_index, cb_size_bytes);

auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src1_config);

bool page_is_pow_2 = tt::tt_metal::is_power_of_two_at_least_32(page_size_bytes);
uint32_t page_pow_2 = page_is_pow_2 ? (std::uint32_t)std::log2(page_size_bytes) : 0;
std::vector<uint32_t> compile_time_args = {
@@ -245,7 +295,93 @@ rm_repeater(
}
}
}
return {.program = std::move(program)};
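// As in rm_repeater_last_dim, recompute buffer addresses and per-core work splits
// when the cached program is reused.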
auto override_runtime_args_callback = [reader_kernel_id, num_repeats](
const void* operation,
const tt::tt_metal::Program& program,
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>&,
const std::vector<Tensor>& output_tensors) {
auto input = input_tensors.at(0);
auto output = output_tensors.at(0);
tt::DataFormat cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype());
const uint32_t data_size = input.element_size();
tt::tt_metal::IDevice* device = input.device();
// Work-split pre-computation across the compute grid
auto compute_with_storage_grid_size = device->compute_with_storage_grid_size();
uint32_t num_cores_x = compute_with_storage_grid_size.x;
uint32_t num_cores_y = compute_with_storage_grid_size.y;
uint32_t num_cores_total = num_cores_x * num_cores_y;
ttnn::Shape input_log_shape = ttnn::Shape(input.get_logical_shape().view());
ttnn::Shape output_log_shape = ttnn::Shape(output.get_logical_shape().view());
uint32_t page_size_bytes = input_log_shape[3] * data_size;
TT_ASSERT(
page_size_bytes == output_log_shape[3] * data_size,
"Data size of output does not match requirement for repeat last dim");
uint32_t read_start_page = 0;
tt::tt_metal::Buffer* src_buffer = input.buffer();
tt::tt_metal::Buffer* dst_buffer = output.buffer();
uint32_t number_of_higher_pages = input_log_shape[0];
uint32_t number_of_lower_pages = input_log_shape[2];
uint32_t done = 0;
// Determine runtime arguments
bool divide_on_higher = number_of_higher_pages > number_of_lower_pages;
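// Split along whichever page dimension is larger; the first responsibility_mod cores each take one extra unit.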

uint32_t responsibility_chunk =
(divide_on_higher ? number_of_higher_pages : number_of_lower_pages) / num_cores_total;
uint32_t responsibility_mod =
(divide_on_higher ? number_of_higher_pages : number_of_lower_pages) % num_cores_total;
uint32_t core_count = 0;
for (int core_x = 0; core_x < num_cores_x; core_x++) {
for (int core_y = 0; core_y < num_cores_y; core_y++) {
uint32_t responsibility =
core_count++ < responsibility_mod ? responsibility_chunk + 1 : responsibility_chunk;
CoreCoord core = {core_x, core_y};
if (done == 1) {
const std::vector<uint32_t> reader_runtime_args = {0, 0, 0, 0, 0, 0, 0, 1};
tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args);
} else if (divide_on_higher) {
// Assign this core its range of higher pages and set the runtime args.
const uint32_t start_of_read = read_start_page;
uint32_t end_of_read = read_start_page + responsibility;
end_of_read = end_of_read < number_of_higher_pages ? end_of_read : number_of_higher_pages;

const std::vector<uint32_t> reader_runtime_args = {
src_buffer->address(),
dst_buffer->address(),
start_of_read,
end_of_read,
0,
number_of_lower_pages,
num_repeats,
0};
read_start_page = end_of_read;
done = (end_of_read == number_of_higher_pages) ? 1 : 0;
tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args);
} else {
// Assign this core its range of lower pages and set the runtime args.
const uint32_t start_of_read = read_start_page;
uint32_t end_of_read = read_start_page + responsibility;
end_of_read = end_of_read < number_of_lower_pages ? end_of_read : number_of_lower_pages;

const std::vector<uint32_t> reader_runtime_args = {
src_buffer->address(),
dst_buffer->address(),
0,
number_of_higher_pages,
start_of_read,
end_of_read,
num_repeats,
0};
read_start_page = end_of_read;
done = (end_of_read == number_of_lower_pages) ? 1 : 0;
tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args);
}
}
}
};
return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback};
}

tt::tt_metal::operation::ProgramWithCallbacks rm_repeat_program_factory(