diff --git a/deepsecure/adv_image.py b/deepsecure/adv_image.py index 31b8b1d..bf62e37 100644 --- a/deepsecure/adv_image.py +++ b/deepsecure/adv_image.py @@ -70,12 +70,21 @@ def train(self, train_dataloader, epochs): lr=0.00001) loss_adv_sum = 0 self.ite = epoch + + # Iterate over the training data loader for i, data in enumerate(train_dataloader, start=0): + # Unpack the current batch of images and labels images, labels = data + + # Move the images and labels to the specified device (e.g., GPU or CPU) images, labels = images.to(self.device), labels.to(self.device) - + + # Perform training for the current batch and obtain adversarial loss and adversarial images loss_adv_batch, adv_img = self.train_batch(images) + + # Accumulate the adversarial loss over all batches loss_adv_sum += loss_adv_batch + # print statistics