@@ -242,6 +242,92 @@ def test_balls_in_bins_sampling_with_large_cycle_length(self):
242242 _check_no_repeated_indices (batches [:cycle_length ])
243243 _check_cyclic_property (batches , cycle_length )
244244
245+ @parameterized .product (
246+ num_examples = [10 , 100 ],
247+ num_participations = [1 , 3 , 5 ],
248+ iterations = [10 , 20 ],
249+ )
250+ def test_random_allocation_sampling (
251+ self , num_examples , num_participations , iterations
252+ ):
253+ """Tests that random allocation gives exact k participations per example."""
254+ strategy = batch_selection .RandomAllocationSampling (
255+ num_participations = num_participations ,
256+ iterations = iterations ,
257+ )
258+ batches = list (strategy .batch_iterator (num_examples , rng = 0 ))
259+ self .assertLen (batches , iterations )
260+ _check_element_range (batches , num_examples )
261+ _check_signed_indices (batches )
262+ _check_max_participation (batches , num_participations )
263+ # Each example must appear in *exactly* k batches.
264+ all_indices = np .concatenate (batches )
265+ counts = collections .Counter (int (x ) for x in all_indices )
266+ for example_idx in range (num_examples ):
267+ self .assertEqual (counts [example_idx ], num_participations )
268+ # Within each batch, no example should appear twice.
269+ for batch in batches :
270+ self .assertEqual (len (batch ), len (set (batch .tolist ())))
271+
272+ def test_random_allocation_sampling_k_equals_zero (self ):
273+ """All batches should be empty when num_participations=0."""
274+ strategy = batch_selection .RandomAllocationSampling (
275+ num_participations = 0 ,
276+ iterations = 5 ,
277+ )
278+ batches = list (strategy .batch_iterator (10 , rng = 0 ))
279+ self .assertLen (batches , 5 )
280+ for batch in batches :
281+ self .assertEmpty (batch )
282+
283+ def test_random_allocation_sampling_k_equals_t (self ):
284+ """Every example should appear in every batch when k == t."""
285+ strategy = batch_selection .RandomAllocationSampling (
286+ num_participations = 5 ,
287+ iterations = 5 ,
288+ )
289+ batches = list (strategy .batch_iterator (10 , rng = 0 ))
290+ self .assertLen (batches , 5 )
291+ for batch in batches :
292+ self .assertLen (batch , 10 )
293+ _check_element_range (batches , 10 )
294+
295+ def test_random_allocation_sampling_expected_batch_size (self ):
296+ """Average batch size should be approximately n*k/t."""
297+ num_examples = 1000
298+ num_participations = 3
299+ iterations = 50
300+ strategy = batch_selection .RandomAllocationSampling (
301+ num_participations = num_participations ,
302+ iterations = iterations ,
303+ )
304+ batches = list (strategy .batch_iterator (num_examples , rng = 0 ))
305+ expected_batch_size = num_examples * num_participations / iterations
306+ actual_mean = sum (len (b ) for b in batches ) / iterations
307+ self .assertAlmostEqual (actual_mean , expected_batch_size , delta = 0.01 )
308+
309+ def test_random_allocation_sampling_is_deterministic (self ):
310+ """RandomAllocationSampling should respect the provided RNG."""
311+ strategy = batch_selection .RandomAllocationSampling (
312+ num_participations = 2 ,
313+ iterations = 10 ,
314+ )
315+ batches_a = list (strategy .batch_iterator (50 , rng = 0 ))
316+ batches_b = list (strategy .batch_iterator (50 , rng = 0 ))
317+ for batch_a , batch_b in zip (batches_a , batches_b , strict = True ):
318+ np .testing .assert_array_equal (batch_a , batch_b )
319+
320+ def test_random_allocation_sampling_zero_examples (self ):
321+ """Should produce empty batches when there are no examples."""
322+ strategy = batch_selection .RandomAllocationSampling (
323+ num_participations = 2 ,
324+ iterations = 5 ,
325+ )
326+ batches = list (strategy .batch_iterator (0 , rng = 0 ))
327+ self .assertLen (batches , 5 )
328+ for batch in batches :
329+ self .assertEmpty (batch )
330+
245331 def test_cyclic_poisson_sampling_independent_is_deterministic (self ):
246332 """CyclicPoissonSampling should respect the provided RNG."""
247333 strategy = batch_selection .CyclicPoissonSampling (
0 commit comments