diff --git a/deepsecure/adv_image.py b/deepsecure/adv_image.py index 31b8b1d..f7a7209 100644 --- a/deepsecure/adv_image.py +++ b/deepsecure/adv_image.py @@ -59,6 +59,7 @@ def train_batch(self, x): return loss_adv.item(), adv_imgs + def train(self, train_dataloader, epochs): for epoch in range(1, epochs+1): diff --git a/deepsecure/cat_adv_image.py b/deepsecure/cat_adv_image.py index 052ccb6..8f73e4e 100644 --- a/deepsecure/cat_adv_image.py +++ b/deepsecure/cat_adv_image.py @@ -133,6 +133,7 @@ def weights_init(m): nn.init.normal_(m.weight.data, 1.0, 0.02) nn.init.constant_(m.bias.data, 0) + class Cat_Adv_Gen: """Concatenated Adversarial Generator class that combines regular and noise generators.""" @@ -171,6 +172,7 @@ def __init__(self, device, model_extractor, generator, reg_g): if not os.path.exists(adv_img_path): os.makedirs(adv_img_path) + def train_batch(self, x): """Train generator on a single batch of images. @@ -210,6 +212,7 @@ def train_batch(self, x): return loss_adv.item(), adv_imgs, idx, loss_img.item() + def train(self, train_dataloader, epochs): """Train the generator for specified number of epochs. diff --git a/deepsecure/catted_generator.py b/deepsecure/catted_generator.py index f3103d4..79fe920 100644 --- a/deepsecure/catted_generator.py +++ b/deepsecure/catted_generator.py @@ -1,68 +1,71 @@ import torch.nn as nn import torch -from module.resnet_block import ResnetBlock -from module.pre_model_extractor import model_extractor -import config as cfg +from module.resnet_block import ResnetBlock # Import custom ResNet block +from module.pre_model_extractor import model_extractor # Import pre-trained model extractor +import config as cfg # Import configuration class catted_generator(nn.Module): - def __init__(self, - num_encoder_layers, - fix_encoder, - tagged, - ): + """ + A generator model using ResNet as an encoder. + It takes two input images, extracts features, concatenates them, + and reconstructs an output image using a decoder. + """ + + def __init__(self, num_encoder_layers, fix_encoder, tagged): super(catted_generator, self).__init__() + # Initialize encoder using ResNet-18 self.encoder = model_extractor('resnet18', num_encoder_layers, fix_encoder) self.tagged = tagged if num_encoder_layers < 5: - raise("Not support on this layer yet") + raise ValueError("Not supported for layers less than 5") + + # Define decoder based on encoder depth elif num_encoder_layers == 7: decoder_lis = [ - ResnetBlock(256), - ResnetBlock(256), + ResnetBlock(256), ResnetBlock(256), nn.UpsamplingNearest2d(scale_factor=2), nn.ConvTranspose2d(256, 128, kernel_size=1, stride=1, bias=False), - ResnetBlock(128), - ResnetBlock(128), + ResnetBlock(128), ResnetBlock(128), nn.UpsamplingNearest2d(scale_factor=2), nn.ConvTranspose2d(128, 64, kernel_size=1, stride=1, bias=False), - ResnetBlock(64), - ResnetBlock(64), + ResnetBlock(64), ResnetBlock(64), nn.UpsamplingNearest2d(scale_factor=2), nn.ConvTranspose2d(64, 3, kernel_size=7, stride=2, padding=3, output_padding=1, bias=False), - nn.Tanh() - # state size. image_nc x 224 x 224 + nn.Tanh() # Output normalized to [-1, 1] ] elif num_encoder_layers == 6: decoder_lis = [ - ResnetBlock(128), - ResnetBlock(128), + ResnetBlock(128), ResnetBlock(128), nn.UpsamplingNearest2d(scale_factor=2), nn.ConvTranspose2d(128, 64, kernel_size=1, stride=1, bias=False), - ResnetBlock(64), - ResnetBlock(64), + ResnetBlock(64), ResnetBlock(64), nn.UpsamplingNearest2d(scale_factor=2), nn.ConvTranspose2d(64, 3, kernel_size=7, stride=2, padding=3, output_padding=1, bias=False), nn.Tanh() - # state size. image_nc x 224 x 224 ] elif num_encoder_layers == 5: decoder_lis = [ - ResnetBlock(64*2), - ResnetBlock(64*2), - ResnetBlock(64*2), + ResnetBlock(64*2), ResnetBlock(64*2), ResnetBlock(64*2), nn.UpsamplingNearest2d(scale_factor=2), nn.ConvTranspose2d(64*2, 3, kernel_size=7, stride=2, padding=3, output_padding=1, bias=False), nn.Tanh() - # state size. image_nc x 224 x 224 ] + # Construct the decoder self.decoder = nn.Sequential(*decoder_lis) + def forward(self, x1, x2): - x_t_1 = self.encoder(x1) - x_t_2 = self.encoder(x2) - out = self.decoder(torch.cat((x_t_1, x_t_2),1)) + """ + Forward pass: + 1. Encode both input images separately. + 2. Concatenate encoded features. + 3. Decode to reconstruct an output image. + """ + x_t_1 = self.encoder(x1) # Encode first image + x_t_2 = self.encoder(x2) # Encode second image + out = self.decoder(torch.cat((x_t_1, x_t_2), 1)) # Concatenate and decode - return out, x_t_2 + return out, x_t_2 # Return generated output and second image features diff --git a/module/pre_model_extractor.py b/module/pre_model_extractor.py index 4ded643..b55b00c 100644 --- a/module/pre_model_extractor.py +++ b/module/pre_model_extractor.py @@ -22,6 +22,7 @@ def __init__(self, arch, num_layers, fix_weights): else : raise("Not support on this architecture yet") + # Extract the first `num_layers` layers from the pretrained model self.features = nn.Sequential(*list(original_model.children())[:num_layers]) diff --git a/module/resnet_block.py b/module/resnet_block.py index c4217bc..286b4de 100644 --- a/module/resnet_block.py +++ b/module/resnet_block.py @@ -23,7 +23,7 @@ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias) norm_layer(dim), nn.ReLU(True)] if use_dropout: - conv_block += [nn.Dropout(0.5)] + conv_block += [nn.Dropout(0.5)] p = 0 if padding_type == 'reflect':