11"""
2- Dataloader of GradDefense
2+ Dataloader of GradDefense.
33
44Reference:
55Wang et al., "Protect Privacy from Gradient Leakage Attack in Federated Learning," INFOCOM 2022.
66https://github.com/wangjunxiao/GradDefense
77"""
88
9+ from __future__ import annotations
10+
11+ from typing import Any , Protocol , Sequence , cast
12+
913import numpy as np
1014from torch .utils .data import Subset
1115from torch .utils .data .dataloader import DataLoader
1216from torch .utils .data .dataset import Dataset
1317
18+
19+ class LabeledDataset (Protocol ):
20+ """Minimal protocol for datasets with labeled samples."""
21+
22+ classes : Sequence [Any ]
23+
24+ def __len__ (self ) -> int : ...
25+
26+ def __getitem__ (self , index : int ) -> tuple [Any , Any ]: ...
27+
28+
1429DEFAULT_NUM_WORKERS = 8
1530ROOTSET_PER_CLASS = 5
1631ROOTSET_SIZE = 50
1732
1833
1934def extract_root_set (
20- dataset : Dataset ,
35+ dataset : LabeledDataset ,
2136 sample_per_class : int = ROOTSET_PER_CLASS ,
22- seed : int = None ,
23- ):
37+ seed : int | None = None ,
38+ ) -> tuple [ list [ int ], dict [ int , list [ int ]]] :
2439 """Extract root dataset."""
2540 num_classes = len (dataset .classes )
2641 class2sample = {i : [] for i in range (num_classes )}
@@ -39,10 +54,10 @@ def extract_root_set(
3954 return select_indices , class2sample
4055
4156
42- def get_root_set_loader (trainset ) :
57+ def get_root_set_loader (trainset : LabeledDataset ) -> DataLoader [ Any ] :
4358 """Obtain root dataset loader."""
4459 rootset_indices , __ = extract_root_set (trainset )
45- root_set = Subset (trainset , rootset_indices )
60+ root_set = Subset (cast ( Dataset [ Any ], trainset ) , rootset_indices )
4661 root_dataloader = DataLoader (
4762 root_set , batch_size = len (root_set ), num_workers = DEFAULT_NUM_WORKERS
4863 )
0 commit comments