Skip to content

Commit b901d8e

Browse files
jlamypoirierclaude
andauthored
Fix KDA equivalence tests and add accelerate dependency (#488)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent a6679fa commit b901d8e

2 files changed

Lines changed: 5 additions & 8 deletions

File tree

fast_llm_external_models/tests/test_apriel2/test_mixer_equivalence.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def kda_mixer_config(kda_config):
144144
"heads": num_heads,
145145
"head_dim": head_dim,
146146
"convolution_layer": {"kernel_size": 4},
147-
"normalization": {"epsilon": 1e-5},
147+
"normalization": {"epsilon": 1e-5, "activation": "sigmoid"},
148148
}
149149

150150

@@ -1088,9 +1088,8 @@ def test_vs_fla(
10881088
fla_cache = FLACache()
10891089
apriel_cache = Apriel2Cache(make_apriel2_config(kda_hidden_size, kda_mixer_config))
10901090

1091-
# Force chunk mode for prefill
1092-
fla_kda.mode = "chunk"
1093-
apriel_kda.mode = "chunk"
1091+
# Match Apriel2's mode selection: fused_recurrent for seq_len<=64 in eval
1092+
fla_kda.mode = "fused_recurrent"
10941093

10951094
# ========== PHASE 1: Initial Prefill ==========
10961095
prefill_input = hidden_states[:, :prefill_len, :]
@@ -1125,7 +1124,6 @@ def test_vs_fla(
11251124

11261125
# ========== PHASE 2: Decode (single tokens) ==========
11271126
fla_kda.mode = "fused_recurrent"
1128-
apriel_kda.mode = "fused_recurrent"
11291127

11301128
for i in range(decode_steps):
11311129
pos = prefill_len + i
@@ -1160,9 +1158,7 @@ def test_vs_fla(
11601158
)
11611159

11621160
# ========== PHASE 3: Prefill again (decode→prefill transition) ==========
1163-
# FLA KDA correctly uses initial_state in chunk mode, so this should match
1164-
fla_kda.mode = "chunk"
1165-
apriel_kda.mode = "chunk"
1161+
fla_kda.mode = "fused_recurrent"
11661162

11671163
prefill2_start = prefill_len + decode_steps
11681164
prefill2_input = hidden_states[:, prefill2_start : prefill2_start + prefill2_len, :]

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ OPTIONAL =
4444
# Huggingface tools
4545
HUGGINGFACE =
4646
transformers>=4.57.3,<5.0.0
47+
accelerate>=1.4.0
4748
hf-transfer>=0.1.9
4849
datasets>=4.4.1
4950
huggingface-hub>=0.36.0

0 commit comments

Comments
 (0)