-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset_import_window.py
110 lines (81 loc) · 3.34 KB
/
dataset_import_window.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import tkinter as tk
import torch
import torchvision.models as models
import os
from tkinter import filedialog as fd
from torchvision import datasets, transforms
from PIL import Image
class DatasetImportWindow:
    """Tkinter panel for picking an image-dataset directory and turning it
    into a batched ``torch`` ``DataLoader``.

    The panel shows a heading, a "Dataset Import" button (opens a directory
    chooser), a "Transform Data" button (builds the dataset and dataloader),
    a status label and an empty canvas.

    Attributes set by this class:
        image_data_paths: directory chosen by the user, or ``None`` until a
            directory has been selected.
        dataset: the ``ImageFolder`` built by ``get_data_transformed``, or
            ``None`` before that.
        dataloader: the ``DataLoader`` built by ``get_data_transformed``, or
            ``None`` before that.
        transform: the transform pipeline built by
            ``preprocess_mobilenetv3``, or ``None`` before that.
    """

    def __init__(self, master):
        """
        Builds the widgets and packs them into ``master``.
        :param master: The parent widget.
        """
        self.master = master
        self.image_data_paths = None
        # Initialised up front so the getters never hit a missing attribute
        # when they are called before a dataset has been transformed.
        self.dataset = None
        self.dataloader = None
        self.transform = None

        self.heading_Label = tk.Label(
            master, text="Data Import Section", font=("Arial", 30)
        )
        self.heading_Label.pack(pady=20)

        self.data_Import_Button = tk.Button(
            master, text="Dataset Import", command=self.load_data, width=15, height=2
        )
        self.data_Import_Button.pack(anchor="center", pady=50)

        self.data_Transform_Button = tk.Button(
            master,
            text="Transform Data",
            command=self.get_data_transformed,
            width=15,
            height=2,
        )
        self.data_Transform_Button.pack(anchor="center", pady=50)

        # Status line; filled in by load_data once a directory is chosen.
        self.data_Imported_Label = tk.Label(
            master,
            text="",
            font=("Arial", 12),
        )
        self.data_Imported_Label.pack(pady=10)

        self.canvas = tk.Canvas(master, width=700, height=500)
        self.canvas.pack()

    def load_data(self):
        """Ask the user for the dataset directory and remember it.

        If the dialog is dismissed without a selection (it then returns an
        empty value), the previous selection and the status label are left
        untouched instead of being overwritten with an empty path.
        """
        # Nothing in this class disables the button; the assignment is kept
        # so the button is re-enabled if other code disabled it.
        self.data_Import_Button["state"] = "normal"

        selected_directory = fd.askdirectory()
        if not selected_directory:
            return

        self.image_data_paths = selected_directory
        self.data_Imported_Label.config(
            text="Data Imported from {}.".format(selected_directory)
        )

    def get_directory(self):
        """Returns the directory path of the dataset (``None`` if not chosen yet)"""
        return self.image_data_paths

    def get_data_transformed(self):
        """Build the dataset and dataloader for the chosen directory.

        :return: the ``DataLoader`` (batch size 32, not shuffled), or ``None``
            when no directory has been selected yet.

        Errors raised by ``datasets.ImageFolder`` for an unusable directory
        are not caught here and propagate to the caller.
        """
        # Covers both "never selected" (None) and an empty path.
        if not self.image_data_paths:
            return None

        self.data_Transform_Button["state"] = "normal"
        self.dataset = datasets.ImageFolder(
            root=self.get_directory(), transform=self.preprocess_mobilenetv3()
        )
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset, batch_size=32, shuffle=False
        )
        print("Dataset transformation complete. You can now move on to the next step.")
        return self.dataloader

    def get_dataloader(self):
        """Returns the dataloader (``None`` until ``get_data_transformed`` has built one)"""
        return self.dataloader

    def preprocess_mobilenetv3(self):
        """Build, store and return the input transform pipeline for MobileNet-V3.

        Steps, in order: convert to tensor, resize to 256x256 (bilinear),
        centre-crop to 224x224, normalise each channel with a fixed mean/std.

        :return: the ``transforms.Compose`` pipeline (also kept in
            ``self.transform``).
        """
        IMG_HEIGHT = 256
        IMG_WIDTH = 256
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),  # Convert the image to PyTorch Tensor data type
                transforms.Resize(
                    (IMG_HEIGHT, IMG_WIDTH),
                    # torchvision's own enum instead of the PIL integer
                    # constant, which torchvision has deprecated.
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),  # Resize the images
                transforms.CenterCrop(
                    224
                ),  # Crop the images to 224x224 about the center
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),  # Normalize the images
            ]
        )
        return self.transform