diff --git a/examples/custom_dm_matmul.py b/examples/custom_dm_matmul.py index 6ea9e364..ca09fc75 100644 --- a/examples/custom_dm_matmul.py +++ b/examples/custom_dm_matmul.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 from ttlang.ttl_api import * -from utils import assert_allclose +from ttlang.utils.correctness import assert_allclose import torch diff --git a/examples/metal_examples/README.MD b/examples/metal_examples/README.MD new file mode 100644 index 00000000..947ae467 --- /dev/null +++ b/examples/metal_examples/README.MD @@ -0,0 +1,11 @@ +# To Run Metal Examples +Examples are meant to be run on a machine with a single n150 card. Set the env value TT_VISIBLE_DEVICES to be the pcie card you want to use in a multi-device setting. +Manual build of tt-mlir needed, to source ttnn module +in tt-mlir, source env/activate +Now go to tt-lang, source build/env/activate and run the desired metal kernels, such as the singlecore matmul kernel. +```bash +pytest ./examples/metal_examples/singlecore_matmul/metal/singlecore_matmul.py +``` + +# TT-Lang Examples +any tt-lang in this folder is up to spec, but currently is not guaranteed to compile/execute diff --git a/examples/metal_examples/__init__.py b/examples/metal_examples/__init__.py new file mode 100644 index 00000000..6de02c7a --- /dev/null +++ b/examples/metal_examples/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/examples/metal_examples/multicore_matmul/metal/kernels/mm_compute.cpp b/examples/metal_examples/multicore_matmul/metal/kernels/mm_compute.cpp new file mode 100644 index 00000000..fe8747f5 --- /dev/null +++ b/examples/metal_examples/multicore_matmul/metal/kernels/mm_compute.cpp @@ -0,0 +1,67 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "compute_kernel_api/matmul.h" +#include "compute_kernel_api/tile_move_copy.h" +#include "hostdevcommon/kernel_structs.h" +#include + +using std::uint32_t; + +namespace NAMESPACE { +void MAIN { + uint32_t num_output_tiles = + get_arg_val(0); // number of output tiles to produce + uint32_t Kt = get_arg_val( + 1); // number of tiles in K dimension for dot product + + constexpr tt::CBIndex cb_in0 = tt::CBIndex::c_0; + constexpr tt::CBIndex cb_in1 = tt::CBIndex::c_1; + constexpr tt::CBIndex cb_out = tt::CBIndex::c_16; + + // Setup the FPU (matrix engine) for the matmul operation. And specify the + // input and output circular buffers. + mm_init(cb_in0, cb_in1, cb_out); + + // the simplest possible version of outer product blocked matmul + // the reader is expected to read the A's and B's tile rows and tile columns + // for each output tile + for (uint32_t i = 0; i < num_output_tiles; ++i) { + // Make sure registers can be used for the output tile. This also sets the + // registers to zero. + tile_regs_acquire(); + for (uint32_t kt = 0; kt < Kt; kt++) { + // Wait for the input tiles to be available in the input circular buffers. + cb_wait_front(cb_in0, 1); + cb_wait_front(cb_in1, 1); + + // Perform the matrix multiplication for the current tile. + // NOTE: This function also accumulates the result into the destination + // tile. + matmul_tiles(cb_in0, cb_in1, 0, 0, 0, false); + + // Mark the input tiles as used by popping them from the front of the + // circular buffers. 
+ cb_pop_front(cb_in0, 1); + cb_pop_front(cb_in1, 1); + } + + // Commit and wait for the registers are populated with the results from the + // FPU + tile_regs_commit(); + tile_regs_wait(); + + // Ensure the output circular buffer has space for the result tile. + cb_reserve_back(cb_out, 1); + // Pack the result tile into the output circular buffer. + pack_tile(0, cb_out); + // Mark the output tile as ready so the writer can read it. + cb_push_back(cb_out, 1); + + // We don't need the registers anymore, so we can release them and prepare + // for the next output tile. + tile_regs_release(); + } +} +} // namespace NAMESPACE diff --git a/examples/metal_examples/multicore_matmul/metal/kernels/mm_reader.cpp b/examples/metal_examples/multicore_matmul/metal/kernels/mm_reader.cpp new file mode 100644 index 00000000..9997d95a --- /dev/null +++ b/examples/metal_examples/multicore_matmul/metal/kernels/mm_reader.cpp @@ -0,0 +1,71 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" +#include +#include + +#include "debug/dprint.h" + +void kernel_main() { + // same arg indices as in reader_binary_diff_lengths for compat + uint32_t src0_addr = get_arg_val(0); + uint32_t src1_addr = get_arg_val(1); + uint32_t Mt = get_arg_val(2); + uint32_t Kt = get_arg_val(3); + uint32_t Nt = get_arg_val(4); + uint32_t output_tile_start_id = + get_arg_val(5); // starting tile ID for output tiles + uint32_t num_output_tiles = + get_arg_val(6); // number of output tiles to read + + constexpr uint32_t cb_id_in0 = tt::CBIndex::c_0; + constexpr uint32_t cb_id_in1 = tt::CBIndex::c_1; + + // Declare address in which we stored the source matrices. We have set the + // exact same format between CBs and DRAM buffers in the host code, so we can + // use the same address for both DRAM and CBs. 
+ const uint32_t in0_tile_bytes = get_tile_size(cb_id_in0); + const uint32_t in1_tile_bytes = get_tile_size(cb_id_in1); + + constexpr auto a_args = TensorAccessorArgs<0>(); + const auto a = TensorAccessor(a_args, src0_addr, in0_tile_bytes); + + constexpr auto b_args = + TensorAccessorArgs(); + const auto b = TensorAccessor(b_args, src1_addr, in1_tile_bytes); + + // Simple 2D matmul: A[Mt, Kt] @ B[Kt, Nt] = C[Mt, Nt] + for (uint32_t output_tile = 0; output_tile < num_output_tiles; + output_tile++) { + uint32_t current_tile_id = output_tile_start_id + output_tile; + + // Convert linear output tile ID to 2D coordinates + uint32_t out_row = current_tile_id / Nt; // Which row in output + uint32_t out_col = current_tile_id % Nt; // Which col in output + + // Read all K tiles for this output position + for (uint32_t k = 0; k < Kt; k++) { + // Read A's tile at (out_row, k) + uint32_t tile_A = out_row * Kt + k; // A is MK, so we stride by Kt + { + cb_reserve_back(cb_id_in0, 1); + uint32_t l1_write_addr_in0 = get_write_ptr(cb_id_in0); + noc_async_read_tile(tile_A, a, l1_write_addr_in0); + noc_async_read_barrier(); + cb_push_back(cb_id_in0, 1); + } + + // Read B's tile at (k, out_col) + uint32_t tile_B = k * Nt + out_col; // B is KN, so we stride by Nt + { + cb_reserve_back(cb_id_in1, 1); + uint32_t l1_write_addr_in1 = get_write_ptr(cb_id_in1); + noc_async_read_tile(tile_B, b, l1_write_addr_in1); + noc_async_read_barrier(); + cb_push_back(cb_id_in1, 1); + } + } + } +} diff --git a/examples/metal_examples/multicore_matmul/metal/kernels/mm_writer.cpp b/examples/metal_examples/multicore_matmul/metal/kernels/mm_writer.cpp new file mode 100644 index 00000000..5858212a --- /dev/null +++ b/examples/metal_examples/multicore_matmul/metal/kernels/mm_writer.cpp @@ -0,0 +1,44 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" + +void kernel_main() { + // Runtime arguments to write data back into the output buffer. + uint32_t dst_addr = get_arg_val(0); + uint32_t num_tiles = + get_arg_val(1); // number of output tiles to write + uint32_t start_id = + get_arg_val(2); // starting tile ID for output tiles + + constexpr uint32_t cb_id_out = tt::CBIndex::c_16; + + // Create the address generator for the output buffer. Due to us sharing + // buffer and circular buffer configuration parameters (e.g. same data type + // and same page size) in the host code, we can grab the same parameters from + // the circular buffer as we would from the DRAM buffer. + constexpr uint32_t onetile = 1; // single-tile ublocks + const uint32_t tile_bytes = get_tile_size(cb_id_out); + + constexpr auto c_args = TensorAccessorArgs<0>(); + const auto c = TensorAccessor(c_args, dst_addr, tile_bytes); + + // Loop through the tile indices and write each tile to DRAM in order. + uint32_t end_id = start_id + num_tiles; + for (uint32_t i = start_id; i < end_id; ++i) { + // Wait for the kernel to produce an output tile + cb_wait_front(cb_id_out, onetile); + // Write the output tile to DRAM. + uint32_t l1_read_addr = get_read_ptr(cb_id_out); + noc_async_write_tile(i, c, l1_read_addr); + noc_async_write_barrier(); // This will wait until the write is done. As an + // alternative, noc_async_write_flushed() can be + // faster because it waits until the write + // request is sent. In that case, you have to use + // noc_async_write_barrier() at least once at the + // end of data movement kernel to make sure all + // writes are done. 
+ cb_pop_front(cb_id_out, onetile); + } +} diff --git a/examples/metal_examples/multicore_matmul/metal/multicore_matmul.py b/examples/metal_examples/multicore_matmul/metal/multicore_matmul.py new file mode 100644 index 00000000..c3bdb466 --- /dev/null +++ b/examples/metal_examples/multicore_matmul/metal/multicore_matmul.py @@ -0,0 +1,228 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +import ttnn +import pytest +import torch + +from ttlang.utils.correctness import assert_with_ulp + + +# (M * N) % (32 *32) == 0 for this implemention +@pytest.mark.parametrize( + "M,K,N", + [ + (640, 640, 640), + ], +) +def test_multicore_matmul(M, K, N): + # might be some l1 config stuff + device = ttnn.open_device(device_id=0) + assert (M * N) % ( + ttnn.TILE_SIZE * ttnn.TILE_SIZE + ) == 0, "M*N must be multiple of TILE_SIZE*TILE_SIZE" + Mt = M // ttnn.TILE_SIZE + Kt = K // ttnn.TILE_SIZE + Nt = N // ttnn.TILE_SIZE + num_output_tiles_total = (M * N) // (ttnn.TILE_SIZE * ttnn.TILE_SIZE) + + device_core_size = device.compute_with_storage_grid_size() + upper_bound_core = ttnn.CoreCoord(device_core_size.x - 1, device_core_size.y - 1) + device_core_grid = ttnn.CoreRangeSet( + [ttnn.CoreRange(ttnn.CoreCoord(0, 0), upper_bound_core)] + ) + print( + f"core_grid: {device_core_grid}, num_output_tiles_total: {num_output_tiles_total}" + ) + (_, all_cores, core_group_1, core_group_2, work_per_core1, work_per_core2) = ( + ttnn.split_work_to_cores( + device_core_grid, num_output_tiles_total, row_wise=True + ) + ) + print( + f"all_cores: {all_cores}, core_group_1: {core_group_1}, core_group_2: {core_group_2}, work_per_core1: {work_per_core1}, work_per_core2: {work_per_core2}" + ) + + # allocate a, b and output tensors for matmul on device dram + dram_memory_config = ttnn.DRAM_MEMORY_CONFIG + a_tensor = ttnn.rand( + (M, K), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=dram_memory_config, + ) + b_tensor = ttnn.rand( + (K, N), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=dram_memory_config, + ) + output_tensor = ttnn.empty( + (M, N), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=dram_memory_config, + ) + dtype_size = 2 # bfloat16 + buffer_factor = 2 + cb_page_size = dtype_size * ttnn.TILE_SIZE * ttnn.TILE_SIZE + cb_total_size = buffer_factor * cb_page_size + + a_cb = 0 + b_cb = 1 + out_cb = 16 + a_cb_format = ttnn.CBFormatDescriptor( + buffer_index=a_cb, + data_format=ttnn.bfloat16, + page_size=cb_page_size, + ) + b_cb_format = ttnn.CBFormatDescriptor( + buffer_index=b_cb, + data_format=ttnn.bfloat16, + page_size=cb_page_size, + ) + out_cb_format = ttnn.CBFormatDescriptor( + buffer_index=out_cb, + data_format=ttnn.bfloat16, + page_size=cb_page_size, + ) + + a_cb_descriptor = ttnn.CBDescriptor( + total_size=cb_total_size, + core_ranges=all_cores, + format_descriptors=[a_cb_format], + ) + b_cb_descriptor = ttnn.CBDescriptor( + total_size=cb_total_size, + core_ranges=all_cores, + format_descriptors=[b_cb_format], + ) + out_cb_descriptor = ttnn.CBDescriptor( + total_size=cb_total_size, + core_ranges=all_cores, + format_descriptors=[out_cb_format], + ) + + # TODO inconsistent metal access patterns for compile/runtime args + reader_compile_time_args = ttnn.TensorAccessorArgs(a_tensor).get_compile_time_args() + reader_compile_time_args.extend( + ttnn.TensorAccessorArgs(b_tensor).get_compile_time_args() + ) + writer_compile_time_args = ttnn.TensorAccessorArgs( + 
output_tensor + ).get_compile_time_args() + + # iterate over cores and assign work via runtime args + # Both core groups should only be one core_range, but handling more just in case + # will always be a smaller core grid than input grid, setting up runtime list + # as the larger one to enable indexing in + num_x_cores = upper_bound_core.x + 1 + num_y_cores = upper_bound_core.y + 1 + reader_rt_args = [[[] for _ in range(num_y_cores)] for _ in range(num_x_cores)] + writer_rt_args = [[[] for _ in range(num_y_cores)] for _ in range(num_x_cores)] + compute_rt_args = [[[] for _ in range(num_y_cores)] for _ in range(num_x_cores)] + current_tile = 0 + for core_range in core_group_1.ranges(): + for x in range(core_range.start.x, core_range.end.x + 1): + for y in range(core_range.start.y, core_range.end.y + 1): + print( + f"Assigning core ({x},{y}) tile {current_tile} work_per_core1 {work_per_core1}" + ) + reader_rt_args[x][y] = [ + a_tensor.buffer_address(), + b_tensor.buffer_address(), + Mt, + Kt, + Nt, + current_tile, + work_per_core1, + ] + writer_rt_args[x][y] = [ + output_tensor.buffer_address(), + work_per_core1, + current_tile, + ] + compute_rt_args[x][y] = [work_per_core1, Kt] + current_tile += work_per_core1 + + for core_range in core_group_2.ranges(): + for x in range(core_range.start.x, core_range.end.x + 1): + for y in range(core_range.start.y, core_range.end.y + 1): + print( + f"Assigning core ({x},{y}) tile {current_tile} work_per_core2 {work_per_core2}" + ) + reader_rt_args[x][y] = [ + a_tensor.buffer_address(), + b_tensor.buffer_address(), + Mt, + Kt, + Nt, + current_tile, + work_per_core2, + ] + writer_rt_args[x][y] = [ + output_tensor.buffer_address(), + work_per_core2, + current_tile, + ] + compute_rt_args[x][y] = [work_per_core2, Kt] + current_tile += work_per_core2 + + # Compute config init can't handle options, set here + computeConfig = ttnn.ComputeConfigDescriptor() + computeConfig.math_fidelity = ttnn.MathFidelity.HiFi4 + computeConfig.fp32_dest_acc_en = True + computeConfig.math_approx_mode = False + + reader_kernel_descriptor = ttnn.KernelDescriptor( + kernel_source="examples/metal_examples/multicore_matmul/metal/kernels/mm_reader.cpp", + source_type=ttnn.KernelDescriptor.SourceType.FILE_PATH, + core_ranges=all_cores, + compile_time_args=reader_compile_time_args, + runtime_args=reader_rt_args, + config=ttnn.ReaderConfigDescriptor(), + ) + writer_kernel_descriptor = ttnn.KernelDescriptor( + kernel_source="examples/metal_examples/multicore_matmul/metal/kernels/mm_writer.cpp", + source_type=ttnn.KernelDescriptor.SourceType.FILE_PATH, + core_ranges=all_cores, + compile_time_args=writer_compile_time_args, + runtime_args=writer_rt_args, + config=ttnn.WriterConfigDescriptor(), + ) + compute_kernel_descriptor = ttnn.KernelDescriptor( + kernel_source="examples/metal_examples/multicore_matmul/metal/kernels/mm_compute.cpp", + source_type=ttnn.KernelDescriptor.SourceType.FILE_PATH, + core_ranges=all_cores, + compile_time_args=[], + runtime_args=compute_rt_args, + config=computeConfig, + ) + + program_descriptor = ttnn.ProgramDescriptor( + kernels=[ + reader_kernel_descriptor, + writer_kernel_descriptor, + compute_kernel_descriptor, + ], + semaphores=[], + cbs=[a_cb_descriptor, b_cb_descriptor, out_cb_descriptor], + ) + + print("Launching generic_op...") + output = ttnn.generic_op([a_tensor, b_tensor, output_tensor], program_descriptor) + print("Completed generic_op.") + metal_output = ttnn.to_torch(output).to(torch.bfloat16) + print(f"metal_output: {metal_output}") + + 
a_tensor_torch = ttnn.to_torch(a_tensor).to(torch.bfloat16) + b_tensor_torch = ttnn.to_torch(b_tensor).to(torch.bfloat16) + torch_output = torch.matmul(a_tensor_torch, b_tensor_torch) + print(f"torch_output: {torch_output}") + + assert_with_ulp(torch_output, metal_output) + + ttnn.close_device(device) diff --git a/examples/metal_examples/multicore_matmul/ttlang/multicore_matmul.py b/examples/metal_examples/multicore_matmul/ttlang/multicore_matmul.py new file mode 100644 index 00000000..c91addf4 --- /dev/null +++ b/examples/metal_examples/multicore_matmul/ttlang/multicore_matmul.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +# up to tt-lang spec, not intended to compile or run currently +import ttnn +import pytest +import torch + +from ttl import Program, make_circular_buffer_like, copy + +from ttlang.utils.correctness import assert_with_ulp +from ttlang.utils.block_allocation import split_work_to_cores + + +def get_number_of_cores(grid_range): + total_cores = 0 + for start, end in grid_range: + x_range = end[0] - start[0] + 1 + y_range = end[1] - start[1] + 1 + total_cores += x_range * y_range + return total_cores + + +@ttl.kernel(grid=(13, 10)) +def tt_lang_multicore_matmul(a: ttnn.Tensor, b: ttnn.Tensor, out: ttnn.Tensor): + assert a.shape[1] == b.shape[0], "Incompatible matrix shapes for multiplication." + assert a.shape[0] == out.shape[0], "Output matrix has incorrect number of rows." + M = a.shape[0] + N = b.shape[1] + K = a.shape[1] + Mt = M // ttnn.TILE_SIZE + Kt = K // ttnn.TILE_SIZE + Nt = N // ttnn.TILE_SIZE + num_output_tiles_total = (M * N) // (ttnn.TILE_SIZE * ttnn.TILE_SIZE) + buffering_factor = 2 + a_cb = make_circular_buffer_like(a, shape=(1, 1), buffer_factor=buffering_factor) + b_cb = make_circular_buffer_like(b, shape=(1, 1), buffer_factor=buffering_factor) + out_cb = make_circular_buffer_like( + out, shape=(1, 1), buffer_factor=buffering_factor + ) + + print(f"core_grid: {core_grid}, num_output_tiles_total: {num_output_tiles_total}") + (_, all_cores, core_group_1, core_group_2, work_per_core1, work_per_core2) = ( + split_work_to_cores( + ttl.grid_size(dims=2), num_output_tiles_total, row_wise=True + ) + ) + print( + f"all_cores: {all_cores}, core_group_1: {core_group_1}, core_group_2: {core_group_2}, work_per_core1: {work_per_core1}, work_per_core2: {work_per_core2}" + ) + + num_cores_group_1 = get_number_of_cores(core_group_1) + num_cores_group_2 = get_number_of_cores(core_group_2) + + def get_tiles_per_core(core_id): + if core_id < num_cores_group_1: + return work_per_core1 + elif core_id < num_cores_group_1 + num_cores_group_2: + return work_per_core2 + else: # no work assigned + return 0 + + def get_start_tile_id(core_id): + if core_id < num_cores_group_1: + return core_id * work_per_core1 + elif core_id < num_cores_group_1 + num_cores_group_2: + return ( + num_cores_group_1 * work_per_core1 + + (core_id - num_cores_group_1) * work_per_core2 + ) + else: # no work assigned + return 0 + + @ttl.compute() + def mm_compute(): + core_id = ttl.core(dims=1) + for _ in range(get_tiles_per_core(core_id)): + with out_cb.reserve() as out_blk: + for _ in range(Kt): + with a_cb.wait() as a_blk, b_cb.wait() as b_blk: + out_blk.store(a_blk @ b_blk, acc=True) + + @ttl.datamovement() + def mm_reader(): + core_id = ttl.core(dims=1) + # A[Mt, Kt] @ B[Kt, Nt] = C[Mt, Nt] + for tile_id in range(get_tiles_per_core(core_id)): + current_tile_id = get_start_tile_id(core_id) + tile_id + out_row = current_tile_id // Nt + 
out_col = current_tile_id % Nt + for k in range(Kt): + with a_cb.reserve() as a_blk, b_cb.reserve() as b_blk: + a_wr = copy(a[out_row, k], a_blk) + b_wr = copy(b[k, out_col], b_blk) + a_wr.wait() + b_wr.wait() + + @ttl.datamovement() + def mm_writer(): + core_id = ttl.core(dims=1) + # A[Mt, Kt] @ B[Kt, Nt] = C[Mt, Nt] + for tile_id in range(get_tiles_per_core(core_id)): + current_tile_id = get_start_tile_id(core_id) + tile_id + out_row = current_tile_id // Nt + out_col = current_tile_id % Nt + with out_cb.wait() as out_blk: + out_wr = copy(out_blk, out[out_row, out_col]) + out_wr.wait() + + return Program(mm_compute, mm_reader, mm_writer)(a, b, out) + + +@pytest.mark.parametrize("M,K,N", [(256, 256, 256), (512, 512, 512)]) +def test_multicore_matmul_tt_lang(M, K, N): + """Test multicore matmul kernel.""" + device = ttnn.open_device(device_id=0) + a = ttnn.rand((M, K), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + b = ttnn.rand((K, N), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + c = ttnn.empty((M, N), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + + tt_lang_multicore_matmul(a, b, c) + + golden = torch.matmul( + ttnn.to_torch(a).to(torch.bfloat16), ttnn.to_torch(b).to(torch.bfloat16) + ) + result = ttnn.to_torch(c).to(torch.bfloat16) + assert_with_ulp(golden, result) + + ttnn.close_device(device) diff --git a/examples/metal_examples/singlecore_matmul/metal/kernels/mm_compute.cpp b/examples/metal_examples/singlecore_matmul/metal/kernels/mm_compute.cpp new file mode 100644 index 00000000..c1ab730b --- /dev/null +++ b/examples/metal_examples/singlecore_matmul/metal/kernels/mm_compute.cpp @@ -0,0 +1,70 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "compute_kernel_api/matmul.h" +#include "compute_kernel_api/tile_move_copy.h" +#include "hostdevcommon/kernel_structs.h" +#include + +using std::uint32_t; + +namespace NAMESPACE { +void MAIN { + const uint32_t Mt = get_compile_time_arg_val(0); + const uint32_t Kt = get_compile_time_arg_val(1); + const uint32_t Nt = get_compile_time_arg_val(2); + constexpr tt::CBIndex cb_in0 = tt::CBIndex::c_0; + constexpr tt::CBIndex cb_in1 = tt::CBIndex::c_1; + constexpr tt::CBIndex cb_out = tt::CBIndex::c_16; + + // Setup the FPU (matrix engine) for the matmul operation. And specify the + // input and output circular buffers. + mm_init(cb_in0, cb_in1, cb_out); + + // the simplest possible version of outer product blocked matmul + // the reader is expected to read the A's and B's tile rows and tile columns + // for each output tile + for (uint32_t mt = 0; mt < Mt; ++mt) { + for (uint32_t nt = 0; nt < Nt; ++nt) { + // Make sure registers can be used for the output tile. This also sets the + // registers to zero. + tile_regs_acquire(); + for (uint32_t kt = 0; kt < Kt; kt++) { + // Wait for the input tiles to be available in the input circular + // buffers. + cb_wait_front(cb_in0, 1); + cb_wait_front(cb_in1, 1); + + // Perform the matrix multiplication for the current tile. + // NOTE: This function also accumulates the result into the destination + // tile. + matmul_tiles(/*in0_cb_id=*/cb_in0, /*in1_cb_id=*/cb_in1, + /*in0_tile_index=*/0, /*in1_tile_index=*/0, + /*idst=*/0, /*transpose=*/false); + + // Mark the input tiles as used by popping them from the front of the + // circular buffers. 
+ cb_pop_front(cb_in0, 1); + cb_pop_front(cb_in1, 1); + } + + // Commit and wait for the registers are populated with the results from + // the FPU + tile_regs_commit(); + tile_regs_wait(); + + // Ensure the output circular buffer has space for the result tile. + cb_reserve_back(cb_out, 1); + // Pack the result tile into the output circular buffer. + pack_tile(0, cb_out); + // Mark the output tile as ready so the writer can read it. + cb_push_back(cb_out, 1); + + // We don't need the registers anymore, so we can release them and prepare + // for the next output tile. + tile_regs_release(); + } + } +} +} // namespace NAMESPACE diff --git a/examples/metal_examples/singlecore_matmul/metal/kernels/mm_reader.cpp b/examples/metal_examples/singlecore_matmul/metal/kernels/mm_reader.cpp new file mode 100644 index 00000000..0441a74c --- /dev/null +++ b/examples/metal_examples/singlecore_matmul/metal/kernels/mm_reader.cpp @@ -0,0 +1,55 @@ + +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" +#include + +void kernel_main() { + // same arg indices as in reader_binary_diff_lengths for compat + uint32_t src0_addr = get_arg_val(0); + uint32_t src1_addr = get_arg_val(1); + uint32_t Mt = get_arg_val(2); + uint32_t Kt = get_arg_val(3); + uint32_t Nt = get_arg_val(4); + + constexpr uint32_t cb_id_in0 = tt::CBIndex::c_0; + constexpr uint32_t cb_id_in1 = tt::CBIndex::c_1; + + // Declare address in which we stored the source matrices. We have set the + // exact same format between CBs and DRAM buffers in the host code, so we can + // use the same address for both DRAM and CBs. + constexpr auto s0_args = TensorAccessorArgs<0>(); + const auto s0 = TensorAccessor(s0_args, src0_addr, get_tile_size(cb_id_in0)); + constexpr auto s1_args = + TensorAccessorArgs(); + const auto s1 = TensorAccessor(s1_args, src1_addr, get_tile_size(cb_id_in1)); + + // Loop through the dimensions of the matrices. Read them and push to the + // circular buffers. Dimension names are called M, N and K. `t` in `mt` means + // tile. + for (uint32_t mt = 0; mt < Mt; mt++) { + for (uint32_t nt = 0; nt < Nt; nt++) { + for (uint32_t kt = 0; kt < Kt; kt++) { + { // Read A's tile at (mt, kt) + uint32_t a_tile_index = mt * Kt + kt; // A is MxK, so we stride by Kt + cb_reserve_back(cb_id_in0, 1); + uint32_t l1_write_addr_in0 = get_write_ptr(cb_id_in0); + noc_async_read_tile(a_tile_index, s0, l1_write_addr_in0); + noc_async_read_barrier(); + cb_push_back(cb_id_in0, 1); + } + + { // Read B's tile at (kt, nt) + uint32_t b_tile_index = kt * Nt + nt; // B is KxN, so we stride by Nt + cb_reserve_back(cb_id_in1, 1); + uint32_t l1_write_addr_in1 = get_write_ptr(cb_id_in1); + noc_async_read_tile(b_tile_index, s1, l1_write_addr_in1); + noc_async_read_barrier(); + cb_push_back(cb_id_in1, 1); + } + } // Kt loop + } // Nt loop + } // Mt loop +} diff --git a/examples/metal_examples/singlecore_matmul/metal/kernels/mm_writer.cpp b/examples/metal_examples/singlecore_matmul/metal/kernels/mm_writer.cpp new file mode 100644 index 00000000..75bf7538 --- /dev/null +++ b/examples/metal_examples/singlecore_matmul/metal/kernels/mm_writer.cpp @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" + +void kernel_main() { + // Runtime arguments to write data back into the output buffer. 
+ uint32_t dst_addr = get_arg_val(0); + uint32_t Mt = get_arg_val(1); + uint32_t Nt = get_arg_val(2); + + constexpr uint32_t cb_id_out0 = tt::CBIndex::c_16; + + // Create the address generator for the output buffer. Due to us sharing + // buffer and circular buffer configuration parameters (e.g. same data type + // and same page size) in the host code, we can grab the same parameters from + // the circular buffer as we would from the DRAM buffer. + constexpr auto s_args = TensorAccessorArgs<0>(); + const auto s = TensorAccessor(s_args, dst_addr, get_tile_size(cb_id_out0)); + + // Loop through the matrix dimensions Mt and Nt. mm_compute matmul will + // generate C's tiles C=A*B, MN=MK*KN, in row major order, we just read them + // from CB and write out to DRAM + for (uint32_t m = 0; m < Mt; ++m) { + for (uint32_t n = 0; n < Nt; ++n) { + // Wait for the matrix multiplication kernel to produce an output + cb_wait_front(cb_id_out0, 1); + // Write the output tile to DRAM. + uint32_t l1_read_addr = get_read_ptr(cb_id_out0); + noc_async_write_tile(m * Nt + n, s, l1_read_addr); + // This will wait until the write is done. As + // an alternative, noc_async_write_flushed() + // can be faster because it waits until the + // write request is sent. In that case, you + // have to use noc_async_write_barrier() at + // least once at the end of data movement + // kernel to make sure all writes are done. + noc_async_write_barrier(); + cb_pop_front(cb_id_out0, 1); + } + } +} diff --git a/examples/metal_examples/singlecore_matmul/metal/singlecore_matmul.py b/examples/metal_examples/singlecore_matmul/metal/singlecore_matmul.py new file mode 100644 index 00000000..57f4de58 --- /dev/null +++ b/examples/metal_examples/singlecore_matmul/metal/singlecore_matmul.py @@ -0,0 +1,146 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +import ttnn +import pytest +import torch + +from ttlang.utils.correctness import assert_with_ulp + + +@pytest.mark.parametrize( + "M,K,N", [(128, 128, 128), (256, 256, 256), (512, 512, 512), (640, 640, 640)] +) +def test_singlecore_matmul_metal(M, K, N): + device = ttnn.open_device(device_id=0) + # allocate a, b and output tensors for matmul on device dram + dram_memory_config = ttnn.DRAM_MEMORY_CONFIG + a_tensor = ttnn.rand( + (M, K), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=dram_memory_config, + ) + b_tensor = ttnn.rand( + (K, N), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=dram_memory_config, + ) + output_tensor = ttnn.empty( + (M, N), + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=device, + memory_config=dram_memory_config, + ) + Mt = M // ttnn.TILE_SIZE + Kt = K // ttnn.TILE_SIZE + Nt = N // ttnn.TILE_SIZE + + a_cb = 0 + b_cb = 1 + out_cb = 16 + dtype_size = 2 # bfloat16 + cb_page_size = dtype_size * ttnn.TILE_SIZE * ttnn.TILE_SIZE + a_cb_format = ttnn.CBFormatDescriptor( + buffer_index=a_cb, + data_format=ttnn.bfloat16, + page_size=cb_page_size, + ) + b_cb_format = ttnn.CBFormatDescriptor( + buffer_index=b_cb, + data_format=ttnn.bfloat16, + page_size=cb_page_size, + ) + out_cb_format = ttnn.CBFormatDescriptor( + buffer_index=out_cb, + data_format=ttnn.bfloat16, + page_size=cb_page_size, + ) + + # single core grid + core = ttnn.CoreCoord(0, 0) + core_grid = ttnn.CoreRangeSet([ttnn.CoreRange(core, core)]) + buffering_factor = 2 + cb_total_size = buffering_factor * cb_page_size + a_cb_descriptor = ttnn.CBDescriptor( + total_size=cb_total_size, + 
core_ranges=core_grid, + format_descriptors=[a_cb_format], + ) + b_cb_descriptor = ttnn.CBDescriptor( + total_size=cb_total_size, + core_ranges=core_grid, + format_descriptors=[b_cb_format], + ) + out_cb_descriptor = ttnn.CBDescriptor( + total_size=cb_total_size, + core_ranges=core_grid, + format_descriptors=[out_cb_format], + ) + + reader_compile_time_args = ttnn.TensorAccessorArgs(a_tensor).get_compile_time_args() + reader_compile_time_args.extend( + ttnn.TensorAccessorArgs(b_tensor).get_compile_time_args() + ) + writer_compile_time_args = ttnn.TensorAccessorArgs( + output_tensor + ).get_compile_time_args() + compute_compile_time_args = [Mt, Kt, Nt] + reader_rt_args = [a_tensor.buffer_address(), b_tensor.buffer_address(), Mt, Kt, Nt] + writer_rt_args = [output_tensor.buffer_address(), Mt, Nt] + + # Compute config init can't handle options, set here + computeConfig = ttnn.ComputeConfigDescriptor() + computeConfig.math_fidelity = ttnn.MathFidelity.HiFi4 + computeConfig.fp32_dest_acc_en = True + computeConfig.math_approx_mode = False + + reader_kernel_descriptor = ttnn.KernelDescriptor( + kernel_source="examples/metal_examples/singlecore_matmul/metal/kernels/mm_reader.cpp", + source_type=ttnn.KernelDescriptor.SourceType.FILE_PATH, + core_ranges=core_grid, + compile_time_args=reader_compile_time_args, + runtime_args=[[reader_rt_args]], + config=ttnn.ReaderConfigDescriptor(), + ) + writer_kernel_descriptor = ttnn.KernelDescriptor( + kernel_source="examples/metal_examples/singlecore_matmul/metal/kernels/mm_writer.cpp", + source_type=ttnn.KernelDescriptor.SourceType.FILE_PATH, + core_ranges=core_grid, + compile_time_args=writer_compile_time_args, + runtime_args=[[writer_rt_args]], + config=ttnn.WriterConfigDescriptor(), + ) + compute_kernel_descriptor = ttnn.KernelDescriptor( + kernel_source="examples/metal_examples/singlecore_matmul/metal/kernels/mm_compute.cpp", + source_type=ttnn.KernelDescriptor.SourceType.FILE_PATH, + core_ranges=core_grid, + compile_time_args=compute_compile_time_args, + runtime_args=[[[]]], + config=computeConfig, + ) + + program_descriptor = ttnn.ProgramDescriptor( + kernels=[ + reader_kernel_descriptor, + writer_kernel_descriptor, + compute_kernel_descriptor, + ], + semaphores=[], + cbs=[a_cb_descriptor, b_cb_descriptor, out_cb_descriptor], + ) + + output = ttnn.generic_op([a_tensor, b_tensor, output_tensor], program_descriptor) + metal_output = ttnn.to_torch(output).to(torch.bfloat16) + + a_tensor_torch = ttnn.to_torch(a_tensor).to(torch.bfloat16) + b_tensor_torch = ttnn.to_torch(b_tensor).to(torch.bfloat16) + torch_output = torch.matmul(a_tensor_torch, b_tensor_torch) + + assert_with_ulp(torch_output, metal_output) + + ttnn.close_device(device) diff --git a/examples/metal_examples/singlecore_matmul/ttlang/singlecore_matmul.py b/examples/metal_examples/singlecore_matmul/ttlang/singlecore_matmul.py new file mode 100644 index 00000000..06afe0c9 --- /dev/null +++ b/examples/metal_examples/singlecore_matmul/ttlang/singlecore_matmul.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +# up to tt-lang spec, not intended to compile or run currently +import sys +from pathlib import Path +import ttnn +import pytest +import torch + +from ttl import Program, make_circular_buffer_like, copy + +from ttlang.utils.correctness import assert_with_ulp + + +@ttl.kernel(grid=(1, 1)) +def tt_lang_singlecore_matmul(a: ttnn.Tensor, b: ttnn.Tensor, out: ttnn.Tensor): + assert a.shape[1] == b.shape[0], "Incompatible matrix 
shapes for multiplication." + assert a.shape[0] == out.shape[0], "Output matrix has incorrect number of rows." + M = a.shape[0] + N = b.shape[1] + K = a.shape[1] + Mt = M // ttnn.TILE_SIZE + Kt = K // ttnn.TILE_SIZE + Nt = N // ttnn.TILE_SIZE + buffering_factor = 2 + a_cb = make_circular_buffer_like(a, shape=(1, 1), buffer_factor=buffering_factor) + b_cb = make_circular_buffer_like(b, shape=(1, 1), buffer_factor=buffering_factor) + out_cb = make_circular_buffer_like( + out, shape=(1, 1), buffer_factor=buffering_factor + ) + + @ttl.compute() + def mm_compute(): + for _ in range(Mt): + for _ in range(Nt): + with out_cb.reserve() as out_blk: + for _ in range(Kt): + with a_cb.wait() as a_blk, b_cb.wait() as b_blk: + out_blk.store(a_blk @ b_blk, acc=True) + + @ttl.datamovement() + def mm_reader(): + for m in range(Mt): + for n in range(Nt): + for k in range(Kt): + with a_cb.reserve() as a_blk, b_cb.reserve() as b_blk: + a_wr = copy(a[m, k], a_blk) + b_wr = copy(b[k, n], b_blk) + a_wr.wait() + b_wr.wait() + + @ttl.datamovement() + def mm_writer(): + for m in range(Mt): + for n in range(Nt): + with out_cb.wait() as out_blk: + out_wr = copy(out_blk, out[m, n]) + out_wr.wait() + + return Program(mm_compute, mm_reader, mm_writer)(a, b, out) + + +def test_singlecore_matmul_tt_lang(): + """Test singlecore matmul kernel.""" + device = ttnn.open_device(device_id=0) + M, K, N = 256, 256, 256 + a = ttnn.rand((M, K), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + b = ttnn.rand((K, N), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + c = ttnn.empty((M, N), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + + tt_lang_singlecore_matmul(a, b, c) + + golden = torch.matmul( + ttnn.to_torch(a).to(torch.bfloat16), ttnn.to_torch(b).to(torch.bfloat16) + ) + result = ttnn.to_torch(c).to(torch.bfloat16) + assert_with_ulp(golden, result) + + ttnn.close_device(device) diff --git a/examples/test_accessor_creation.py b/examples/test_accessor_creation.py index 4076d425..8a3811a9 100644 --- a/examples/test_accessor_creation.py +++ b/examples/test_accessor_creation.py @@ -9,7 +9,7 @@ in the Python-generated IR, rather than being added later by a pass. """ from ttlang.ttl_api import * -from utils import assert_allclose +from ttlang.utils.correctness import assert_allclose import torch diff --git a/examples/utils.py b/examples/utils.py deleted file mode 100644 index f5c11bcc..00000000 --- a/examples/utils.py +++ /dev/null @@ -1,114 +0,0 @@ -# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC -# -# SPDX-License-Identifier: Apache-2.0 - -"""Utilities for comparing tensor outputs in tests.""" - -import torch - - -def assert_pcc(golden, actual, threshold=0.99): - """ - Assert Pearson correlation coefficient is above threshold. - - Args: - golden: Expected tensor - actual: Actual output tensor - threshold: Minimum acceptable PCC (default 0.99) - - Raises: - AssertionError: If PCC < threshold - """ - combined = torch.stack([golden.flatten(), actual.flatten()]) - pcc = torch.corrcoef(combined)[0, 1].item() - assert ( - pcc >= threshold - ), f"Expected pcc {pcc} >= {threshold}\ngolden:\n{golden}\nactual:\n{actual}" - - -def assert_allclose( - actual, - expected, - rtol=1e-5, - atol=1e-8, - verbose=True, -): - """ - Assert tensors are close with detailed error reporting. - - Computes both absolute and relative errors with informative failure messages - showing error statistics and worst-case locations. 
- - Args: - actual: Actual output tensor - expected: Expected tensor - rtol: Relative tolerance (default 1e-5) - atol: Absolute tolerance (default 1e-8) - verbose: Show detailed error statistics on failure (default True) - - Raises: - AssertionError: If tensors don't match within tolerance - - Examples: - >>> out = model(input) - >>> assert_allclose(out, expected, rtol=1e-4, atol=1e-6) - """ - if actual.shape != expected.shape: - raise AssertionError( - f"Shape mismatch: actual {actual.shape} vs expected {expected.shape}" - ) - - # Compute element-wise absolute error - abs_diff = torch.abs(actual - expected) - max_abs_error = abs_diff.max().item() - mean_abs_error = abs_diff.mean().item() - - # Compute element-wise relative error with epsilon for numerical stability - eps = 1e-10 - rel_error = abs_diff / (torch.abs(expected) + eps) - max_rel_error = rel_error.max().item() - mean_rel_error = rel_error.mean().item() - - # Check if within tolerance - is_close = torch.allclose(actual, expected, rtol=rtol, atol=atol) - - if not is_close and verbose: - # Find locations of worst errors - abs_error_flat = abs_diff.flatten() - rel_error_flat = rel_error.flatten() - - worst_abs_idx = abs_error_flat.argmax() - worst_rel_idx = rel_error_flat.argmax() - - # Convert flat indices to coordinates - worst_abs_coord = torch.unravel_index(worst_abs_idx, actual.shape) - worst_rel_coord = torch.unravel_index(worst_rel_idx, actual.shape) - - error_msg = f""" -Tensor comparison failed! - -Error Statistics: - Absolute Error: - Mean: {mean_abs_error:.6e} - Max: {max_abs_error:.6e} at {tuple(c.item() for c in worst_abs_coord)} - actual={actual[worst_abs_coord].item():.6f}, expected={expected[worst_abs_coord].item():.6f} - - Relative Error: - Mean: {mean_rel_error:.6e} - Max: {max_rel_error:.6e} at {tuple(c.item() for c in worst_rel_coord)} - actual={actual[worst_rel_coord].item():.6f}, expected={expected[worst_rel_coord].item():.6f} - -Thresholds: - rtol: {rtol} - atol: {atol} - -Shape: {actual.shape} -Mismatched elements: {(abs_diff > atol + rtol * torch.abs(expected)).sum().item()} / {actual.numel()} -""" - raise AssertionError(error_msg) - - if not is_close: - raise AssertionError( - f"Tensors not close: max_abs_error={max_abs_error:.6e}, " - f"max_rel_error={max_rel_error:.6e}, rtol={rtol}, atol={atol}" - ) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index ce13c87c..d06e8760 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -121,6 +121,15 @@ declare_mlir_python_sources(TTLangPythonCommon.Src _src/tensor_registry.py ) +declare_mlir_python_sources(TTLangPythonCommon.Utils + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/ttlang" + ADD_TO_PARENT TTLangPythonCommon + SOURCES + utils/__init__.py + utils/block_allocation.py + utils/correctness.py +) + # ############################################################################### # Generate packages and shared library # ############################################################################### diff --git a/python/ttlang/utils/__init__.py b/python/ttlang/utils/__init__.py new file mode 100644 index 00000000..03dcc72c --- /dev/null +++ b/python/ttlang/utils/__init__.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +"""Utility functions for tt-lang.""" + +from .block_allocation import ( + split_work_to_cores, +) + +from .correctness import ( + assert_pcc, + assert_allclose, + assert_with_ulp, +) + +__all__ = [ + # block_allocation + "split_work_to_cores", + # correctness + 
"assert_pcc", + "assert_allclose", + "assert_with_ulp", +] diff --git a/python/ttlang/utils/block_allocation.py b/python/ttlang/utils/block_allocation.py new file mode 100644 index 00000000..451b72d8 --- /dev/null +++ b/python/ttlang/utils/block_allocation.py @@ -0,0 +1,285 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Tuple, List +import itertools +import math + + +def remove_leading_ones(grid: Tuple[int, ...]) -> Tuple[int, ...]: + return tuple(itertools.dropwhile(lambda x: x == 1, grid)) + + +def get_number_of_cores(grid: Tuple[int, ...]) -> int: + core_count = 1 + for dim in grid: + assert dim > 0, "grid dimensions must be positive" + core_count *= dim + return core_count + + +def filter_factor_pairs_by_2d_grid( + factor_pairs: list[Tuple[int, int]], grid: Tuple[int, int] +) -> list[Tuple[int, int]]: + valid_pairs = [] + for pair in factor_pairs: + if pair[0] <= grid[0] and pair[1] <= grid[1]: + valid_pairs.append(pair) + elif pair[1] <= grid[0] and pair[0] <= grid[1]: + valid_pairs.append((pair[1], pair[0])) + return valid_pairs + + +def num_cores_to_grid_ranges( + start_coord: Tuple[int, ...], + target_num_cores: int, + grid_size: Tuple[int, ...], + row_wise: bool = True, +) -> List[Tuple[Tuple[int, ...], Tuple[int, ...]]]: + """ + Generate a list of grid ranges covering target_num_cores cores starting from start_coord. + Similar to num_cores_to_corerangeset but returns simple tuples. + + Args: + start_coord: The starting coordinate + target_num_cores: Number of cores to include + grid_size: The dimensions of the grid + row_wise: If True, fill row-wise, else column-wise + + Returns: + List of (start_coord, end_coord) tuples representing rectangular ranges + """ + assert len(start_coord) == len( + grid_size + ), "start_coord and grid_size must have same dimensions" + + # Only support 2D grids for now + simplified_grid = remove_leading_ones(grid_size) + assert len(simplified_grid) <= 2, "Only supports 2D grids" + + # Get the actual 2D dimensions (last 2 dimensions) + num_cores_x = grid_size[-1] if len(grid_size) >= 1 else 1 + num_cores_y = grid_size[-2] if len(grid_size) >= 2 else 1 + + start_x = start_coord[-1] if len(start_coord) >= 1 else 0 + start_y = start_coord[-2] if len(start_coord) >= 2 else 0 + + assert ( + start_x < num_cores_x and start_y < num_cores_y + ), "Start coord must be within grid" + + if row_wise: + # Calculate available cores + total_available_cores = (num_cores_y - 1 - start_y) * num_cores_x + total_available_cores += num_cores_x - start_x + else: + # Column-wise + total_available_cores = (num_cores_x - 1 - start_x) * num_cores_y + total_available_cores += num_cores_y - start_y + + assert ( + target_num_cores <= total_available_cores + ), f"Target {target_num_cores} exceeds available {total_available_cores}" + + # Build list of ranges + all_ranges = [] + leftover_size = target_num_cores + s_x, s_y = start_x, start_y + + prefix = tuple(0 for _ in range(len(grid_size) - 2)) + + if row_wise: + # Partial row at start + if s_x != 0 and leftover_size > num_cores_x - start_x: + start_c = prefix + (s_y, s_x) + end_c = prefix + (s_y, num_cores_x - 1) + all_ranges.append((start_c, end_c)) + cores_taken = num_cores_x - s_x + leftover_size -= cores_taken + s_x = 0 + s_y += 1 + + # Full rows + if leftover_size >= num_cores_x: + num_full_rows = leftover_size // num_cores_x + start_c = prefix + (s_y, s_x) + end_c = prefix + (s_y + num_full_rows - 1, num_cores_x - 1) + all_ranges.append((start_c, end_c)) 
+ leftover_size -= num_full_rows * num_cores_x + s_y += num_full_rows + s_x = 0 + + # Partial row at end + if leftover_size > 0: + start_c = prefix + (s_y, s_x) + end_c = prefix + (s_y, s_x + leftover_size - 1) + all_ranges.append((start_c, end_c)) + else: + # Column-wise + # Partial col at start + if s_y != 0 and leftover_size > num_cores_y - start_y: + start_c = prefix + (s_y, s_x) + end_c = prefix + (num_cores_y - 1, s_x) + all_ranges.append((start_c, end_c)) + cores_taken = num_cores_y - s_y + leftover_size -= cores_taken + s_y = 0 + s_x += 1 + + # Full cols + if leftover_size >= num_cores_y: + num_full_cols = leftover_size // num_cores_y + start_c = prefix + (s_y, s_x) + end_c = prefix + (num_cores_y - 1, s_x + num_full_cols - 1) + all_ranges.append((start_c, end_c)) + leftover_size -= num_full_cols * num_cores_y + s_x += num_full_cols + s_y = 0 + + # Partial col at end + if leftover_size > 0: + start_c = prefix + (s_y, s_x) + end_c = prefix + (s_y + leftover_size - 1, s_x) + all_ranges.append((start_c, end_c)) + + return all_ranges + + +def split_work_to_cores( + grid_size: Tuple[int, ...], units_to_divide: int, row_wise: bool = True +) -> Tuple[ + int, + Tuple[Tuple[int, ...], Tuple[int, ...]], + Tuple[Tuple[int, ...], Tuple[int, ...]], + int, + int, +]: + """Splits work units among cores in a from a single device grid. + currently can produce work splits that cannot map to CoreRanges directly, particlarily in 1-d grids + + Args: + grid_size: A tuple representing the dimensions of the core grid. + units_to_divide: The total number of work units to be divided among the cores. + row_wise: If True, split work in a row-wise manner; otherwise, column-wise. + + Returns: A tuple containing: + - total number of cores + - core group 1 as a tuple of tuples, start coord to end coord rectangle [inclusive, inclusive] + - core group 2 as a tuple of tuples, start coord to end coord rectangle [inclusive, inclusive] + - work units per core in group 1 + - work units per core in group 2 + """ + if units_to_divide == 0: + return (0, (), (), 0, 0) + simplified_grid_size = remove_leading_ones(grid_size) + assert len(simplified_grid_size) <= 2, "only supports grids with a single device" + total_cores = get_number_of_cores(grid_size) + assert total_cores > 0, "grid must have at least one core" + start_coord = (0,) * len(grid_size) + if ( + total_cores >= units_to_divide + ): # more cores than work units, assign 1 unit to first N cores + if len(simplified_grid_size) == 1: + end_coord = ((0,) * (len(grid_size) - 1)) + (units_to_divide - 1,) + elif len(simplified_grid_size) == 2: + ranges = num_cores_to_grid_ranges( + start_coord, units_to_divide, grid_size, row_wise + ) + end_coord = ((0,) * (len(grid_size) - 2)) + ranges[-1][ + 1 + ] # Last range's end coordinate + return (units_to_divide, (start_coord, end_coord), (), 1, 0) + else: + # more work units than cores, divide work as evenly as possible + if len(simplified_grid_size) == 1: + work_per_core = units_to_divide // total_cores + remaining_work = units_to_divide % total_cores + end_coord_1 = ((0,) * (len(grid_size) - 1)) + (remaining_work,) + start_coord_2 = ((0,) * (len(grid_size) - 1)) + (remaining_work + 1,) + end_coord_2 = ((0,) * (len(grid_size) - 1)) + (total_cores - 1,) + return ( + total_cores, + ((0,) * len(grid_size), end_coord_1), + (start_coord_2, end_coord_2), + work_per_core + 1, + work_per_core, + ) + + elif len(simplified_grid_size) == 2: + """ + For 2D grids with more work than cores: + - Use all available cores + - Distribute work as 
evenly as possible + - Group 1: cores that get (work_per_core + 1) units + - Group 2: cores that get work_per_core units + """ + work_per_core = units_to_divide // total_cores + num_cores_with_more_work = units_to_divide % total_cores + + # Evenly divided - all cores get same amount + if num_cores_with_more_work == 0: + num_cores_x = grid_size[-1] + num_cores_y = grid_size[-2] + prefix = (0,) * (len(grid_size) - 2) + end_coord = prefix + (num_cores_y - 1, num_cores_x - 1) + return (total_cores, (start_coord, end_coord), (), work_per_core, 0) + + # Uneven division - need two groups + else: + # Group 1: first num_cores_with_more_work cores get (work_per_core + 1) + group1_ranges = num_cores_to_grid_ranges( + (0,) * len(grid_size), num_cores_with_more_work, grid_size, row_wise + ) + + # Find the last core of group 1 to determine where group 2 starts + last_range_group1 = group1_ranges[-1] + last_coord_group1 = last_range_group1[1] # end coord of last range + + # Calculate starting position for group 2 + num_cores_x = grid_size[-1] + num_cores_y = grid_size[-2] + last_x = last_coord_group1[-1] + last_y = last_coord_group1[-2] + + if row_wise: + # Start in the same row if possible + if last_x != num_cores_x - 1: + start_x_group2 = last_x + 1 + start_y_group2 = last_y + # Otherwise start in the next row + else: + start_x_group2 = 0 + start_y_group2 = last_y + 1 + else: + # Column-wise: Start in the same column if possible + if last_y != num_cores_y - 1: + start_x_group2 = last_x + start_y_group2 = last_y + 1 + # Otherwise start in the next column + else: + start_x_group2 = last_x + 1 + start_y_group2 = 0 + + prefix = (0,) * (len(grid_size) - 2) + start_coord_group2 = prefix + (start_y_group2, start_x_group2) + + num_cores_group2 = total_cores - num_cores_with_more_work + group2_ranges = num_cores_to_grid_ranges( + start_coord_group2, num_cores_group2, grid_size, row_wise + ) + + # For simplified return, we'll return the bounding boxes + # Group 1: from (0,0,...) to last coord of group 1 + group1_bbox = (start_coord, last_coord_group1) + + # Group 2: from start to last coord of group 2 + last_coord_group2 = group2_ranges[-1][1] + group2_bbox = (start_coord_group2, last_coord_group2) + + return ( + total_cores, + group1_bbox, + group2_bbox, + work_per_core + 1, + work_per_core, + ) diff --git a/python/ttlang/utils/correctness.py b/python/ttlang/utils/correctness.py new file mode 100644 index 00000000..5d539d38 --- /dev/null +++ b/python/ttlang/utils/correctness.py @@ -0,0 +1,292 @@ +# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +"""Utilities for comparing tensor outputs in tests.""" + +import torch +import math + + +def assert_pcc(golden, actual, threshold=0.99): + """ + Assert Pearson correlation coefficient is above threshold. + + Args: + golden: Expected tensor + actual: Actual output tensor + threshold: Minimum acceptable PCC (default 0.99) + + Raises: + AssertionError: If PCC < threshold + """ + combined = torch.stack([golden.flatten(), actual.flatten()]) + pcc = torch.corrcoef(combined)[0, 1].item() + assert ( + pcc >= threshold + ), f"Expected pcc {pcc} >= {threshold}\ngolden:\n{golden}\nactual:\n{actual}" + + +def assert_allclose( + actual, + expected, + rtol=1e-5, + atol=1e-8, + verbose=True, +): + """ + Assert tensors are close with detailed error reporting. + + Computes both absolute and relative errors with informative failure messages + showing error statistics and worst-case locations. 
+ + + Args: + actual: Actual output tensor + expected: Expected tensor + rtol: Relative tolerance (default 1e-5) + atol: Absolute tolerance (default 1e-8) + verbose: Show detailed error statistics on failure (default True) + + Raises: + AssertionError: If tensors don't match within tolerance + + Examples: + >>> out = model(input) + >>> assert_allclose(out, expected, rtol=1e-4, atol=1e-6) + """ + if actual.shape != expected.shape: + raise AssertionError( + f"Shape mismatch: actual {actual.shape} vs expected {expected.shape}" + ) + + # Compute element-wise absolute error + abs_diff = torch.abs(actual - expected) + max_abs_error = abs_diff.max().item() + mean_abs_error = abs_diff.mean().item() + + # Compute element-wise relative error with epsilon for numerical stability + eps = 1e-10 + rel_error = abs_diff / (torch.abs(expected) + eps) + max_rel_error = rel_error.max().item() + mean_rel_error = rel_error.mean().item() + + # Check if within tolerance + is_close = torch.allclose(actual, expected, rtol=rtol, atol=atol) + + if not is_close and verbose: + # Find locations of worst errors + abs_error_flat = abs_diff.flatten() + rel_error_flat = rel_error.flatten() + + worst_abs_idx = abs_error_flat.argmax() + worst_rel_idx = rel_error_flat.argmax() + + # Convert flat indices to coordinates + worst_abs_coord = torch.unravel_index(worst_abs_idx, actual.shape) + worst_rel_coord = torch.unravel_index(worst_rel_idx, actual.shape) + + error_msg = f""" +Tensor comparison failed! + +Error Statistics: + Absolute Error: + Mean: {mean_abs_error:.6e} + Max: {max_abs_error:.6e} at {tuple(c.item() for c in worst_abs_coord)} + actual={actual[worst_abs_coord].item():.6f}, expected={expected[worst_abs_coord].item():.6f} + + Relative Error: + Mean: {mean_rel_error:.6e} + Max: {max_rel_error:.6e} at {tuple(c.item() for c in worst_rel_coord)} + actual={actual[worst_rel_coord].item():.6f}, expected={expected[worst_rel_coord].item():.6f} + +Thresholds: + rtol: {rtol} + atol: {atol} + +Shape: {actual.shape} +Mismatched elements: {(abs_diff > atol + rtol * torch.abs(expected)).sum().item()} / {actual.numel()} +""" + raise AssertionError(error_msg) + + if not is_close: + raise AssertionError( + f"Tensors not close: max_abs_error={max_abs_error:.6e}, " + f"max_rel_error={max_rel_error:.6e}, rtol={rtol}, atol={atol}" + ) + + +def _comp_nonfinite(golden, calculated): + """ + Returns True if tensors contain the same non-finite values (nan, inf, -inf) at the same positions. Also returns True if all elements are finite. + Returns False if non-finite values differ between both tensors. 
+ """ + + # torch.equal(['nan'], ['nan']] => False + # For this reason, we check for nan and inf separately + if torch.not_equal(torch.isnan(golden), torch.isnan(calculated)).any(): + return False + + golden_inf_mask = torch.isinf(golden) + calculated_inf_mask = torch.isinf(calculated) + + if torch.not_equal(golden_inf_mask, calculated_inf_mask).any(): + return False + + golden_inf = golden[golden_inf_mask] + calculated_inf = calculated[calculated_inf_mask] + return torch.equal(golden_inf, calculated_inf) + + +def ulp(x: torch.Tensor) -> torch.Tensor: + "Return Unit of Least Precision for each element of a given tensor" + # Notes: + # - This should be identical to the definition of ULP by Goldberg + # "What every computer scientist should know about floating-point arithmetic" + # https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html + # - We use torch.abs(x) to ensure symmetry ULP(-x) == ULP(x) + # - For x powers of 2, x + ULP(x) is not closest number but second closest (previous number is 2x closer) + # However, this avoids rounding-to-nearest-tie-to-even issues on addition (i.e. x + ULP(x) != x) + abs_x = torch.abs(x) + next = torch.nextafter( + abs_x, torch.tensor(math.inf, dtype=x.dtype) + ) # 1 ULP ~ Difference between two consecutive floating point numbers + ulp_value = next - abs_x + + # Special case: if abs_x == torch.finfo(x.dtype).max, then next == math.inf, which leads to ULP(x) == inf rather than finite number + # We fix this problem by manually calculating ULP at max value, and masking tensor when input == max + dtype_max = torch.finfo(x.dtype).max + max_epsilon = dtype_max - torch.nextafter( + torch.tensor(dtype_max, dtype=x.dtype), torch.tensor(-math.inf, dtype=x.dtype) + ) + ulp_value = torch.where(abs_x == dtype_max, max_epsilon, ulp_value) + + return ulp_value + + +def comp_ulp(golden, calculated, ulp_threshold, allow_nonfinite=True): + """ + Compute absolute error between two tensors in Units of Least Precision (ULP) + """ + + # If both tensors are empty, then we can return True + if torch.numel(golden) == 0 and torch.numel(calculated) == 0: + return True, "Both tensors are empty" + + if not allow_nonfinite and not torch.all(torch.isfinite(calculated)): + return False, "Calculated tensor contains non-finite values" + + if not _comp_nonfinite(golden, calculated): + return False, "Tensors are not finite at the same positions" + # nonfinite elements can interfere with ULP error calculation + # To avoid this, replace nan, +inf, -inf with 0 + # (we have already checked that both tensors have the same nonfinite elements) + mask_finite = ~torch.isfinite(golden) + golden = golden.clone() + calculated = calculated.clone() + golden[mask_finite] = 0 + calculated[mask_finite] = 0 + + # ULP is measured according to the golden tensor + # In most cases, data type of golden tensor should be the same as calculated tensor. + # However, in some cases, we may want to measure < 1 ULP differences, which requires golden tensor + # to have higher precision than calculated tensor. + # If we passed golden tensor to ulp() as is, we would get ULP of higher precision. + # e.g. ulp of float32 rather bfloat16 calculation, which would give us a wrong value. 
+ ulp_value = ulp(golden.type(calculated.dtype)) + + if ( + golden.dtype != calculated.dtype + ): # Note: assumes that golden has higher precision than calculated tensor + calculated = calculated.type(golden.dtype) + ulp_value = ulp_value.type( + golden.dtype + ) # Convert ULP to higher precision (for sub-1 ULP measurements) + + ulp_delta = torch.max(torch.abs(calculated - golden) / ulp_value) + + return (ulp_delta <= ulp_threshold, f"Max ULP Delta: {ulp_delta}") + + +# TODO: add support for ttnn.Tensor inputs when ttnn module is part of tt-lang dependencies +def assert_with_ulp( + expected_result: torch.Tensor, + actual_result: torch.Tensor, + ulp_threshold=10, + allow_nonfinite=False, +): + """ + Assert that two tensors are similar within a given distance expressed in Units of Least Precision (ULP) + + The error is measured using the following formula: + `` + | expected - actual | / ULP(expected) + `` + + Where ULP(expected) returns, for each element, the length of a single Unit of Least Precision (ULP). + + + Args: + expected_result (Union[ttnn.Tensor, torch.Tensor]): The expected reference tensor + actual_result (Union[ttnn.Tensor, torch.Tensor]): The actual tensor to compare against the reference + ulp_threshold (float, optional): Maximum tolerated ULP distance. Defaults to 10. + allow_nonfinite (bool, optional): If disabled, any non-finite value (NaN, +inf, -inf) will trigger an assertion. If enabled, differences between non-finite values at the same positions will trigger an assertion. + + Notes: + The length of a single ULP is measured using the difference between two consecutive floating point numbers. + + ULP should be preferred when errors between `calculated` and `golden` outputs are known to be small (difference < 10s of ULPs). + This is typically the case for element-wise operations that approximate common numerical functions (e.g. exp, pow, log, ...). + + For more significant differences, where `calculated` and `golden` differ by orders of magnitude, ULPs may be harder to compare + Indeed, with current definition, on bfloat16: + - ULP-Delta(4, 0) = 128 + - ULP-Delta(0, 4) = 4.36e+40 + + Generally, if the ULP error exceeds the 2**(#mantissa bits) (128-ULP for bfloat16, 8388608 for float32), then it means that both outputs are different by more than an order of magnitude. + For these cases, functions such as `assert_allclose(golden, calculated, rtol, atol)` should be used instead. + + To measure the accuracy in ULP of operations on bfloat8_b data type, the ttnn bfloat8_b tensor should be either passed directly to the + function, or converted to bfloat16 beforehand (bfloat16 has the 'same' resolution as bfloat8_b). + Indeed, ttnn.to_torch() converts bfloat8_b to float32 by default, which would lead to assert_with_ulp() measuring ULP error as if + data type was computed as float32. + + This should be identical to the definition of ULP by Goldberg + "What every computer scientist should know about floating-point arithmetic" + https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html + + Returns: + tuple: A tuple containing: + - ulp_passed (bool): True if ulp check passed, False otherwise + - ulp_message (str): A message describing comparison result + + Raises: + AssertionError: If the tensor shapes don't match or if tensor difference is greater than ulp_threshold. 
+    """
+
+    assert list(expected_result.shape) == list(
+        actual_result.shape
+    ), f"list(expected_result.shape)={list(expected_result.shape)} vs list(actual_result.shape)={list(actual_result.shape)}"
+
+    maximum_meaningful_ulp_thresholds = {
+        torch.float64: 2**52,
+        torch.float32: 2**23,
+        torch.float16: 2**10,
+        torch.bfloat16: 2**7,
+    }
+    maximum_meaningful_ulp_threshold = (
+        maximum_meaningful_ulp_thresholds[expected_result.dtype]
+        if expected_result.dtype in maximum_meaningful_ulp_thresholds
+        else maximum_meaningful_ulp_thresholds[torch.float32]
+    )
+
+    if ulp_threshold > maximum_meaningful_ulp_threshold:
+        print(
+            f"ULP threshold {ulp_threshold} is greater than the maximum meaningful ULP threshold of {maximum_meaningful_ulp_threshold} for dtype {expected_result.dtype}"
+        )
+
+    ulp_passed, ulp_message = comp_ulp(
+        expected_result, actual_result, ulp_threshold, allow_nonfinite
+    )
+    assert ulp_passed, ulp_message
+    return ulp_passed, ulp_message
diff --git a/test/python/utils/test_block_allocation.py b/test/python/utils/test_block_allocation.py
new file mode 100644
index 00000000..828f6879
--- /dev/null
+++ b/test/python/utils/test_block_allocation.py
@@ -0,0 +1,117 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Test comparing new_split_work_to_cores with ttnn.split_work_to_cores
+"""
+import pytest
+
+from ttlang.utils.block_allocation import new_split_work_to_cores
+import ttnn
+
+
+def extract_coords_from_ttnn_corerangeset(core_range_set):
+    """Extract all start and end coordinates from a ttnn CoreRangeSet"""
+    if not core_range_set.ranges():
+        return []
+
+    coords = []
+    for r in core_range_set.ranges():
+        coords.append(((r.start.y, r.start.x), (r.end.y, r.end.x)))
+    return coords
+
+
+@pytest.mark.parametrize(
+    "grid_size_tuple,units,row_wise",
+    [
+        # Test cases with more work than cores
+        ((8, 8), 100, True),
+        ((8, 8), 100, False),
+        ((8, 8), 65, True),
+        ((8, 8), 129, True),
+        # Test even distribution
+        ((8, 8), 64, True),
+        ((8, 8), 128, True),
+        # Test with different grid sizes
+        ((4, 8), 50, True),
+        ((7, 9), 100, False),
+        # Test fewer units than cores
+        ((8, 8), 10, True),
+        ((8, 8), 20, False),
+        ((8, 8), 1, True),
+        # Test edge cases
+        ((8, 8), 63, True),
+        ((8, 8), 127, True),
+    ],
+)
+def test_split_work_to_cores(grid_size_tuple, units, row_wise):
+    """Compare results from new_split_work_to_cores and ttnn.split_work_to_cores"""
+    # Call the new function
+    new_result = new_split_work_to_cores(grid_size_tuple, units, row_wise)
+    new_total, new_g1, new_g2, new_w1, new_w2 = new_result
+
+    # Call the ttnn function
+    # Create a CoreRangeSet from grid_size_tuple
+    num_cores_x = grid_size_tuple[-1]
+    num_cores_y = grid_size_tuple[-2]
+    ttnn_grid = ttnn.CoreRangeSet(
+        [
+            ttnn.CoreRange(
+                ttnn.CoreCoord(0, 0), ttnn.CoreCoord(num_cores_x - 1, num_cores_y - 1)
+            )
+        ]
+    )
+
+    ttnn_result = ttnn.split_work_to_cores(ttnn_grid, units, row_wise)
+    ttnn_total, ttnn_all, ttnn_g1, ttnn_g2, ttnn_w1, ttnn_w2 = ttnn_result
+
+    # Extract coordinates from the ttnn result
+    ttnn_g1_coords = extract_coords_from_ttnn_corerangeset(ttnn_g1)
+    ttnn_g2_coords = extract_coords_from_ttnn_corerangeset(ttnn_g2)
+
+    # Verify work distribution matches
+    assert new_w1 == ttnn_w1, f"Work per core G1 mismatch: {new_w1} vs {ttnn_w1}"
+    assert new_w2 == ttnn_w2, f"Work per core G2 mismatch: {new_w2} vs {ttnn_w2}"
+
+    # Calculate the total cores in each group from ttnn
+    ttnn_g1_num_cores = sum(
+        (end[1] - start[1] + 1) * (end[0] -
start[0] + 1) + for start, end in ttnn_g1_coords + ) + ttnn_g2_num_cores = sum( + (end[1] - start[1] + 1) * (end[0] - start[0] + 1) + for start, end in ttnn_g2_coords + ) + + # Verify total work matches + new_total_work = ttnn_g1_num_cores * new_w1 + ttnn_g2_num_cores * new_w2 + ttnn_total_work = ttnn_g1_num_cores * ttnn_w1 + ttnn_g2_num_cores * ttnn_w2 + assert ( + new_total_work == ttnn_total_work == units + ), f"Total work mismatch: {new_total_work} vs {ttnn_total_work} vs {units}" + + # Verify group 1 coordinates + if new_g1 and ttnn_g1_coords: + new_g1_start, new_g1_end = new_g1 + ttnn_g1_first_start = ttnn_g1_coords[0][0] + ttnn_g1_last_end = ttnn_g1_coords[-1][1] + assert ( + new_g1_start == ttnn_g1_first_start and new_g1_end == ttnn_g1_last_end + ), f"Group 1 coordinates mismatch: new {new_g1_start} -> {new_g1_end}, ttnn {ttnn_g1_first_start} -> {ttnn_g1_last_end}" + + # Verify group 2 coordinates + if new_g2 and ttnn_g2_coords: + new_g2_start, new_g2_end = new_g2 + ttnn_g2_first_start = ttnn_g2_coords[0][0] + ttnn_g2_last_end = ttnn_g2_coords[-1][1] + assert ( + new_g2_start == ttnn_g2_first_start and new_g2_end == ttnn_g2_last_end + ), f"Group 2 coordinates mismatch: new {new_g2_start} -> {new_g2_end}, ttnn {ttnn_g2_first_start} -> {ttnn_g2_last_end}" + + # Check empty groups match + if not new_g1: + assert not ttnn_g1_coords, "Group 1 empty mismatch" + if not new_g2: + assert not ttnn_g2_coords, "Group 2 empty mismatch"