
Commit

#17972 and #17975 Fixing PCC and Program Cache issues in Repeat and Expand (#18002)

### Ticket
#17975
#17972

### Problem description
This PR closes two P0 issues by fixing bugs in the repeat program
factory and adding program cache support to repeat.
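
For context, program cache support means that a second invocation of an op with an identical configuration reuses the already-compiled program instead of rebuilding it. Below is a minimal sketch of the property the new tests verify, assuming a `device` fixture with the program cache enabled (the `ttnn` calls are the ones used in the test diff in this PR):

```python
import torch
import ttnn


def check_repeat_hits_program_cache(device):
    x = ttnn.from_torch(
        torch.rand((1, 1, 32, 32), dtype=torch.bfloat16),
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        device=device,
    )
    # First call compiles the repeat program and populates the cache.
    ttnn.repeat(x, ttnn.Shape([4, 1, 1, 1]))
    entries = device.num_program_cache_entries()
    # A second call with an identical config must reuse the cached program.
    ttnn.repeat(x, ttnn.Shape([4, 1, 1, 1]))
    assert device.num_program_cache_entries() == entries
```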

### What's changed
- Program factory changes to repeat, adding program cache support via a runtime-args override callback (a sketch of the mechanism follows this list)
- Added program cache testing to the CI pipelines for repeat
- Removed redundant CI tests for repeat to improve CI pipeline times
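
Under the hood, the cache support attaches an `override_runtime_arguments_callback` to the repeat program factory (see the C++ diff below): on a cache hit the compiled program is reused, and only the input/output buffer addresses in the reader kernel's runtime arguments are refreshed. The toy model below illustrates the mechanism in plain Python; the class and names here are hypothetical, not tt-metal API:

```python
class ToyProgramCache:
    """Toy model: compile once per config, patch buffer addresses per call."""

    def __init__(self):
        self._programs = {}

    def run(self, config, input_addr, output_addr):
        program = self._programs.get(config)
        if program is None:
            # Cache miss: "compile" the program and remember it.
            program = {"config": config, "runtime_args": [None, None]}
            self._programs[config] = program
        # Hit or miss, always override the runtime args: new tensors live at
        # new buffer addresses even when the op configuration is identical.
        program["runtime_args"][0] = input_addr
        program["runtime_args"][1] = output_addr
        return program


cache = ToyProgramCache()
first = cache.run(("repeat", (4, 1, 1, 1)), 0x1000, 0x2000)
second = cache.run(("repeat", (4, 1, 1, 1)), 0x3000, 0x4000)
assert first is second  # program reused, not recompiled
assert second["runtime_args"] == [0x3000, 0x4000]  # addresses refreshed
```

Without such an override, a cached program would replay the first invocation's buffer addresses, which matches the failure mode the new tests guard against.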

### Checklist
- [ ] [All post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml) CI passes. [Submitted](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml), [rerun after PR changes](https://github.com/tenstorrent/tt-metal/actions/runs/13422573676)
- [ ] [Blackhole Post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml) CI passes (if applicable)
- [ ] [Model regression](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-models.yaml) CI passes (if applicable). [Submitted](https://github.com/tenstorrent/tt-metal/actions/runs/13416833113)
- [ ] [Device performance regression](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-device-models.yaml) CI passes (if applicable)
- [ ] **(For models and ops writers)** Full [new models tests](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) CI passes (if applicable)
- [ ] New/Existing tests provide coverage for changes
- [ ] T3K Demo. [Submitted](https://github.com/tenstorrent/tt-metal/actions/runs/13416778070)
jvegaTT authored Feb 20, 2025
1 parent 8e4a6e0 commit 705b94d
Showing 2 changed files with 157 additions and 10 deletions.
128 changes: 121 additions & 7 deletions tests/ttnn/unit_tests/operations/test_repeat.py
@@ -15,17 +15,12 @@
layouts = [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]

dtypes = [(torch.float32, ttnn.float32), (torch.bfloat16, ttnn.bfloat16), (torch.bfloat16, ttnn.bfloat8_b)]
shapes = [(1,), (2,), (2, 1), (2, 3), (2, 1, 3), (4, 16, 3, 2), (4, 3, 1, 2, 2)]
shapes = [(1,), (2,), (2, 3), (4, 16, 3, 1), (4, 3, 1, 2, 2)]
repeat_shapes = [
    (1,),
    (2,),
    (1, 2),
    (1, 4),
    (2, 1, 3),
    (1, 2, 3),
    (4, 3, 2, 1),
    (2, 3, 4, 5, 2),
    (2, 1, 3, 1, 3, 1),
    (2048,),
]

@@ -75,4 +70,123 @@ def test_repeat(device, layout, dtype, shape, repeat_shape):
    assert_with_pcc(torch_result, output, 0.9999)


# TODO! test program cache when it is implemented
@pytest.mark.parametrize("layout", layouts)
@pytest.mark.parametrize("shape", shapes)
@pytest.mark.parametrize("repeat_shape", repeat_shapes)
def test_pc_repeat(device, layout, shape, repeat_shape, use_program_cache):
# trying to avoid the `buffer not divisible by page size` error. Does this make sense?
if layout == ttnn.TILE_LAYOUT and (
prod(shape) % ttnn.TILE_SIZE != 0 or _get_final_size(shape, repeat_shape) % ttnn.TILE_SIZE != 0
):
pytest.skip("Tensor not suitable for tile layout")

if len(repeat_shape) < len(shape):
pytest.skip("PyTorch repeat dim must be >= tensor dim (although we can handle this).")
num_iters = 3
input_tensors = []
torch_results = []
for i in range(num_iters):
torch_tensor = torch.rand(shape, dtype=torch.bfloat16)
torch_results.append(torch_tensor.repeat(repeat_shape))
input_tensors.append(ttnn.from_torch(torch_tensor, layout=layout, device=device, dtype=ttnn.bfloat16))
for i in range(num_iters):
output = ttnn.repeat(input_tensors[i], ttnn.Shape(repeat_shape))
output = ttnn.to_torch(output)
assert (
output.shape == torch_results[i].shape
), f"Output shape {output.shape} does not match torch shape {torch_results[i].shape}"

assert_with_pcc(torch_results[i], output, 0.9999)
if i == 0:
base_program_cache_entires = device.num_program_cache_entries()
else:
assert (
device.num_program_cache_entries() == base_program_cache_entires,
"program cache entries differ on same configs",
)


# Test cases for #17975


def test_pc_with_different_shapes_in_sequence(device, use_program_cache):
    y = torch.rand((1, 1, 256, 384), dtype=torch.bfloat16)
    y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

    x = torch.zeros((64, 1, 256, 384), dtype=torch.bfloat16)
    x_tt = ttnn.from_torch(x, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    num_iters = 4
    z_tt = x_tt + y_tt

    for i in range(64):
        z_torch = ttnn.to_torch(z_tt[i : i + 1])
        assert torch.allclose(z_torch, y, atol=1e-2), f"z_torch[{i}] != y"

    # Capture the entry count once every config in this phase has compiled once.
    base_program_cache_entries = device.num_program_cache_entries()
    for _ in range(num_iters):
        y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
        assert (
            device.num_program_cache_entries() == base_program_cache_entries
        ), "program cache entries differ on same configs"

        x = torch.zeros((64, 1, 256, 384), dtype=torch.bfloat16)
        x_tt = ttnn.from_torch(x, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

        z_tt = x_tt + y_tt

        for i in range(64):
            z_torch = ttnn.to_torch(z_tt[i : i + 1])
            assert torch.allclose(z_torch, y, atol=1e-2), f"z_torch[{i}] != y"
    y = torch.rand((1, 1, 32, 32), dtype=torch.bfloat16)

    y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

    x = torch.zeros((4, 1, 32, 32), dtype=torch.bfloat16)
    x_tt = ttnn.from_torch(x, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

    ttnn.repeat(y_tt, [4, 1, 1, 1])
    z_tt = ttnn.experimental.add(x_tt, y_tt)
    # z_tt = x_tt + y_tt

    for i in range(num_iters):
        z_torch = ttnn.to_torch(z_tt[i : i + 1])
        assert torch.allclose(z_torch, y, atol=1e-2), f"z_torch[{i}] != y"

    base_program_cache_entries = device.num_program_cache_entries()
    for _ in range(num_iters):
        y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
        assert (
            device.num_program_cache_entries() == base_program_cache_entries
        ), "program cache entries differ on same configs"

        x = torch.zeros((4, 1, 32, 32), dtype=torch.bfloat16)
        x_tt = ttnn.from_torch(x, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

        ttnn.repeat(y_tt, [4, 1, 1, 1])
        z_tt = ttnn.experimental.add(x_tt, y_tt)
        # z_tt = x_tt + y_tt

        for i in range(num_iters):
            z_torch = ttnn.to_torch(z_tt[i : i + 1])
            assert torch.allclose(z_torch, y, atol=1e-2), f"z_torch[{i}] != y"
    y = torch.rand((1, 1, 256, 384), dtype=torch.bfloat16)

    y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    z_tt = ttnn.repeat(y_tt, ttnn.Shape([64, 1, 1, 1]))

    for i in range(64):
        z_torch = ttnn.to_torch(z_tt[i : i + 1])
        assert torch.allclose(z_torch, y, atol=1e-2), f"z_torch[{i}] != y"

    base_program_cache_entries = device.num_program_cache_entries()
    for _ in range(num_iters):
        y = torch.rand((1, 1, 256, 384), dtype=torch.bfloat16)
        y_tt = ttnn.from_torch(y, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
        assert (
            device.num_program_cache_entries() == base_program_cache_entries
        ), "program cache entries differ on same configs"
        z_tt = ttnn.repeat(y_tt, ttnn.Shape([64, 1, 1, 1]))

        for i in range(64):
            z_torch = ttnn.to_torch(z_tt[i : i + 1])
            assert torch.allclose(z_torch, y, atol=1e-2), f"z_torch[{i}] != y"
@@ -119,7 +119,23 @@ tt::tt_metal::operation::ProgramWithCallbacks rm_repeater_last_dim(
}
}
}
    return {.program = std::move(program)};
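    // On a program-cache hit the compiled program is reused, so this callback
    // re-points the reader kernel's runtime args at the current input/output
    // buffers, whose addresses can change between invocations.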
    auto override_runtime_args_callback = [reader_kernel_id, total_cores](
                                              const void* operation,
                                              const tt::tt_metal::Program& program,
                                              const std::vector<Tensor>& input_tensors,
                                              const std::vector<std::optional<const Tensor>>&,
                                              const std::vector<Tensor>& output_tensors) {
        auto input = input_tensors.at(0);
        auto output = output_tensors.at(0);
        auto& runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id);
        for (const auto& core : total_cores) {
            auto& runtime_args = runtime_args_by_core[core.x][core.y];
            runtime_args.at(0) = input.buffer()->address();
            runtime_args.at(1) = output.buffer()->address();
        }
    };

    return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback};
}

tt::tt_metal::operation::ProgramWithCallbacks rm_repeater(
@@ -162,15 +178,17 @@ tt::tt_metal::operation::ProgramWithCallbacks rm_repeater(
    uint32_t cb_size_bytes = READ_ALIGNMENT * 2 + page_size_bytes;
    uint32_t src0_cb_index = 0;
    uint32_t src1_cb_index = 1;

    tt::tt_metal::CircularBufferConfig cb_src0_config =
        tt::tt_metal::CircularBufferConfig(cb_size_bytes, {{src0_cb_index, cb_data_format}})
            .set_page_size(src0_cb_index, cb_size_bytes);
    auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src0_config);

    tt::tt_metal::CircularBufferConfig cb_src1_config =
        tt::tt_metal::CircularBufferConfig(cb_size_bytes, {{src1_cb_index, cb_data_format}})
            .set_page_size(src1_cb_index, cb_size_bytes);

    auto cb_src1 = tt::tt_metal::CreateCircularBuffer(program, total_cores, cb_src1_config);

    bool page_is_pow_2 = tt::tt_metal::is_power_of_two_at_least_32(page_size_bytes);
    uint32_t page_pow_2 = page_is_pow_2 ? (std::uint32_t)std::log2(page_size_bytes) : 0;
    std::vector<uint32_t> compile_time_args = {
@@ -245,7 +263,22 @@
}
}
}
    return {.program = std::move(program)};
    auto override_runtime_args_callback = [reader_kernel_id, total_cores](
                                              const void* operation,
                                              const tt::tt_metal::Program& program,
                                              const std::vector<Tensor>& input_tensors,
                                              const std::vector<std::optional<const Tensor>>&,
                                              const std::vector<Tensor>& output_tensors) {
        auto input = input_tensors.at(0);
        auto output = output_tensors.at(0);
        auto& runtime_args_by_core = GetRuntimeArgs(program, reader_kernel_id);
        for (const auto& core : total_cores) {
            auto& runtime_args = runtime_args_by_core[core.x][core.y];
            runtime_args.at(0) = input.buffer()->address();
            runtime_args.at(1) = output.buffer()->address();
        }
    };
    return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback};
}

tt::tt_metal::operation::ProgramWithCallbacks rm_repeat_program_factory(
