Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,40 +6,62 @@
from regular_generator import regular_generator
from catted_generator import catted_generator
from pre_model_extractor import model_extractor
#-------------------------------

# Main entry point of the script
if __name__ == '__main__':
# Check if CUDA is available and print the result
print("CUDA Available: ", torch.cuda.is_available())

# Set the device to CUDA if available and configured to use, otherwise use CPU
device = torch.device("cuda" if (cfg.use_cuda and torch.cuda.is_available()) else "cpu")
print(f"Using device: {device}") # Output the chosen device

# Get the data loaders for training and validation datasets
train_loader, val_loader = get_data_loaders()
print("Data loaders loaded. Train set size:", len(train_loader.dataset), "Validation set size:", len(val_loader.dataset))

# Extract features using a pre-trained model
print(f"Initializing model extractor with {cfg.pretrained_model_arch} architecture and extracting {cfg.num_layers_ext} layers.")
feature_ext = model_extractor(cfg.pretrained_model_arch, cfg.num_layers_ext, cfg.ext_fixed)

# Check if concatenated generator is to be used
if cfg.cat_G:
print("Using concatenated generator architecture.")
if cfg.noise_img:
# Initialize and load the regular generator with noise images
print(f"Loading regular generator from {cfg.noise_g_path}")
reg_generator = regular_generator(cfg.num_layers_ext, cfg.ext_fixed, cfg.G_tagged)
reg_generator.load_state_dict(torch.load(cfg.noise_g_path))
reg_generator.eval()
reg_generator.load_state_dict(torch.load(cfg.noise_g_path)) # Load pre-trained weights
reg_generator.eval() # Set generator to evaluation mode
print("Regular generator loaded and set to eval mode.")

# Initialize the concatenated generator
generator = catted_generator(cfg.num_layers_ext, cfg.ext_fixed, cfg.G_tagged)
print("Concatenated generator initialized.")

# Create an instance of Cat_Adv_Gen with both generators
advGen = Cat_Adv_Gen(device, feature_ext, generator, reg_generator)
print("Cat_Adv_Gen instance created with both generators.")
else:
# Initialize the concatenated generator without noise images
print("Loading concatenated generator without noise images.")
generator = catted_generator(cfg.num_layers_ext, cfg.ext_fixed, cfg.G_tagged)
print("Concatenated generator initialized.")

# Create an instance of Cat_Adv_Gen with only the concatenated generator
advGen = Cat_Adv_Gen(device, feature_ext, generator, False)
print("Cat_Adv_Gen instance created without regular generator.")
else:
# Initialize the regular generator
# Create adversarial generator with concatenated architecture
# False parameter indicates no regular generator for noise
# Initialize the regular generator if concatenated generator is not used
print("Using regular generator architecture.")

# Create adversarial generator with concatenated architecture (False means no regular generator)
generator = regular_generator(cfg.num_layers_ext, cfg.ext_fixed, cfg.G_tagged)
print("Regular generator initialized.")

# Create an instance of Adv_Gen with the regular generator
advGen = Adv_Gen(device, feature_ext, generator)
print("Adv_Gen instance created with regular generator.")


# Print the architecture of the generator
print(f"Generator architecture: {generator}")