Skip to content

Commit de26319

Browse files
Merge pull request #290 from matchms/minor_improvements
Minor improvements
2 parents 30787fa + 3f053d0 commit de26319

File tree

6 files changed

+143
-23
lines changed

6 files changed

+143
-23
lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,15 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## 2.8.0
9+
10+
### Fixed
11+
- progress bars can be switched off
12+
- adjust random generator for test split
13+
14+
### Changed
15+
- expand tests
16+
817
## 2.7.2
918

1019
### Changed

ms2deepscore/MS2DeepScore.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(self, model: SiameseSpectralModel, progress_bar: bool = True):
5656
self.progress_bar = progress_bar
5757

5858
def get_embedding_array(self, spectrums):
59-
return compute_embedding_array(self.model, spectrums)
59+
return compute_embedding_array(self.model, spectrums, progress_bar=self.progress_bar)
6060

6161
def pair(self, reference: Spectrum, query: Spectrum) -> float:
6262
"""Calculate the MS2DeepScore similarity between a reference and a query spectrum.

ms2deepscore/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '2.7.2'
1+
__version__ = '2.8.0'

ms2deepscore/models/SiameseSpectralModel.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,13 @@ def dense_layer(input_size, output_size, activation="lrelu"):
344344
return nn.Sequential(nn.Linear(input_size, output_size), activations[activation])
345345

346346

347-
def compute_embedding_array(model: SiameseSpectralModel, spectra, datatype="numpy", device=None):
347+
def compute_embedding_array(
348+
model: SiameseSpectralModel,
349+
spectra,
350+
datatype="numpy",
351+
device=None,
352+
progress_bar: bool = True,
353+
):
348354
"""
349355
Compute the embeddings of all given spectra (list of matchms Spectrum objects).
350356
@@ -361,6 +367,8 @@ def compute_embedding_array(model: SiameseSpectralModel, spectra, datatype="nump
361367
device:
362368
The device on which to perform the computation.
363369
If None, it automatically uses CUDA if available, otherwise CPU.
370+
progress_bar:
371+
Whether to display a progress bar during embedding computation.
364372
"""
365373
if datatype.lower() not in ["numpy", "pytorch"]:
366374
raise ValueError("datatype can only be 'numpy' or 'pytorch'.")
@@ -372,7 +380,11 @@ def compute_embedding_array(model: SiameseSpectralModel, spectra, datatype="nump
372380
if device is None:
373381
device = torch_device("cuda" if cuda.is_available() else "cpu")
374382
model.to(device)
375-
for i, spec in tqdm(enumerate(spectra), total=len(spectra), desc="Computing spectral embeddings ..."):
383+
for i, spec in tqdm(
384+
enumerate(spectra),
385+
total=len(spectra),
386+
desc="Computing spectral embeddings ...",
387+
disable=not progress_bar):
376388
X = tensorize_spectra([spec], model.model_settings)
377389
with no_grad():
378390
if datatype.lower() == "numpy":

ms2deepscore/train_new_model/validation_and_test_split.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import random
1+
import numpy as np
22
from typing import List, Tuple
33
from matchms import Spectrum
44
from tqdm import tqdm
@@ -32,8 +32,8 @@ def split_spectra_in_random_inchikey_sets(
3232
"""Splits a set of spectra into a val, test and train set. The val and test sets each contain n // k unique inchikeys, where n is the number of unique inchikeys.
3333
"""
3434
unique_inchikeys = select_unique_inchikeys(spectra)
35-
random.seed(random_seed)
36-
random.shuffle(unique_inchikeys)
35+
rng = np.random.default_rng(random_seed)
36+
rng.shuffle(unique_inchikeys)
3737
fraction_size = len(unique_inchikeys) // k
3838

3939
validation_inchikeys = unique_inchikeys[-fraction_size:]
Lines changed: 115 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,137 @@
11
import numpy as np
22
import pytest
33
from matchms import Spectrum
4+
45
from ms2deepscore.train_new_model.validation_and_test_split import (
5-
select_spectra_belonging_to_inchikey, select_unique_inchikeys,
6-
split_spectra_in_random_inchikey_sets)
6+
select_spectra_belonging_to_inchikey,
7+
select_unique_inchikeys,
8+
split_spectra_in_random_inchikey_sets,
9+
)
10+
11+
12+
def _inchikey(letter: str) -> str:
13+
return 14 * letter
14+
15+
16+
def _make_spectrum(letter: str, idx: int = 0) -> Spectrum:
17+
return Spectrum(
18+
mz=np.array([100.1 + idx]),
19+
intensities=np.array([0.9]),
20+
metadata={"inchikey": _inchikey(letter)},
21+
)
722

823

924
@pytest.fixture
1025
def sample_spectra():
1126
return [
12-
Spectrum(mz=np.array([100.1]), intensities=np.array([0.9]),
13-
metadata={"inchikey": 14 * "A"}),
14-
Spectrum(mz=np.array([100.1]), intensities=np.array([0.9]),
15-
metadata={"inchikey": 14 * "B"}),
16-
Spectrum(mz=np.array([100.1]), intensities=np.array([0.9]),
17-
metadata={"inchikey": 14 * "B"}),
18-
Spectrum(mz=np.array([100.1]), intensities=np.array([0.9]),
19-
metadata={"inchikey": 14 * "C"}),
27+
_make_spectrum("A", 0),
28+
_make_spectrum("B", 1),
29+
_make_spectrum("B", 2),
30+
_make_spectrum("C", 3),
2031
]
2132

2233

34+
@pytest.fixture
35+
def larger_sample_spectra():
36+
spectra = []
37+
# 8 unique inchikeys, 2 spectra each
38+
for letter in "ABCDEFGH":
39+
spectra.append(_make_spectrum(letter, 0))
40+
spectra.append(_make_spectrum(letter, 1))
41+
return spectra
42+
43+
44+
def _unique_inchikeys_in_spectra(spectra):
45+
return sorted({s.get("inchikey")[:14] for s in spectra})
46+
47+
2348
def test_select_unique_inchikeys(sample_spectra):
2449
result = select_unique_inchikeys(sample_spectra)
25-
assert result == [14 * "A", 14 * "B", 14 * "C"]
50+
assert result == [_inchikey("A"), _inchikey("B"), _inchikey("C")]
2651

2752

2853
def test_select_spectra_belonging_to_inchikey(sample_spectra):
29-
inchikeys = [14 * "A", 14 * "B"]
54+
inchikeys = [_inchikey("A"), _inchikey("B")]
3055
result = select_spectra_belonging_to_inchikey(sample_spectra, inchikeys)
3156
assert len(result) == 3
32-
assert result[0].get("inchikey") == 14 * "A"
57+
assert result[0].get("inchikey") == _inchikey("A")
58+
assert all(s.get("inchikey")[:14] in inchikeys for s in result)
59+
60+
61+
def test_select_spectra_belonging_to_inchikey_empty_match(sample_spectra):
62+
result = select_spectra_belonging_to_inchikey(sample_spectra, [_inchikey("Z")])
63+
assert result == []
3364

3465

35-
def test_split_spectra_in_random_inchikey_sets(sample_spectra):
36-
# TODO: this is still a dummy test mostly
66+
def test_split_spectra_in_random_inchikey_sets_preserves_all_spectra(sample_spectra):
3767
val, test, train = split_spectra_in_random_inchikey_sets(sample_spectra, 2, 42)
38-
assert len(val) + len(test) + len(train) == 4
68+
assert len(val) + len(test) + len(train) == len(sample_spectra)
69+
70+
71+
def test_split_spectra_in_random_inchikey_sets_splits_by_inchikey_group(larger_sample_spectra):
72+
val, test, train = split_spectra_in_random_inchikey_sets(larger_sample_spectra, 4, 42)
73+
74+
val_keys = set(_unique_inchikeys_in_spectra(val))
75+
test_keys = set(_unique_inchikeys_in_spectra(test))
76+
train_keys = set(_unique_inchikeys_in_spectra(train))
77+
78+
assert val_keys.isdisjoint(test_keys)
79+
assert val_keys.isdisjoint(train_keys)
80+
assert test_keys.isdisjoint(train_keys)
81+
82+
all_keys = val_keys | test_keys | train_keys
83+
assert all_keys == set(_unique_inchikeys_in_spectra(larger_sample_spectra))
84+
85+
86+
def test_split_spectra_in_random_inchikey_sets_expected_unique_group_sizes(larger_sample_spectra):
87+
val, test, train = split_spectra_in_random_inchikey_sets(larger_sample_spectra, 4, 42)
88+
89+
# 8 unique inchikeys, k=4 -> fraction_size = 2
90+
assert len(_unique_inchikeys_in_spectra(val)) == 2
91+
assert len(_unique_inchikeys_in_spectra(test)) == 2
92+
assert len(_unique_inchikeys_in_spectra(train)) == 4
93+
94+
# two spectra per inchikey
95+
assert len(val) == 4
96+
assert len(test) == 4
97+
assert len(train) == 8
98+
99+
100+
def test_split_spectra_in_random_inchikey_sets_same_seed_is_stable(larger_sample_spectra):
101+
val1, test1, train1 = split_spectra_in_random_inchikey_sets(larger_sample_spectra, 4, 42)
102+
val2, test2, train2 = split_spectra_in_random_inchikey_sets(larger_sample_spectra, 4, 42)
103+
104+
assert _unique_inchikeys_in_spectra(val1) == _unique_inchikeys_in_spectra(val2)
105+
assert _unique_inchikeys_in_spectra(test1) == _unique_inchikeys_in_spectra(test2)
106+
assert _unique_inchikeys_in_spectra(train1) == _unique_inchikeys_in_spectra(train2)
107+
108+
109+
def test_split_spectra_in_random_inchikey_sets_different_seed_can_change_split(larger_sample_spectra):
110+
val1, test1, train1 = split_spectra_in_random_inchikey_sets(larger_sample_spectra, 4, 1)
111+
val2, test2, train2 = split_spectra_in_random_inchikey_sets(larger_sample_spectra, 4, 2)
112+
113+
split1 = (
114+
_unique_inchikeys_in_spectra(val1),
115+
_unique_inchikeys_in_spectra(test1),
116+
_unique_inchikeys_in_spectra(train1),
117+
)
118+
split2 = (
119+
_unique_inchikeys_in_spectra(val2),
120+
_unique_inchikeys_in_spectra(test2),
121+
_unique_inchikeys_in_spectra(train2),
122+
)
123+
124+
assert split1 != split2
125+
126+
127+
def test_split_spectra_in_random_inchikey_sets_none_seed_still_preserves_partition(larger_sample_spectra):
128+
val, test, train = split_spectra_in_random_inchikey_sets(larger_sample_spectra, 4, None)
129+
130+
val_keys = set(_unique_inchikeys_in_spectra(val))
131+
test_keys = set(_unique_inchikeys_in_spectra(test))
132+
train_keys = set(_unique_inchikeys_in_spectra(train))
133+
134+
assert val_keys.isdisjoint(test_keys)
135+
assert val_keys.isdisjoint(train_keys)
136+
assert test_keys.isdisjoint(train_keys)
137+
assert len(val) + len(test) + len(train) == len(larger_sample_spectra)

0 commit comments

Comments
 (0)