Skip to content

Commit 2cafe28

Browse files
The previous version uses tf.nest.map_structure to apply add_noise to a tf.RaggedTensor. This causes a bug when used in tensorflow federated because tf.nest.map_structure will also map add_noise to the tensor for shape information in tf.RaggedTensor. This causes failure when tff conducts automatic type conversion.
Also use fixed random seed to avoid flaky timeouts and testing failures. PiperOrigin-RevId: 384573740
1 parent 7f44b02 commit 2cafe28

File tree

2 files changed

+40
-43
lines changed

2 files changed

+40
-43
lines changed

tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py

+30-19
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@
2525
"""
2626
import distutils
2727
import math
28-
import attr
28+
from typing import Optional
2929

30+
import attr
3031
import tensorflow as tf
31-
3232
from tensorflow_privacy.privacy.dp_query import dp_query
3333
from tensorflow_privacy.privacy.dp_query import tree_aggregation
3434

@@ -442,16 +442,20 @@ def _loop_body(i, h):
442442
return tree
443443

444444

445-
def _get_add_noise(stddev):
445+
def _get_add_noise(stddev, seed: int = None):
446446
"""Utility function to decide which `add_noise` to use according to tf version."""
447447
if distutils.version.LooseVersion(
448448
tf.__version__) < distutils.version.LooseVersion('2.0.0'):
449449

450+
# The seed should be only used for testing purpose.
451+
if seed is not None:
452+
tf.random.set_seed(seed)
453+
450454
def add_noise(v):
451455
return v + tf.random.normal(
452456
tf.shape(input=v), stddev=stddev, dtype=v.dtype)
453457
else:
454-
random_normal = tf.random_normal_initializer(stddev=stddev)
458+
random_normal = tf.random_normal_initializer(stddev=stddev, seed=seed)
455459

456460
def add_noise(v):
457461
return v + tf.cast(random_normal(tf.shape(input=v)), dtype=v.dtype)
@@ -478,17 +482,16 @@ class GlobalState(object):
478482
"""Class defining global state for `CentralTreeSumQuery`.
479483
480484
Attributes:
481-
stddev: The stddev of the noise added to each node in the tree.
482-
arity: The branching factor of the tree (i.e. the number of children each
483-
internal node has).
484485
l1_bound: An upper bound on the L1 norm of the input record. This is
485486
needed to bound the sensitivity and deploy differential privacy.
486487
"""
487-
stddev = attr.ib()
488-
arity = attr.ib()
489488
l1_bound = attr.ib()
490489

491-
def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10):
490+
def __init__(self,
491+
stddev: float,
492+
arity: int = 2,
493+
l1_bound: int = 10,
494+
seed: Optional[int] = None):
492495
"""Initializes the `CentralTreeSumQuery`.
493496
494497
Args:
@@ -497,15 +500,17 @@ def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10):
497500
arity: The branching factor of the tree.
498501
l1_bound: An upper bound on the L1 norm of the input record. This is
499502
needed to bound the sensitivity and deploy differential privacy.
503+
seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for
504+
test purpose.
500505
"""
501506
self._stddev = stddev
502507
self._arity = arity
503508
self._l1_bound = l1_bound
509+
self._seed = seed
504510

505511
def initial_global_state(self):
506512
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
507-
return CentralTreeSumQuery.GlobalState(
508-
stddev=self._stddev, arity=self._arity, l1_bound=self._l1_bound)
513+
return CentralTreeSumQuery.GlobalState(l1_bound=self._l1_bound)
509514

510515
def derive_sample_params(self, global_state):
511516
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
@@ -536,10 +541,9 @@ def get_noised_result(self, sample_state, global_state):
536541
The jth node on the ith layer of the tree can be accessed by tree[i][j]
537542
where tree is the returned value.
538543
"""
539-
add_noise = _get_add_noise(self._stddev)
540-
tree = _build_tree_from_leaf(sample_state, global_state.arity)
541-
return tf.nest.map_structure(
542-
add_noise, tree, expand_composites=True), global_state
544+
add_noise = _get_add_noise(self._stddev, self._seed)
545+
tree = _build_tree_from_leaf(sample_state, self._arity)
546+
return tf.map_fn(add_noise, tree), global_state
543547

544548

545549
class DistributedTreeSumQuery(dp_query.SumAggregationDPQuery):
@@ -577,18 +581,25 @@ class GlobalState(object):
577581
arity = attr.ib()
578582
l1_bound = attr.ib()
579583

580-
def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10):
584+
def __init__(self,
585+
stddev: float,
586+
arity: int = 2,
587+
l1_bound: int = 10,
588+
seed: Optional[int] = None):
581589
"""Initializes the `DistributedTreeSumQuery`.
582590
583591
Args:
584592
stddev: The stddev of the noise added to each node in the tree.
585593
arity: The branching factor of the tree.
586594
l1_bound: An upper bound on the L1 norm of the input record. This is
587595
needed to bound the sensitivity and deploy differential privacy.
596+
seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for
597+
test purpose.
588598
"""
589599
self._stddev = stddev
590600
self._arity = arity
591601
self._l1_bound = l1_bound
602+
self._seed = seed
592603

593604
def initial_global_state(self):
594605
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
@@ -628,9 +639,9 @@ def preprocess_record(self, params, record):
628639
use_norm=l1_norm)
629640
preprocessed_record = preprocessed_record[0]
630641

631-
add_noise = _get_add_noise(self._stddev)
642+
add_noise = _get_add_noise(self._stddev, self._seed)
632643
tree = _build_tree_from_leaf(preprocessed_record, arity)
633-
noisy_tree = tf.nest.map_structure(add_noise, tree, expand_composites=True)
644+
noisy_tree = tf.map_fn(add_noise, tree)
634645

635646
# The following codes reshape the output vector so the output shape of can
636647
# be statically inferred. This is useful when used with

tensorflow_privacy/privacy/dp_query/tree_aggregation_query_test.py

+10-24
Original file line numberDiff line numberDiff line change
@@ -502,21 +502,15 @@ def test_get_noised_result(self, arity, record, expected_tree):
502502
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
503503
)
504504
def test_get_noised_result_with_noise(self, stddev, record, expected_tree):
505-
query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev)
505+
query = tree_aggregation_query.CentralTreeSumQuery(stddev=stddev, seed=0)
506506
global_state = query.initial_global_state()
507507
params = query.derive_sample_params(global_state)
508508
preprocessed_record = query.preprocess_record(params, record)
509-
sample_state_list = []
510-
for _ in range(1000):
511-
sample_state, _ = query.get_noised_result(preprocessed_record,
512-
global_state)
513-
sample_state_list.append(sample_state.flat_values.numpy())
514-
expectation = np.mean(sample_state_list, axis=0)
515-
variance = np.std(sample_state_list, axis=0)
516-
517-
self.assertAllClose(expectation, expected_tree, rtol=3 * stddev, atol=1e-4)
509+
510+
sample_state, _ = query.get_noised_result(preprocessed_record, global_state)
511+
518512
self.assertAllClose(
519-
variance, np.ones(len(variance)) * stddev, rtol=0.1, atol=1e-4)
513+
sample_state.flat_values, expected_tree, atol=3 * stddev)
520514

521515
@parameterized.named_parameters(
522516
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),
@@ -556,8 +550,7 @@ def test_initial_global_state_type(self):
556550
def test_derive_sample_params(self):
557551
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=NOISE_STD)
558552
global_state = query.initial_global_state()
559-
stddev, arity, l1_bound = query.derive_sample_params(
560-
global_state)
553+
stddev, arity, l1_bound = query.derive_sample_params(global_state)
561554
self.assertAllClose(stddev, NOISE_STD)
562555
self.assertAllClose(arity, 2)
563556
self.assertAllClose(l1_bound, 10)
@@ -587,21 +580,14 @@ def test_preprocess_record(self, arity, record, expected_tree):
587580
('stddev_0_1', 0.1, tf.constant([1, 0], dtype=tf.int32), [1., 1., 0.]),
588581
)
589582
def test_preprocess_record_with_noise(self, stddev, record, expected_tree):
590-
query = tree_aggregation_query.DistributedTreeSumQuery(stddev=stddev)
583+
query = tree_aggregation_query.DistributedTreeSumQuery(
584+
stddev=stddev, seed=0)
591585
global_state = query.initial_global_state()
592586
params = query.derive_sample_params(global_state)
593587

594-
preprocessed_record_list = []
595-
for _ in range(1000):
596-
preprocessed_record = query.preprocess_record(params, record)
597-
preprocessed_record_list.append(preprocessed_record.numpy())
598-
599-
expectation = np.mean(preprocessed_record_list, axis=0)
600-
variance = np.std(preprocessed_record_list, axis=0)
588+
preprocessed_record = query.preprocess_record(params, record)
601589

602-
self.assertAllClose(expectation, expected_tree, rtol=3 * stddev, atol=1e-4)
603-
self.assertAllClose(
604-
variance, np.ones(len(variance)) * stddev, rtol=0.1, atol=1e-4)
590+
self.assertAllClose(preprocessed_record, expected_tree, atol=3 * stddev)
605591

606592
@parameterized.named_parameters(
607593
('binary_test_int', 2, tf.constant([10, 10, 0, 0], dtype=tf.int32),

0 commit comments

Comments
 (0)