-
Notifications
You must be signed in to change notification settings - Fork 31
Expand file tree
/
Copy pathpre_model_extractor.py
More file actions
39 lines (35 loc) · 1.66 KB
/
pre_model_extractor.py
File metadata and controls
39 lines (35 loc) · 1.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch.nn as nn
import torchvision.models as pre_models
# Return first n layers of a pretrained model
class model_extractor(nn.Module):
    """Wrap the first ``num_layers`` top-level children of a pretrained torchvision model.

    Parameters:
    - arch (str): Architecture name, matched by prefix, first match wins:
      anything starting with 'alexnet' loads AlexNet, 'resnet' loads
      ResNet-18 (any suffix is ignored, so 'resnet50' still loads ResNet-18),
      and 'vgg16' loads VGG-16.
    - num_layers (int): How many top-level child modules, as yielded by
      ``original_model.children()``, to keep, counted from the input side.
      The value is used as a slice bound, so ``None`` keeps every child and a
      negative value drops that many children from the end.
    - fix_weights (bool): If truthy, set ``requires_grad = False`` on every
      parameter of the extracted layers so an optimizer will not update them.

    Raises:
    - ValueError: if ``arch`` does not start with any supported prefix.
    """

    # Ordered (prefix, torchvision constructor name) pairs. They are checked in
    # this order, and the constructor is looked up by name only once a prefix
    # matches.
    _ARCH_CONSTRUCTORS = (
        ("alexnet", "alexnet"),
        ("resnet", "resnet18"),
        ("vgg16", "vgg16"),
    )

    def __init__(self, arch, num_layers, fix_weights):
        super(model_extractor, self).__init__()

        # Pick the pretrained backbone for the first matching prefix.
        for prefix, constructor_name in self._ARCH_CONSTRUCTORS:
            if arch.startswith(prefix):
                # NOTE(review): newer torchvision releases deprecate
                # `pretrained=True` in favour of `weights=...`; confirm
                # against the installed torchvision version.
                original_model = getattr(pre_models, constructor_name)(pretrained=True)
                break
        else:
            supported = ", ".join(prefix for prefix, _ in self._ARCH_CONSTRUCTORS)
            # The original code raised a bare string, which itself fails with
            # TypeError. Raise a real exception carrying a useful message.
            raise ValueError(
                f"Architecture {arch!r} is not supported yet; "
                f"expected a name starting with one of: {supported}"
            )

        # children() yields only direct submodules, so a single "layer" here
        # may itself be a nested container, depending on the architecture.
        self.features = nn.Sequential(*list(original_model.children())[:num_layers])

        # Optionally freeze the extracted layers so they are not trained.
        if fix_weights:
            for param in self.features.parameters():
                param.requires_grad = False

        # Keep the requested architecture name for reference.
        self.modelName = arch

    def forward(self, x):
        """Run ``x`` through the extracted layers and return their output."""
        return self.features(x)