Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions deepsecure/adv_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,22 @@ def weights_init(m):
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)


# Adversarial Image Generator class
class Adv_Gen:
def __init__(self,
device,
model_extractor,
generator,):

"""
Initialize the Adversarial Image Generator.

Args:
device (torch.device): The device (CPU/GPU) to run the model on.
model_extractor (nn.Module): Feature extractor model.
generator (nn.Module): Generator model to create adversarial images.
"""

self.device = device
self.model_extractor = model_extractor
self.generator = generator
Expand All @@ -32,15 +41,16 @@ def __init__(self,
self.ite = 0
#self.CELoss = nn.CrossEntropyLoss()

# move model and generator to pre-defined device
self.model_extractor.to(device)
#self.model_extractor.eval()

self.generator.to(device)

# initialize optimizers
self.optimizer_G = torch.optim.Adam(self.generator.parameters(),
lr=0.001)


# Create directories for saving models and adversarial images if they don't exist
if not os.path.exists(models_path):
os.makedirs(models_path)
if not os.path.exists(adv_img_path):
Expand All @@ -63,9 +73,11 @@ def train(self, train_dataloader, epochs):
for epoch in range(1, epochs+1):

if epoch == 200:
# Adjust netG's learning rate to 1e-4 at epoch 200
self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
lr=0.0001)
if epoch == 400:
# Adjust netG's learning rate to 1e-5 at epoch 400
self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
lr=0.00001)
loss_adv_sum = 0
Expand Down