Skip to content

Commit ef8ab2c

Browse files
Ryan McKennacopybara-github
authored andcommitted
Refactor BandMFExecutionPlanConfig: separate calibration from config.
The main reasoning for this change was to make it easier to calibrate a mechanism to things other than epsilon/delta. From speaking with folks at TPDP, we probably want to support calibrating to GDP or TPR/FPR as well in the future. This change will make that feel less clunky. An additional benefit of this change is that the dataclass becomes significantly simpler, with epsilon/delta being removed, accountant being removed, and things like partition type and neighboring relation being automatically inferred based on the other inputs. This should reduce cognitive load for the different classes/functions and make it easier to understand and configure a BandMF mechanism. PiperOrigin-RevId: 926393400
1 parent 7ad9c12 commit ef8ab2c

5 files changed

Lines changed: 158 additions & 191 deletions

File tree

examples/dp_logistic_regression.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,8 @@ def main(_):
112112
num_bands=BANDS,
113113
l2_clip_norm=L2_CLIP_NORM,
114114
normalize_by=EXPECTED_BATCH_SIZE,
115-
epsilon=EPSILON,
116-
delta=DELTA,
117115
sampling_prob=EXPECTED_BATCH_SIZE / train_users * BANDS,
118-
)
116+
).calibrate(epsilon=EPSILON, delta=DELTA)
119117
print('Initialized BandMFExecutionPlanConfig')
120118
plan = config.make()
121119
grad_fn = plan.clipped_grad(logistic_loss, batch_argnums=(1, 2))

examples/dp_sgd_transformer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,13 +247,11 @@ def main(argv: Sequence[str]) -> None:
247247
config = execution_plan.BandMFExecutionPlanConfig.default(
248248
iterations=iterations,
249249
num_bands=1,
250-
epsilon=epsilon,
251-
delta=delta,
252250
sampling_prob=expected_batch_size / train_size,
253251
rescale_to_unit_norm=True,
254252
normalize_by=expected_batch_size,
255253
l2_clip_norm=clipping_norm,
256-
)
254+
).calibrate(epsilon=epsilon, delta=delta)
257255
plan = config.make()
258256
grad_fn = plan.clipped_grad(
259257
loss_fn,

jax_privacy/experimental/README.md

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ include standard DP-SGD as the special case `num_bands=1`.
2828

2929
```python
3030
import jax.numpy as jnp
31-
import dp_accounting
32-
from jax_privacy import batch_selection
3331
from jax_privacy import clipping
3432
from jax_privacy.experimental import execution_plan
3533

@@ -44,13 +42,7 @@ config = execution_plan.BandMFExecutionPlanConfig.default(
4442
iterations=1000,
4543
num_bands=1,
4644
sampling_prob=128 / 60000,
47-
epsilon=2.0,
48-
delta=1e-6,
49-
partition_type=batch_selection.PartitionType.INDEPENDENT,
50-
accountant=dp_accounting.pld.PLDAccountant(
51-
dp_accounting.NeighboringRelation.ADD_OR_REMOVE_ONE
52-
),
53-
)
45+
).calibrate(epsilon=2.0, delta=1e-6)
5446

5547
plan = config.make()
5648

0 commit comments

Comments
 (0)