@@ -144,7 +144,7 @@ def kda_mixer_config(kda_config):
144144 "heads" : num_heads ,
145145 "head_dim" : head_dim ,
146146 "convolution_layer" : {"kernel_size" : 4 },
147- "normalization" : {"epsilon" : 1e-5 },
147+ "normalization" : {"epsilon" : 1e-5 , "activation" : "sigmoid" },
148148 }
149149
150150
@@ -1088,9 +1088,8 @@ def test_vs_fla(
10881088 fla_cache = FLACache ()
10891089 apriel_cache = Apriel2Cache (make_apriel2_config (kda_hidden_size , kda_mixer_config ))
10901090
1091- # Force chunk mode for prefill
1092- fla_kda .mode = "chunk"
1093- apriel_kda .mode = "chunk"
1091+ # Match Apriel2's mode selection: fused_recurrent for seq_len<=64 in eval
1092+ fla_kda .mode = "fused_recurrent"
10941093
10951094 # ========== PHASE 1: Initial Prefill ==========
10961095 prefill_input = hidden_states [:, :prefill_len , :]
@@ -1125,7 +1124,6 @@ def test_vs_fla(
11251124
11261125 # ========== PHASE 2: Decode (single tokens) ==========
11271126 fla_kda .mode = "fused_recurrent"
1128- apriel_kda .mode = "fused_recurrent"
11291127
11301128 for i in range (decode_steps ):
11311129 pos = prefill_len + i
@@ -1160,9 +1158,7 @@ def test_vs_fla(
11601158 )
11611159
11621160 # ========== PHASE 3: Prefill again (decode→prefill transition) ==========
1163- # FLA KDA correctly uses initial_state in chunk mode, so this should match
1164- fla_kda .mode = "chunk"
1165- apriel_kda .mode = "chunk"
1161+ fla_kda .mode = "fused_recurrent"
11661162
11671163 prefill2_start = prefill_len + decode_steps
11681164 prefill2_input = hidden_states [:, prefill2_start : prefill2_start + prefill2_len , :]
0 commit comments