Skip to content

Commit eeed096

Browse files
authored
Merge pull request #237 from iomega/update_ms2deepscore_version
Update ms2deepscore to 0.5.0
2 parents a3ed83d + 0644750 commit eeed096

5 files changed

Lines changed: 25 additions & 100 deletions

File tree

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file.
44

55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7+
## 1.4.0
8+
### Changed
9+
- Made compatible with MS2Deepscore 0.5.0
710

811
## 1.3.0
912
### Changed

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@ dependencies:
1212
- pyarrow=12.0.1
1313
- tensorflow=2.12.1
1414
- scikit-learn=1.3.2
15-
- ms2deepscore=0.4.0
15+
- ms2deepscore=0.5.0
1616
- pandas=2.0.3
1717
- matplotlib=3.7.3
1818
- skl2onnx=1.16.0

ms2query/create_new_library/train_ms2deepscore.py

Lines changed: 20 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -4,102 +4,43 @@
44
"""
55

66
import os
7-
from typing import Dict, List, Optional
8-
import numpy as np
9-
import tensorflow as tf
7+
from typing import List
108
from matchms import Spectrum
11-
from matplotlib import pyplot as plt
129
from ms2deepscore import SpectrumBinner
13-
from ms2deepscore.data_generators import DataGeneratorAllInchikeys
14-
from ms2deepscore.models import SiameseModel
15-
from tensorflow.keras.callbacks import ( # pylint: disable=import-error
16-
EarlyStopping, ModelCheckpoint)
17-
from tensorflow.keras.optimizers import Adam # pylint: disable=import-error
10+
from ms2deepscore.train_new_model.train_ms2deepscore import (plot_history,
11+
train_ms2ds_model)
1812
from ms2query.create_new_library.calculate_tanimoto_scores import \
1913
calculate_tanimoto_scores_unique_inchikey
2014
from ms2query.create_new_library.split_data_for_training import \
2115
split_spectra_on_inchikeys
2216

2317

24-
def train_ms2ds_model(training_spectra,
25-
validation_spectra,
26-
tanimoto_df,
27-
output_model_file_name,
28-
epochs=150):
18+
def train_ms2deepscore_wrapper(spectra: List[Spectrum],
19+
output_model_file_name,
20+
fraction_validation_spectra,
21+
epochs,
22+
ms2ds_history_file_name=None):
2923
assert not os.path.isfile(output_model_file_name), "The MS2Deepscore output model file name already exists"
30-
# assert len(validation_spectra) >= 100, \
31-
# "Expected more validation spectra, too little validation spectra causes keras to crash"
32-
# Bin training spectra
24+
training_spectra, validation_spectra = split_spectra_on_inchikeys(spectra,
25+
fraction_validation_spectra)
26+
tanimoto_score_df = calculate_tanimoto_scores_unique_inchikey(spectra, spectra)
3327
spectrum_binner = SpectrumBinner(10000, mz_min=10.0, mz_max=1000.0, peak_scaling=0.5,
3428
allowed_missing_percentage=100.0)
3529
binned_spectrums_training = spectrum_binner.fit_transform(training_spectra)
3630
# Bin validation spectra using the binner based on the training spectra.
3731
# Peaks that do not occur in the training spectra will not be binned in the validation spectra.
3832
binned_spectrums_val = spectrum_binner.transform(validation_spectra)
3933

40-
same_prob_bins = list(zip(np.linspace(0, 0.9, 10), np.linspace(0.1, 1, 10)))
41-
42-
training_generator = DataGeneratorAllInchikeys(
43-
binned_spectrums_training,
44-
selected_inchikeys=list({s.get("inchikey")[:14] for s in training_spectra}),
45-
reference_scores_df=tanimoto_df,
46-
dim=len(spectrum_binner.known_bins), # The number of bins created
47-
same_prob_bins=same_prob_bins,
48-
num_turns=2,
49-
augment_noise_max=10,
50-
augment_noise_intensity=0.01)
51-
52-
validation_generator = DataGeneratorAllInchikeys(
53-
binned_spectrums_val,
54-
selected_inchikeys=list({s.get("inchikey")[:14] for s in binned_spectrums_val}),
55-
reference_scores_df=tanimoto_df,
56-
dim=len(spectrum_binner.known_bins), # The number of bins created
57-
same_prob_bins=same_prob_bins,
58-
num_turns=10, # Number of pairs for each InChiKey14 during each epoch.
59-
# To prevent data augmentation
60-
augment_removal_max=0, augment_removal_intensity=0, augment_intensity=0, augment_noise_max=0, use_fixed_set=True
34+
history = train_ms2ds_model(
35+
binned_spectrums_training,
36+
binned_spectrums_val,
37+
spectrum_binner,
38+
tanimoto_score_df,
39+
output_model_file_name,
40+
epochs=epochs,
41+
base_dims=(500, 500),
42+
embedding_dim=200,
6143
)
6244

63-
model = SiameseModel(spectrum_binner, base_dims=(500, 500), embedding_dim=200, dropout_rate=0.2)
64-
65-
model.compile(loss='mse', optimizer=Adam(lr=0.001), metrics=["mae", tf.keras.metrics.RootMeanSquaredError()])
66-
67-
# Save best model and include early stopping
68-
checkpointer = ModelCheckpoint(filepath=output_model_file_name, monitor='val_loss', mode="min", verbose=1, save_best_only=True)
69-
earlystopper_scoring_net = EarlyStopping(monitor='val_loss', mode="min", patience=10, verbose=1)
70-
# Fit model and save history
71-
history = model.model.fit(training_generator, validation_data=validation_generator, epochs=epochs, verbose=1,
72-
callbacks=[checkpointer, earlystopper_scoring_net])
73-
model.load_weights(output_model_file_name)
74-
model.save(output_model_file_name)
75-
return history.history
76-
77-
78-
def plot_history(history: Dict[str, List[float]],
79-
file_name: Optional[str] = None):
80-
plt.plot(history['loss'])
81-
plt.plot(history['val_loss'])
82-
plt.title('model loss')
83-
plt.ylabel('loss')
84-
plt.xlabel('epoch')
85-
plt.legend(['train', 'val'], loc='upper left')
86-
if file_name:
87-
plt.savefig(file_name)
88-
else:
89-
plt.show()
90-
91-
92-
def train_ms2deepscore_wrapper(spectra: List[Spectrum],
93-
output_model_file_name,
94-
fraction_validation_spectra,
95-
epochs,
96-
ms2ds_history_file_name=None):
97-
assert not os.path.isfile(output_model_file_name), "The MS2Deepscore output model file name already exists"
98-
training_spectra, validation_spectra = split_spectra_on_inchikeys(spectra,
99-
fraction_validation_spectra)
100-
tanimoto_score_df = calculate_tanimoto_scores_unique_inchikey(spectra, spectra)
101-
history = train_ms2ds_model(training_spectra, validation_spectra,
102-
tanimoto_score_df, output_model_file_name,
103-
epochs)
10445
print(f"The training history is: {history}")
10546
plot_history(history, ms2ds_history_file_name)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@
3535
"h5py",
3636
"pyarrow",
3737
"scikit-learn",
38-
"ms2deepscore<=0.4.0",
38+
"ms2deepscore==0.5.0",
3939
"gensim>=4.0.0",
4040
"pandas",
4141
"tqdm",

tests/test_train_ms2deepscore.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

0 commit comments

Comments
 (0)