Skip to content

Commit ec7c9bd

Browse files
authored
[skip uplift] Optimised fp32→bf16 typecast with RNE; fp32_to_[u]int32; uint16_to_uint32. (#945)
### Ticket tenstorrent/tt-metal#30147 Part of typecast LLK revamp: tenstorrent/tt-metal#33976 ### Problem description - fp32 to bf16 typecast used `SFPSTOCHRND` with "round to nearest, ties away from zero". - fp32_to_[u]int32 LLKs were in need of optimisation. ### What's changed - Modify fp32 to bf16 typecast to use round-to-nearest-even - Use SFPLOADMACRO to achieve 3-cycle throughput - Needs pairing with tt-metal PR tenstorrent/tt-metal#34376 to use init - Optimised fp32_to_[u]int32 LLKs. - Optimised uint16_to_uint32 via SFPLOADMACRO. ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update ### Checklist <!-- These are required steps and need to be run from tt-metal repository's Actions. Use links below and replace them with your run --> - [ ] [All post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml) CI passes - [ ] [Blackhole Post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml) CI passes (if applicable)
1 parent c4a0105 commit ec7c9bd

File tree

5 files changed

+344
-209
lines changed

5 files changed

+344
-209
lines changed

tests/helpers/include/llk_sfpu_types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,4 +105,5 @@ enum class SfpuType
105105
acosh,
106106
reduce,
107107
add_top_row,
108+
typecast,
108109
};

tt_llk_blackhole/common/inc/sfpu/ckernel_sfpu_typecast.h

Lines changed: 161 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
2+
// SPDX-FileCopyrightText: © 2025 Jason Davies <jason@jasondavies.com>
23
//
34
// SPDX-License-Identifier: Apache-2.0
45

@@ -70,64 +71,104 @@ inline void _calculate_typecast_int32_to_fp16b_()
7071
}
7172

7273
template <bool APPROXIMATION_MODE, int ITERATIONS>
73-
inline void _calculate_typecast_fp16b_to_int32_()
74+
inline void _calculate_typecast_fp32_to_int32_()
7475
{
75-
#pragma GCC unroll 0
76+
#pragma GCC unroll 8
7677
for (int d = 0; d < ITERATIONS; d++)
7778
{
78-
sfpi::vFloat in = sfpi::dst_reg[0];
79-
80-
// extract exponent
81-
sfpi::vInt exp = exexp(in);
82-
83-
v_if (exp < 0)
84-
{
85-
sfpi::dst_reg[0] = 0;
86-
}
87-
v_elseif (exp > 30)
88-
{
89-
// set to int32 max value in case of overflow
90-
sfpi::vInt tmp = std::numeric_limits<int32_t>::max();
91-
// check sign
92-
v_if (in < 0)
93-
{
94-
// 2's complement conversion
95-
tmp = (~tmp) + 1;
96-
}
97-
v_endif sfpi::dst_reg[0] = tmp;
98-
}
99-
v_else
100-
{
101-
// extract mantissa
102-
sfpi::vInt man = exman8(in);
103-
// shift the mantissa by (23-exponent) to the right
104-
sfpi::vInt shift = exp - 23;
105-
man = sfpi::shft(sfpi::reinterpret<sfpi::vUInt>(man), shift);
106-
// check sign
107-
v_if (in < 0)
108-
{
109-
// 2's complement conversion
110-
man = (~man) + 1;
111-
}
112-
v_endif sfpi::dst_reg[0] = man;
113-
}
114-
v_endif
115-
116-
sfpi::dst_reg++;
79+
TTI_SFPLOAD(p_sfpu::LREG0, InstrModLoadStore::DEFAULT, ADDR_MOD_7, 0);
80+
// result = 0
81+
TTI_SFPLOADI(p_sfpu::LREG1, sfpi::SFPLOADI_MOD0_USHORT, 0);
82+
83+
// exp = in.Exp (LaneEnabled = exp >= 0)
84+
TTI_SFPEXEXP(0, p_sfpu::LREG0, p_sfpu::LREG2, sfpi::SFPEXEXP_MOD1_SET_CC_SGN_EXP | sfpi::SFPEXEXP_MOD1_SET_CC_COMP_EXP);
85+
// result = INT_MIN
86+
TTI_SFPLOADI(p_sfpu::LREG1, sfpi::SFPLOADI_MOD0_FLOATB, 0x8000);
87+
// exp -= 31 (LaneEnabled = exp < 31)
88+
TTI_SFPIADD(-31 & 0xfff, p_sfpu::LREG2, p_sfpu::LREG2, sfpi::SFPIADD_MOD1_ARG_IMM | sfpi::SFPIADD_MOD1_CC_LT0);
89+
// exp += 8
90+
TTI_SFPIADD(8, p_sfpu::LREG2, p_sfpu::LREG2, sfpi::SFPIADD_MOD1_ARG_IMM | sfpi::SFPIADD_MOD1_CC_NONE);
91+
// result = exman8(in) << (exp - 23)
92+
TTI_SFPEXMAN(0, p_sfpu::LREG0, p_sfpu::LREG1, 0);
93+
TTI_SFPSHFT(0, p_sfpu::LREG2, p_sfpu::LREG1, 0);
94+
// LaneEnabled = true
95+
TTI_SFPENCC(0, 0, 0, 0);
96+
97+
// LaneEnabled = in < 0
98+
TTI_SFPSETCC(0, p_sfpu::LREG0, 0, sfpi::SFPSETCC_MOD1_LREG_LT0);
99+
// result = -result (two's complement)
100+
TTI_SFPIADD(0, p_sfpu::LCONST_0, p_sfpu::LREG1, sfpi::SFPIADD_MOD1_ARG_2SCOMP_LREG_DST | sfpi::SFPIADD_MOD1_CC_NONE);
101+
// LaneEnabled = true
102+
TTI_SFPENCC(0, 0, 0, 0);
103+
104+
TTI_SFPSTORE(p_sfpu::LREG1, InstrModLoadStore::INT32, ADDR_MOD_6, 0);
105+
}
106+
}
107+
108+
template <bool APPROXIMATION_MODE, int ITERATIONS>
109+
inline void _calculate_typecast_fp32_to_uint32_()
110+
{
111+
#pragma GCC unroll 8
112+
for (int d = 0; d < ITERATIONS; d++)
113+
{
114+
TTI_SFPLOAD(p_sfpu::LREG0, InstrModLoadStore::DEFAULT, ADDR_MOD_7, 0);
115+
// result = 0
116+
TTI_SFPLOADI(p_sfpu::LREG1, sfpi::SFPLOADI_MOD0_USHORT, 0);
117+
118+
// LaneEnabled = in >= 0
119+
TTI_SFPSETCC(0, p_sfpu::LREG0, 0, sfpi::SFPSETCC_MOD1_LREG_GTE0);
120+
// exp = in.Exp (LaneEnabled = exp >= 0)
121+
TTI_SFPEXEXP(0, p_sfpu::LREG0, p_sfpu::LREG2, sfpi::SFPEXEXP_MOD1_SET_CC_SGN_EXP | sfpi::SFPEXEXP_MOD1_SET_CC_COMP_EXP);
122+
// result = 0xffffffff
123+
TTI_SFPLOADI(p_sfpu::LREG1, sfpi::SFPLOADI_MOD0_SHORT, 0xffff);
124+
// exp -= 32 (LaneEnabled = exp < 31)
125+
TTI_SFPIADD(-32 & 0xfff, p_sfpu::LREG2, p_sfpu::LREG2, sfpi::SFPIADD_MOD1_ARG_IMM | sfpi::SFPIADD_MOD1_CC_LT0);
126+
// exp += 9
127+
TTI_SFPIADD(9, p_sfpu::LREG2, p_sfpu::LREG2, sfpi::SFPIADD_MOD1_ARG_IMM | sfpi::SFPIADD_MOD1_CC_NONE);
128+
// result = exman8(in) << (exp - 23)
129+
TTI_SFPEXMAN(0, p_sfpu::LREG0, p_sfpu::LREG1, 0);
130+
TTI_SFPSHFT(0, p_sfpu::LREG2, p_sfpu::LREG1, 0);
131+
// LaneEnabled = true
132+
TTI_SFPENCC(0, 0, 0, 0);
133+
134+
TTI_SFPSTORE(p_sfpu::LREG1, InstrModLoadStore::INT32, ADDR_MOD_6, 0);
117135
}
118136
}
119137

120138
template <bool APPROXIMATION_MODE, int ITERATIONS>
121139
inline void _calculate_typecast_fp32_to_fp16b_()
122140
{
123-
#pragma GCC unroll 0
141+
// This uses SFPLOADMACRO to achieve a throughput of 3 cycles per input row.
142+
//
143+
// Notation: [x] means scheduled by SFPLOADMACRO with VD=x.
144+
//
145+
// t | Load | Simple | MAD | Round | Store |
146+
// - | ---- | --------------- | --- | ---------- | ------- |
147+
// 0 | [a] | | | | |
148+
// 1 | [b] | | | [a] >>= 16 | |
149+
// 2 | | a &= 1 | | | |
150+
// 0 | ... | [b] += 0x7fff | | | |
151+
// 1 | ... | [a] L16 = a + b | | | [a] |
152+
// 2 | ... | | | | [b] L16 |
153+
//
154+
// Note that [a] schedules a 32-bit store, writing all zeros except for the
155+
// LSB, which may be 0 or 1. Then, [b] schedules a 16-bit store with
156+
// MOD0_FMT_BF16. The zeros mean that even if rounding is applied by
157+
// packers, the result will be truncated.
158+
159+
constexpr int b = p_sfpu::LREG2;
160+
161+
#pragma GCC unroll 8
124162
for (int d = 0; d < ITERATIONS; d++)
125163
{
126-
TTI_SFPLOAD(0, 0, ADDR_MOD_7, 0);
127-
TTI_SFP_STOCH_RND(0, 0, 2, 0, 1, 1);
128-
TTI_SFPSTORE(1, 0, ADDR_MOD_7, 0);
129-
sfpi::dst_reg++;
164+
int a = d & 1;
165+
TT_SFPLOADMACRO((0 << 2) | (a & 3), 0, ADDR_MOD_7, a >> 2);
166+
TTI_SFPLOADMACRO((1 << 2) | (b & 3), 0, ADDR_MOD_6, b >> 2);
167+
TT_SFPAND(0, p_sfpu::LREG12, a, 0);
130168
}
169+
TTI_SFPNOP;
170+
TTI_SFPNOP;
171+
TTI_SFPNOP;
131172
}
132173

133174
template <bool APPROXIMATION_MODE, int ITERATIONS>
@@ -166,61 +207,6 @@ inline void _calculate_typecast_int32_to_fp32_()
166207
}
167208
}
168209

169-
template <bool APPROXIMATION_MODE, int ITERATIONS>
170-
inline void _calculate_typecast_fp16b_to_uint32_()
171-
{
172-
#pragma GCC unroll 0
173-
for (int d = 0; d < ITERATIONS; d++)
174-
{
175-
sfpi::vFloat in = sfpi::dst_reg[0];
176-
177-
// check sign
178-
v_if (in <= 0)
179-
{
180-
sfpi::dst_reg[0] = 0;
181-
}
182-
v_else
183-
{
184-
// extract exponent
185-
sfpi::vInt exp = exexp(in);
186-
187-
v_if (exp < 0)
188-
{
189-
sfpi::dst_reg[0] = 0;
190-
}
191-
v_elseif (exp > 31)
192-
{
193-
// set to uint32 max value in case of overflow
194-
sfpi::vInt tmp = std::numeric_limits<int32_t>::max();
195-
sfpi::dst_reg[0] = sfpi::setsgn(sfpi::reinterpret<sfpi::vFloat>(tmp), 1);
196-
}
197-
v_elseif (exp == 31)
198-
{
199-
// extract mantissa without hidden bit
200-
sfpi::vInt man = exman9(in);
201-
// shift the mantissa by (23-exponent) to the right
202-
sfpi::vInt shift = exp - 23;
203-
man = sfpi::shft(sfpi::reinterpret<sfpi::vUInt>(man), shift);
204-
// add hidden bit back (due to bug when shifting a 1 into MSB)
205-
sfpi::dst_reg[0] = sfpi::setsgn(sfpi::reinterpret<sfpi::vFloat>(man), 1);
206-
}
207-
v_else
208-
{
209-
// extract mantissa
210-
sfpi::vInt man = exman8(in);
211-
// shift the mantissa by (23-exponent) to the right
212-
sfpi::vInt shift = exp - 23;
213-
man = sfpi::shft(sfpi::reinterpret<sfpi::vUInt>(man), shift);
214-
sfpi::dst_reg[0] = man;
215-
}
216-
v_endif
217-
}
218-
v_endif
219-
220-
sfpi::dst_reg++;
221-
}
222-
}
223-
224210
template <bool APPROXIMATION_MODE, int ITERATIONS>
225211
inline void _calculate_typecast_uint32_to_fp16b_()
226212
{
@@ -259,13 +245,12 @@ inline void _calculate_typecast_uint32_to_fp32_()
259245
template <bool APPROXIMATION_MODE, int ITERATIONS>
260246
inline void _calculate_typecast_uint16_to_uint32_()
261247
{
262-
#pragma GCC unroll 0
248+
#pragma GCC unroll 8
263249
for (int d = 0; d < ITERATIONS; d++)
264250
{
265-
TTI_SFPLOAD(p_sfpu::LREG0, InstrModLoadStore::LO16, ADDR_MOD_7, 0);
266-
TTI_SFPSTORE(p_sfpu::LREG0, InstrModLoadStore::INT32_2S_COMP, ADDR_MOD_7, 0);
267-
sfpi::dst_reg++;
251+
TTI_SFPLOADMACRO((0 << 2) | 0, InstrModLoadStore::LO16, ADDR_MOD_6, 0);
268252
}
253+
TTI_SFPNOP;
269254
}
270255

271256
template <bool APPROXIMATION_MODE, int ITERATIONS>
@@ -301,5 +286,76 @@ inline void _calculate_typecast_int32_to_uint16_()
301286
}
302287
}
303288

289+
template <bool APPROXIMATION_MODE>
290+
inline void _init_typecast_fp32_to_fp16b_()
291+
{
292+
constexpr int b = p_sfpu::LREG2;
293+
294+
sfpi::vConstIntPrgm0 = 1;
295+
sfpi::vConstIntPrgm1 = 0x7fff;
296+
297+
// InstructionTemplate[0]
298+
TTI_SFPSHFT2(-16 & 0xfff, 0, 12, 6); // SFPSHFT2_MOD1_SHFT_IMM
299+
300+
// InstructionTemplate[1]
301+
TTI_SFPIADD(0, p_sfpu::LREG13, 13, sfpi::SFPIADD_MOD1_CC_NONE);
302+
303+
// InstructionTemplate[2]
304+
TTI_SFPIADD(0, b, 14, sfpi::SFPIADD_MOD1_CC_NONE);
305+
306+
// Macro 0: [a]
307+
{
308+
constexpr uint simple_bits = 0x80 | 0x40 | (3 << 3) | (4 + 2);
309+
constexpr uint mad_bits = 0;
310+
constexpr uint round_bits = 0x80 | 0x00 | (0 << 3) | (4 + 0);
311+
constexpr uint store_bits = 0x00 | 0x00 | (3 << 3) | 3;
312+
313+
TTI_SFPLOADI(0, sfpi::SFPLOADI_MOD0_LOWER, (mad_bits << 8) | simple_bits);
314+
TTI_SFPLOADI(0, sfpi::SFPLOADI_MOD0_UPPER, (store_bits << 8) | round_bits);
315+
TTI_SFPCONFIG(0, 4 + 0, 0);
316+
}
317+
318+
// Macro 1: [b]
319+
{
320+
constexpr uint simple_bits = 0x80 | 0x00 | (1 << 3) | (4 + 1);
321+
constexpr uint mad_bits = 0;
322+
constexpr uint round_bits = 0;
323+
constexpr uint store_bits = 0x00 | 0x40 | (3 << 3) | 3;
324+
325+
TTI_SFPLOADI(0, sfpi::SFPLOADI_MOD0_LOWER, (mad_bits << 8) | simple_bits);
326+
TTI_SFPLOADI(0, sfpi::SFPLOADI_MOD0_UPPER, (store_bits << 8) | round_bits);
327+
TTI_SFPCONFIG(0, 4 + 1, 0);
328+
}
329+
330+
// Misc: {
331+
// StoreMod0: 2,
332+
// UsesLoadMod0ForStore: {1,0},
333+
// UnitDelayKind: {1,1}, (WaitForElapsedInstructions=1)
334+
// }
335+
TTI_SFPCONFIG(0x312, 8, 1);
336+
}
337+
338+
template <bool APPROXIMATION_MODE>
339+
inline void _init_typecast_uint16_to_uint32_()
340+
{
341+
{
342+
constexpr uint simple_bits = 0;
343+
constexpr uint mad_bits = 0;
344+
constexpr uint round_bits = 0;
345+
constexpr uint store_bits = 0x00 | 0x00 | (0 << 3) | 3;
346+
347+
TTI_SFPLOADI(0, sfpi::SFPLOADI_MOD0_LOWER, (mad_bits << 8) | simple_bits);
348+
TTI_SFPLOADI(0, sfpi::SFPLOADI_MOD0_UPPER, (store_bits << 8) | round_bits);
349+
TTI_SFPCONFIG(0, 4 + 0, 0);
350+
}
351+
352+
// Misc: {
353+
// StoreMod0: InstrModLoadStore::INT32,
354+
// UsesLoadMod0ForStore: {0},
355+
// UnitDelayKind: {1}, (WaitForElapsedInstructions=1)
356+
// }
357+
TTI_SFPCONFIG(0x100 | InstrModLoadStore::INT32, 8, 1);
358+
}
359+
304360
} // namespace sfpu
305361
} // namespace ckernel

tt_llk_blackhole/llk_lib/llk_math_eltwise_unary_sfpu.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,16 @@ inline void eltwise_unary_sfpu_configure_addrmod()
5050
}
5151
.set(ADDR_MOD_6);
5252
}
53+
54+
if constexpr (sfpu_op == SfpuType::typecast)
55+
{
56+
addr_mod_t {
57+
.srca = {.incr = 0},
58+
.srcb = {.incr = 0},
59+
.dest = {.incr = 2},
60+
}
61+
.set(ADDR_MOD_6);
62+
}
5363
}
5464

5565
inline void eltwise_unary_sfpu_configure_mop();

0 commit comments

Comments
 (0)