forked from rigley007/OpenPrivML
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtransfer_learning_clean_imagenet10_0721.py
More file actions
175 lines (141 loc) · 6.82 KB
/
transfer_learning_clean_imagenet10_0721.py
File metadata and controls
175 lines (141 loc) · 6.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from torchvision.models.resnet import ResNet, BasicBlock
import torchvision.models as t_models
from tqdm.autonotebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import inspect
import time
from torch import nn, optim
import torch
from imagenet10_dataloader import get_data_loaders
# Define a custom ResNet-18 model for the Imagenet10 dataset
class Imagenet10ResNet18(ResNet):
    """Custom ResNet18 model modified for ImageNet10 classification.

    This class adapts a pretrained ResNet18 model for 10-class classification by:
    1. Loading pretrained ImageNet weights from ``weights_path``
    2. Freezing all pretrained layers
    3. Replacing the final fully connected layer (the new layer stays trainable)
    4. Adding softmax activation

    Because ``forward`` returns softmax *probabilities* rather than logits, a
    training loop should use ``nn.NLLLoss`` on ``log(probs)``. Feeding the output
    to ``nn.CrossEntropyLoss`` would apply a second softmax.

    Args:
        weights_path: Path to a ResNet-18 ``state_dict`` with a 1000-way head.
            Defaults to the location the original script used.
    """

    def __init__(self, weights_path='/home/rui/.torch/resnet18-5c106cde.pth'):
        # Build the stock ResNet-18 topology with a 1000-way head so the pretrained
        # ImageNet checkpoint loads without shape mismatches.
        super(Imagenet10ResNet18, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=1000)
        # map_location='cpu' lets a checkpoint saved from a GPU be read on a CPU-only
        # host. load_state_dict copies the values into this model's own parameters,
        # and the model can be moved to a device afterwards with .to().
        state_dict = torch.load(weights_path, map_location='cpu')
        super(Imagenet10ResNet18, self).load_state_dict(state_dict)
        # Freeze every pretrained parameter so only layers created below are trained.
        for param in self.parameters():
            param.requires_grad = False
        # Created after the freeze loop, so this new 10-way head is trainable.
        self.fc = torch.nn.Linear(512, 10)

    def forward(self, x):
        """Run ResNet-18 and return softmax probabilities over the 10 classes."""
        return torch.softmax(super(Imagenet10ResNet18, self).forward(x), dim=-1)
class Imagenet10ResNet18_3x3(ResNet):
    """ResNet-18 variant with a 3x3 stem convolution, for 10-class ImageNet10.

    Starts from pretrained ResNet-18 weights and freezes them. It then replaces
    both the classifier head (512 -> 10) and the first convolution (7x7 -> 3x3,
    stride 2). Both replacement layers are created after the freeze loop, so they
    are freshly initialised and trainable; every other layer stays frozen.
    ``forward`` returns softmax probabilities, not logits.

    Args:
        weights_path: Path to a ResNet-18 ``state_dict`` with a 1000-way head.
            Defaults to the location the original script used.
    """

    def __init__(self, weights_path='/home/rui/.torch/resnet18-5c106cde.pth'):
        # Stock 1000-way ResNet-18 so the pretrained checkpoint loads cleanly.
        super(Imagenet10ResNet18_3x3, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=1000)
        # map_location='cpu' keeps loading working on hosts without a GPU.
        state_dict = torch.load(weights_path, map_location='cpu')
        super(Imagenet10ResNet18_3x3, self).load_state_dict(state_dict)
        for param in self.parameters():
            param.requires_grad = False
        self.fc = torch.nn.Linear(512, 10)
        # NOTE(review): padding=(3, 3) matches the stock 7x7 stem. With a 3x3 kernel,
        # padding=1 is the usual choice; padding 3 gives a slightly larger feature map.
        # Kept unchanged so existing trained weights behave the same. Confirm intent.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(3, 3), bias=False)

    def forward(self, x):
        """Run the modified ResNet-18 and return softmax probabilities over 10 classes."""
        return torch.softmax(super(Imagenet10ResNet18_3x3, self).forward(x), dim=-1)
class Imagenet10Googlenet(nn.Module):
    """Pretrained torchvision GoogLeNet with a new 10-way classification head.

    Every pretrained parameter is frozen. The replacement ``fc`` layer
    (1024 -> 10) is created afterwards, so it is the only trainable part.
    ``forward`` returns the wrapped model's output unchanged (no softmax here).
    """

    def __init__(self):
        super().__init__()
        backbone = t_models.googlenet(pretrained=True)
        # Module.requires_grad_ flips requires_grad on every parameter in one call.
        backbone.requires_grad_(False)
        backbone.fc = torch.nn.Linear(1024, 10)
        self.model = backbone

    def forward(self, x):
        """Delegate straight to the wrapped GoogLeNet."""
        return self.model(x)
class Imagenet10inception_v3(nn.Module):
    """Pretrained torchvision Inception-v3 with a new 10-way classification head.

    All pretrained parameters are frozen before the main ``fc`` layer is replaced
    (2048 -> 10), so only that new layer is trainable.

    NOTE(review): torchvision's pretrained inception_v3 keeps its auxiliary
    classifier. In ``train()`` mode the wrapped model may therefore return a
    (logits, aux_logits) named tuple instead of a plain tensor, and the auxiliary
    head is still the frozen 1000-way one. Confirm against the installed
    torchvision version before training with this class.
    """

    def __init__(self):
        super().__init__()
        net = t_models.inception_v3(pretrained=True)
        for weight in net.parameters():
            weight.requires_grad_(False)
        net.fc = torch.nn.Linear(2048, 10)
        self.model = net

    def forward(self, x):
        """Delegate straight to the wrapped Inception-v3."""
        return self.model(x)
class Imagenet10vgg16_bn(nn.Module):
    """Pretrained torchvision VGG (batch-norm) with a new 10-way final layer.

    NOTE(review): despite the class name, this loads ``vgg11_bn``, not
    ``vgg16_bn``. The architecture is left as-is because changing it would
    invalidate weights saved from this class. Confirm which variant was intended.

    The backbone is frozen; ``classifier[6]`` is replaced by a fresh
    Linear(4096, 10), which is therefore the only trainable layer.
    """

    def __init__(self):
        super().__init__()
        self.model = t_models.vgg11_bn(pretrained=True)
        self.model.requires_grad_(False)
        head = self.model.classifier
        head[6] = torch.nn.Linear(4096, 10)

    def forward(self, x):
        """Delegate straight to the wrapped VGG network."""
        return self.model(x)
def calculate_metric(metric_fn, true_y, pred_y):
    """
    Calculates the evaluation metric for the given true and predicted labels.

    If ``metric_fn`` accepts an ``average`` parameter (as sklearn's precision,
    recall and F1 scorers do), it is called with ``average="macro"`` so that
    multiclass labels are supported. Otherwise it is called with just the labels.

    Parameters:
        metric_fn (function): The metric function to be used for evaluation (e.g., precision, recall, f1-score).
        true_y (array-like): The ground truth (true) labels.
        pred_y (array-like): The predicted labels.
    Returns:
        float: The calculated metric value.
    """
    # inspect.signature sees keyword-only parameters and follows functools.wraps
    # (__wrapped__). getfullargspec(...).args misses both, so it never found the
    # keyword-only `average` of sklearn's scorers, and the binary default then
    # raised on multiclass labels.
    try:
        parameters = inspect.signature(metric_fn).parameters
    except (TypeError, ValueError):
        # Callables without an introspectable signature are called plainly.
        parameters = {}
    if "average" in parameters:
        return metric_fn(true_y, pred_y, average="macro")
    return metric_fn(true_y, pred_y)
def print_scores(p, r, f1, a, batch_size):
    """Print each metric's summed per-batch values divided by ``batch_size``.

    Args:
        p, r, f1, a: Sequences of per-batch precision, recall, F1 and accuracy.
        batch_size: Divisor for each sum. In this script the caller passes the
            number of validation batches, so each line is a per-batch mean.
    """
    labelled = {"precision": p, "recall": r, "F1": f1, "accuracy": a}
    for label, values in labelled.items():
        mean = sum(values) / batch_size
        print(f"\t{label:>14}: {mean:.4f}")
if __name__ == '__main__':
    import os

    start_ts = time.time()
    # Fall back to CPU so the script still runs on hosts without a GPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    epochs = 10
    model = Imagenet10ResNet18()
    model.to(device)
    #model = torch.nn.DataParallel(model, device_ids=[0, 1])
    train_loader, val_loader = get_data_loaders()

    # Initialize training components
    losses = []

    # Imagenet10ResNet18.forward already applies softmax. nn.CrossEntropyLoss would
    # apply log-softmax a second time, flattening the gradients. NLL on the log of
    # the probabilities is the correct cross-entropy for this model. The clamp keeps
    # log() finite if a probability underflows to 0.
    # NOTE: if you swap in a model that returns raw logits (the GoogLeNet,
    # Inception-v3 or VGG wrappers), use nn.CrossEntropyLoss() instead.
    nll_loss = nn.NLLLoss()

    def loss_function(probs, targets):
        """Cross-entropy for outputs that are already softmax probabilities."""
        return nll_loss(torch.log(probs.clamp_min(1e-12)), targets)

    # Only the new fc head requires grad; hand Adam just those parameters.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=0.001)
    batches = len(train_loader)
    val_batches = len(val_loader)

    # training loop + eval loop
    for epoch in range(epochs):
        total_loss = 0
        progress = tqdm(enumerate(train_loader), desc="Loss: ", total=batches)
        model.train()
        # training phase
        for i, data in progress:
            X, y = data[0].to(device), data[1].to(device)
            model.zero_grad()
            outputs = model(X)
            loss = loss_function(outputs, y)
            # Each graph is backpropagated once, so retain_graph is unnecessary
            # (it only kept activations alive longer).
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            progress.set_description("Loss: {:.4f}".format(total_loss / (i + 1)))

        # clear cuda memory after training
        torch.cuda.empty_cache()

        # inference phase
        val_losses = 0.0
        precision, recall, f1, accuracy = [], [], [], []
        model.eval()
        with torch.no_grad():
            for data in val_loader:
                X, y = data[0].to(device), data[1].to(device)
                outputs = model(X)
                # .item() keeps a plain float so the summary prints a number,
                # not a tensor repr.
                val_losses += loss_function(outputs, y).item()
                predicted_classes = torch.max(outputs, 1)[1]
                y_cpu, pred_cpu = y.cpu(), predicted_classes.cpu()
                for acc, metric in zip((precision, recall, f1, accuracy),
                                       (precision_score, recall_score, f1_score, accuracy_score)):
                    acc.append(calculate_metric(metric, y_cpu, pred_cpu))
        print(
            f"Epoch {epoch + 1}/{epochs}, training loss: {total_loss / batches}, validation loss: {val_losses / val_batches}")
        print_scores(precision, recall, f1, accuracy, val_batches)
        losses.append(total_loss / batches)

    print(losses)
    print(f"Training time: {time.time() - start_ts}s")
    # `.module` only exists when the model is wrapped in DataParallel (commented out
    # above). Unwrap conditionally so the final save does not crash after training.
    model_to_save = getattr(model, "module", model)
    os.makedirs('models', exist_ok=True)
    torch.save(model_to_save.state_dict(), 'models/imagenet10_transferlearning.pth')