@@ -84,9 +84,10 @@ class SlicingSpec:
8484 # When is set to true, one of the slices is the whole dataset.
8585 entire_dataset : bool = True
8686
87- # Used in classification tasks for slicing by classes. It is assumed that
88- # classes are integers 0, 1, ... number of classes. When true one slice per
89- # each class is generated.
87+ # Used in classification tasks for slicing by classes. When true one slice per
88+ # each class is generated. Classes can either be
89+ # - integers 0, 1, ..., (for single label) or
90+ # - an array of integers (for multi-label).
9091 by_class : Union [bool , Iterable [int ], int ] = False
9192
9293 # if true, it generates 10 slices for percentiles of the loss - 0-10%, 10-20%,
@@ -238,8 +239,10 @@ class AttackInputData:
238239 probs_train : Optional [np .ndarray ] = None
239240 probs_test : Optional [np .ndarray ] = None
240241
241- # Contains ground-truth classes. Classes are assumed to be integers starting
242- # from 0.
242+ # Contains ground-truth classes. For single-label classification, classes are
243+ # assumed to be integers starting from 0. For multi-label classification,
244+ # label is assumed to be multi-hot, i.e., labels is a binary array of shape
245+ # (num_examples, num_classes).
243246 labels_train : Optional [np .ndarray ] = None
244247 labels_test : Optional [np .ndarray ] = None
245248
@@ -290,7 +293,9 @@ def num_classes(self):
290293 raise ValueError (
291294 'Can\' t identify the number of classes as no labels were provided. '
292295 'Please set labels_train and labels_test' )
293- return int (max (np .max (self .labels_train ), np .max (self .labels_test ))) + 1
296+ if not self .multilabel_data :
297+ return int (max (np .max (self .labels_train ), np .max (self .labels_test ))) + 1
298+ return self .labels_train .shape [1 ]
294299
295300 @property
296301 def logits_or_probs_train (self ):
@@ -586,6 +591,8 @@ def validate(self):
586591 _is_array_two_dimensional (self .entropy_test , 'entropy_test' )
587592 _is_array_two_dimensional (self .labels_train , 'labels_train' )
588593 _is_array_two_dimensional (self .labels_test , 'labels_test' )
594+ self .is_multihot_labels (self .labels_train , 'labels_train' )
595+ self .is_multihot_labels (self .labels_test , 'labels_test' )
589596 else :
590597 _is_array_one_dimensional (self .loss_train , 'loss_train' )
591598 _is_array_one_dimensional (self .loss_test , 'loss_test' )
0 commit comments