Skip to content

Commit 563075a

Browse files
JAX Privacy Teamcopybara-github
authored andcommitted
Add sampling method to privacy accounting.
PiperOrigin-RevId: 921568766
1 parent e286fd6 commit 563075a

1 file changed

Lines changed: 9 additions & 0 deletions

File tree

jax_privacy/accounting/calibrate.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def calibrate_num_updates(
5353
examples_per_user: int | None = None,
5454
cycle_length: int | None = None,
5555
truncated_batch_size: int | None = None,
56+
sampling_method: analysis.SamplingMethod = analysis.SamplingMethod.POISSON,
5657
initial_max_updates: int = 4,
5758
initial_min_updates: int = 1,
5859
tol: float = 0.1,
@@ -76,6 +77,7 @@ def calibrate_num_updates(
7677
the cycle.
7778
truncated_batch_size: If using truncated Poisson sampling, the maximum batch
7879
size to truncate to.
80+
sampling_method: Which sampling method the privacy analysis should assume.
7981
initial_max_updates: An initial estimate of the number of updates.
8082
initial_min_updates: Minimum number of updates.
8183
tol: tolerance of the optimizer for the calibration.
@@ -95,6 +97,7 @@ def get_epsilon(num_updates: int) -> float:
9597
examples_per_user=examples_per_user,
9698
cycle_length=cycle_length,
9799
truncated_batch_size=truncated_batch_size,
100+
sampling_method=sampling_method,
98101
)
99102
return accountant.compute_epsilon(num_updates, dp_params)
100103

@@ -133,6 +136,7 @@ def calibrate_noise_multiplier(
133136
examples_per_user: int | None = None,
134137
cycle_length: int | None = None,
135138
truncated_batch_size: int | None = None,
139+
sampling_method: analysis.SamplingMethod = analysis.SamplingMethod.POISSON,
136140
initial_max_noise: float = 1.0,
137141
initial_min_noise: float = 0.0,
138142
tol: float = 0.01,
@@ -154,6 +158,7 @@ def calibrate_noise_multiplier(
154158
the cycle.
155159
truncated_batch_size: If using truncated Poisson sampling, the maximum batch
156160
size to truncate to.
161+
sampling_method: Which sampling method the privacy analysis should assume.
157162
initial_max_noise: An initial estimate of the noise multiplier.
158163
initial_min_noise: Minimum noise multiplier.
159164
tol: tolerance of the optimizer for the calibration.
@@ -175,6 +180,7 @@ def get_epsilon(noise_multiplier: float) -> float:
175180
examples_per_user=examples_per_user,
176181
cycle_length=cycle_length,
177182
truncated_batch_size=truncated_batch_size,
183+
sampling_method=sampling_method,
178184
)
179185
return accountant.compute_epsilon(num_updates, dp_params)
180186

@@ -202,6 +208,7 @@ def calibrate_batch_size(
202208
examples_per_user: int | None = None,
203209
cycle_length: int | None = None,
204210
truncated_batch_size: int | None = None,
211+
sampling_method: analysis.SamplingMethod = analysis.SamplingMethod.POISSON,
205212
initial_max_batch_size: int = 8,
206213
initial_min_batch_size: int = 1,
207214
tol: float = 0.01,
@@ -223,6 +230,7 @@ def calibrate_batch_size(
223230
the cycle.
224231
truncated_batch_size: If using truncated Poisson sampling, the maximum batch
225232
size to truncate to.
233+
sampling_method: Which sampling method the privacy analysis should assume.
226234
initial_max_batch_size: An initial estimate of the batch size.
227235
initial_min_batch_size: Minimum batch size.
228236
tol: tolerance of the optimizer for the calibration.
@@ -244,6 +252,7 @@ def get_epsilon(batch_size: int) -> float:
244252
examples_per_user=examples_per_user,
245253
cycle_length=cycle_length,
246254
truncated_batch_size=truncated_batch_size,
255+
sampling_method=sampling_method,
247256
)
248257
return accountant.compute_epsilon(num_updates, dp_params)
249258

0 commit comments

Comments
 (0)