1
+ import os
2
+ import sys
3
+ sys .path .append ('models' )
4
+ import torch
5
+ import timm
6
+ import pytorch_lightning as pl
7
+ import torch .nn .functional as F
8
+ from torchmetrics .functional import f1_score , precision , recall
9
+ from sklearn .metrics import confusion_matrix
10
+ import matplotlib .pyplot as plt
11
+ import seaborn as sns
12
+ from optimizer import Lookahead
13
+
14
+
15
+ class ClassificationModel (pl .LightningModule ):
16
+
17
+ def __init__ (self , config : dict ):
18
+ super ().__init__ ()
19
+ self .config = config
20
+ model_name = config ['model' ]['params' ]['arch' ]
21
+ try :
22
+ self .backbone = timm .create_model (model_name ,
23
+ pretrained = config ['model' ]['params' ]['pretrained' ],
24
+ num_classes = config ['model' ]['params' ]['num_classes' ])
25
+ except :
26
+ raise ValueError (f'Undefined value of model name: { model_name } ' )
27
+
28
+ self .num_classes = config ['model' ]['params' ]['num_classes' ]
29
+
30
+ def forward (self , x ):
31
+ return self .backbone (x )
32
+
33
+ def configure_optimizers (self ):
34
+ optimizer_name = self .config ['optimizers' ][0 ]['target' ]
35
+ optimizer_params = self .config ['optimizers' ][0 ]['params' ]
36
+
37
+ optimizer_class = getattr (torch .optim , optimizer_name )
38
+
39
+ if self .config ['optimizers' ][0 ].get ('use_lookahead' , True ):
40
+ base_optim = optimizer_class (self .parameters (), ** optimizer_params )
41
+ optimizer = Lookahead (base_optim )
42
+ else :
43
+ optimizer = optimizer_class (self .parameters (), ** optimizer_params )
44
+
45
+ scheduler_name = self .config ['scheduler' ][0 ]['target' ]
46
+ scheduler_params = getattr (torch .optim .lr_scheduler , scheduler_name )
47
+ scheduler = scheduler_params (optimizer )
48
+
49
+ monitor = self .config ['scheduler' ][0 ].get ('monitor' , '' )
50
+ # return [optimizer], [scheduler]
51
+ return {"optimizer" : optimizer , "lr_scheduler" : scheduler , "monitor" : monitor }
52
+
53
+ def compute_metrics (self , pred , target ):
54
+ metrics = dict ()
55
+ metrics ['f1_score' ] = f1_score (pred , target , num_classes = self .num_classes , task = 'multiclass' )
56
+ metrics ['precision' ] = precision (pred , target , num_classes = self .num_classes , task = 'multiclass' )
57
+ metrics ['recall' ] = recall (pred , target , num_classes = self .num_classes , task = 'multiclass' )
58
+ return metrics
59
+
60
+ def training_step (self , batch , batch_idx ):
61
+ x , y = batch
62
+ output = self (x )
63
+ loss = F .cross_entropy (output , y )
64
+ metrics = self .compute_metrics (output , y )
65
+ self .log ('train_f1' , metrics ['f1_score' ], prog_bar = True , on_step = False , on_epoch = True )
66
+ self .log ('train_prec' , metrics ['precision' ], prog_bar = True , on_step = False , on_epoch = True )
67
+ self .log ('train_rec' , metrics ['recall' ], prog_bar = True , on_step = False , on_epoch = True )
68
+ self .log ('train_loss' , loss , prog_bar = True , on_step = False , on_epoch = True )
69
+ return loss
70
+
71
+ def validation_step (self , batch , batch_idx ):
72
+ x , y = batch
73
+ output = self (x )
74
+ loss = F .cross_entropy (output , y )
75
+ metrics = self .compute_metrics (output , y )
76
+ self .log ('val_f1' , metrics ['f1_score' ], on_step = False , on_epoch = True )
77
+ self .log ('val_prec' , metrics ['precision' ], on_step = False , on_epoch = True )
78
+ self .log ('val_rec' , metrics ['recall' ], on_step = False , on_epoch = True )
79
+ self .log ('val_loss' , loss , prog_bar = True , on_step = False , on_epoch = True )
80
+ return loss
81
+
82
+ def test_step (self , batch , batch_idx ):
83
+ x , y = batch
84
+ output = self (x )
85
+ loss = F .cross_entropy (output , y )
86
+ metrics = self .compute_metrics (output , y )
87
+ self .log ('test_f1' , metrics ['f1_score' ], on_step = False , on_epoch = True )
88
+ self .log ('test_prec' , metrics ['precision' ], on_step = False , on_epoch = True )
89
+ self .log ('test_rec' , metrics ['recall' ], on_step = False , on_epoch = True )
90
+ self .log ('test_loss' , loss , prog_bar = True , on_step = False , on_epoch = True )
91
+ print (f'test_metrics: { metrics } ' )
92
+
93
+ # Convert predictions to class indices
94
+ _ , preds = torch .max (output , 1 )
95
+
96
+ return {"loss" : loss , "preds" : preds , "labels" : y }
97
+
98
+ def test_epoch_end (self , outputs ):
99
+ all_preds = torch .cat ([out ["preds" ] for out in outputs ])
100
+ all_labels = torch .cat ([out ["labels" ] for out in outputs ])
101
+
102
+ all_preds = all_preds .cpu ().numpy ()
103
+ all_labels = all_labels .cpu ().numpy ()
104
+
105
+ conf_matrix = confusion_matrix (all_labels , all_preds )
106
+
107
+ plt .figure (figsize = (self .num_classes , self .num_classes ))
108
+ sns .heatmap (conf_matrix , annot = True , fmt = "d" , cmap = "Blues" ,
109
+ xticklabels = range (self .num_classes ), yticklabels = range (self .num_classes ))
110
+ plt .xlabel ("Predicted" )
111
+ plt .ylabel ("True" )
112
+ plt .title ("Confusion Matrix" )
113
+
114
+ # Save the confusion matrix as an image
115
+ plt .savefig (os .path .join (self .config ['common' ].get ('exp_name' , 'exp0' ) ,"confusion_matrix.png" ))
116
+ plt .close ()
0 commit comments