25
25
"""
26
26
import distutils
27
27
import math
28
- import attr
28
+ from typing import Optional
29
29
30
+ import attr
30
31
import tensorflow as tf
31
-
32
32
from tensorflow_privacy .privacy .dp_query import dp_query
33
33
from tensorflow_privacy .privacy .dp_query import tree_aggregation
34
34
@@ -442,16 +442,20 @@ def _loop_body(i, h):
442
442
return tree
443
443
444
444
445
- def _get_add_noise (stddev ):
445
+ def _get_add_noise (stddev , seed : int = None ):
446
446
"""Utility function to decide which `add_noise` to use according to tf version."""
447
447
if distutils .version .LooseVersion (
448
448
tf .__version__ ) < distutils .version .LooseVersion ('2.0.0' ):
449
449
450
+ # The seed should be only used for testing purpose.
451
+ if seed is not None :
452
+ tf .random .set_seed (seed )
453
+
450
454
def add_noise (v ):
451
455
return v + tf .random .normal (
452
456
tf .shape (input = v ), stddev = stddev , dtype = v .dtype )
453
457
else :
454
- random_normal = tf .random_normal_initializer (stddev = stddev )
458
+ random_normal = tf .random_normal_initializer (stddev = stddev , seed = seed )
455
459
456
460
def add_noise (v ):
457
461
return v + tf .cast (random_normal (tf .shape (input = v )), dtype = v .dtype )
@@ -478,17 +482,16 @@ class GlobalState(object):
478
482
"""Class defining global state for `CentralTreeSumQuery`.
479
483
480
484
Attributes:
481
- stddev: The stddev of the noise added to each node in the tree.
482
- arity: The branching factor of the tree (i.e. the number of children each
483
- internal node has).
484
485
l1_bound: An upper bound on the L1 norm of the input record. This is
485
486
needed to bound the sensitivity and deploy differential privacy.
486
487
"""
487
- stddev = attr .ib ()
488
- arity = attr .ib ()
489
488
l1_bound = attr .ib ()
490
489
491
- def __init__ (self , stddev : float , arity : int = 2 , l1_bound : int = 10 ):
490
+ def __init__ (self ,
491
+ stddev : float ,
492
+ arity : int = 2 ,
493
+ l1_bound : int = 10 ,
494
+ seed : Optional [int ] = None ):
492
495
"""Initializes the `CentralTreeSumQuery`.
493
496
494
497
Args:
@@ -497,15 +500,17 @@ def __init__(self, stddev: float, arity: int = 2, l1_bound: int = 10):
497
500
arity: The branching factor of the tree.
498
501
l1_bound: An upper bound on the L1 norm of the input record. This is
499
502
needed to bound the sensitivity and deploy differential privacy.
503
+ seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for
504
+ test purpose.
500
505
"""
501
506
self ._stddev = stddev
502
507
self ._arity = arity
503
508
self ._l1_bound = l1_bound
509
+ self ._seed = seed
504
510
505
511
def initial_global_state (self ):
506
512
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
507
- return CentralTreeSumQuery .GlobalState (
508
- stddev = self ._stddev , arity = self ._arity , l1_bound = self ._l1_bound )
513
+ return CentralTreeSumQuery .GlobalState (l1_bound = self ._l1_bound )
509
514
510
515
def derive_sample_params (self , global_state ):
511
516
"""Implements `tensorflow_privacy.DPQuery.derive_sample_params`."""
@@ -536,10 +541,9 @@ def get_noised_result(self, sample_state, global_state):
536
541
The jth node on the ith layer of the tree can be accessed by tree[i][j]
537
542
where tree is the returned value.
538
543
"""
539
- add_noise = _get_add_noise (self ._stddev )
540
- tree = _build_tree_from_leaf (sample_state , global_state .arity )
541
- return tf .nest .map_structure (
542
- add_noise , tree , expand_composites = True ), global_state
544
+ add_noise = _get_add_noise (self ._stddev , self ._seed )
545
+ tree = _build_tree_from_leaf (sample_state , self ._arity )
546
+ return tf .map_fn (add_noise , tree ), global_state
543
547
544
548
545
549
class DistributedTreeSumQuery (dp_query .SumAggregationDPQuery ):
@@ -577,18 +581,25 @@ class GlobalState(object):
577
581
arity = attr .ib ()
578
582
l1_bound = attr .ib ()
579
583
580
- def __init__ (self , stddev : float , arity : int = 2 , l1_bound : int = 10 ):
584
+ def __init__ (self ,
585
+ stddev : float ,
586
+ arity : int = 2 ,
587
+ l1_bound : int = 10 ,
588
+ seed : Optional [int ] = None ):
581
589
"""Initializes the `DistributedTreeSumQuery`.
582
590
583
591
Args:
584
592
stddev: The stddev of the noise added to each node in the tree.
585
593
arity: The branching factor of the tree.
586
594
l1_bound: An upper bound on the L1 norm of the input record. This is
587
595
needed to bound the sensitivity and deploy differential privacy.
596
+ seed: Random seed to generate Gaussian noise. Defaults to `None`. Only for
597
+ test purpose.
588
598
"""
589
599
self ._stddev = stddev
590
600
self ._arity = arity
591
601
self ._l1_bound = l1_bound
602
+ self ._seed = seed
592
603
593
604
def initial_global_state (self ):
594
605
"""Implements `tensorflow_privacy.DPQuery.initial_global_state`."""
@@ -628,9 +639,9 @@ def preprocess_record(self, params, record):
628
639
use_norm = l1_norm )
629
640
preprocessed_record = preprocessed_record [0 ]
630
641
631
- add_noise = _get_add_noise (self ._stddev )
642
+ add_noise = _get_add_noise (self ._stddev , self . _seed )
632
643
tree = _build_tree_from_leaf (preprocessed_record , arity )
633
- noisy_tree = tf .nest . map_structure (add_noise , tree , expand_composites = True )
644
+ noisy_tree = tf .map_fn (add_noise , tree )
634
645
635
646
# The following codes reshape the output vector so the output shape of can
636
647
# be statically inferred. This is useful when used with
0 commit comments