import os
import random
from datetime import datetime
from typing import Dict

import numpy as np
import torch
import torch.nn as nn
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from examples.Machine_learning.NeuralNetwork.Segmentation.scripts.dataset import SegmentationDataset
from examples.Machine_learning.NeuralNetwork.Segmentation.scripts.metric import AllMetricTracker, SegmentationMetric
from examples.Machine_learning.NeuralNetwork.Segmentation.scripts.model import Decoder, Encoder, UNet
from iOpt.problem import Problem
from iOpt.trial import FunctionValue, Point


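# LightningModule wrapping a U-Net for ECG wave segmentation: a 12-channel
# Encoder, a 4-class Decoder, cross-entropy loss, and F1 metrics for the
# p, qrs and t waves.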
class UnetModule(LightningModule):
    def __init__(self, kernel_size=23, q=1.2, label_smoothing=0, p=0.75):
        super().__init__()
        self.save_hyperparameters()
        encoder = Encoder(12, kernel_size=kernel_size, q=q, p=p)
        decoder = Decoder(encoder, 4)

        self.model = UNet(encoder, decoder)
        self.loss = nn.CrossEntropyLoss(ignore_index=4, label_smoothing=label_smoothing)

        self.p_metric = SegmentationMetric('p', 'all', return_type='f1', samples=150)
        self.t_metric = SegmentationMetric('t', 'all', return_type='f1', samples=150)
        self.qrs_metric = SegmentationMetric('qrs', 'all', return_type='f1', samples=150)

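    # Converts raw NumPy input to a batched tensor on the module's device and
    # returns the per-sample argmax class mask as a NumPy array.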
    def predict(self, x):
        if isinstance(x, np.ndarray):
            x = torch.Tensor(x)
        x = x.unsqueeze(0) if len(x.shape) == 2 else x
        x = x.to(self.device)
        logits = self.model(x)
        y_pred = logits.argmax(dim=1)
        return y_pred.cpu().detach().numpy()

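    # Lightning training/validation hooks: both log the cross-entropy loss per
    # epoch; validation additionally logs per-wave F1 scores from get_metric().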
    def training_step(self, batch):
        _, x, y = batch
        logits = self.model(x)
        loss = self.loss(logits, y)
        dict_ = {'train_loss': loss}
        self.log_dict(dict_, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch):
        _, x, y = batch
        logits = self.model(x)
        loss = self.loss(logits, y)
        dict_ = {'val_loss': loss}

        metrics = self.get_metric(x, y, 'val')
        dict_.update(metrics)

        self.log_dict(dict_, on_epoch=True, on_step=False)

        return loss

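    # Scores one batch with the three wave metrics and returns them keyed as
    # '<prefix>_p_wave', '<prefix>_qrs_wave' and '<prefix>_t_wave'.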
    def get_metric(self, x, y_true, prefix):
        y_true = y_true.cpu().detach().numpy()
        y_pred = self.predict(x)
        p_f1_score = self.p_metric(y_pred, y_true)
        qrs_f1_score = self.qrs_metric(y_pred, y_true)
        t_f1_score = self.t_metric(y_pred, y_true)
        metrics = {f'{prefix}_p_wave': p_f1_score,
                   f'{prefix}_qrs_wave': qrs_f1_score,
                   f'{prefix}_t_wave': t_f1_score}
        return metrics

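    # AdamW with ReduceLROnPlateau: the learning rate is multiplied by 0.3
    # whenever 'train_loss' fails to improve for 50 consecutive epochs.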
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters())
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.3, patience=50)
        return [optimizer], [{"scheduler": scheduler,
                              "interval": "epoch",
                              "monitor": "train_loss"}]


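# Loads signal and mask arrays for the given file names from data/signals/
# and data/masks/.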
def get_dataset(paths):
    return [np.load(f'data/signals/{x}') for x in paths], \
           [np.load(f'data/masks/{x}') for x in paths]


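# iOpt problem: a 2-D box-constrained search over the U-Net's p and q
# hyperparameters. Supervised recordings are split 80/20 into train/val;
# 'unsupervised' recordings are always appended to the training split.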
class Cardio2D(Problem):
    def __init__(self, p_bound: Dict[str, float], q_bound: Dict[str, float]):
        super(Cardio2D, self).__init__()
        self.dimension = 2
        self.number_of_float_variables = 2
        self.number_of_discrete_variables = 0
        self.number_of_objectives = 1
        self.number_of_constraints = 0

        ecg_list = sorted(os.listdir('data/signals/'))
        ecg_list = [x for x in ecg_list if x.split('_')[-1] != 'unsupervised.npy']

        train_list, test_list = train_test_split(ecg_list, test_size=0.2, shuffle=True, random_state=42)

        for x in sorted(os.listdir('data/signals/')):
            if x.split('_')[-1] == 'unsupervised.npy':
                train_list.append(x)

        x_train, y_train = get_dataset(train_list)
        x_test, y_test = get_dataset(test_list)

        train_dataset = SegmentationDataset('cpu', train_list, x_train, y_train, common_mask=True, for_train=True)
        val_dataset = SegmentationDataset('cpu', test_list, x_test, y_test, common_mask=True)

        self.train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        self.val_loader = DataLoader(val_dataset, batch_size=32)

        self.float_variable_names = np.array(["P parameter", "Q parameter"], dtype=str)
        self.lower_bound_of_float_variables = np.array([p_bound['low'], q_bound['low']], dtype=np.double)
        self.upper_bound_of_float_variables = np.array([p_bound['up'], q_bound['up']], dtype=np.double)

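    # Objective evaluation: train a fresh U-Net for the candidate (p, q) and
    # return the negative of the best validation p-wave F1, so minimizing the
    # objective maximizes the score.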
    def calculate(self, point: Point, function_value: FunctionValue) -> FunctionValue:
        p, q = point.float_variables[0], point.float_variables[1]

        now = datetime.now().strftime('%d.%m.%Y_%H:%M:%S')

        checkpoint = ModelCheckpoint(dirpath='models/',
                                     filename=f"{random.uniform(1, 100):.9f}" + " " + f"{p:.9f}" + '_' + f"{q:.9f}" + '_' +
                                              '{epoch}_{val_p_wave:.6f}_{val_qrs_wave:.6f}_{val_t_wave:.6f}',
                                     monitor='val_p_wave',
                                     save_top_k=3,
                                     mode='max')
        early_stopping = EarlyStopping(monitor='val_loss',
                                       patience=300)

        cb = AllMetricTracker()
        model = UnetModule(p=p, q=q)
        trainer = Trainer(max_epochs=1_000_000, callbacks=[checkpoint, early_stopping, cb])
        try:
            trainer.fit(model, self.train_loader, self.val_loader)
        except Exception as err:
            print(f"Unexpected {err=}, {type(err)=}")

        print('p ' + f"{p:.9f}")
        print('q ' + f"{q:.9f}")
        function_value.value = -cb.best_p_valscore
        print(-cb.best_p_valscore)
        return function_value
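

# A minimal usage sketch (not part of the original script): the bound
# dictionaries and solver parameters below are illustrative assumptions;
# Solver and SolverParameters are iOpt's standard entry points.
if __name__ == '__main__':
    from iOpt.solver import Solver
    from iOpt.solver_parametrs import SolverParameters

    # Hypothetical search ranges for the p and q hyperparameters.
    problem = Cardio2D(p_bound={'low': 0.0, 'up': 1.0},
                       q_bound={'low': 1.0, 'up': 2.0})
    solver = Solver(problem, parameters=SolverParameters(r=3.0, iters_limit=100))
    solver.solve()  # each trial trains one U-Net and reports -F1(p wave)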