Fixing PCC and Program Cache issues in Repeat and Expand (#18002, addresses #17972 and #17975)

Merged: 17 commits, Feb 20, 2025
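A minimal sketch of the program-cache property the new tests assert, assuming the same public ttnn API used in the diff below (this helper is not part of the PR): re-running an op with an identical configuration must reuse the cached compiled program rather than add a new program-cache entry.

import torch
import ttnn

def check_repeat_hits_program_cache(device):  # hypothetical helper for illustration
    x = ttnn.from_torch(
        torch.rand((1, 1, 32, 32), dtype=torch.bfloat16),
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        device=device,
    )
    ttnn.repeat(x, ttnn.Shape((4, 1, 1, 1)))  # first run compiles and caches the program
    baseline = device.num_program_cache_entries()
    ttnn.repeat(x, ttnn.Shape((4, 1, 1, 1)))  # identical config: should be a cache hit
    assert device.num_program_cache_entries() == baseline, "repeat did not reuse the cached program"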
128 changes: 121 additions & 7 deletions tests/ttnn/unit_tests/operations/test_repeat.py
@@ -15,17 +15,12 @@
layouts = [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]

dtypes = [(torch.float32, ttnn.float32), (torch.bfloat16, ttnn.bfloat16), (torch.bfloat16, ttnn.bfloat8_b)]
shapes = [(1,), (2,), (2, 1), (2, 3), (2, 1, 3), (4, 16, 3, 2), (4, 3, 1, 2, 2)]
shapes = [(1,), (2,), (2, 3), (4, 16, 3, 1), (4, 3, 1, 2, 2)]
repeat_shapes = [
(1,),
(2,),
(1, 2),
(1, 4),
(2, 1, 3),
(1, 2, 3),
(4, 3, 2, 1),
(2, 3, 4, 5, 2),
(2, 1, 3, 1, 3, 1),
(2048,),
]

@@ -75,4 +70,123 @@ def test_repeat(device, layout, dtype, shape, repeat_shape):
assert_with_pcc(torch_result, output, 0.9999)


# TODO! test program cache when it is implemented
@pytest.mark.parametrize("layout", layouts)
@pytest.mark.parametrize("shape", shapes)
@pytest.mark.parametrize("repeat_shape", repeat_shapes)
def test_pc_repeat(device, layout, shape, repeat_shape, use_program_cache):
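# Run the same repeat configuration several times: each run must match torch's repeat
# (PCC check) and, after the first run, must not add new program-cache entries.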
# Skip cases where the tensor volume is not a multiple of the tile size; these would hit
# a `buffer not divisible by page size` error in TILE_LAYOUT.
if layout == ttnn.TILE_LAYOUT and (
prod(shape) % ttnn.TILE_SIZE != 0 or _get_final_size(shape, repeat_shape) % ttnn.TILE_SIZE != 0
):
pytest.skip("Tensor not suitable for tile layout")

if len(repeat_shape) < len(shape):
pytest.skip("PyTorch repeat dim must be >= tensor dim (although we can handle this).")
num_iters = 3
input_tensors = []
torch_results = []
for i in range(num_iters):
torch_tensor = torch.rand(shape, dtype=torch.bfloat16)
torch_results.append(torch_tensor.repeat(repeat_shape))
input_tensors.append(ttnn.from_torch(torch_tensor, layout=layout, device=device, dtype=ttnn.bfloat16))
for i in range(num_iters):
output = ttnn.repeat(input_tensors[i], ttnn.Shape(repeat_shape))
output = ttnn.to_torch(output)
assert (
output.shape == torch_results[i].shape
), f"Output shape {output.shape} does not match torch shape {torch_results[i].shape}"

assert_with_pcc(torch_results[i], output, 0.9999)
if i == 0:
base_program_cache_entires = device.num_program_cache_entries()
else:
assert (
device.num_program_cache_entries() == base_program_cache_entires
), "program cache entries differ on same configs"


# Test cases for issue #17975


def test_pc_with_different_shapes_in_sequence(device, use_program_cache):
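# Repro for #17975: interleave broadcasted adds and repeats over tensors of different
# shapes, checking both the numerical results and that re-running identical configs
# does not grow the program cache.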
y = torch.rand((1, 1, 256, 384), dtype=torch.bfloat16)
y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
base_program_cache_entires = device.num_program_cache_entries()

x = torch.zeros((64, 1, 256, 384), dtype=torch.bfloat16)
x_tt = ttnn.from_torch(x, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
num_iters = 4
z_tt = x_tt + y_tt

for i in range(64):
z_torch = ttnn.to_torch(z_tt[i : i + 1])
assert torch.allclose(z_torch, y, atol=1e-2), f"z_torch[{i}] != y"
for _ in range(num_iters):
y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
assert (
device.num_program_cache_entries() == base_program_cache_entires
), "program cache entries differ on same configs"

x = torch.zeros((64, 1, 256, 384), dtype=torch.bfloat16)
x_tt = ttnn.from_torch(x, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

z_tt = x_tt + y_tt

for i in range(64):
z_torch = ttnn.to_torch(z_tt[i : i + 1])
assert torch.allclose(z_torch, y, atol=1e-2), f"z_torch[{i}] != y"
y = torch.rand((1, 1, 32, 32), dtype=torch.bfloat16)

y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
base_program_cache_entires = device.num_program_cache_entries()

x = torch.zeros((4, 1, 32, 32), dtype=torch.bfloat16)
x_tt = ttnn.from_torch(x, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

ttnn.repeat(y_tt, [4, 1, 1, 1])
z_tt = ttnn.experimental.add(x_tt, y_tt)
# z_tt = x_tt + y_tt

for i in range(num_iters):
z_torch = ttnn.to_torch(z_tt[i : i + 1])
assert torch.allclose(z_torch, y, atol=1e-2), f"z_torch[{i}] != y"
for _ in range(num_iters):
y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
assert (
device.num_program_cache_entries() == base_program_cache_entires
), "program cache entries differ on same configs"

x = torch.zeros((4, 1, 32, 32), dtype=torch.bfloat16)
x_tt = ttnn.from_torch(x, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

ttnn.repeat(y_tt, [4, 1, 1, 1])
z_tt = ttnn.experimental.add(x_tt, y_tt)
# z_tt = x_tt + y_tt

for i in range(num_iters):
z_torch = ttnn.to_torch(z_tt[i : i + 1])
assert torch.allclose(z_torch, y, atol=1e-2), f"z_torch[{i}] != y"
y = torch.rand((1, 1, 256, 384), dtype=torch.bfloat16)

y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
base_program_cache_entires = device.num_program_cache_entries()
z_tt = ttnn.repeat(y_tt, ttnn.Shape([64, 1, 1, 1]))

for i in range(64):
z_torch = ttnn.to_torch(z_tt[i : i + 1])
assert torch.allclose(z_torch, y, atol=1e-2), f"z_torch[{i}] != y"
for _ in range(num_iters):
y = torch.rand((1, 1, 256, 384), dtype=torch.bfloat16)
y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
assert (
device.num_program_cache_entries() == base_program_cache_entires
), "program cache entries differ on same configs"
z_tt = ttnn.repeat(y_tt, ttnn.Shape([64, 1, 1, 1]))

for i in range(64):
z_torch = ttnn.to_torch(z_tt[i : i + 1])
assert torch.allclose(z_torch, y, atol=1e-2), f"z_torch[{i}] != y"
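For reference, assert_with_pcc(expected, actual, 0.9999) used above gates on the Pearson correlation coefficient (PCC) between the two tensors. A rough sketch of that metric in plain PyTorch (an assumption for illustration; the real helper in the tt-metal test utilities may differ in details):

import torch

def pcc(expected: torch.Tensor, actual: torch.Tensor) -> float:
    e = expected.flatten().to(torch.float32)
    a = actual.flatten().to(torch.float32)
    # torch.corrcoef returns the 2x2 correlation matrix; entry [0, 1] is the Pearson r.
    return torch.corrcoef(torch.stack([e, a]))[0, 1].item()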
Changes to the repeat op's C++ program factory (second changed file; its file header was not captured):
@@ -119,7 +119,23 @@ tt::tt_metal::operation::ProgramWithCallbacks rm_repeater_last_dim(
}
}
}
return {.program = std::move(program)};
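// Added so cached programs can be reused: on a program-cache hit the input/output
// tensors generally live at new buffer addresses, so the reader kernel's runtime
// args (source/destination addresses) must be re-pointed on every invocation.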
auto override_runtime_args_callback = [reader_kernel_id, total_cores](
const void* operation,
const tt::tt_metal::Program& program,
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>&,
const std::vector<Tensor>& output_tensors) {
auto input = input_tensors.at(0);
auto output = output_tensors.at(0);
auto& runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id);
for (const auto& core : total_cores) {
auto& runtime_args = runtime_args_by_core[core.x][core.y];
runtime_args.at(0) = input.buffer()->address();
runtime_args.at(1) = output.buffer()->address();
}
};

return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback};
}

tt::tt_metal::operation::ProgramWithCallbacks rm_repeater(
@@ -162,15 +178,17 @@ tt::tt_metal::operation::ProgramWithCallbacks rm_repeater(
uint32_t cb_size_bytes = READ_ALIGNMENT * 2 + page_size_bytes;
uint32_t src0_cb_index = 0;
uint32_t src1_cb_index = 1;

tt::tt_metal::CircularBufferConfig cb_src0_config =
tt::tt_metal::CircularBufferConfig(cb_size_bytes, {{src0_cb_index, cb_data_format}})
.set_page_size(src0_cb_index, cb_size_bytes);
auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src0_config);

tt::tt_metal::CircularBufferConfig cb_src1_config =
tt::tt_metal::CircularBufferConfig(cb_size_bytes, {{src1_cb_index, cb_data_format}})
.set_page_size(src1_cb_index, cb_size_bytes);

auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src1_config);

bool page_is_pow_2 = tt::tt_metal::is_power_of_two_at_least_32(page_size_bytes);
uint32_t page_pow_2 = page_is_pow_2 ? (std::uint32_t)std::log2(page_size_bytes) : 0;
std::vector<uint32_t> compile_time_args = {
@@ -245,7 +263,22 @@
}
}
}
return {.program = std::move(program)};
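// Same runtime-arg override as in rm_repeater_last_dim: patch the reader kernel's
// source/destination buffer addresses whenever the cached program is re-run.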
auto override_runtime_args_callback = [reader_kernel_id, total_cores](
const void* operation,
const tt::tt_metal::Program& program,
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>&,
const std::vector<Tensor>& output_tensors) {
auto input = input_tensors.at(0);
auto output = output_tensors.at(0);
auto& runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id);
for (const auto& core : total_cores) {
auto& runtime_args = runtime_args_by_core[core.x][core.y];
runtime_args.at(0) = input.buffer()->address();
runtime_args.at(1) = output.buffer()->address();
}
};
return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback};
}

tt::tt_metal::operation::ProgramWithCallbacks rm_repeat_program_factory(