Skip to content

Commit e286fd6

Browse files
JAX Privacy Teamcopybara-github
authored andcommitted
No public description
PiperOrigin-RevId: 917296085
1 parent c3ddf25 commit e286fd6

2 files changed

Lines changed: 261 additions & 362 deletions

File tree

jax_privacy/experimental/monte_carlo/sample_generation.py

Lines changed: 39 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -75,23 +75,14 @@ def _all_balls_in_bins_modes(
7575
return sp.linalg.toeplitz(first_mode, zeros_vector)
7676

7777

78-
def _generate_zero_mean_sample(
79-
iterations: int,
80-
noise_multiplier: float,
81-
rng: np.random.Generator,
82-
) -> np.ndarray:
83-
"""Generates a sample from a zero-mean Gaussian."""
84-
return rng.normal(loc=0.0, scale=noise_multiplier, size=iterations)
85-
86-
8778
def _generate_balls_in_bins_sample(
8879
iterations: int,
8980
cycle_length: int,
9081
noise_multiplier: float,
9182
c_col: np.ndarray,
9283
seed: Seed = None,
9384
positive_sample: bool = True,
94-
num_samples: int | None = None,
85+
num_samples: int = 1,
9586
) -> np.ndarray:
9687
"""Sample from the dominating pair for DP-BandMF using balls-in-bins sampling.
9788
@@ -109,8 +100,7 @@ def _generate_balls_in_bins_sample(
109100
pair corresponding to the case where the sensitive example is included.
110101
Otherwise, we sample from the other case in the dominating pair, where the
111102
sensitive example is not included.
112-
num_samples: The number of samples to generate. None means generate a single
113-
sample.
103+
num_samples: The number of samples to generate.
114104
115105
Returns:
116106
A sample from the dominating PLD for DP-BandMF using balls-in-bins sampling.
@@ -121,7 +111,6 @@ def _generate_balls_in_bins_sample(
121111
raise ValueError('cycle_length must be positive.')
122112
if c_col.size > iterations:
123113
c_col = c_col[:iterations]
124-
num_samples_or_one = num_samples or 1
125114
rng = np.random.default_rng(seed)
126115
if positive_sample:
127116
# Add Cx to the Gaussian noise, where x is a vector which is 1 in every
@@ -131,13 +120,12 @@ def _generate_balls_in_bins_sample(
131120
iterations, cycle_length, tuple(c_col)
132121
)
133122
counts = rng.multinomial(
134-
n=num_samples_or_one, pvals=np.full(cycle_length, 1.0 / cycle_length)
123+
n=num_samples, pvals=np.full(cycle_length, 1.0 / cycle_length)
135124
)
136125
mode = np.repeat(possible_modes, repeats=counts, axis=1)
137126
else:
138-
mode = np.zeros((iterations, num_samples_or_one))
139-
sample = rng.normal(loc=mode, scale=noise_multiplier)
140-
return sample[:, 0] if num_samples is None else sample
127+
mode = np.zeros((iterations, num_samples))
128+
return rng.normal(loc=mode, scale=noise_multiplier)
141129

142130

143131
def _sample_b_min_sep_positive_modes_no_truncation(
@@ -301,7 +289,7 @@ def _generate_b_min_sep_sample(
301289
c_col: np.ndarray,
302290
seed: Seed = None,
303291
positive_sample: bool = True,
304-
num_samples: int | None = None,
292+
num_samples: int = 1,
305293
dataset_size: int | None = None,
306294
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
307295
"""Samples from the dominating pair for DP-BandMF using b-min-sep sampling.
@@ -320,8 +308,7 @@ def _generate_b_min_sep_sample(
320308
pair corresponding to the case where the sensitive example is included.
321309
Otherwise, we sample from the other case in the dominating pair, where the
322310
sensitive example is not included.
323-
num_samples: The number of samples to generate. None means generate a single
324-
sample.
311+
num_samples: The number of samples to generate.
325312
dataset_size: The size of the dataset. Only used if
326313
strategy.truncated_batch_size is not None.
327314
@@ -335,29 +322,22 @@ def _generate_b_min_sep_sample(
335322
if c_col.size > strategy.iterations:
336323
c_col = c_col[: strategy.iterations]
337324
rng = np.random.default_rng(seed)
338-
# For simplicity, if num_samples is None, we still create a 2D array.
339-
num_samples_or_one = num_samples or 1
340325
if strategy.truncated_batch_size:
341326
mode, rest_batch_sizes = _sample_b_min_sep_modes_with_truncation(
342-
strategy, c_col, rng, num_samples_or_one, positive_sample, dataset_size
327+
strategy, c_col, rng, num_samples, positive_sample, dataset_size
343328
)
344329
elif positive_sample:
345330
mode = _sample_b_min_sep_positive_modes_no_truncation(
346-
strategy, c_col, rng, num_samples_or_one
331+
strategy, c_col, rng, num_samples
347332
)
348333
rest_batch_sizes = None
349334
else:
350-
mode = np.zeros((strategy.iterations, num_samples_or_one))
335+
mode = np.zeros((strategy.iterations, num_samples))
351336
rest_batch_sizes = None
352-
if num_samples is None:
353-
output = rng.normal(loc=mode[:, 0], scale=noise_multiplier)
354-
else:
355-
output = rng.normal(loc=mode, scale=noise_multiplier)
337+
output = rng.normal(loc=mode, scale=noise_multiplier)
356338
if rest_batch_sizes is None:
357339
return output
358340
else:
359-
if num_samples is None:
360-
rest_batch_sizes = rest_batch_sizes[:, 0]
361341
return output, rest_batch_sizes
362342

363343

@@ -367,7 +347,7 @@ def generate_sample(
367347
c_col: np.ndarray,
368348
seed: Seed = None,
369349
positive_sample: bool = True,
370-
num_samples: int | None = None,
350+
num_samples: int = 1,
371351
dataset_size: int | None = None,
372352
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
373353
"""Generates a sample from the dominating pair for amplified DP-BandMF.
@@ -385,19 +365,18 @@ def generate_sample(
385365
pair corresponding to the case where the sensitive example is included.
386366
Otherwise, we sample from the other case in the dominating pair, where the
387367
sensitive example is not included.
388-
num_samples: The number of samples to generate. None means generate a single
389-
sample.
368+
num_samples: The number of samples to generate. The default is 1, but it is
369+
typically much more efficient to generate multiple samples in a single
370+
call to benefit from vectorization.
390371
dataset_size: The size of the dataset. Should only be set if accounting for
391372
the strategy supports truncation, and strategy.truncated_batch_size is not
392373
None.
393374
394375
Returns:
395376
Sample(s) from the dominating PLD for DP-BandMF using balls-in-bins
396-
sampling. If num_samples is None, the output is 1D with dimension
397-
strategy.iterations. If num_samples is not None, the output is 2D with
398-
dimension (strategy.iterations, num_samples). Potentially also returns
399-
a second array containing auxiliary information needed to evaluate the
400-
privacy loss.
377+
sampling. The output is 2D with dimension (strategy.iterations,
378+
num_samples). Potentially also returns a second array containing auxiliary
379+
information needed to evaluate the privacy loss.
401380
"""
402381
if noise_multiplier < 0:
403382
raise ValueError('noise_multiplier must be non-negative.')
@@ -447,14 +426,13 @@ def _compute_balls_in_bins_privacy_loss(
447426
sample: np.ndarray,
448427
noise_multiplier: float,
449428
c_col: np.ndarray,
450-
) -> float | np.ndarray:
429+
) -> np.ndarray:
451430
"""Computes the privacy loss for a sample from balls-in-bins sampling.
452431
453432
Args:
454433
epoch_length: The length of each epoch (number of bins) for balls-in-bins
455434
sampling.
456-
sample: The sample(s), generated by _generate_balls_in_bins_sample. If 1D,
457-
treated as a single sample. If 2D, columns are treated as samples.
435+
sample: The sample(s), generated by _generate_balls_in_bins_sample.
458436
noise_multiplier: The noise multiplier of DP-MF. This is multiplied by the
459437
clip norm, not accounting for the norm of c_col.
460438
c_col: The non-zero entries in the first column of C. Should be non-negative
@@ -468,12 +446,8 @@ def _compute_balls_in_bins_privacy_loss(
468446
raise ValueError('epoch_length must be positive.')
469447
if noise_multiplier <= 0:
470448
raise ValueError('noise_multiplier must be positive.')
471-
if sample.ndim > 2:
472-
raise ValueError('sample must be a 1D or 2D array.')
473-
squeeze_output = False
474-
if sample.ndim == 1:
475-
sample = sample[:, np.newaxis]
476-
squeeze_output = True
449+
if sample.ndim != 2:
450+
raise ValueError('sample must be a 2D array.')
477451
iterations = sample.shape[0]
478452
_validate_c_col(c_col)
479453
if c_col.size > iterations:
@@ -486,7 +460,7 @@ def _compute_balls_in_bins_privacy_loss(
486460
squared_mode_norms = (modes_matrix**2).sum(axis=0)[:, np.newaxis]
487461
llrs = (2 * dot_products - squared_mode_norms) / (2 * noise_multiplier**2)
488462
privacy_loss = sp.special.logsumexp(llrs, axis=0) - np.log(epoch_length)
489-
return privacy_loss[0] if squeeze_output else privacy_loss
463+
return privacy_loss
490464

491465

492466
def _compute_b_min_sep_privacy_loss_no_truncation(
@@ -507,7 +481,7 @@ def _compute_b_min_sep_privacy_loss_no_truncation(
507481
it will participate in 1 / (b - 1 + 1 / sampling_prob) fraction of the
508482
iterations on average, not sampling_prob fraction of the iterations as in
509483
Poisson sampling.
510-
samples: The sample(s), generated by _generate_b_min_sep_sample.
484+
samples: The samples, generated by _generate_b_min_sep_sample.
511485
noise_multiplier: The noise multiplier of DP-MF. This is multiplied by the
512486
clip norm, not accounting for the norm of c_col.
513487
c_col: The non-zero entries in the first column of C. Should be non-negative
@@ -585,7 +559,7 @@ def _compute_b_min_sep_privacy_loss(
585559
it will participate in 1 / (b - 1 + 1 / sampling_prob) fraction of the
586560
iterations on average, not sampling_prob fraction of the iterations as in
587561
Poisson sampling.
588-
samples: The sample(s), generated by _generate_b_min_sep_sample.
562+
samples: The samples, generated by _generate_b_min_sep_sample.
589563
noise_multiplier: The noise multiplier of DP-MF. This is multiplied by the
590564
clip norm, not accounting for the norm of c_col.
591565
c_col: The non-zero entries in the first column of C. Should be non-negative
@@ -714,7 +688,7 @@ def compute_privacy_loss(
714688
noise_multiplier: float,
715689
c_col: np.ndarray,
716690
aux: np.ndarray | None = None,
717-
) -> float | np.ndarray:
691+
) -> np.ndarray:
718692
"""Computes the privacy loss on a sample from the dominating pair.
719693
720694
This method reports the privacy loss assuming we sample from the distribution
@@ -724,9 +698,7 @@ def compute_privacy_loss(
724698
725699
Args:
726700
strategy: The batch selection strategy used to generate the sample.
727-
sample: The sample(s), generated by generate_sample. If 2D, we assume the
728-
second dimension is the number of samples and return a vector of privacy
729-
losses.
701+
sample: The samples, generated by generate_sample.
730702
noise_multiplier: The noise multiplier of DP-MF. This is multiplied by the
731703
clip norm, not accounting for the norm of c_col.
732704
c_col: The non-zero entries in the first column of C. Should be non-negative
@@ -762,22 +734,13 @@ def compute_privacy_loss(
762734
)
763735
if aux is not None and aux.shape != sample.shape:
764736
raise ValueError('aux must have the same shape as sample.')
765-
if sample.ndim == 1:
766-
return _compute_b_min_sep_privacy_loss(
767-
strategy,
768-
np.expand_dims(sample, axis=1),
769-
noise_multiplier,
770-
c_col,
771-
rest_batch_sizes=aux,
772-
)[0]
773-
else:
774-
return _compute_b_min_sep_privacy_loss(
775-
strategy,
776-
sample,
777-
noise_multiplier,
778-
c_col,
779-
rest_batch_sizes=aux,
780-
)
737+
return _compute_b_min_sep_privacy_loss(
738+
strategy,
739+
sample,
740+
noise_multiplier,
741+
c_col,
742+
rest_batch_sizes=aux,
743+
)
781744
else:
782745
raise ValueError(f'Unsupported batch selection strategy: {type(strategy)}')
783746

@@ -788,9 +751,9 @@ def get_privacy_loss_sample(
788751
c_col: np.ndarray,
789752
seed: Seed = None,
790753
positive_sample: bool = True,
791-
num_samples: int | None = None,
754+
num_samples: int = 1,
792755
dataset_size: int | None = None,
793-
) -> tuple[float | np.ndarray, np.ndarray | tuple[np.ndarray, np.ndarray]]:
756+
) -> tuple[np.ndarray, np.ndarray | tuple[np.ndarray, np.ndarray]]:
794757
"""Returns sample(s) from DP-BandMF's dominating privacy loss distribution.
795758
796759
Args:
@@ -804,9 +767,9 @@ def get_privacy_loss_sample(
804767
pair corresponding to the case where the sensitive example is included.
805768
Otherwise, we sample from the other case in the dominating pair, where the
806769
sensitive example is not included.
807-
num_samples: The number of samples to generate. None means generate a single
808-
sample, and the output will be a float. If not None, the output will be a
809-
vector of floats.
770+
num_samples: The number of samples to generate. The default is 1, but it is
771+
typically much more efficient to generate multiple samples in a single
772+
call to benefit from vectorization.
810773
dataset_size: The size of the dataset. Should only be set if accounting for
811774
the strategy supports truncation, and strategy.truncated_batch_size is not
812775
None.

0 commit comments

Comments
 (0)