
Commit bbfa487
committed Jan 20, 2024
Initial commit
1 parent f39d10c, commit bbfa487

18 files changed: +504 additions, -447 deletions

.gitignore

+20
@@ -0,0 +1,20 @@
__pycache__/
saves/
*.mp4
*.csv
*.pt

# virtualenv
venv/
ENV/
env/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

.idea/
dataset/
classification_models_pytorch/dataset/
+58
@@ -0,0 +1,58 @@
common:
  exp_name: test_exp
  seed: 666
  batch_size: 8
  num_workers: 1
  max_epochs: 4
  use_cross_validation: False
  num_splits: 4

dataset:
  path: dataset
  img_size: 224
  val_size: 0.3

trainer:
  target: pytorch_lightning.Trainer
  params:
    # gpus: [0]
    max_epochs: 20

model:
  target: models.ClassificationModel
  params:
    arch: resnet18
    pretrained: True
    num_classes: 2

criterions:
  - target: torch.nn.CrossEntropyLoss
    weight: 1.0
    name: cross_entropy

optimizers:
  - target: Adam  # optimizer name from torch.optim
    params:
      lr: 0.001
    use_lookahead: True

scheduler:
  - target: ReduceLROnPlateau
    monitor: val_loss

metrics:
  - target: torch.nn.functional.cross_entropy

callbacks:
  - target: pytorch_lightning.callbacks.ModelCheckpoint
    params:
      filename: best-{epoch:02d}-min-{val_loss:2.2f}
      monitor: val_loss
      mode: min
      save_top_k: 1
      save_last: true
  - target: pytorch_lightning.callbacks.EarlyStopping
    params:
      monitor: val_loss
      patience: 5
      mode: min
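
For reference, a minimal sketch of loading this config and reading a few of the fields the code below relies on. The config path is an assumption; this file's actual name is not shown in the diff.

from omegaconf import OmegaConf

# Placeholder path; point it at wherever this YAML lives in the repository.
config = OmegaConf.load('config.yaml')

print(config['model']['params']['arch'])        # resnet18
print(config['dataset']['img_size'])            # 224
print(config['optimizers'][0]['target'])        # Adam
print(config['scheduler'][0]['monitor'])        # val_loss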

classification_models_pytorch/data.py

+35
@@ -0,0 +1,35 @@
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from utils import make_weights_for_balanced_classes


class ClassificationDataset(Dataset):
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        image = cv2.imread(self.file_paths[idx])
        # cv2.imread returns BGR; convert to RGB before applying transforms
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        if self.transform:
            image = self.transform(image)

        return image, label


def create_dataloader(file_paths, labels, batch_size, transform, num_workers):
    dataset = ClassificationDataset(file_paths, labels, transform)

    # Weight each sample inversely to its class frequency so batches are roughly balanced
    weights = make_weights_for_balanced_classes(dataset.labels)
    weights = torch.DoubleTensor(weights)
    sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers)
    return dataloader
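
A usage sketch for create_dataloader. The file paths, labels, and torchvision transform are illustrative placeholders; the repository's training script presumably builds them from config['dataset']['path'] and its own Transforms class instead.

import torchvision.transforms as T
from data import create_dataloader

# Illustrative inputs only; these paths do not exist in the repository.
file_paths = ['dataset/class0/img_001.jpg', 'dataset/class1/img_002.jpg']
labels = [0, 1]

# ToPILImage accepts the HWC uint8 RGB array produced by ClassificationDataset.__getitem__.
transform = T.Compose([T.ToPILImage(), T.Resize((224, 224)), T.ToTensor()])

loader = create_dataloader(file_paths, labels, batch_size=2, transform=transform, num_workers=0)
images, targets = next(iter(loader))
print(images.shape, targets)   # torch.Size([2, 3, 224, 224]) and a tensor of class indices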
+43
@@ -0,0 +1,43 @@
import os
import warnings
import torch
import cv2
from argparse import ArgumentParser
from models.model import ClassificationModel
import torch.nn.functional as F
from omegaconf import OmegaConf
from transforms import Transforms


warnings.filterwarnings('ignore')


def predict(model, image, config):
    transforms = Transforms(config['dataset']['img_size'])
    test_transform = transforms.val_test_transform()
    img_tensor = test_transform(image).unsqueeze_(0)
    # Run inference without tracking gradients; softmax over the class dimension
    with torch.no_grad():
        prediction = torch.argmax(F.softmax(model(img_tensor), dim=1)).item()
    return prediction


def main():
    parser = ArgumentParser()
    parser.add_argument('--config', type=str, required=True, help='Path to YAML config file')
    parser.add_argument('--model_path', required=True, type=str, help='Path to model checkpoint')
    parser.add_argument('--image', type=str, required=True, help='Path to input image')
    args = parser.parse_args()

    # Load the checkpoint onto CPU so inference also works on machines without a GPU
    checkpoint = torch.load(args.model_path, map_location='cpu')
    config = OmegaConf.load(args.config)
    model = ClassificationModel(config)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    image = cv2.imread(args.image)
    # cv2.imread returns BGR; convert to RGB to match the training pipeline
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    print(f'Prediction is: {predict(model, image, config)}')


if __name__ == '__main__':
    main()
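
Assuming this script is saved as a predict script (its filename is not shown in this excerpt), the expected invocation is:

python <predict script>.py --config <path to the YAML config above> --model_path <Lightning .ckpt checkpoint> --image <input image file>

The angle-bracketed paths are placeholders, not names taken from the repository.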
@@ -0,0 +1,116 @@
import os
import sys
sys.path.append('models')
import torch
import timm
import pytorch_lightning as pl
import torch.nn.functional as F
from torchmetrics.functional import f1_score, precision, recall
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from optimizer import Lookahead


class ClassificationModel(pl.LightningModule):

    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        model_name = config['model']['params']['arch']
        try:
            # Any architecture name known to timm can be used here
            self.backbone = timm.create_model(model_name,
                                              pretrained=config['model']['params']['pretrained'],
                                              num_classes=config['model']['params']['num_classes'])
        except Exception as exc:
            raise ValueError(f'Undefined value of model name: {model_name}') from exc

        self.num_classes = config['model']['params']['num_classes']

    def forward(self, x):
        return self.backbone(x)

    def configure_optimizers(self):
        optimizer_name = self.config['optimizers'][0]['target']
        optimizer_params = self.config['optimizers'][0]['params']

        # Resolve the optimizer class by name from torch.optim (e.g. Adam)
        optimizer_class = getattr(torch.optim, optimizer_name)

        if self.config['optimizers'][0].get('use_lookahead', True):
            # Wrap the base optimizer with Lookahead when enabled in the config
            base_optim = optimizer_class(self.parameters(), **optimizer_params)
            optimizer = Lookahead(base_optim)
        else:
            optimizer = optimizer_class(self.parameters(), **optimizer_params)

        # Resolve the scheduler class by name from torch.optim.lr_scheduler
        scheduler_name = self.config['scheduler'][0]['target']
        scheduler_class = getattr(torch.optim.lr_scheduler, scheduler_name)
        scheduler = scheduler_class(optimizer)

        # ReduceLROnPlateau needs a monitored metric, e.g. val_loss
        monitor = self.config['scheduler'][0].get('monitor', '')
        # return [optimizer], [scheduler]
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": monitor}

    def compute_metrics(self, pred, target):
        metrics = dict()
        metrics['f1_score'] = f1_score(pred, target, num_classes=self.num_classes, task='multiclass')
        metrics['precision'] = precision(pred, target, num_classes=self.num_classes, task='multiclass')
        metrics['recall'] = recall(pred, target, num_classes=self.num_classes, task='multiclass')
        return metrics

    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self(x)
        loss = F.cross_entropy(output, y)
        metrics = self.compute_metrics(output, y)
        self.log('train_f1', metrics['f1_score'], prog_bar=True, on_step=False, on_epoch=True)
        self.log('train_prec', metrics['precision'], prog_bar=True, on_step=False, on_epoch=True)
        self.log('train_rec', metrics['recall'], prog_bar=True, on_step=False, on_epoch=True)
        self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self(x)
        loss = F.cross_entropy(output, y)
        metrics = self.compute_metrics(output, y)
        self.log('val_f1', metrics['f1_score'], on_step=False, on_epoch=True)
        self.log('val_prec', metrics['precision'], on_step=False, on_epoch=True)
        self.log('val_rec', metrics['recall'], on_step=False, on_epoch=True)
        self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        output = self(x)
        loss = F.cross_entropy(output, y)
        metrics = self.compute_metrics(output, y)
        self.log('test_f1', metrics['f1_score'], on_step=False, on_epoch=True)
        self.log('test_prec', metrics['precision'], on_step=False, on_epoch=True)
        self.log('test_rec', metrics['recall'], on_step=False, on_epoch=True)
        self.log('test_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        print(f'test_metrics: {metrics}')

        # Convert logits to predicted class indices
        _, preds = torch.max(output, 1)

        return {"loss": loss, "preds": preds, "labels": y}

    def test_epoch_end(self, outputs):
        all_preds = torch.cat([out["preds"] for out in outputs])
        all_labels = torch.cat([out["labels"] for out in outputs])

        all_preds = all_preds.cpu().numpy()
        all_labels = all_labels.cpu().numpy()

        conf_matrix = confusion_matrix(all_labels, all_preds)

        plt.figure(figsize=(self.num_classes, self.num_classes))
        sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues",
                    xticklabels=range(self.num_classes), yticklabels=range(self.num_classes))
        plt.xlabel("Predicted")
        plt.ylabel("True")
        plt.title("Confusion Matrix")

        # Save the confusion matrix as an image inside the experiment directory
        exp_dir = self.config['common'].get('exp_name', 'exp0')
        os.makedirs(exp_dir, exist_ok=True)
        plt.savefig(os.path.join(exp_dir, "confusion_matrix.png"))
        plt.close()
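
The training entry point is not among the files shown in this excerpt. Below is a minimal sketch of how this LightningModule might be driven using the repository's data.py and transforms.py together with the config above; the config path, the placeholder file lists, and the use of val_test_transform for training are all assumptions, not code from this commit.

import pytorch_lightning as pl
from omegaconf import OmegaConf
from models.model import ClassificationModel
from data import create_dataloader
from transforms import Transforms

config = OmegaConf.load('config.yaml')  # placeholder path

# Placeholder file lists; the real script presumably builds them from config['dataset']['path'].
train_files, train_labels = ['dataset/a.jpg'], [0]
val_files, val_labels = ['dataset/b.jpg'], [1]

transforms = Transforms(config['dataset']['img_size'])
transform = transforms.val_test_transform()  # a dedicated train transform likely exists but is not shown here

train_loader = create_dataloader(train_files, train_labels, config['common']['batch_size'],
                                 transform, config['common']['num_workers'])
val_loader = create_dataloader(val_files, val_labels, config['common']['batch_size'],
                               transform, config['common']['num_workers'])

model = ClassificationModel(config)
trainer = pl.Trainer(**config['trainer']['params'])  # max_epochs: 20 from the config above
trainer.fit(model, train_loader, val_loader)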
