Skip to content

Commit bc9ec9b

Browse files
authored
Append ECG neural network tune example (#199)
* Added a documentation paragraph describing how the framework handles the optimal selection of two real parameters and one discrete parameter. Corrected the problem code for finding real and discrete parameters. * Correct target score * Add new examples. Correct documentation * Corrected documentation of examples
1 parent 6109a69 commit bc9ec9b

File tree

9 files changed

+533
-0
lines changed

9 files changed

+533
-0
lines changed

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,7 @@ dmypy.json
134134

135135
# datasets
136136
benchmarks/data/datasets
137+
examples/Machine_learning/NeuralNetwork/Segmentation/data
138+
examples/Machine_learning/NeuralNetwork/Segmentation/models/*
139+
examples/Machine_learning/NeuralNetwork/Segmentation/lightning_logs
140+
examples/Machine_learning/NeuralNetwork/Segmentation/data.zip
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import random
2+
3+
from examples.Machine_learning.NeuralNetwork.Segmentation.scripts.dataset import SegmentationDataset
4+
from examples.Machine_learning.NeuralNetwork.Segmentation.scripts.metric import AllMetricTracker
5+
from iOpt.trial import Point
6+
from iOpt.trial import FunctionValue
7+
from iOpt.problem import Problem
8+
from typing import Dict
9+
from datetime import datetime
10+
import os
11+
from sklearn.model_selection import train_test_split
12+
from torch.utils.data import DataLoader
13+
from lightning.pytorch import Trainer
14+
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
15+
import torch
16+
import torch.nn as nn
17+
import numpy as np
18+
from lightning.pytorch import LightningModule
19+
from examples.Machine_learning.NeuralNetwork.Segmentation.scripts.metric import SegmentationMetric
20+
from examples.Machine_learning.NeuralNetwork.Segmentation.scripts.model import Encoder, Decoder, UNet
21+
22+
23+
class UnetModule(LightningModule):
    """LightningModule that trains a ``UNet`` and tracks per-wave F1 scores.

    Args:
        kernel_size: Forwarded to ``Encoder`` as ``kernel_size``.
        q: Forwarded to ``Encoder`` as ``q``.
        label_smoothing: ``label_smoothing`` of the cross-entropy loss.
        p: Forwarded to ``Encoder`` as ``p``.
    """

    def __init__(self, kernel_size=23, q=1.2, label_smoothing=0, p=0.75):
        super().__init__()
        self.save_hyperparameters()
        # NOTE(review): 12 and 4 look like the number of input leads and of
        # output classes -- confirm against scripts/model.py.
        encoder = Encoder(12, kernel_size=kernel_size, q=q, p=p)
        decoder = Decoder(encoder, 4)

        self.model = UNet(encoder, decoder)
        # Target positions labelled 4 do not contribute to the loss.
        self.loss = nn.CrossEntropyLoss(ignore_index=4, label_smoothing=label_smoothing)

        self.p_metric = SegmentationMetric('p', 'all', return_type='f1', samples=150)
        self.t_metric = SegmentationMetric('t', 'all', return_type='f1', samples=150)
        self.qrs_metric = SegmentationMetric('qrs', 'all', return_type='f1', samples=150)

    def predict(self, x):
        """Return the arg-max over dim 1 of the model output as a NumPy array.

        ``x`` may be a NumPy array or a tensor; a 2-D input receives a leading
        batch dimension before being moved to the module's device.
        """
        if isinstance(x, np.ndarray):
            x = torch.Tensor(x)
        if len(x.shape) == 2:
            x = x.unsqueeze(0)
        x = x.to(self.device)
        # Inference only: do not record an autograd graph for this pass.
        with torch.no_grad():
            logits = self.model(x)
        y_pred = logits.argmax(dim=1)
        return y_pred.cpu().detach().numpy()

    def training_step(self, batch):
        """Compute the cross-entropy loss for one batch and log it per epoch."""
        _, x, y = batch
        logits = self.model(x)
        loss = self.loss(logits, y)
        self.log_dict({'train_loss': loss}, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch):
        """Log the validation loss together with the three wave metrics."""
        _, x, y = batch
        logits = self.model(x)
        loss = self.loss(logits, y)
        logged = {'val_loss': loss}

        # get_metric() goes through predict(), i.e. a second forward pass.
        logged.update(self.get_metric(x, y, 'val'))

        self.log_dict(logged, on_epoch=True, on_step=False)

        return loss

    def get_metric(self, x, y_true, prefix):
        """Return ``{prefix}_p_wave``, ``{prefix}_qrs_wave`` and ``{prefix}_t_wave`` scores.

        Args:
            x: Model input passed to :meth:`predict`.
            y_true: Target tensor; converted to a NumPy array for the metrics.
            prefix: String prepended to every key of the returned mapping.
        """
        y_true = y_true.cpu().detach().numpy()
        y_pred = self.predict(x)
        # Metrics are evaluated in the original order: p, qrs, t.
        p_f1_score = self.p_metric(y_pred, y_true)
        qrs_f1_score = self.qrs_metric(y_pred, y_true)
        t_f1_score = self.t_metric(y_pred, y_true)
        # Named ``scores`` so the builtin ``dict`` is not shadowed.
        scores = {
            f'{prefix}_p_wave': p_f1_score,
            f'{prefix}_qrs_wave': qrs_f1_score,
            f'{prefix}_t_wave': t_f1_score,
        }
        return scores

    def configure_optimizers(self):
        """Use AdamW with a ReduceLROnPlateau schedule driven by ``train_loss``."""
        optimizer = torch.optim.AdamW(self.model.parameters())
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.3, patience=50)
        return [optimizer], [{"scheduler": scheduler,
                              "interval": "epoch",
                              "monitor": "train_loss"}]
82+
83+
def get_dataset(paths, data_dir='data'):
    """Load the signal and mask files that share the given file names.

    Args:
        paths: File names to load; each name is looked up in both
            ``<data_dir>/signals`` and ``<data_dir>/masks``.
        data_dir: Root folder of the dataset. Defaults to ``'data'``, which is
            resolved against the current working directory.

    Returns:
        A tuple ``(signals, masks)`` of two lists holding the objects returned
        by ``np.load``, in the order of ``paths``.
    """
    # Materialise once so that a one-shot iterable can feed both lists.
    names = list(paths)
    signals = [np.load(os.path.join(data_dir, 'signals', name)) for name in names]
    masks = [np.load(os.path.join(data_dir, 'masks', name)) for name in names]
    return signals, masks
86+
87+
class Cardio2D(Problem):
    """Optimisation problem over the ``p`` and ``q`` arguments of ``UnetModule``.

    The problem has two float variables, no discrete variables, one objective
    and no constraints. Each trial trains a fresh ``UnetModule`` and reports the
    negated ``best_p_valscore`` of an ``AllMetricTracker`` callback.
    """

    def __init__(self, p_bound: Dict[str, float], q_bound: Dict[str, float]):
        """Read the dataset from ``data/`` and register the search bounds.

        Args:
            p_bound: Mapping with keys ``'low'`` and ``'up'`` for the first variable.
            q_bound: Mapping with keys ``'low'`` and ``'up'`` for the second variable.
        """
        super(Cardio2D, self).__init__()
        self.dimension = 2
        self.number_of_float_variables = 2
        self.number_of_discrete_variables = 0
        self.number_of_objectives = 1
        self.number_of_constraints = 0

        # List the directory once and partition it by the file-name suffix.
        all_files = sorted(os.listdir('data/signals/'))
        split_files = [x for x in all_files if x.split('_')[-1] != 'unsupervised.npy']
        unsupervised_files = [x for x in all_files if x.split('_')[-1] == 'unsupervised.npy']

        train_list, test_list = train_test_split(split_files, test_size=0.2, shuffle=True, random_state=42)

        # Files with the "unsupervised" suffix are used for training only.
        train_list.extend(unsupervised_files)

        x_train, y_train = get_dataset(train_list)
        x_test, y_test = get_dataset(test_list)

        train_dataset = SegmentationDataset('cpu', train_list, x_train, y_train, common_mask=True, for_train=True)
        val_dataset = SegmentationDataset('cpu', test_list, x_test, y_test, common_mask=True)

        self.train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        self.val_loader = DataLoader(val_dataset, batch_size=32)

        self.float_variable_names = np.array(["P parameter", "Q parameter"], dtype=str)
        self.lower_bound_of_float_variables = np.array([p_bound['low'], q_bound['low']],
                                                       dtype=np.double)
        self.upper_bound_of_float_variables = np.array([p_bound['up'], q_bound['up']],
                                                       dtype=np.double)

    def calculate(self, point: Point, function_value: FunctionValue) -> FunctionValue:
        """Train one model for the trial point and store the objective value.

        Args:
            point: Trial point; ``float_variables[0]`` is ``p`` and
                ``float_variables[1]`` is ``q``.
            function_value: Object whose ``value`` is set to the negated best
                score reported by the ``AllMetricTracker`` callback.

        Returns:
            The same ``function_value`` object.
        """
        p, q = point.float_variables[0], point.float_variables[1]

        # A random prefix keeps checkpoint names of different trials apart; the
        # trailing ``{...}`` fields are left for ModelCheckpoint to fill in.
        checkpoint_name = (f"{random.uniform(1, 100):.9f} {p:.9f}_{q:.9f}_"
                           + '{epoch}_{val_p_wave:.6f}_{val_qrs_wave:.6f}_{val_t_wave:.6f}')
        checkpoint = ModelCheckpoint(dirpath='models/',
                                     filename=checkpoint_name,
                                     monitor='val_p_wave',
                                     save_top_k=3,
                                     mode='max')
        early_stopping = EarlyStopping(monitor='val_loss',
                                       patience=300)

        cb = AllMetricTracker()
        model = UnetModule(p=p, q=q)
        trainer = Trainer(max_epochs=1_000_000, callbacks=[checkpoint, early_stopping, cb])
        try:
            trainer.fit(model, self.train_loader, self.val_loader)
        except Exception as err:
            # Best effort: a failed training run is reported and the trial is
            # still scored with whatever the tracker holds at that moment.
            print(f"Unexpected {err=}, {type(err)=}")

        print('p ' + f"{p:.9f}")
        print('q ' + f"{q:.9f}")
        # Negated before being stored as the objective value.
        function_value.value = -cb.best_p_valscore
        print(-cb.best_p_valscore)
        return function_value
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import shutil
2+
3+
import numpy as np
4+
from examples.Machine_learning.NeuralNetwork.Segmentation.Problem.Cardio2D import Cardio2D
5+
from iOpt.output_system.listeners.console_outputers import ConsoleOutputListener
6+
from iOpt.solver import Solver
7+
from iOpt.solver_parametrs import SolverParameters
8+
import hashlib
9+
import os
10+
from pathlib import Path
11+
12+
import requests
13+
from tqdm import tqdm
14+
15+
16+
def _get_hash(path: Path) -> str:
17+
file_hash = hashlib.sha256()
18+
with open(path, "rb") as f:
19+
while chunk := f.read(8192):
20+
file_hash.update(chunk)
21+
return file_hash.hexdigest()
22+
23+
24+
def download(path: Path, public_key: str) -> None:
    """Download a public Yandex.Disk resource to ``path``.

    The download is skipped when ``path`` already holds a file whose size and
    SHA-256 digest both match the metadata reported by the API.

    Args:
        path: Destination file.
        public_key: Identifier appended to ``https://disk.yandex.ru/d/``.

    Raises:
        requests.HTTPError: If the metadata or the download request returns
            an error status.
    """
    url = "https://cloud-api.yandex.net/v1/disk/public/resources"
    params = {"public_key": f"https://disk.yandex.ru/d/{public_key}"}

    meta_response = requests.get(url, params=params, timeout=30)
    meta_response.raise_for_status()
    meta = meta_response.json()
    download_url = meta["file"]
    file_size = meta["size"]
    sha256 = meta["sha256"]

    # Verify the local copy BEFORE opening the download stream, so no
    # connection is opened (and left dangling) when nothing has to be fetched.
    # The cheap size comparison runs first; the hash is only computed on a match.
    if path.is_file() and os.path.getsize(path) == file_size and _get_hash(path) == sha256:
        print(f"File already downloaded: {path}")
        return

    # The context manager releases the streamed connection even on errors.
    with requests.get(download_url, stream=True, timeout=30) as response:
        response.raise_for_status()
        with tqdm(total=file_size, unit="B", unit_scale=True) as progress_bar, \
                open(path, "wb") as f:
            for data in response.iter_content(1024):
                progress_bar.update(len(data))
                f.write(data)
45+
46+
47+
if __name__ == "__main__":
    # Fetch and unpack the dataset only when the ``data`` folder is missing.
    if not os.path.exists('data'):
        archive = Path('data.zip')
        download(archive, 'Oqxcid6uX58kYQ')
        shutil.unpack_archive(archive, 'data', format="zip")
        os.remove(archive)

    # Bounds of the two float variables optimised by Cardio2D.
    bounds_p = dict(low=0.0, up=1.0)
    bounds_q = dict(low=1.0, up=1.6)

    solver = Solver(
        Cardio2D(bounds_p, bounds_q),
        parameters=SolverParameters(r=np.double(3.0), iters_limit=10),
    )
    # Print the full optimisation log to the console.
    solver.add_listener(ConsoleOutputListener(mode='full'))
    solver_info = solver.solve()

examples/Machine_learning/NeuralNetwork/Segmentation/__init__.py

Whitespace-only changes.

examples/Machine_learning/NeuralNetwork/Segmentation/scripts/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import torch
2+
import numpy as np
3+
4+
class SegmentationDataset(torch.utils.data.Dataset):
    """Dataset of multi-channel signals with optional segmentation masks.

    In evaluation mode (``for_train=False``) an item is the stored signal plus
    the first mask row with its borders marked as ignored. In training mode an
    item is a random ``sub_len``-sample crop with additive Gaussian noise,
    random per-channel sign flips and a random channel permutation.

    Args:
        device: Torch device on which all tensors are stored and created.
        paths: Per-sample identifiers, returned as the first item element.
        signals: Sequence of array-likes, one per sample.
        masks: Optional sequence of integer array-likes, one per sample.
            When omitted, items contain no mask.
        common_mask: In training mode, return only the first mask row instead
            of one permuted row per channel.
        for_train: Enable the random crop and augmentations.
    """

    def __init__(self, device, paths, signals, masks=None, common_mask=False, for_train=False):
        self._device = device
        self._paths = paths
        self._signals = [torch.Tensor(x).to(device) for x in signals]
        # ``masks`` defaults to None, so it must not be iterated unconditionally.
        self._masks = None if masks is None else [torch.LongTensor(x).to(device) for x in masks]

        # Range from which the noise standard deviation is drawn per item.
        self.begin_noise, self.end_noise = 1e-3, 3e-3
        # The amplitude, frequency and isoline settings below are not read by
        # any method of this class as shown here.
        self.begin_ampl, self.end_ampl = 0, 0.3

        self.begin_freq, self.end_freq = 0, 0.009

        self.prob_isoline = 0.7
        # Probability that a channel keeps its sign in reverse_ecg().
        self.prob_reverse = 0.5
        # Length of the random training crop, in samples.
        self.sub_len = 4000

        self.common_mask = common_mask
        self.for_train = for_train

    def reverse_ecg(self, signal):
        """Return a copy of ``signal`` with each row multiplied by a random +1/-1."""
        result = torch.zeros_like(signal, device=self._device)
        for i, x in enumerate(signal):
            # +1 with probability ``prob_reverse``, otherwise -1.
            sign = 2 * (np.random.rand() < self.prob_reverse) - 1
            result[i] = sign * x
        return result

    def __len__(self):
        return len(self._signals)

    def __getitem__(self, i):
        """Return ``(path, signal)`` or, when masks exist, ``(path, signal, mask)``."""
        if not self.for_train:
            if self._masks is None:
                return self._paths[i], self._signals[i]
            # Evaluation always uses the first mask row, whatever ``common_mask`` is.
            return self._paths[i], self._signals[i], self.skip_borders(self._masks[i][0])

        # NOTE(review): the crop offset assumes 5000-sample signals -- confirm.
        shift = np.random.randint(0, 5000 - self.sub_len - 1)
        noise = self.begin_noise + (self.end_noise - self.begin_noise) * np.random.rand()
        # One noise vector of length ``sub_len`` is broadcast over all channels.
        signal = self._signals[i][:, shift:shift + self.sub_len] + torch.normal(0, noise,
                                                                                size=(self.sub_len,),
                                                                                device=self._device)

        signal = self.reverse_ecg(signal)

        if self._masks is None:
            return self._paths[i], signal

        mask = self._masks[i][:, shift: shift + self.sub_len]
        # NOTE(review): the permutation is hard-coded for 12 channels.
        indexes = torch.randperm(12, device=self._device)

        if self.common_mask:
            mask = mask[0]
        else:
            mask = mask[indexes]

        return self._paths[i], signal[indexes], self.skip_borders(mask)

    def skip_borders(self, mask):
        """Return a copy of ``mask`` whose leading and trailing parts are set to 4.

        Everything before the first wave start located after index 500 and
        everything from the last wave end located before ``len(mask) - 500``
        onwards is overwritten with 4. Indexing an empty selection raises
        ``IndexError`` when no such start or end exists.
        """
        # A start is a non-zero sample preceded by zero, an end is a non-zero
        # sample followed by zero; torch.roll wraps around at the borders.
        wave_start = torch.logical_and(torch.roll(mask, 1) == 0, mask != 0).type(torch.uint8)
        wave_finish = torch.logical_and(torch.roll(mask, -1) == 0, mask != 0).type(torch.uint8)

        indexes_starts, = torch.where(wave_start == 1)
        indexes_finish, = torch.where(wave_finish == 1)

        left_skip = indexes_starts[indexes_starts > 500][0]
        right_skip = indexes_finish[indexes_finish < len(mask) - 500][-1]

        mask_copy = torch.clone(mask)
        mask_copy[:left_skip] = 4
        mask_copy[right_skip:] = 4

        return mask_copy

0 commit comments

Comments
 (0)