diff --git a/deepsecure/adv_image.py b/deepsecure/adv_image.py index 31b8b1d..016f898 100644 --- a/deepsecure/adv_image.py +++ b/deepsecure/adv_image.py @@ -17,13 +17,22 @@ def weights_init(m): nn.init.normal_(m.weight.data, 1.0, 0.02) nn.init.constant_(m.bias.data, 0) - +# Adversarial Image Generator class class Adv_Gen: def __init__(self, device, model_extractor, generator,): + """ + Initialize the Adversarial Image Generator. + + Args: + device (torch.device): The device (CPU/GPU) to run the model on. + model_extractor (nn.Module): Feature extractor model. + generator (nn.Module): Generator model to create adversarial images. + """ + self.device = device self.model_extractor = model_extractor self.generator = generator @@ -32,15 +41,16 @@ def __init__(self, self.ite = 0 #self.CELoss = nn.CrossEntropyLoss() + # move model and generator to pre-defined device self.model_extractor.to(device) #self.model_extractor.eval() - self.generator.to(device) # initialize optimizers self.optimizer_G = torch.optim.Adam(self.generator.parameters(), lr=0.001) - + + # Create directories for saving models and adversarial images if they don't exist if not os.path.exists(models_path): os.makedirs(models_path) if not os.path.exists(adv_img_path): @@ -63,9 +73,11 @@ def train(self, train_dataloader, epochs): for epoch in range(1, epochs+1): if epoch == 200: + # Adjust netG's learning rate to 1e-4 at epoch 200 self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=0.0001) if epoch == 400: + # Adjust netG's learning rate to 1e-5 at epoch 400 self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=0.00001) loss_adv_sum = 0