Large Kernel Support for MPWI #1045
base: main
Conversation
Thank you for your contribution! 🚀
Pull request overview
This PR adds support for large kernels (kernel_hw > 32) to the Max Pool With Indices (MPWI) operation by implementing accumulation over multiple chunks. Previously, MPWI only supported kernels with height/width ≤ 32.
Key changes:
- Added an accumulate template parameter to enable chunk-based processing for large kernels (a hypothetical driver loop is sketched below)
- Renamed the third parameter from tile_idx (unused) to chunk to track the chunk index during accumulation
- Implemented accumulation logic that stores intermediate results in separate destination tiles and combines them across chunks
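As a rough illustration of how the accumulate path is meant to be driven, here is a minimal sketch; num_chunks and the concrete template arguments are hypothetical, since the host-side loop is not part of this PR:

// Hypothetical driver loop: each chunk reduces up to 32 rows, and
// accumulate=true carries the running max (values and indices) in the
// scratch tiles that follow the values/indices tiles in Dest.
for (uint chunk = 0; chunk < num_chunks; ++chunk)
{
    _calculate_max_pool_with_indices_generic_</*APPROXIMATION_MODE=*/false, /*is_fp32_dest_acc_en=*/true, /*ITERATIONS=*/8, /*accumulate=*/true>(
        values_tile_idx, indices_tile_idx, chunk);
}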
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| tt_llk_wormhole_b0/common/inc/sfpu/ckernel_sfpu_max_pool_indices.h | Updated MPWI functions to support large kernel accumulation for Wormhole B0 architecture |
| tt_llk_blackhole/common/inc/sfpu/ckernel_sfpu_max_pool_indices.h | Updated MPWI functions to support large kernel accumulation for Blackhole architecture |
TT_SFPSTORE(p_sfpu::LREG0, InstrModLoadStore::DEFAULT, ADDR_MOD_3, values_accum_tile_offset + col_offset);
}

// store the final result to DST 0 (data) and DST 2 (indexes)
Copilot AI, Jan 1, 2026
The comment uses "indexes" but throughout the rest of the codebase and documentation, the term "indices" is consistently used. This should be changed to "indices" for consistency.
Suggested change:
- // store the final result to DST 0 (data) and DST 2 (indexes)
+ // store the final result to DST 0 (data) and DST 2 (indices)
🚀 tt-metal post-commit tests: workflow #20644145065
🚀 tt-metal post-commit tests: workflow #20644810710
Force-pushed from 253e11a to 478513d
🚀 tt-metal post-commit tests: workflow #20725100735
🚀 tt-metal post-commit tests: workflow #20727817391
}
else
{
    static_assert(!accumulate, "accumulate mode is not supported for TILE layout");
Please make sure that layout == TILE in that case; that way, the assert will trigger for ROW_MAJOR + accumulate as well.
This is in the else block of the if constexpr (layout == ckernel::DataLayout::ROW_MAJOR), so I think it's already checking this.
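For reference, the control flow in question looks roughly like this (a sketch reconstructed from the diff above, not the exact source):

if constexpr (layout == ckernel::DataLayout::ROW_MAJOR)
{
    // ROW_MAJOR path: accumulate is supported here
}
else
{
    // Only non-ROW_MAJOR (i.e. TILE) layouts reach this branch, so the assert
    // rejects TILE + accumulate and can never fire for ROW_MAJOR + accumulate.
    static_assert(!accumulate, "accumulate mode is not supported for TILE layout");
}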
}
else
{
    static_assert(!accumulate, "accumulate mode is not supported for TILE layout");
ditto
same thing here
Force-pushed from 478513d to 0cf13ce
    bool accumulate = false>
inline void _calculate_max_pool_with_indices_(const uint values_tile_idx, const uint indices_tile_idx, const uint chunk)
{
    // size of each tile in Dest is 64 rows
Consider adding LLK_ASSERTs to check that the values of values_tile_idx, indices_tile_idx, and chunk are valid.
will do
I'm not sure I can do much to validate the chunk number. Chunks should increase one by one from 0, but without a way to save state between calls I'm not sure how to check this.
For the DST tiles, the only real requirement is that each tile has a clear tile after it if accumulation is true. Otherwise there are no restrictions; would you like me to add these checks?
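A sketch of what such checks could look like; MAX_DEST_TILES is a hypothetical bound (the real limit depends on the Dest configuration), and as noted above, chunk monotonicity can't be validated without saved state:

// Illustrative only: bound the tile indices, and when accumulating,
// require the scratch tile right after each input tile to also fit.
LLK_ASSERT(values_tile_idx < MAX_DEST_TILES);
LLK_ASSERT(indices_tile_idx < MAX_DEST_TILES);
LLK_ASSERT(!accumulate || (values_tile_idx + 1) < MAX_DEST_TILES);
LLK_ASSERT(!accumulate || (indices_tile_idx + 1) < MAX_DEST_TILES);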
🚀 tt-metal post-commit tests: workflow #20929388907
🚀 tt-metal post-commit tests: workflow #20930842380
Force-pushed from 107b48c to 5a534ab
🚀 tt-metal post-commit tests: workflow #21005031265
🚀 tt-metal post-commit tests: workflow #21005217637
{ // for all but the first chunk we need to load the previous result from DST 1 and 3 and do a max with the current result in DST 0 and 2
    TT_SFPLOAD(p_sfpu::LREG1, InstrModLoadStore::DEFAULT, ADDR_MOD_3, values_accum_tile_offset + col_offset); // previous accumulated value
    TT_SFPLOAD(p_sfpu::LREG5, instr_mod_index, ADDR_MOD_3, indices_accum_tile_offset + col_offset);           // previous accumulated index
    TTI_SFPSWAP(0, p_sfpu::LREG0, p_sfpu::LREG1, p_sfpswap::ALL_ROWS_MAX); // LREG0 contains max of current and previous value
This instruction requires 2 cycles to complete and should always be followed by an SFPNOP instruction.
Could we swap the order of the second SFPLOAD and the SFPSWAP to hide this NOP?
edit: Nevermind, it's already handled by the order of STOREs.
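For context, the two ways to satisfy that 2-cycle constraint, sketched (illustrative; as the edit above says, the existing STORE ordering already covers it):

// Option A: follow the 2-cycle SFPSWAP with an explicit NOP
TTI_SFPSWAP(0, p_sfpu::LREG0, p_sfpu::LREG1, p_sfpswap::ALL_ROWS_MAX);
TTI_SFPNOP;
// Option B: hide the latency behind a following instruction that does not
// conflict with the swap, which is what the order of the subsequent STOREs does here.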
- template <bool APPROXIMATION_MODE, bool is_fp32_dest_acc_en, int ITERATIONS>
- inline void _calculate_max_pool_with_indices_generic_(const uint values_tile_idx, const uint indices_tile_idx, const uint tile_idx /* unused */)
+ template <bool APPROXIMATION_MODE, bool is_fp32_dest_acc_en, int ITERATIONS, bool accumulate = false>
+ inline void _calculate_max_pool_with_indices_generic_(const uint values_tile_idx, const uint indices_tile_idx, const uint chunk)
After looking into this locally, I think we can eliminate some loads and stores. Will write more later.
Basically, after reducing an 8-row block, results were stored to memory and immediately reloaded for the next reduction stage. I think we can avoid that.
In the original implementation we reduce 8 rows and immediately store the result, only to reload it right after.
Original:
1. Load rows 0-7 from Dest Reg -> LREG0-3, LREG4-7
2. Reduce (max in LREG0/LREG4)
3. Store LREG0/LREG4 -> Dest Reg offset 0
4. Load rows 8-15 from Dest Reg -> LREG0-3, LREG4-7
5. Reduce (max in LREG0/LREG4)
6. Store LREG0/LREG4 -> Dest Reg offset 16
7. Load from Dest Reg offset 0 -> LREG0/LREG4 <- reloads the result of step 3
8. Load from Dest Reg offset 16 -> LREG1/LREG5 <- reloads the result of step 6
9. SWAP
10. Store final
Idea:
1. Load rows 0-7 from Dest Reg -> LREG0-3, LREG4-7
2. Reduce (max in LREG0/LREG4)
3. Store LREG0/LREG4 -> Dest Reg offset 0
4. Load rows 8-15 from Dest Reg -> LREG0-3, LREG4-7
5. Reduce (max in LREG0/LREG4)
6. Skip the store - keep the result in LREG0/LREG4
7. Load from Dest Reg offset 0 -> LREG1/LREG5
8. Skip the load - the second result is already in LREG0/LREG4
9. SWAP
10. Store final
Steps 6 and 8 are where the two versions differ, and this is where we save time.
This cuts out 2 stores and 2 loads per call to process_16_rows; across 4 calls, that's 8 loads and 8 stores saved.
This one's a bit more tricky. Originally we process all rows 0-15 first, then all rows 16-31, then do the final swaps:
process_16_rows(0, even);  // rows 0-15 even
process_16_rows(0, odd);   // rows 0-15 odd
process_16_rows(32, even); // rows 16-31 even, store result
process_16_rows(32, odd);  // rows 16-31 odd, store result
final_swap(even);          // has to reload rows 16-31 even
final_swap(odd);           // has to reload rows 16-31 odd
But there's no dependency between even and odd columns (they touch different memory), so we can process each column completely and then move on:
process_16_rows(0, even, true);
process_16_rows(32, even, false); // keep result in LREG0/LREG4
final_swap(even);                 // use what's already in registers
process_16_rows(0, odd, true);
process_16_rows(32, odd, false);  // keep result in LREG0/LREG4
final_swap(odd);                  // use what's already in registers
Now final_swap() can use the result directly from LREG0/LREG4 instead of loading it back from memory.
That's another 2 stores and 2 loads saved per column, so 4 more of each total.
OK, I summarized my ideas, wrote instructions for the LLM, and here is what it produced. I haven't had time to actually test it, but I went through the code and I think it does what I summarized above:
/**
 * @brief Calculates column-wise MaxPool of a tile, placing output values into the first row.
 * Also places the index of the max value into the first row of the indices tile.
 * Supports {FP32, FP16_B} for values, and {UINT16, INT32, UINT32} for indices, inferred from the Dest mode used.
 * Can reduce up to 32 rows of a tile.
 * @tparam APPROXIMATION_MODE Whether to use the approximation mode (unused).
 * @tparam is_fp32_dest_acc_en Whether Dest is in 32bit mode (true) or 16bit mode (false).
 * @tparam ITERATIONS The number of iterations to use for the MaxPool operation (unused).
 * @tparam accumulate Whether to accumulate results for large kernels (default is false).
 * @param values_tile_idx The index of the tile in the Dest register containing the data to be reduced.
 * @param indices_tile_idx The index of the tile in the Dest register containing the indices of the data.
 * @param chunk The chunk index for large kernel accumulation.
 *
 * Note this function is only implemented for ROW_MAJOR data layout, so when _init_max_pool_with_indices_ is called
 * it must be called with layout=DataLayout::ROW_MAJOR.
 */
template <bool APPROXIMATION_MODE, bool is_fp32_dest_acc_en, int ITERATIONS, bool accumulate = false>
inline void _calculate_max_pool_with_indices_generic_(const uint values_tile_idx, const uint indices_tile_idx, const uint chunk)
{
    // size of each tile in Dest is 64 rows
    constexpr uint dst_tile_size = 64;
    const uint values_tile_offset        = values_tile_idx * dst_tile_size;
    const uint indices_tile_offset       = indices_tile_idx * dst_tile_size;
    const uint values_accum_tile_offset  = (values_tile_idx + 1) * dst_tile_size;
    const uint indices_accum_tile_offset = (indices_tile_idx + 1) * dst_tile_size;
    // each face is 16 rows
    constexpr uint eight_row_offset   = 16;
    constexpr uint sixteen_row_offset = 32;
    constexpr uint8_t instr_mod_index = is_fp32_dest_acc_en ? InstrModLoadStore::INT32 : InstrModLoadStore::LO16;

    // ROW MAJOR DATA VERSION OF MPWI
    // DATA IS EXPECTED TO BE IN THE FOLLOWING ORDER IN DEST:
    // Face 0 Row 0
    // Face 1 Row 0
    // Face 0 Row 1
    // Face 1 Row 1
    // .
    // .
    // .
    // Face 0 Row 31
    // Face 1 Row 31

    // Reduces 8 rows to max in LREG0/LREG4, optionally stores result.
    auto reduce_8_rows = [instr_mod_index](const uint val_base, const uint idx_base, const bool store_result) __attribute__((always_inline))
    {
        // data - precomputed base address eliminates repeated arithmetic
        TT_SFPLOAD(p_sfpu::LREG0, InstrModLoadStore::DEFAULT, ADDR_MOD_7, val_base + 0);  // Row 0 and 1
        TT_SFPLOAD(p_sfpu::LREG1, InstrModLoadStore::DEFAULT, ADDR_MOD_7, val_base + 4);  // Row 2 and 3
        TT_SFPLOAD(p_sfpu::LREG2, InstrModLoadStore::DEFAULT, ADDR_MOD_7, val_base + 8);  // Row 4 and 5
        TT_SFPLOAD(p_sfpu::LREG3, InstrModLoadStore::DEFAULT, ADDR_MOD_7, val_base + 12); // Row 6 and 7
        // index
        TT_SFPLOAD(p_sfpu::LREG4, instr_mod_index, ADDR_MOD_7, idx_base + 0);
        TT_SFPLOAD(p_sfpu::LREG5, instr_mod_index, ADDR_MOD_7, idx_base + 4);
        TT_SFPLOAD(p_sfpu::LREG6, instr_mod_index, ADDR_MOD_7, idx_base + 8);
        TT_SFPLOAD(p_sfpu::LREG7, instr_mod_index, ADDR_MOD_7, idx_base + 12);
        // Reduce 8 rows to max in LREG0/LREG4 via replay buffer
        lltt::replay(0, 7);
        // Only store when necessary - caller controls this
        if (store_result)
        {
            TT_SFPSTORE(p_sfpu::LREG0, InstrModLoadStore::DEFAULT, ADDR_MOD_7, val_base + 0);
            TT_SFPSTORE(p_sfpu::LREG4, instr_mod_index, ADDR_MOD_7, idx_base + 0);
        }
        // Result: Max of 8 rows in LREG0 (values) and LREG4 (indices)
    };

    // OPTIMIZATION: Flattened process_16_rows - eliminates nested lambda overhead
    // and removes redundant store-load pairs.
    // Note: After reducing the second 8-row block, the result is already in LREG0/LREG4.
    // We only need to reload the first block's result into LREG1/LREG5 for the final swap.
    // store_result: if false, result stays in LREG0/LREG4 for caller to use directly.
    auto process_16_rows = [&reduce_8_rows, values_tile_offset, indices_tile_offset, eight_row_offset, instr_mod_index](
                               const uint base_offset, const uint col_offset, const bool store_result) __attribute__((always_inline))
    {
        // Precompute base addresses for both 8-row blocks
        const uint val_base_first  = values_tile_offset + base_offset + col_offset;
        const uint idx_base_first  = indices_tile_offset + base_offset + col_offset;
        const uint val_base_second = values_tile_offset + eight_row_offset + base_offset + col_offset;
        const uint idx_base_second = indices_tile_offset + eight_row_offset + base_offset + col_offset;
        // First 8 rows: reduce and STORE (we need to free registers for second block)
        reduce_8_rows(val_base_first, idx_base_first, true);
        // Second 8 rows: reduce but DON'T STORE - keep result in LREG0/LREG4
        reduce_8_rows(val_base_second, idx_base_second, false);
        // Now: LREG0/LREG4 contains Max(R8-15), need to load Max(R0-7) into LREG1/LREG5
        TT_SFPLOAD(p_sfpu::LREG1, InstrModLoadStore::DEFAULT, ADDR_MOD_7, val_base_first);
        TT_SFPLOAD(p_sfpu::LREG5, instr_mod_index, ADDR_MOD_7, idx_base_first);
        // Swap to get Max(R0-15) in LREG0/LREG4
        TTI_SFPSWAP(0, p_sfpu::LREG0, p_sfpu::LREG1, p_sfpswap::ALL_ROWS_MAX);
        // Only store if needed - caller may use LREG0/LREG4 directly
        if (store_result)
        {
            TT_SFPSTORE(p_sfpu::LREG4, instr_mod_index, ADDR_MOD_7, idx_base_first);
            TT_SFPSTORE(p_sfpu::LREG0, InstrModLoadStore::DEFAULT, ADDR_MOD_7, val_base_first);
        }
    };

    // Final swap: combine first 16 rows with second 16 rows.
    // OPTIMIZATION: After process_16_rows(sixteen_row_offset, col), Max(R16-31) is already in LREG0/LREG4.
    // We only need to load Max(R0-15) into LREG1/LREG5, saving 2 loads per column.
    auto final_swap =
        [values_tile_offset, indices_tile_offset, values_accum_tile_offset, indices_accum_tile_offset, instr_mod_index, chunk](
            const uint col_offset) __attribute__((always_inline))
    {
        // Precompute addresses
        const uint val_first = values_tile_offset + col_offset;
        const uint idx_first = indices_tile_offset + col_offset;
        // LREG0/LREG4 already contains Max(R16-31) from the previous process_16_rows call
        // Only need to load Max(R0-15) into LREG1/LREG5
        TT_SFPLOAD(p_sfpu::LREG1, InstrModLoadStore::DEFAULT, ADDR_MOD_7, val_first); // Max(R0-15) for F0,1
        TT_SFPLOAD(p_sfpu::LREG5, instr_mod_index, ADDR_MOD_7, idx_first);
        TTI_SFPSWAP(0, p_sfpu::LREG0, p_sfpu::LREG1, p_sfpswap::ALL_ROWS_MAX); // LREG0 contains Max(R0-31) for F0,1
        if constexpr (accumulate)
        {
            const uint val_accum = values_accum_tile_offset + col_offset;
            const uint idx_accum = indices_accum_tile_offset + col_offset;
            if (chunk > 0)
            { // for all but the first chunk we need to load the previous result from DST 1 and 3 and do a max with the current result in DST 0 and 2
                TT_SFPLOAD(p_sfpu::LREG1, InstrModLoadStore::DEFAULT, ADDR_MOD_7, val_accum); // previous accumulated value
                TT_SFPLOAD(p_sfpu::LREG5, instr_mod_index, ADDR_MOD_7, idx_accum);            // previous accumulated index
                TTI_SFPSWAP(0, p_sfpu::LREG0, p_sfpu::LREG1, p_sfpswap::ALL_ROWS_MAX); // LREG0 contains max of current and previous value
            }
            // for each chunk we store the running result to DST 1 and 3
            TT_SFPSTORE(p_sfpu::LREG4, instr_mod_index, ADDR_MOD_7, idx_accum);
            TT_SFPSTORE(p_sfpu::LREG0, InstrModLoadStore::DEFAULT, ADDR_MOD_7, val_accum);
        }
        // store the final result to DST 0 (data) and DST 2 (indices)
        TT_SFPSTORE(p_sfpu::LREG4, instr_mod_index, ADDR_MOD_7, idx_first);
        TT_SFPSTORE(p_sfpu::LREG0, InstrModLoadStore::DEFAULT, ADDR_MOD_7, val_first);
    };

    // OPTIMIZATION: Process each column completely before moving to the next.
    // This allows the second process_16_rows to leave Max(R16-31) in LREG0/LREG4,
    // which final_swap can use directly without reloading.
    // Saves 2 stores + 2 loads per column = 4 stores + 4 loads total.
    constexpr int even_column_offset = 0;
    constexpr int odd_column_offset  = 2;
    // Even columns: process rows 0-15, then 16-31, then final swap
    process_16_rows(0, even_column_offset, true);                   // Store Max(R0-15) for final_swap to load
    process_16_rows(sixteen_row_offset, even_column_offset, false); // Keep Max(R16-31) in LREG0/LREG4
    final_swap(even_column_offset);                                 // Uses LREG0/LREG4 directly
    // Odd columns: process rows 0-15, then 16-31, then final swap
    process_16_rows(0, odd_column_offset, true);                    // Store Max(R0-15) for final_swap to load
    process_16_rows(sixteen_row_offset, odd_column_offset, false);  // Keep Max(R16-31) in LREG0/LREG4
    final_swap(odd_column_offset);                                  // Uses LREG0/LREG4 directly
}
Ticket
tenstorrent/tt-metal#27845
Metal PR:
tenstorrent/tt-metal#35216
Problem description
MPWI currently only supports kernel_hw <= 32.
What's changed
SFPU MPWI functions have been updated to do accumulation over multiple chunks.
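A rough sketch of the chunking arithmetic implied here (hypothetical host-side code; kernel_hw and this computation are not part of this diff). For example, kernel_hw = 49 would split into two chunks of 32 and 17 rows:

// Hypothetical: split a large kernel into chunks of at most 32 rows each.
constexpr uint max_rows_per_chunk = 32;
const uint num_chunks = (kernel_hw + max_rows_per_chunk - 1) / max_rows_per_chunk; // 49 -> 2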
Type of change
Checklist