#17972 and #17975 Fixing PCC and Program Cache issues in Repeat and Expand #18002

Merged
merged 17 commits on Feb 20, 2025
Changes from 9 commits
86 changes: 79 additions & 7 deletions tests/ttnn/unit_tests/operations/test_repeat.py
@@ -15,17 +15,12 @@
layouts = [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]

dtypes = [(torch.float32, ttnn.float32), (torch.bfloat16, ttnn.bfloat16), (torch.bfloat16, ttnn.bfloat8_b)]
shapes = [(1,), (2,), (2, 1), (2, 3), (2, 1, 3), (4, 16, 3, 2), (4, 3, 1, 2, 2)]
shapes = [(1,), (2,), (2, 3), (4, 16, 3, 1), (4, 3, 1, 2, 2)]
repeat_shapes = [
(1,),
(2,),
(1, 2),
(1, 4),
(2, 1, 3),
(1, 2, 3),
(4, 3, 2, 1),
(2, 3, 4, 5, 2),
(2, 1, 3, 1, 3, 1),
(2048,),
]

@@ -75,4 +70,81 @@ def test_repeat(device, layout, dtype, shape, repeat_shape):
assert_with_pcc(torch_result, output, 0.9999)


# TODO! test program cache when it is implemented
@pytest.mark.parametrize("layout", layouts)
@pytest.mark.parametrize("shape", shapes)
@pytest.mark.parametrize("repeat_shape", repeat_shapes)
def test_pc_repeat(device, layout, shape, repeat_shape, use_program_cache):
# trying to avoid the `buffer not divisible by page size` error. Does this make sense?
if layout == ttnn.TILE_LAYOUT and (
prod(shape) % ttnn.TILE_SIZE != 0 or _get_final_size(shape, repeat_shape) % ttnn.TILE_SIZE != 0
):
pytest.skip("Tensor not suitable for tile layout")

if len(repeat_shape) < len(shape):
pytest.skip("PyTorch repeat dim must be >= tensor dim (although we can handle this).")
num_iters = 3
input_tensors = []
torch_results = []
for i in range(num_iters):
torch_tensor = torch.rand(shape, dtype=torch.bfloat16)
torch_results.append(torch_tensor.repeat(repeat_shape))
input_tensors.append(ttnn.from_torch(torch_tensor, layout=layout, device=device, dtype=ttnn.bfloat16))
for i in range(num_iters):
output = ttnn.repeat(input_tensors[i], ttnn.Shape(repeat_shape))
output = ttnn.to_torch(output)
assert (
output.shape == torch_results[i].shape
), f"Output shape {output.shape} does not match torch shape {torch_results[i].shape}"

assert_with_pcc(torch_results[i], output, 0.9999)


# Regression test cases for issue #17975


def test_17975_a(device, use_program_cache):
y = torch.rand((1, 1, 256, 384), dtype=torch.bfloat16)

for _ in range(10):
y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
print("program cache: ", device.num_program_cache_entries())

x = torch.zeros((64, 1, 256, 384), dtype=torch.bfloat16)
x_tt = ttnn.from_torch(x, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

z_tt = x_tt + y_tt

for i in range(64):
z_torch = ttnn.to_torch(z_tt[i : i + 1])
assert torch.allclose(z_torch, y, atol=1e-2), f"z_torch[{i}] != y"


def test_17975_b(device, use_program_cache):
y = torch.rand((1, 1, 32, 32), dtype=torch.bfloat16)

for _ in range(10):
y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
print("program cache: ", device.num_program_cache_entries())

x = torch.zeros((4, 1, 32, 32), dtype=torch.bfloat16)
x_tt = ttnn.from_torch(x, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

ttnn.repeat(y_tt, [4, 1, 1, 1])
z_tt = ttnn.experimental.add(x_tt, y_tt)
# z_tt = x_tt + y_tt

for i in range(4):
z_torch = ttnn.to_torch(z_tt[i : i + 1])
assert torch.allclose(z_torch, y, atol=1e-2), f"z_torch[{i}] != y"


def test_17975_c(device, use_program_cache):
for _ in range(10):
y = torch.rand((1, 1, 256, 384), dtype=torch.bfloat16)

y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
z_tt = ttnn.repeat(y_tt, ttnn.Shape([64, 1, 1, 1]))

for i in range(64):
z_torch = ttnn.to_torch(z_tt[i : i + 1])
assert torch.allclose(z_torch, y, atol=1e-2), f"z_torch[{i}] != y"
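Not part of this diff, for reference only: the cache behavior these tests target can also be asserted directly through the entry count the device exposes. A minimal sketch, assuming the standard `device` and `use_program_cache` fixtures and that all programs involved (repeat plus layout conversion) are cached after the first iteration; the test name is illustrative:

import torch
import ttnn


def test_pc_repeat_entry_count(device, use_program_cache):
    torch_input = torch.rand((1, 1, 32, 32), dtype=torch.bfloat16)
    expected = torch_input.repeat((4, 1, 1, 1))

    entries = []
    for _ in range(3):
        # A fresh device tensor each iteration means the buffer address can
        # change between runs, which is what the override callback must handle.
        tt_input = ttnn.from_torch(torch_input, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
        tt_output = ttnn.repeat(tt_input, ttnn.Shape([4, 1, 1, 1]))
        assert torch.allclose(ttnn.to_torch(tt_output), expected, atol=1e-2)
        entries.append(device.num_program_cache_entries())

    # After the first compile the cached programs should be reused, so the
    # entry count should not keep growing across iterations.
    assert entries[-1] == entries[0]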
@@ -119,7 +119,42 @@ tt::tt_metal::operation::ProgramWithCallbacks rm_repeater_last_dim(
}
}
}
return {.program = std::move(program)};
auto override_runtime_args_callback = [reader_kernel_id, num_cores_x, num_cores_y, number_of_pages, responsibility](
const void* operation,
const tt::tt_metal::Program& program,
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>&,
const std::vector<Tensor>& output_tensors) {
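// On a program-cache hit the input/output buffers may have been reallocated,
// so recompute every core's runtime args with the current addresses and page ranges.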
uint32_t read_start_page = 0;
uint32_t done = 0;
auto input = input_tensors.at(0);
auto output = output_tensors.at(0);
for (int core_x = 0; core_x < num_cores_x; core_x++) {
for (int core_y = 0; core_y < num_cores_y; core_y++) {
CoreCoord core = {core_x, core_y};
if (done == 1) {
const std::vector<uint32_t> reader_runtime_args = {0, 0, 0, 0, 1};
tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args);
} else {
// Recompute this core's read range and refresh the buffer addresses in the runtime args.
const uint32_t start_of_read = read_start_page;
uint32_t end_of_read = read_start_page + responsibility;
end_of_read = end_of_read < number_of_pages ? end_of_read : number_of_pages;

const std::vector<uint32_t> reader_runtime_args = {
input.buffer()->address(), output.buffer()->address(), start_of_read, end_of_read, 0};
read_start_page = end_of_read;
done = (end_of_read == number_of_pages) ? 1 : 0;
tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args);
}
}
}
};

return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback};
}

tt::tt_metal::operation::ProgramWithCallbacks rm_repeater(
@@ -162,15 +197,17 @@ tt::tt_metal::operation::ProgramWithCallbacks rm_repeater(
uint32_t cb_size_bytes = READ_ALIGNMENT * 2 + page_size_bytes;
uint32_t src0_cb_index = 0;
uint32_t src1_cb_index = 1;

tt::tt_metal::CircularBufferConfig cb_src0_config =
tt::tt_metal::CircularBufferConfig(cb_size_bytes, {{src0_cb_index, cb_data_format}})
.set_page_size(src0_cb_index, cb_size_bytes);
auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src0_config);

tt::tt_metal::CircularBufferConfig cb_src1_config =
tt::tt_metal::CircularBufferConfig(cb_size_bytes, {{src1_cb_index, cb_data_format}})
.set_page_size(src1_cb_index, cb_size_bytes);

auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src1_config);

bool page_is_pow_2 = tt::tt_metal::is_power_of_two_at_least_32(page_size_bytes);
uint32_t page_pow_2 = page_is_pow_2 ? (std::uint32_t)std::log2(page_size_bytes) : 0;
std::vector<uint32_t> compile_time_args = {
@@ -245,7 +282,75 @@
}
}
}
return {.program = std::move(program)};
auto override_runtime_args_callback = [reader_kernel_id,
num_repeats,
num_cores_total,
num_cores_x,
num_cores_y,
number_of_higher_pages,
number_of_lower_pages,
divide_on_higher,
responsibility_chunk,
responsibility_mod](
const void* operation,
const tt::tt_metal::Program& program,
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>&,
const std::vector<Tensor>& output_tensors) {
auto input = input_tensors.at(0);
auto output = output_tensors.at(0);
tt::tt_metal::Buffer* src_buffer = input.buffer();
tt::tt_metal::Buffer* dst_buffer = output.buffer();
uint32_t done = 0;
uint32_t read_start_page = 0;
uint32_t core_count = 0;
for (int core_x = 0; core_x < num_cores_x; core_x++) {
for (int core_y = 0; core_y < num_cores_y; core_y++) {
uint32_t responsibility =
core_count++ < responsibility_mod ? responsibility_chunk + 1 : responsibility_chunk;
CoreCoord core = {core_x, core_y};
if (done == 1) {
const std::vector<uint32_t> reader_runtime_args = {0, 0, 0, 0, 0, 0, 0, 1};
tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args);
} else if (divide_on_higher) {
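// divide_on_higher splits the higher-dimension page range across cores;
// the else branch splits the lower (last-dim) page range instead.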
const uint32_t start_of_read = read_start_page;
uint32_t end_of_read = read_start_page + responsibility;
end_of_read = end_of_read < number_of_higher_pages ? end_of_read : number_of_higher_pages;

const std::vector<uint32_t> reader_runtime_args = {
src_buffer->address(),
dst_buffer->address(),
start_of_read,
end_of_read,
0,
number_of_lower_pages,
num_repeats,
0};
read_start_page = end_of_read;
done = (end_of_read == number_of_higher_pages) ? 1 : 0;
tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args);
} else {
const uint32_t start_of_read = read_start_page;
uint32_t end_of_read = read_start_page + responsibility;
end_of_read = end_of_read < number_of_lower_pages ? end_of_read : number_of_lower_pages;

const std::vector<uint32_t> reader_runtime_args = {
src_buffer->address(),
dst_buffer->address(),
0,
number_of_higher_pages,
start_of_read,
end_of_read,
num_repeats,
0};
read_start_page = end_of_read;
done = (end_of_read == number_of_lower_pages) ? 1 : 0;
tt::tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args);
}
}
}
};
return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback};
}

tt::tt_metal::operation::ProgramWithCallbacks rm_repeat_program_factory(