Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions deepsecure/adv_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def train_batch(self, x):

return loss_adv.item(), adv_imgs


def train(self, train_dataloader, epochs):
for epoch in range(1, epochs+1):

Expand Down
3 changes: 3 additions & 0 deletions deepsecure/cat_adv_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def weights_init(m):
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)


class Cat_Adv_Gen:
"""Concatenated Adversarial Generator class that combines regular and noise generators."""

Expand Down Expand Up @@ -171,6 +172,7 @@ def __init__(self, device, model_extractor, generator, reg_g):
if not os.path.exists(adv_img_path):
os.makedirs(adv_img_path)


def train_batch(self, x):
"""Train generator on a single batch of images.

Expand Down Expand Up @@ -210,6 +212,7 @@ def train_batch(self, x):

return loss_adv.item(), adv_imgs, idx, loss_img.item()


def train(self, train_dataloader, epochs):
"""Train the generator for specified number of epochs.

Expand Down
63 changes: 33 additions & 30 deletions deepsecure/catted_generator.py
Original file line number Diff line number Diff line change
import torch.nn as nn
import torch
from module.resnet_block import ResnetBlock             # custom ResNet residual block
from module.pre_model_extractor import model_extractor  # pre-trained feature extractor
import config as cfg                                    # project configuration


class catted_generator(nn.Module):
    """
    A generator model using ResNet-18 as a shared encoder.

    Two input images are encoded separately, their feature maps are
    concatenated along the channel dimension, and a ResNet-block decoder
    reconstructs a 3-channel output image normalized to [-1, 1] by Tanh.
    """

    def __init__(self, num_encoder_layers, fix_encoder, tagged):
        """
        Args:
            num_encoder_layers: number of leading ResNet-18 layers used as
                the encoder; only 5, 6 and 7 are supported.
            fix_encoder: forwarded to `model_extractor`; presumably freezes
                the encoder weights when true — confirm in
                `module.pre_model_extractor`.
            tagged: flag stored on the instance for use by callers.

        Raises:
            ValueError: if `num_encoder_layers` is not 5, 6 or 7.
        """
        super(catted_generator, self).__init__()

        # Shared encoder: the first `num_encoder_layers` layers of ResNet-18.
        self.encoder = model_extractor('resnet18', num_encoder_layers, fix_encoder)
        self.tagged = tagged

        if num_encoder_layers < 5:
            # Raising a plain string (the older `raise("...")`) is itself a
            # TypeError; use a real exception type.
            raise ValueError("Not supported for layers less than 5")

        # Decoder depth mirrors encoder depth: a deeper encoder yields a
        # smaller feature map, so more upsampling stages are needed.
        # NOTE(review): forward() concatenates TWO encodings, doubling the
        # channel count; only the 5-layer branch (64*2) accounts for this —
        # confirm the 6- and 7-layer channel sizes against ResnetBlock.
        elif num_encoder_layers == 7:
            decoder_lis = [
                ResnetBlock(256), ResnetBlock(256),
                nn.UpsamplingNearest2d(scale_factor=2),
                nn.ConvTranspose2d(256, 128, kernel_size=1, stride=1, bias=False),
                ResnetBlock(128), ResnetBlock(128),
                nn.UpsamplingNearest2d(scale_factor=2),
                nn.ConvTranspose2d(128, 64, kernel_size=1, stride=1, bias=False),
                ResnetBlock(64), ResnetBlock(64),
                nn.UpsamplingNearest2d(scale_factor=2),
                nn.ConvTranspose2d(64, 3, kernel_size=7, stride=2, padding=3, output_padding=1, bias=False),
                nn.Tanh()  # output normalized to [-1, 1]
            ]
        elif num_encoder_layers == 6:
            decoder_lis = [
                ResnetBlock(128), ResnetBlock(128),
                nn.UpsamplingNearest2d(scale_factor=2),
                nn.ConvTranspose2d(128, 64, kernel_size=1, stride=1, bias=False),
                ResnetBlock(64), ResnetBlock(64),
                nn.UpsamplingNearest2d(scale_factor=2),
                nn.ConvTranspose2d(64, 3, kernel_size=7, stride=2, padding=3, output_padding=1, bias=False),
                nn.Tanh()  # output normalized to [-1, 1]
            ]
        elif num_encoder_layers == 5:
            decoder_lis = [
                ResnetBlock(64 * 2), ResnetBlock(64 * 2), ResnetBlock(64 * 2),
                nn.UpsamplingNearest2d(scale_factor=2),
                nn.ConvTranspose2d(64 * 2, 3, kernel_size=7, stride=2, padding=3, output_padding=1, bias=False),
                nn.Tanh()  # output normalized to [-1, 1]
            ]
        else:
            # Previously any value > 7 fell through every branch and crashed
            # later with an undefined `decoder_lis`; reject it explicitly.
            raise ValueError("num_encoder_layers must be 5, 6 or 7")

        # Assemble the decoder from the chosen layer list.
        self.decoder = nn.Sequential(*decoder_lis)

    def forward(self, x1, x2):
        """
        Encode both inputs with the shared encoder, concatenate the feature
        maps channel-wise, and decode to an output image.

        Returns:
            A tuple `(out, x_t_2)`: the generated image and the encoded
            features of the second input.
        """
        x_t_1 = self.encoder(x1)
        x_t_2 = self.encoder(x2)
        out = self.decoder(torch.cat((x_t_1, x_t_2), 1))

        return out, x_t_2
1 change: 1 addition & 0 deletions module/pre_model_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(self, arch, num_layers, fix_weights):
else :
raise("Not support on this architecture yet")


# Extract the first `num_layers` layers from the pretrained model
self.features = nn.Sequential(*list(original_model.children())[:num_layers])

Expand Down
2 changes: 1 addition & 1 deletion module/resnet_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias)
norm_layer(dim),
nn.ReLU(True)]
if use_dropout:
conv_block += [nn.Dropout(0.5)]
conv_block += [nn.Dropout(0.5)]

p = 0
if padding_type == 'reflect':
Expand Down