@@ -75,23 +75,14 @@ def _all_balls_in_bins_modes(
7575 return sp .linalg .toeplitz (first_mode , zeros_vector )
7676
7777
78- def _generate_zero_mean_sample (
79- iterations : int ,
80- noise_multiplier : float ,
81- rng : np .random .Generator ,
82- ) -> np .ndarray :
83- """Generates a sample from a zero-mean Gaussian."""
84- return rng .normal (loc = 0.0 , scale = noise_multiplier , size = iterations )
85-
86-
8778def _generate_balls_in_bins_sample (
8879 iterations : int ,
8980 cycle_length : int ,
9081 noise_multiplier : float ,
9182 c_col : np .ndarray ,
9283 seed : Seed = None ,
9384 positive_sample : bool = True ,
94- num_samples : int | None = None ,
85+ num_samples : int = 1 ,
9586) -> np .ndarray :
9687 """Sample from the dominating pair for DP-BandMF using balls-in-bins sampling.
9788
@@ -109,8 +100,7 @@ def _generate_balls_in_bins_sample(
109100 pair corresponding to the case where the sensitive example is included.
110101 Otherwise, we sample from the other case in the dominating pair, where the
111102 sensitive example is not included.
112- num_samples: The number of samples to generate. None means generate a single
113- sample.
103+ num_samples: The number of samples to generate.
114104
115105 Returns:
116106 Samples from the dominating PLD for DP-BandMF using balls-in-bins sampling, as a 2D array of shape (iterations, num_samples).
@@ -121,7 +111,6 @@ def _generate_balls_in_bins_sample(
121111 raise ValueError ('cycle_length must be positive.' )
122112 if c_col .size > iterations :
123113 c_col = c_col [:iterations ]
124- num_samples_or_one = num_samples or 1
125114 rng = np .random .default_rng (seed )
126115 if positive_sample :
127116 # Add Cx to the Gaussian noise, where x is a vector which is 1 in every
@@ -131,13 +120,12 @@ def _generate_balls_in_bins_sample(
131120 iterations , cycle_length , tuple (c_col )
132121 )
133122 counts = rng .multinomial (
134- n = num_samples_or_one , pvals = np .full (cycle_length , 1.0 / cycle_length )
123+ n = num_samples , pvals = np .full (cycle_length , 1.0 / cycle_length )
135124 )
136125 mode = np .repeat (possible_modes , repeats = counts , axis = 1 )
137126 else :
138- mode = np .zeros ((iterations , num_samples_or_one ))
139- sample = rng .normal (loc = mode , scale = noise_multiplier )
140- return sample [:, 0 ] if num_samples is None else sample
127+ mode = np .zeros ((iterations , num_samples ))
128+ return rng .normal (loc = mode , scale = noise_multiplier )
141129
142130
143131def _sample_b_min_sep_positive_modes_no_truncation (
@@ -301,7 +289,7 @@ def _generate_b_min_sep_sample(
301289 c_col : np .ndarray ,
302290 seed : Seed = None ,
303291 positive_sample : bool = True ,
304- num_samples : int | None = None ,
292+ num_samples : int = 1 ,
305293 dataset_size : int | None = None ,
306294) -> np .ndarray | tuple [np .ndarray , np .ndarray ]:
307295 """Samples from the dominating pair for DP-BandMF using b-min-sep sampling.
@@ -320,8 +308,7 @@ def _generate_b_min_sep_sample(
320308 pair corresponding to the case where the sensitive example is included.
321309 Otherwise, we sample from the other case in the dominating pair, where the
322310 sensitive example is not included.
323- num_samples: The number of samples to generate. None means generate a single
324- sample.
311+ num_samples: The number of samples to generate.
325312 dataset_size: The size of the dataset. Only used if
326313 strategy.truncated_batch_size is not None.
327314
@@ -335,29 +322,22 @@ def _generate_b_min_sep_sample(
335322 if c_col .size > strategy .iterations :
336323 c_col = c_col [: strategy .iterations ]
337324 rng = np .random .default_rng (seed )
338- # For simplicity, if num_samples is None, we still create a 2D array.
339- num_samples_or_one = num_samples or 1
340325 if strategy .truncated_batch_size :
341326 mode , rest_batch_sizes = _sample_b_min_sep_modes_with_truncation (
342- strategy , c_col , rng , num_samples_or_one , positive_sample , dataset_size
327+ strategy , c_col , rng , num_samples , positive_sample , dataset_size
343328 )
344329 elif positive_sample :
345330 mode = _sample_b_min_sep_positive_modes_no_truncation (
346- strategy , c_col , rng , num_samples_or_one
331+ strategy , c_col , rng , num_samples
347332 )
348333 rest_batch_sizes = None
349334 else :
350- mode = np .zeros ((strategy .iterations , num_samples_or_one ))
335+ mode = np .zeros ((strategy .iterations , num_samples ))
351336 rest_batch_sizes = None
352- if num_samples is None :
353- output = rng .normal (loc = mode [:, 0 ], scale = noise_multiplier )
354- else :
355- output = rng .normal (loc = mode , scale = noise_multiplier )
337+ output = rng .normal (loc = mode , scale = noise_multiplier )
356338 if rest_batch_sizes is None :
357339 return output
358340 else :
359- if num_samples is None :
360- rest_batch_sizes = rest_batch_sizes [:, 0 ]
361341 return output , rest_batch_sizes
362342
363343
@@ -367,7 +347,7 @@ def generate_sample(
367347 c_col : np .ndarray ,
368348 seed : Seed = None ,
369349 positive_sample : bool = True ,
370- num_samples : int | None = None ,
350+ num_samples : int = 1 ,
371351 dataset_size : int | None = None ,
372352) -> np .ndarray | tuple [np .ndarray , np .ndarray ]:
373353 """Generates a sample from the dominating pair for amplified DP-BandMF.
@@ -385,19 +365,18 @@ def generate_sample(
385365 pair corresponding to the case where the sensitive example is included.
386366 Otherwise, we sample from the other case in the dominating pair, where the
387367 sensitive example is not included.
388- num_samples: The number of samples to generate. None means generate a single
389- sample.
368+ num_samples: The number of samples to generate. The default is 1, but it is
369+ typically much more efficient to generate multiple samples in a single
370+ call to benefit from vectorization.
390371 dataset_size: The size of the dataset. Should only be set if accounting for
391372 the strategy supports truncation, and strategy.truncated_batch_size is not
392373 None.
393374
394375 Returns:
395376 Sample(s) from the dominating PLD for DP-BandMF using balls-in-bins
396- sampling. If num_samples is None, the output is 1D with dimension
397- strategy.iterations. If num_samples is not None, the output is 2D with
398- dimension (strategy.iterations, num_samples). Potentially also returns
399- a second array containing auxiliary information needed to evaluate the
400- privacy loss.
377+ sampling. The output is 2D with dimension (strategy.iterations,
378+ num_samples). Potentially also returns a second array containing auxiliary
379+ information needed to evaluate the privacy loss.
401380 """
402381 if noise_multiplier < 0 :
403382 raise ValueError ('noise_multiplier must be non-negative.' )
@@ -447,14 +426,13 @@ def _compute_balls_in_bins_privacy_loss(
447426 sample : np .ndarray ,
448427 noise_multiplier : float ,
449428 c_col : np .ndarray ,
450- ) -> float | np .ndarray :
429+ ) -> np .ndarray :
451430 """Computes the privacy loss for a sample from balls-in-bins sampling.
452431
453432 Args:
454433 epoch_length: The length of each epoch (number of bins) for balls-in-bins
455434 sampling.
456- sample: The sample(s), generated by _generate_balls_in_bins_sample. If 1D,
457- treated as a single sample. If 2D, columns are treated as samples.
435+ sample: The sample(s), generated by _generate_balls_in_bins_sample.
458436 noise_multiplier: The noise multiplier of DP-MF. This is multiplied by the
459437 clip norm, not accounting for the norm of c_col.
460438 c_col: The non-zero entries in the first column of C. Should be non-negative
@@ -468,12 +446,8 @@ def _compute_balls_in_bins_privacy_loss(
468446 raise ValueError ('epoch_length must be positive.' )
469447 if noise_multiplier <= 0 :
470448 raise ValueError ('noise_multiplier must be positive.' )
471- if sample .ndim > 2 :
472- raise ValueError ('sample must be a 1D or 2D array.' )
473- squeeze_output = False
474- if sample .ndim == 1 :
475- sample = sample [:, np .newaxis ]
476- squeeze_output = True
449+ if sample .ndim != 2 :
450+ raise ValueError ('sample must be a 2D array.' )
477451 iterations = sample .shape [0 ]
478452 _validate_c_col (c_col )
479453 if c_col .size > iterations :
@@ -486,7 +460,7 @@ def _compute_balls_in_bins_privacy_loss(
486460 squared_mode_norms = (modes_matrix ** 2 ).sum (axis = 0 )[:, np .newaxis ]
487461 llrs = (2 * dot_products - squared_mode_norms ) / (2 * noise_multiplier ** 2 )
488462 privacy_loss = sp .special .logsumexp (llrs , axis = 0 ) - np .log (epoch_length )
489- return privacy_loss [ 0 ] if squeeze_output else privacy_loss
463+ return privacy_loss
490464
491465
492466def _compute_b_min_sep_privacy_loss_no_truncation (
@@ -507,7 +481,7 @@ def _compute_b_min_sep_privacy_loss_no_truncation(
507481 it will participate in 1 / (b - 1 + 1 / sampling_prob) fraction of the
508482 iterations on average, not sampling_prob fraction of the iterations as in
509483 Poisson sampling.
510- samples: The sample(s) , generated by _generate_b_min_sep_sample.
484+ samples: The samples , generated by _generate_b_min_sep_sample.
511485 noise_multiplier: The noise multiplier of DP-MF. This is multiplied by the
512486 clip norm, not accounting for the norm of c_col.
513487 c_col: The non-zero entries in the first column of C. Should be non-negative
@@ -585,7 +559,7 @@ def _compute_b_min_sep_privacy_loss(
585559 it will participate in 1 / (b - 1 + 1 / sampling_prob) fraction of the
586560 iterations on average, not sampling_prob fraction of the iterations as in
587561 Poisson sampling.
588- samples: The sample(s) , generated by _generate_b_min_sep_sample.
562+ samples: The samples , generated by _generate_b_min_sep_sample.
589563 noise_multiplier: The noise multiplier of DP-MF. This is multiplied by the
590564 clip norm, not accounting for the norm of c_col.
591565 c_col: The non-zero entries in the first column of C. Should be non-negative
@@ -714,7 +688,7 @@ def compute_privacy_loss(
714688 noise_multiplier : float ,
715689 c_col : np .ndarray ,
716690 aux : np .ndarray | None = None ,
717- ) -> float | np .ndarray :
691+ ) -> np .ndarray :
718692 """Computes the privacy loss on a sample from the dominating pair.
719693
720694 This method reports the privacy loss assuming we sample from the distribution
@@ -724,9 +698,7 @@ def compute_privacy_loss(
724698
725699 Args:
726700 strategy: The batch selection strategy used to generate the sample.
727- sample: The sample(s), generated by generate_sample. If 2D, we assume the
728- second dimension is the number of samples and return a vector of privacy
729- losses.
701+ sample: The samples, generated by generate_sample.
730702 noise_multiplier: The noise multiplier of DP-MF. This is multiplied by the
731703 clip norm, not accounting for the norm of c_col.
732704 c_col: The non-zero entries in the first column of C. Should be non-negative
@@ -762,22 +734,13 @@ def compute_privacy_loss(
762734 )
763735 if aux is not None and aux .shape != sample .shape :
764736 raise ValueError ('aux must have the same shape as sample.' )
765- if sample .ndim == 1 :
766- return _compute_b_min_sep_privacy_loss (
767- strategy ,
768- np .expand_dims (sample , axis = 1 ),
769- noise_multiplier ,
770- c_col ,
771- rest_batch_sizes = aux ,
772- )[0 ]
773- else :
774- return _compute_b_min_sep_privacy_loss (
775- strategy ,
776- sample ,
777- noise_multiplier ,
778- c_col ,
779- rest_batch_sizes = aux ,
780- )
737+ return _compute_b_min_sep_privacy_loss (
738+ strategy ,
739+ sample ,
740+ noise_multiplier ,
741+ c_col ,
742+ rest_batch_sizes = aux ,
743+ )
781744 else :
782745 raise ValueError (f'Unsupported batch selection strategy: { type (strategy )} ' )
783746
@@ -788,9 +751,9 @@ def get_privacy_loss_sample(
788751 c_col : np .ndarray ,
789752 seed : Seed = None ,
790753 positive_sample : bool = True ,
791- num_samples : int | None = None ,
754+ num_samples : int = 1 ,
792755 dataset_size : int | None = None ,
793- ) -> tuple [float | np .ndarray , np .ndarray | tuple [np .ndarray , np .ndarray ]]:
756+ ) -> tuple [np .ndarray , np .ndarray | tuple [np .ndarray , np .ndarray ]]:
794757 """Returns sample(s) from DP-BandMF's dominating privacy loss distribution.
795758
796759 Args:
@@ -804,9 +767,9 @@ def get_privacy_loss_sample(
804767 pair corresponding to the case where the sensitive example is included.
805768 Otherwise, we sample from the other case in the dominating pair, where the
806769 sensitive example is not included.
807- num_samples: The number of samples to generate. None means generate a single
808- sample, and the output will be a float. If not None, the output will be a
809- vector of floats .
770+ num_samples: The number of samples to generate. The default is 1, but it is
771+ typically much more efficient to generate multiple samples in a single
772+ call to benefit from vectorization .
810773 dataset_size: The size of the dataset. Should only be set if accounting for
811774 the strategy supports truncation, and strategy.truncated_batch_size is not
812775 None.
0 commit comments