Skip to content

Commit af68377

Browse files
Ryan McKennacopybara-github
authored andcommitted
Refactor BandMFExecutionPlanConfig: separate calibration from config
PiperOrigin-RevId: 926393400
1 parent 7ad9c12 commit af68377

5 files changed

Lines changed: 154 additions & 191 deletions

File tree

examples/dp_logistic_regression.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,8 @@ def main(_):
112112
num_bands=BANDS,
113113
l2_clip_norm=L2_CLIP_NORM,
114114
normalize_by=EXPECTED_BATCH_SIZE,
115-
epsilon=EPSILON,
116-
delta=DELTA,
117115
sampling_prob=EXPECTED_BATCH_SIZE / train_users * BANDS,
118-
)
116+
).calibrate(epsilon=EPSILON, delta=DELTA)
119117
print('Initialized BandMFExecutionPlanConfig')
120118
plan = config.make()
121119
grad_fn = plan.clipped_grad(logistic_loss, batch_argnums=(1, 2))

examples/dp_sgd_transformer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,13 +247,11 @@ def main(argv: Sequence[str]) -> None:
247247
config = execution_plan.BandMFExecutionPlanConfig.default(
248248
iterations=iterations,
249249
num_bands=1,
250-
epsilon=epsilon,
251-
delta=delta,
252250
sampling_prob=expected_batch_size / train_size,
253251
rescale_to_unit_norm=True,
254252
normalize_by=expected_batch_size,
255253
l2_clip_norm=clipping_norm,
256-
)
254+
).calibrate(epsilon=epsilon, delta=delta)
257255
plan = config.make()
258256
grad_fn = plan.clipped_grad(
259257
loss_fn,

jax_privacy/experimental/README.md

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ include standard DP-SGD as the special case `num_bands=1`.
2828

2929
```python
3030
import jax.numpy as jnp
31-
import dp_accounting
32-
from jax_privacy import batch_selection
3331
from jax_privacy import clipping
3432
from jax_privacy.experimental import execution_plan
3533

@@ -44,13 +42,7 @@ config = execution_plan.BandMFExecutionPlanConfig.default(
4442
iterations=1000,
4543
num_bands=1,
4644
sampling_prob=128 / 60000,
47-
epsilon=2.0,
48-
delta=1e-6,
49-
partition_type=batch_selection.PartitionType.INDEPENDENT,
50-
accountant=dp_accounting.pld.PLDAccountant(
51-
dp_accounting.NeighboringRelation.ADD_OR_REMOVE_ONE
52-
),
53-
)
45+
).calibrate(epsilon=2.0, delta=1e-6)
5446

5547
plan = config.make()
5648

0 commit comments

Comments
 (0)