Skip to content

Commit 1860569

Browse files
committed
perf: TT prefetch + stack-bounded move scoring + L2/L3 NNUE SIMD
Three pure-perf changes with bit-identical search behavior, verified against baseline (same node counts and PVs at depth 14 on three test positions; all 137 tests pass):

- tt: add prefetch() method, called after make_move and make_null_move to warm the cache line for the child's TT probe.
- search: replace per-node Vec<(Move, Score)> in quiescence captures and order_moves with stack-bounded arrays. Avoids millions of small heap allocs per search.
- nnue: vectorize L2 and L3 forward passes the same way L1 was, with bit-exact tests. Forward-pass microbench: 661ns -> 565ns (~15%).

Adds scripts/build-pgo.sh for opt-in profile-guided builds (requires `rustup component add llvm-tools-preview`).
1 parent d97c9b8 commit 1860569

5 files changed

Lines changed: 362 additions & 30 deletions

File tree

scripts/build-pgo.sh

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#!/usr/bin/env bash
#
# Two-pass profile-guided optimization (PGO) build for focalors.
#
# Pass 1 builds an instrumented binary, runs it through a representative
# search workload, and captures profile data. Pass 2 rebuilds with the
# captured profile so the optimizer can specialize hot paths.
#
# Output: a PGO-optimized binary at target/release/focalors (overwrites the
# regular release build). PGO is purely a compiler optimization; no source
# changes, no semantic changes — search results are identical to a regular
# release build, just typically 10–20% faster.
#
# Any RUSTFLAGS already exported by the caller are preserved; the PGO flag is
# appended to them rather than replacing them, so both passes are built with
# the same base flags the caller uses for regular builds.
#
# Requires: `rustup component add llvm-tools-preview`
#
# Usage:
#   ./scripts/build-pgo.sh

set -euo pipefail

cd "$(git rev-parse --show-toplevel)"

PROFDATA_DIR="$(pwd)/target/pgo-data"
PROFDATA_FILE="$PROFDATA_DIR/merged.profdata"

# Caller-supplied compiler flags (empty when RUSTFLAGS is unset). The
# ${VAR:-} form is required because `set -u` rejects unset variables.
BASE_RUSTFLAGS="${RUSTFLAGS:-}"

# Locate llvm-profdata from the active rustup toolchain.
HOST="$(rustc -vV | sed -n 's|host: ||p')"
LLVM_PROFDATA="$(rustc --print sysroot)/lib/rustlib/${HOST}/bin/llvm-profdata"

if [[ ! -x "$LLVM_PROFDATA" ]]; then
    # Diagnostics go to stderr so they are not lost when stdout is redirected.
    {
        echo "Error: llvm-profdata not found at $LLVM_PROFDATA"
        echo
        echo "Install the llvm-tools-preview component:"
        echo " rustup component add llvm-tools-preview"
    } >&2
    exit 1
fi

# Clean prior profile data so stale .profraw files don't bias the merge.
rm -rf "$PROFDATA_DIR"
mkdir -p "$PROFDATA_DIR"

echo "=== Pass 1: instrumented build ==="
# ${BASE_RUSTFLAGS:+...} expands to "<flags> " only when flags were supplied,
# so the unset case yields exactly the bare PGO flag.
RUSTFLAGS="${BASE_RUSTFLAGS:+$BASE_RUSTFLAGS }-Cprofile-generate=$PROFDATA_DIR" \
    cargo build --release --bin focalors

echo "=== Pass 1: running profiling workload ==="
# Representative workload — three searches at depth 12 covering opening,
# tactical, and middlegame patterns. Enough variety to give the optimizer
# good signal without taking forever.
./target/release/focalors uci > /dev/null 2>&1 << 'EOF'
uci
isready
position startpos
go depth 12
position fen r1bqkb1r/pppp1ppp/2n2n2/4p3/2B1P3/5N2/PPPP1PPP/RNBQK2R w KQkq - 4 4
go depth 12
position fen r2q1rk1/ppp1bppp/2np1n2/2b1p3/2B1P3/2NP1N2/PPP1QPPP/R1B2RK1 w - - 0 8
go depth 12
quit
EOF

echo "=== Merging profile data ==="
"$LLVM_PROFDATA" merge -o "$PROFDATA_FILE" "$PROFDATA_DIR"

echo "=== Pass 2: PGO-optimized build ==="
RUSTFLAGS="${BASE_RUSTFLAGS:+$BASE_RUSTFLAGS }-Cprofile-use=$PROFDATA_FILE" \
    cargo build --release --bin focalors

echo
echo "Done. PGO binary at target/release/focalors"
echo "Verify with: cargo test --release --locked"

src/nnue/mod.rs

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -174,26 +174,13 @@ fn forward(acc: &Accumulator, side_to_move: Color, net: &Network) -> Score {
174174
simd::l1_forward(&l1_input, net, &mut l1_out);
175175

176176
// Step 4: Layer 2 — l1_out[32] * l2_weights[32][32] + l2_biases[32]
177+
// Dispatches to AVX2 implementation when available; bit-exact equivalent.
177178
let mut l2_out = [0i32; L2_SIZE];
178-
for j in 0..L2_SIZE {
179-
let mut sum = net.l2_biases[j];
180-
for i in 0..L1_SIZE {
181-
sum += l1_out[i] * net.l2_weight(i, j) as i32;
182-
}
183-
l2_out[j] = sum.clamp(0, QB);
184-
}
179+
simd::l2_forward(&l1_out, net, &mut l2_out);
185180

186-
// Step 5: Output layer — dot product + bias
187-
let mut output = net.l3_bias;
188-
for j in 0..L2_SIZE {
189-
output += l2_out[j] * net.l3_weights[j] as i32;
190-
}
191-
192-
// Scale to centipawns.
193-
// The network output is in internal quantized units.
194-
// Divide by QB to account for hidden layer quantization.
195-
// The result is roughly in centipawn-ish units depending on training.
196-
output / QB
181+
// Step 5: Output layer — dot product + bias, scaled to centipawn-ish units.
182+
// Dispatches to AVX2 implementation when available; bit-exact equivalent.
183+
simd::l3_forward(&l2_out, net)
197184
}
198185

199186
/// The default net shipped with the binary. Embedded at compile time.

src/nnue/simd.rs

Lines changed: 237 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
//! via `is_x86_feature_detected!` and dispatches to scalar fallback
55
//! on non-AVX2 CPUs.
66
7-
use super::network::{L1_INPUT, L1_SIZE, QA, QB, Network};
7+
use super::network::{L1_INPUT, L1_SIZE, L2_SIZE, QA, QB, Network};
88

99
/// Cached result of CPU feature detection.
1010
use std::sync::OnceLock;
@@ -133,6 +133,175 @@ pub fn l1_forward(
133133
l1_forward_scalar(l1_input, net, l1_out);
134134
}
135135

136+
// ════════════════════════════════════════════════════════════════════════════
137+
// Layer 2: 32 i32 inputs → 32 i32 outputs, i8 weights, clamp to [0, QB].
138+
// No division (unlike L1); the QA scaling already happened at L1.
139+
// ════════════════════════════════════════════════════════════════════════════
140+
141+
/// Scalar L2:
142+
/// l2_out[j] = clamp(bias[j] + sum_i(l1_out[i] * w[i][j]), 0, QB)
143+
#[inline]
144+
pub fn l2_forward_scalar(
145+
l1_out: &[i32; L1_SIZE],
146+
net: &Network,
147+
l2_out: &mut [i32; L2_SIZE],
148+
) {
149+
for j in 0..L2_SIZE {
150+
let mut sum = net.l2_biases[j];
151+
for i in 0..L1_SIZE {
152+
sum += l1_out[i] * net.l2_weight(i, j) as i32;
153+
}
154+
l2_out[j] = sum.clamp(0, QB);
155+
}
156+
}
157+
158+
/// AVX2 L2 — bit-exact equivalent to `l2_forward_scalar`. Same row-major
159+
/// `[L1_SIZE][L2_SIZE]` weight layout as L1, so the same broadcast-and-
160+
/// accumulate pattern applies, just with one less inner dimension.
161+
///
162+
/// # Safety
163+
/// Caller must ensure AVX2 support. Enforced by `#[target_feature]`.
164+
#[cfg(target_arch = "x86_64")]
165+
#[target_feature(enable = "avx2")]
166+
pub unsafe fn l2_forward_avx2(
167+
l1_out: &[i32; L1_SIZE],
168+
net: &Network,
169+
l2_out: &mut [i32; L2_SIZE],
170+
) { unsafe {
171+
use std::arch::x86_64::*;
172+
173+
// 4 accumulators initialized with the 32 i32 biases.
174+
let bias_ptr = net.l2_biases.as_ptr() as *const __m256i;
175+
let mut acc0 = _mm256_loadu_si256(bias_ptr.add(0));
176+
let mut acc1 = _mm256_loadu_si256(bias_ptr.add(1));
177+
let mut acc2 = _mm256_loadu_si256(bias_ptr.add(2));
178+
let mut acc3 = _mm256_loadu_si256(bias_ptr.add(3));
179+
180+
let weights_ptr = net.l2_weights.as_ptr();
181+
182+
for i in 0..L1_SIZE {
183+
let x = _mm256_set1_epi32(l1_out[i]);
184+
185+
// Load 32 i8 weights for input row i.
186+
let w_i8 = _mm256_loadu_si256(weights_ptr.add(i * L2_SIZE) as *const __m256i);
187+
188+
// Sign-extend i8 → i32 (same dance as L1).
189+
let lo_i8 = _mm256_extracti128_si256::<0>(w_i8);
190+
let hi_i8 = _mm256_extracti128_si256::<1>(w_i8);
191+
let lo_i16 = _mm256_cvtepi8_epi16(lo_i8);
192+
let hi_i16 = _mm256_cvtepi8_epi16(hi_i8);
193+
let w0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256::<0>(lo_i16));
194+
let w1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256::<1>(lo_i16));
195+
let w2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256::<0>(hi_i16));
196+
let w3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256::<1>(hi_i16));
197+
198+
acc0 = _mm256_add_epi32(acc0, _mm256_mullo_epi32(x, w0));
199+
acc1 = _mm256_add_epi32(acc1, _mm256_mullo_epi32(x, w1));
200+
acc2 = _mm256_add_epi32(acc2, _mm256_mullo_epi32(x, w2));
201+
acc3 = _mm256_add_epi32(acc3, _mm256_mullo_epi32(x, w3));
202+
}
203+
204+
// Store sums and clamp scalar (no division for L2).
205+
let mut tmp = [0i32; L2_SIZE];
206+
let tmp_ptr = tmp.as_mut_ptr() as *mut __m256i;
207+
_mm256_storeu_si256(tmp_ptr.add(0), acc0);
208+
_mm256_storeu_si256(tmp_ptr.add(1), acc1);
209+
_mm256_storeu_si256(tmp_ptr.add(2), acc2);
210+
_mm256_storeu_si256(tmp_ptr.add(3), acc3);
211+
212+
for j in 0..L2_SIZE {
213+
l2_out[j] = tmp[j].clamp(0, QB);
214+
}
215+
}}
216+
217+
/// Dispatch to the best available L2 implementation.
218+
#[inline]
219+
pub fn l2_forward(
220+
l1_out: &[i32; L1_SIZE],
221+
net: &Network,
222+
l2_out: &mut [i32; L2_SIZE],
223+
) {
224+
#[cfg(target_arch = "x86_64")]
225+
{
226+
if has_avx2() {
227+
unsafe { l2_forward_avx2(l1_out, net, l2_out); }
228+
return;
229+
}
230+
}
231+
l2_forward_scalar(l1_out, net, l2_out);
232+
}
233+
234+
// ════════════════════════════════════════════════════════════════════════════
235+
// Layer 3: 32 i32 inputs → 1 i32 output, i8 weights, divide by QB.
236+
// Single dot product, no per-output broadcast needed.
237+
// ════════════════════════════════════════════════════════════════════════════
238+
239+
/// Scalar L3:
240+
/// l3_out = (bias + sum_j(l2_out[j] * w[j])) / QB
241+
#[inline]
242+
pub fn l3_forward_scalar(l2_out: &[i32; L2_SIZE], net: &Network) -> i32 {
243+
let mut output = net.l3_bias;
244+
for j in 0..L2_SIZE {
245+
output += l2_out[j] * net.l3_weights[j] as i32;
246+
}
247+
output / QB
248+
}
249+
250+
/// AVX2 L3 — bit-exact equivalent to `l3_forward_scalar`. Loads all 32 i32
251+
/// inputs and 32 i8 weights, sign-extends weights to i32, multiplies pairwise,
252+
/// and reduces horizontally to one scalar.
253+
///
254+
/// # Safety
255+
/// Caller must ensure AVX2 support. Enforced by `#[target_feature]`.
256+
#[cfg(target_arch = "x86_64")]
257+
#[target_feature(enable = "avx2")]
258+
pub unsafe fn l3_forward_avx2(l2_out: &[i32; L2_SIZE], net: &Network) -> i32 { unsafe {
259+
use std::arch::x86_64::*;
260+
261+
// Load l2_out: 32 i32 = 4 × __m256i.
262+
let l2_ptr = l2_out.as_ptr() as *const __m256i;
263+
let l2_0 = _mm256_loadu_si256(l2_ptr.add(0));
264+
let l2_1 = _mm256_loadu_si256(l2_ptr.add(1));
265+
let l2_2 = _mm256_loadu_si256(l2_ptr.add(2));
266+
let l2_3 = _mm256_loadu_si256(l2_ptr.add(3));
267+
268+
// Load 32 i8 weights, sign-extend to 4 × i32 vectors.
269+
let w_i8 = _mm256_loadu_si256(net.l3_weights.as_ptr() as *const __m256i);
270+
let lo_i8 = _mm256_extracti128_si256::<0>(w_i8);
271+
let hi_i8 = _mm256_extracti128_si256::<1>(w_i8);
272+
let lo_i16 = _mm256_cvtepi8_epi16(lo_i8);
273+
let hi_i16 = _mm256_cvtepi8_epi16(hi_i8);
274+
let w0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256::<0>(lo_i16));
275+
let w1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256::<1>(lo_i16));
276+
let w2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256::<0>(hi_i16));
277+
let w3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256::<1>(hi_i16));
278+
279+
// Pairwise multiply, then sum the 4 product vectors into one.
280+
let s = _mm256_add_epi32(
281+
_mm256_add_epi32(_mm256_mullo_epi32(l2_0, w0), _mm256_mullo_epi32(l2_1, w1)),
282+
_mm256_add_epi32(_mm256_mullo_epi32(l2_2, w2), _mm256_mullo_epi32(l2_3, w3)),
283+
);
284+
285+
// Horizontal sum of 8 i32 lanes — store and reduce in scalar (cheap for 8 elems).
286+
let mut tmp = [0i32; 8];
287+
_mm256_storeu_si256(tmp.as_mut_ptr() as *mut __m256i, s);
288+
let sum = tmp[0] + tmp[1] + tmp[2] + tmp[3] + tmp[4] + tmp[5] + tmp[6] + tmp[7];
289+
290+
(net.l3_bias + sum) / QB
291+
}}
292+
293+
/// Dispatch to the best available L3 implementation.
294+
#[inline]
295+
pub fn l3_forward(l2_out: &[i32; L2_SIZE], net: &Network) -> i32 {
296+
#[cfg(target_arch = "x86_64")]
297+
{
298+
if has_avx2() {
299+
return unsafe { l3_forward_avx2(l2_out, net) };
300+
}
301+
}
302+
l3_forward_scalar(l2_out, net)
303+
}
304+
136305
#[cfg(test)]
137306
mod tests {
138307
use super::*;
@@ -186,4 +355,71 @@ mod tests {
186355
assert!(v >= 0 && v <= QB);
187356
}
188357
}
358+
359+
/// Bit-exactness check: the AVX2 L2 kernel must produce the same output array
/// as the scalar reference for 100 pseudo-random input vectors.
#[test]
#[cfg(target_arch = "x86_64")]
fn l2_avx2_matches_scalar() {
    if !has_avx2() {
        eprintln!("Skipping: AVX2 not available on this CPU");
        return;
    }

    let net = Network::random_for_test();

    // Inputs mimic post-clamp L1 outputs, i.e. values in [0, QB].
    let modulus = QB as u64 + 1;
    // Deterministic xorshift generator; its state carries over between trials.
    let mut rng = 0xCAFEBABE_u64;

    for trial in 0..100 {
        let mut l1_out = [0i32; L1_SIZE];
        l1_out.iter_mut().for_each(|slot| {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            *slot = (rng % modulus) as i32;
        });

        let mut out_scalar = [0i32; L2_SIZE];
        l2_forward_scalar(&l1_out, &net, &mut out_scalar);

        let mut out_simd = [0i32; L2_SIZE];
        // SAFETY: the `has_avx2()` guard at the top of the test passed.
        unsafe { l2_forward_avx2(&l1_out, &net, &mut out_simd) };

        assert_eq!(
            out_scalar, out_simd,
            "Trial {trial}: L2 AVX2 output diverges from scalar"
        );
    }
}
393+
394+
/// Bit-exactness check: the AVX2 L3 kernel must return the same value as the
/// scalar reference for 100 pseudo-random input vectors.
#[test]
#[cfg(target_arch = "x86_64")]
fn l3_avx2_matches_scalar() {
    if !has_avx2() {
        eprintln!("Skipping: AVX2 not available on this CPU");
        return;
    }

    let net = Network::random_for_test();

    // Inputs mimic post-clamp L2 outputs, i.e. values in [0, QB].
    let modulus = QB as u64 + 1;
    // Deterministic xorshift generator; its state carries over between trials.
    let mut rng = 0xDEADBEEF_u64;

    for trial in 0..100 {
        let mut l2_out = [0i32; L2_SIZE];
        l2_out.iter_mut().for_each(|slot| {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            *slot = (rng % modulus) as i32;
        });

        let out_scalar = l3_forward_scalar(&l2_out, &net);
        // SAFETY: the `has_avx2()` guard at the top of the test passed.
        let out_simd = unsafe { l3_forward_avx2(&l2_out, &net) };

        assert_eq!(
            out_scalar, out_simd,
            "Trial {trial}: L3 AVX2 output diverges from scalar"
        );
    }
}
189425
}

0 commit comments

Comments
 (0)