diff --git a/module/pre_model_extractor.py b/module/pre_model_extractor.py index 4ded643..a415b91 100644 --- a/module/pre_model_extractor.py +++ b/module/pre_model_extractor.py @@ -11,9 +11,11 @@ class model_extractor(nn.Module): - fix_weights (bool): If True, freeze the weights of the extracted layers to prevent training. """ def __init__(self, arch, num_layers, fix_weights): + # Initialize the model_extractor class and its parent class nn.Module super(model_extractor, self).__init__() # Load the specified pretrained model if arch.startswith('alexnet') : + # If the architecture is 'alexnet', load the AlexNet pretrained model original_model = pre_models.alexnet(pretrained=True) elif arch.startswith('resnet') : original_model = pre_models.resnet18(pretrained=True)