|
4 | 4 | """ |
5 | 5 |
|
6 | 6 | import os |
7 | | -from typing import Dict, List, Optional |
8 | | -import numpy as np |
9 | | -import tensorflow as tf |
| 7 | +from typing import List |
10 | 8 | from matchms import Spectrum |
11 | | -from matplotlib import pyplot as plt |
12 | 9 | from ms2deepscore import SpectrumBinner |
13 | | -from ms2deepscore.data_generators import DataGeneratorAllInchikeys |
14 | | -from ms2deepscore.models import SiameseModel |
15 | | -from tensorflow.keras.callbacks import ( # pylint: disable=import-error |
16 | | - EarlyStopping, ModelCheckpoint) |
17 | | -from tensorflow.keras.optimizers import Adam # pylint: disable=import-error |
| 10 | +from ms2deepscore.train_new_model.train_ms2deepscore import (plot_history, |
| 11 | + train_ms2ds_model) |
18 | 12 | from ms2query.create_new_library.calculate_tanimoto_scores import \ |
19 | 13 | calculate_tanimoto_scores_unique_inchikey |
20 | 14 | from ms2query.create_new_library.split_data_for_training import \ |
21 | 15 | split_spectra_on_inchikeys |
22 | 16 |
|
23 | 17 |
|
24 | | -def train_ms2ds_model(training_spectra, |
25 | | - validation_spectra, |
26 | | - tanimoto_df, |
27 | | - output_model_file_name, |
28 | | - epochs=150): |
| 18 | +def train_ms2deepscore_wrapper(spectra: List[Spectrum], |
| 19 | + output_model_file_name, |
| 20 | + fraction_validation_spectra, |
| 21 | + epochs, |
| 22 | + ms2ds_history_file_name=None): |
29 | 23 | assert not os.path.isfile(output_model_file_name), "The MS2Deepscore output model file name already exists" |
30 | | - # assert len(validation_spectra) >= 100, \ |
31 | | - # "Expected more validation spectra, too little validation spectra causes keras to crash" |
32 | | - # Bin training spectra |
| 24 | + training_spectra, validation_spectra = split_spectra_on_inchikeys(spectra, |
| 25 | + fraction_validation_spectra) |
| 26 | + tanimoto_score_df = calculate_tanimoto_scores_unique_inchikey(spectra, spectra) |
33 | 27 | spectrum_binner = SpectrumBinner(10000, mz_min=10.0, mz_max=1000.0, peak_scaling=0.5, |
34 | 28 | allowed_missing_percentage=100.0) |
35 | 29 | binned_spectrums_training = spectrum_binner.fit_transform(training_spectra) |
36 | 30 | # Bin validation spectra using the binner based on the training spectra. |
37 | 31 | # Peaks that do not occur in the training spectra will not be binned in the validation spectra. |
38 | 32 | binned_spectrums_val = spectrum_binner.transform(validation_spectra) |
39 | 33 |
|
40 | | - same_prob_bins = list(zip(np.linspace(0, 0.9, 10), np.linspace(0.1, 1, 10))) |
41 | | - |
42 | | - training_generator = DataGeneratorAllInchikeys( |
43 | | - binned_spectrums_training, |
44 | | - selected_inchikeys=list({s.get("inchikey")[:14] for s in training_spectra}), |
45 | | - reference_scores_df=tanimoto_df, |
46 | | - dim=len(spectrum_binner.known_bins), # The number of bins created |
47 | | - same_prob_bins=same_prob_bins, |
48 | | - num_turns=2, |
49 | | - augment_noise_max=10, |
50 | | - augment_noise_intensity=0.01) |
51 | | - |
52 | | - validation_generator = DataGeneratorAllInchikeys( |
53 | | - binned_spectrums_val, |
54 | | - selected_inchikeys=list({s.get("inchikey")[:14] for s in binned_spectrums_val}), |
55 | | - reference_scores_df=tanimoto_df, |
56 | | - dim=len(spectrum_binner.known_bins), # The number of bins created |
57 | | - same_prob_bins=same_prob_bins, |
58 | | - num_turns=10, # Number of pairs for each InChiKey14 during each epoch. |
59 | | - # To prevent data augmentation |
60 | | - augment_removal_max=0, augment_removal_intensity=0, augment_intensity=0, augment_noise_max=0, use_fixed_set=True |
| 34 | + history = train_ms2ds_model( |
| 35 | + binned_spectrums_training, |
| 36 | + binned_spectrums_val, |
| 37 | + spectrum_binner, |
| 38 | + tanimoto_score_df, |
| 39 | + output_model_file_name, |
| 40 | + epochs=epochs, |
| 41 | + base_dims=(500, 500), |
| 42 | + embedding_dim=200, |
61 | 43 | ) |
62 | 44 |
|
63 | | - model = SiameseModel(spectrum_binner, base_dims=(500, 500), embedding_dim=200, dropout_rate=0.2) |
64 | | - |
65 | | - model.compile(loss='mse', optimizer=Adam(lr=0.001), metrics=["mae", tf.keras.metrics.RootMeanSquaredError()]) |
66 | | - |
67 | | - # Save best model and include early stopping |
68 | | - checkpointer = ModelCheckpoint(filepath=output_model_file_name, monitor='val_loss', mode="min", verbose=1, save_best_only=True) |
69 | | - earlystopper_scoring_net = EarlyStopping(monitor='val_loss', mode="min", patience=10, verbose=1) |
70 | | - # Fit model and save history |
71 | | - history = model.model.fit(training_generator, validation_data=validation_generator, epochs=epochs, verbose=1, |
72 | | - callbacks=[checkpointer, earlystopper_scoring_net]) |
73 | | - model.load_weights(output_model_file_name) |
74 | | - model.save(output_model_file_name) |
75 | | - return history.history |
76 | | - |
77 | | - |
78 | | -def plot_history(history: Dict[str, List[float]], |
79 | | - file_name: Optional[str] = None): |
80 | | - plt.plot(history['loss']) |
81 | | - plt.plot(history['val_loss']) |
82 | | - plt.title('model loss') |
83 | | - plt.ylabel('loss') |
84 | | - plt.xlabel('epoch') |
85 | | - plt.legend(['train', 'val'], loc='upper left') |
86 | | - if file_name: |
87 | | - plt.savefig(file_name) |
88 | | - else: |
89 | | - plt.show() |
90 | | - |
91 | | - |
92 | | -def train_ms2deepscore_wrapper(spectra: List[Spectrum], |
93 | | - output_model_file_name, |
94 | | - fraction_validation_spectra, |
95 | | - epochs, |
96 | | - ms2ds_history_file_name=None): |
97 | | - assert not os.path.isfile(output_model_file_name), "The MS2Deepscore output model file name already exists" |
98 | | - training_spectra, validation_spectra = split_spectra_on_inchikeys(spectra, |
99 | | - fraction_validation_spectra) |
100 | | - tanimoto_score_df = calculate_tanimoto_scores_unique_inchikey(spectra, spectra) |
101 | | - history = train_ms2ds_model(training_spectra, validation_spectra, |
102 | | - tanimoto_score_df, output_model_file_name, |
103 | | - epochs) |
104 | 45 | print(f"The training history is: {history}") |
105 | 46 | plot_history(history, ms2ds_history_file_name) |
0 commit comments