Skip to content

Commit 93a9632

Browse files
Ryan McKennacopybara-github
authored andcommitted
Add RandomAllocationSampling to jax_privacy.batch_selection
PiperOrigin-RevId: 924347562
1 parent 42f75fe commit 93a9632

2 files changed

Lines changed: 125 additions & 1 deletion

File tree

jax_privacy/batch_selection.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@
6161
from jax_privacy import sharding_utils
6262
import numpy as np
6363

64-
6564
RngType = np.random.Generator | int | None
6665

6766

@@ -301,6 +300,45 @@ def batch_iterator(
301300
yield groups[i % self.cycle_length]
302301

303302

303+
@dataclasses.dataclass(frozen=True)
304+
class RandomAllocationSampling(BatchSelectionStrategy):
305+
"""Implements k-out-of-t random allocation (aka balanced-iteration sampling).
306+
307+
Each example independently selects exactly k steps (out of iterations total)
308+
to participate in, uniformly at random. See https://arxiv.org/abs/2602.17284
309+
and https://arxiv.org/abs/2605.07072 for details about this strategy.
310+
311+
Formal guarantees of the batch_iterator:
312+
- All batches consist of indices in the range [0, num_examples).
313+
- Each example appears in exactly k of the iterations batches, chosen
314+
uniformly at random without replacement from [0, iterations).
315+
- The allocation for each example is independent of all other examples.
316+
317+
Attributes:
318+
num_participations: The number of steps each example participates in (k).
319+
iterations: The total number of iterations / batches to generate (t).
320+
"""
321+
322+
num_participations: int
323+
iterations: int
324+
325+
def batch_iterator(
326+
self, num_examples: int, rng: RngType = None
327+
) -> Iterator[np.ndarray]:
328+
rng = np.random.default_rng(rng)
329+
dtype = np.min_scalar_type(-num_examples)
330+
# At step i, each example with r remaining participations and (t-i)
331+
# remaining steps participates with probability r/(t-i). This is equivalent
332+
# to each example choosing k steps uniformly without replacement, but uses
333+
# only O(n) space instead of O(n*k).
334+
remaining = np.full(num_examples, self.num_participations)
335+
for i in range(self.iterations):
336+
probs = remaining / (self.iterations - i)
337+
mask = rng.random(num_examples) < probs
338+
yield np.where(mask)[0].astype(dtype)
339+
remaining -= mask
340+
341+
304342
@dataclasses.dataclass(frozen=True)
305343
class FixedBatchSampling(BatchSelectionStrategy):
306344
"""Implements fixed-size batch sampling.

tests/batch_selection_test.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,92 @@ def test_balls_in_bins_sampling_with_large_cycle_length(self):
242242
_check_no_repeated_indices(batches[:cycle_length])
243243
_check_cyclic_property(batches, cycle_length)
244244

245+
@parameterized.product(
246+
num_examples=[10, 100],
247+
num_participations=[1, 3, 5],
248+
iterations=[10, 20],
249+
)
250+
def test_random_allocation_sampling(
251+
self, num_examples, num_participations, iterations
252+
):
253+
"""Tests that random allocation gives exact k participations per example."""
254+
strategy = batch_selection.RandomAllocationSampling(
255+
num_participations=num_participations,
256+
iterations=iterations,
257+
)
258+
batches = list(strategy.batch_iterator(num_examples, rng=0))
259+
self.assertLen(batches, iterations)
260+
_check_element_range(batches, num_examples)
261+
_check_signed_indices(batches)
262+
_check_max_participation(batches, num_participations)
263+
# Each example must appear in *exactly* k batches.
264+
all_indices = np.concatenate(batches)
265+
counts = collections.Counter(int(x) for x in all_indices)
266+
for example_idx in range(num_examples):
267+
self.assertEqual(counts[example_idx], num_participations)
268+
# Within each batch, no example should appear twice.
269+
for batch in batches:
270+
self.assertEqual(len(batch), len(set(batch.tolist())))
271+
272+
def test_random_allocation_sampling_k_equals_zero(self):
273+
"""All batches should be empty when num_participations=0."""
274+
strategy = batch_selection.RandomAllocationSampling(
275+
num_participations=0,
276+
iterations=5,
277+
)
278+
batches = list(strategy.batch_iterator(10, rng=0))
279+
self.assertLen(batches, 5)
280+
for batch in batches:
281+
self.assertEmpty(batch)
282+
283+
def test_random_allocation_sampling_k_equals_t(self):
284+
"""Every example should appear in every batch when k == t."""
285+
strategy = batch_selection.RandomAllocationSampling(
286+
num_participations=5,
287+
iterations=5,
288+
)
289+
batches = list(strategy.batch_iterator(10, rng=0))
290+
self.assertLen(batches, 5)
291+
for batch in batches:
292+
self.assertLen(batch, 10)
293+
_check_element_range(batches, 10)
294+
295+
def test_random_allocation_sampling_expected_batch_size(self):
296+
"""Average batch size should be exactly n*k/t."""
297+
num_examples = 1000
298+
num_participations = 3
299+
iterations = 50
300+
strategy = batch_selection.RandomAllocationSampling(
301+
num_participations=num_participations,
302+
iterations=iterations,
303+
)
304+
batches = list(strategy.batch_iterator(num_examples, rng=0))
305+
expected_batch_size = num_examples * num_participations / iterations
306+
actual_mean = sum(len(b) for b in batches) / iterations
307+
self.assertAlmostEqual(actual_mean, expected_batch_size, delta=1e-5)
308+
309+
def test_random_allocation_sampling_is_deterministic(self):
310+
"""RandomAllocationSampling should respect the provided RNG."""
311+
strategy = batch_selection.RandomAllocationSampling(
312+
num_participations=2,
313+
iterations=10,
314+
)
315+
batches_a = list(strategy.batch_iterator(50, rng=0))
316+
batches_b = list(strategy.batch_iterator(50, rng=0))
317+
for batch_a, batch_b in zip(batches_a, batches_b, strict=True):
318+
np.testing.assert_array_equal(batch_a, batch_b)
319+
320+
def test_random_allocation_sampling_zero_examples(self):
321+
"""Should produce empty batches when there are no examples."""
322+
strategy = batch_selection.RandomAllocationSampling(
323+
num_participations=2,
324+
iterations=5,
325+
)
326+
batches = list(strategy.batch_iterator(0, rng=0))
327+
self.assertLen(batches, 5)
328+
for batch in batches:
329+
self.assertEmpty(batch)
330+
245331
def test_cyclic_poisson_sampling_independent_is_deterministic(self):
246332
"""CyclicPoissonSampling should respect the provided RNG."""
247333
strategy = batch_selection.CyclicPoissonSampling(

0 commit comments

Comments
 (0)