3333 eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
3434"""
3535
36- import attr
36+ from typing import Any , NamedTuple
37+
3738import dp_accounting
3839import tensorflow as tf
40+
3941from tensorflow_privacy .privacy .dp_query import dp_query
4042from tensorflow_privacy .privacy .dp_query import tree_aggregation
4143
@@ -84,8 +86,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
8486 O(clip_norm*log(T)/eps) to guarantee eps-DP.
8587 """
8688
87- @attr .s (frozen = True )
88- class GlobalState (object ):
89+ class GlobalState (NamedTuple ):
8990 """Class defining global state for Tree sum queries.
9091
9192 Attributes:
@@ -94,9 +95,9 @@ class GlobalState(object):
9495 clip_value: The clipping value to be passed to clip_fn.
9596 samples_cumulative_sum: Noiseless cumulative sum of samples over time.
9697 """
97- tree_state = attr . ib ()
98- clip_value = attr . ib ()
99- samples_cumulative_sum = attr . ib ()
98+ tree_state : Any
99+ clip_value : Any
100+ samples_cumulative_sum : Any
100101
101102 def __init__ (self ,
102103 record_specs ,
@@ -182,10 +183,11 @@ def get_noised_result(self, sample_state, global_state):
182183 global_state .tree_state )
183184 noised_cumulative_sum = tf .nest .map_structure (tf .add , new_cumulative_sum ,
184185 cumulative_sum_noise )
185- new_global_state = attr .evolve (
186- global_state ,
186+ new_global_state = TreeCumulativeSumQuery .GlobalState (
187+ tree_state = new_tree_state ,
188+ clip_value = global_state .clip_value ,
187189 samples_cumulative_sum = new_cumulative_sum ,
188- tree_state = new_tree_state )
190+ )
189191 event = dp_accounting .UnsupportedDpEvent ()
190192 return noised_cumulative_sum , new_global_state , event
191193
@@ -206,10 +208,11 @@ def reset_state(self, noised_results, global_state):
206208 state for the next cumulative sum.
207209 """
208210 new_tree_state = self ._tree_aggregator .reset_state (global_state .tree_state )
209- return attr .evolve (
210- global_state ,
211+ return TreeCumulativeSumQuery .GlobalState (
212+ tree_state = new_tree_state ,
213+ clip_value = global_state .clip_value ,
211214 samples_cumulative_sum = noised_results ,
212- tree_state = new_tree_state )
215+ )
213216
214217 @classmethod
215218 def build_l2_gaussian_query (cls ,
@@ -312,8 +315,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
312315 O(clip_norm*log(T)/eps) to guarantee eps-DP.
313316 """
314317
315- @attr .s (frozen = True )
316- class GlobalState (object ):
318+ class GlobalState (NamedTuple ):
317319 """Class defining global state for Tree sum queries.
318320
319321 Attributes:
@@ -323,9 +325,9 @@ class GlobalState(object):
323325 previous_tree_noise: Cumulative noise by tree aggregation from the
324326 previous time the query is called on a sample.
325327 """
326- tree_state = attr . ib ()
327- clip_value = attr . ib ()
328- previous_tree_noise = attr . ib ()
328+ tree_state : Any
329+ clip_value : Any
330+ previous_tree_noise : Any
329331
330332 def __init__ (self ,
331333 record_specs ,
@@ -426,8 +428,11 @@ def get_noised_result(self, sample_state, global_state):
426428 noised_sample = tf .nest .map_structure (lambda a , b , c : a + b - c ,
427429 sample_state , tree_noise ,
428430 global_state .previous_tree_noise )
429- new_global_state = attr .evolve (
430- global_state , previous_tree_noise = tree_noise , tree_state = new_tree_state )
431+ new_global_state = TreeResidualSumQuery .GlobalState (
432+ tree_state = new_tree_state ,
433+ clip_value = global_state .clip_value ,
434+ previous_tree_noise = tree_noise ,
435+ )
431436 event = dp_accounting .UnsupportedDpEvent ()
432437 return noised_sample , new_global_state , event
433438
@@ -448,21 +453,28 @@ def reset_state(self, noised_results, global_state):
448453 """
449454 del noised_results
450455 new_tree_state = self ._tree_aggregator .reset_state (global_state .tree_state )
451- return attr .evolve (
452- global_state ,
456+ return TreeResidualSumQuery .GlobalState (
457+ tree_state = new_tree_state ,
458+ clip_value = global_state .clip_value ,
453459 previous_tree_noise = self ._zero_initial_noise (),
454- tree_state = new_tree_state )
460+ )
455461
456462 def reset_l2_clip_gaussian_noise (self , global_state , clip_norm , stddev ):
457463 noise_generator_state = global_state .tree_state .value_generator_state
458464 assert isinstance (self ._tree_aggregator .value_generator ,
459465 tree_aggregation .GaussianNoiseGenerator )
460466 noise_generator_state = self ._tree_aggregator .value_generator .make_state (
461467 noise_generator_state .seeds , stddev )
462- new_tree_state = attr .evolve (
463- global_state .tree_state , value_generator_state = noise_generator_state )
464- return attr .evolve (
465- global_state , clip_value = clip_norm , tree_state = new_tree_state )
468+ new_tree_state = tree_aggregation .TreeState (
469+ level_buffer = global_state .tree_state .level_buffer ,
470+ level_buffer_idx = global_state .tree_state .level_buffer_idx ,
471+ value_generator_state = noise_generator_state ,
472+ )
473+ return TreeResidualSumQuery .GlobalState (
474+ tree_state = new_tree_state ,
475+ clip_value = clip_norm ,
476+ previous_tree_noise = global_state .previous_tree_noise ,
477+ )
466478
467479 @classmethod
468480 def build_l2_gaussian_query (cls ,
0 commit comments