-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfeature_extract_window.py
115 lines (86 loc) · 3.36 KB
/
feature_extract_window.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import tkinter as tk
import torchvision.models as models
import torch
from dataset_import_window import DatasetImportWindow
class FeatureExtractWindow:
"""A class that creates a UI for loading and visualizing a feature"""
def __init__(self, master, data: DatasetImportWindow):
"""
Initializes the FeatureVisualization object.
:param master: The parent widget.
"""
self.inherited_data = data.get_dataloader
self.master = master
self.feature_heading_Label = tk.Label(
master, text="Feature Extraction Section", font=("Arial", 30)
)
self.feature_heading_Label.pack(pady=5)
self.feature_choice_label = tk.Label(
master,
text="Please choose from the following models, through which the features will be extracted.",
font=("Arial", 20),
)
self.feature_choice_label.pack(pady=20)
self.load_mobilenet_model = tk.Button(
master, text="MobileNet-V2", command=self.mobilenetv2_model
)
self.load_mobilenet_model.pack(pady=10)
self.load_vgg_model = tk.Button(master, text="VGG-16", command=self.vgg16_model)
self.load_vgg_model.pack(pady=10)
# self.load_resnet_model = tk.Button(
# master, text="ResNet-50", command=self.resnet50_model
# )
# self.load_resnet_model.pack()
self.load_inception_model = tk.Button(
master, text="Inception-V3", command=self.inceptionv3_model
)
self.load_inception_model.pack(pady=10)
self.note_label = tk.Label(
master,
text="Note: Only one model must be chosen.",
font=("Arial", 15, "bold"),
)
self.note_label.pack(pady=20)
self.model_Imported_Label = tk.Label(
master,
text="",
font=("Arial", 12),
)
self.model_Imported_Label.pack(pady=10)
self.visualize_feature_button = tk.Button(
master,
text="Store Features",
command=self.store_visualized_features
)
self.visualize_feature_button.pack()
self.canvas = tk.Canvas(master, width=100, height=100)
self.canvas.pack()
def mobilenetv2_model(self):
"""Loads the MobileNet-V2 model"""
self.model = models.mobilenet_v3_small(pretrained=True, weights="IMAGENET1K_V1")
self.model.eval()
print("MobileNetV3 model loaded successfully.")
self.model_Imported_Label.config(text="MobileNetV3 model has been imported.")
return self.model
def vgg16_model(self):
"""Loads the VGG-16 model"""
pass
def resnet50_model(self):
"""Loads the ResNet-50 model"""
pass
def inceptionv3_model(self):
"""Loads the Inception-V3 model"""
pass
def store_visualized_features(self):
"""Stores the visualized features"""
self.features = []
self.labels = []
for images, labels in self.inherited_data(): # This line
with torch.no_grad():
print("Features being stored...")
self.outputs = self.model(images)
self.features.append(self.model(images))
self.labels.append(labels)
print('FEATURES STORED SUCCESSFULLY!')
print(images.shape)
print(labels.shape)