|
1 | 1 | import numpy as np |
2 | 2 | import pytest |
3 | 3 | from matchms import Spectrum |
| 4 | + |
4 | 5 | from ms2deepscore.train_new_model.validation_and_test_split import ( |
5 | | - select_spectra_belonging_to_inchikey, select_unique_inchikeys, |
6 | | - split_spectra_in_random_inchikey_sets) |
| 6 | + select_spectra_belonging_to_inchikey, |
| 7 | + select_unique_inchikeys, |
| 8 | + split_spectra_in_random_inchikey_sets, |
| 9 | +) |
| 10 | + |
| 11 | + |
| 12 | +def _inchikey(letter: str) -> str: |
| 13 | + return 14 * letter |
| 14 | + |
| 15 | + |
def _make_spectrum(letter: str, idx: int = 0) -> Spectrum:
    """Create a single-peak ``Spectrum`` tagged with a dummy inchikey.

    Args:
        letter: Passed to ``_inchikey`` to build the ``"inchikey"`` metadata.
        idx: Offset added to the base m/z of 100.1, so spectra sharing an
            inchikey can still be given different peaks.

    Returns:
        A ``Spectrum`` with one peak at ``100.1 + idx`` and intensity 0.9.
    """
    return Spectrum(
        mz=np.array([100.1 + idx]),
        intensities=np.array([0.9]),
        metadata={"inchikey": _inchikey(letter)},
    )
7 | 22 |
|
8 | 23 |
|
@pytest.fixture
def sample_spectra():
    """Return four spectra over three dummy inchikeys: A, B (twice) and C.

    Each spectrum gets a different ``idx`` so the two "B" spectra are not
    identical copies of each other.
    """
    return [
        _make_spectrum("A", 0),
        _make_spectrum("B", 1),
        _make_spectrum("B", 2),
        _make_spectrum("C", 3),
    ]
21 | 32 |
|
22 | 33 |
|
@pytest.fixture
def larger_sample_spectra():
    """Return 16 spectra: 8 dummy inchikeys (A-H) with 2 spectra each.

    The spectra are ordered by inchikey letter, and within a letter by
    ``idx`` (0 then 1).
    """
    return [
        _make_spectrum(letter, idx)
        for letter in "ABCDEFGH"
        for idx in (0, 1)
    ]
| 42 | + |
| 43 | + |
| 44 | +def _unique_inchikeys_in_spectra(spectra): |
| 45 | + return sorted({s.get("inchikey")[:14] for s in spectra}) |
| 46 | + |
| 47 | + |
def test_select_unique_inchikeys(sample_spectra):
    """The A, B, B, C fixture yields the unique inchikeys A, B, C in that order."""
    result = select_unique_inchikeys(sample_spectra)
    assert result == [_inchikey("A"), _inchikey("B"), _inchikey("C")]
26 | 51 |
|
27 | 52 |
|
def test_select_spectra_belonging_to_inchikey(sample_spectra):
    """Selecting inchikeys A and B returns the three matching spectra only."""
    inchikeys = [_inchikey("A"), _inchikey("B")]
    result = select_spectra_belonging_to_inchikey(sample_spectra, inchikeys)
    # One "A" spectrum plus two "B" spectra in the fixture.
    assert len(result) == 3
    assert result[0].get("inchikey") == _inchikey("A")
    # No spectrum with a non-requested inchikey (i.e. "C") may slip through.
    assert all(s.get("inchikey")[:14] in inchikeys for s in result)
| 59 | + |
| 60 | + |
def test_select_spectra_belonging_to_inchikey_empty_match(sample_spectra):
    """Asking for an inchikey absent from the fixture returns an empty list."""
    result = select_spectra_belonging_to_inchikey(sample_spectra, [_inchikey("Z")])
    assert result == []
33 | 64 |
|
34 | 65 |
|
def test_split_spectra_in_random_inchikey_sets_preserves_all_spectra(sample_spectra):
    """The three returned sets together contain as many spectra as the input."""
    val, test, train = split_spectra_in_random_inchikey_sets(sample_spectra, 2, 42)
    # Compare against the fixture length rather than a hard-coded count.
    assert len(val) + len(test) + len(train) == len(sample_spectra)
| 69 | + |
| 70 | + |
def test_split_spectra_in_random_inchikey_sets_splits_by_inchikey_group(larger_sample_spectra):
    """No inchikey is shared between sets, and no inchikey is lost."""
    val, test, train = split_spectra_in_random_inchikey_sets(larger_sample_spectra, 4, 42)

    val_keys = set(_unique_inchikeys_in_spectra(val))
    test_keys = set(_unique_inchikeys_in_spectra(test))
    train_keys = set(_unique_inchikeys_in_spectra(train))

    # Pairwise disjoint: all spectra of one inchikey land in the same set.
    assert val_keys.isdisjoint(test_keys)
    assert val_keys.isdisjoint(train_keys)
    assert test_keys.isdisjoint(train_keys)

    # The union of the three sets covers every inchikey of the input.
    all_keys = val_keys | test_keys | train_keys
    assert all_keys == set(_unique_inchikeys_in_spectra(larger_sample_spectra))
| 84 | + |
| 85 | + |
def test_split_spectra_in_random_inchikey_sets_expected_unique_group_sizes(larger_sample_spectra):
    """With 8 inchikeys and k=4, the split is 2 / 2 / 4 inchikeys (val / test / train)."""
    val, test, train = split_spectra_in_random_inchikey_sets(larger_sample_spectra, 4, 42)

    # 8 unique inchikeys, k=4 -> fraction_size = 2
    assert len(_unique_inchikeys_in_spectra(val)) == 2
    assert len(_unique_inchikeys_in_spectra(test)) == 2
    assert len(_unique_inchikeys_in_spectra(train)) == 4

    # two spectra per inchikey, so spectrum counts are double the key counts
    assert len(val) == 4
    assert len(test) == 4
    assert len(train) == 8
| 98 | + |
| 99 | + |
def test_split_spectra_in_random_inchikey_sets_same_seed_is_stable(larger_sample_spectra):
    """Two calls with the same third argument (42) give the same inchikey split."""
    val1, test1, train1 = split_spectra_in_random_inchikey_sets(larger_sample_spectra, 4, 42)
    val2, test2, train2 = split_spectra_in_random_inchikey_sets(larger_sample_spectra, 4, 42)

    # Compare inchikey membership per set rather than Spectrum objects.
    assert _unique_inchikeys_in_spectra(val1) == _unique_inchikeys_in_spectra(val2)
    assert _unique_inchikeys_in_spectra(test1) == _unique_inchikeys_in_spectra(test2)
    assert _unique_inchikeys_in_spectra(train1) == _unique_inchikeys_in_spectra(train2)
| 107 | + |
| 108 | + |
def test_split_spectra_in_random_inchikey_sets_different_seed_can_change_split(larger_sample_spectra):
    """Different seeds do not all produce the same inchikey split.

    Several seeds are compared rather than one fixed pair. Two particular
    seeds may legitimately yield the same split, so the test only requires
    that more than one distinct split shows up across the seeds tried.
    Seeds 1 and 2 are included, so any run in which those two differ passes.
    """
    distinct_splits = set()
    for seed in (1, 2, 3, 4, 5):
        parts = split_spectra_in_random_inchikey_sets(larger_sample_spectra, 4, seed)
        # Reduce each (val, test, train) result to hashable inchikey tuples.
        distinct_splits.add(
            tuple(tuple(_unique_inchikeys_in_spectra(part)) for part in parts)
        )

    assert len(distinct_splits) > 1
| 125 | + |
| 126 | + |
def test_split_spectra_in_random_inchikey_sets_none_seed_still_preserves_partition(larger_sample_spectra):
    """Passing ``None`` as the third argument still yields a valid partition.

    The concrete split is not pinned here, so only its invariants are
    checked: disjoint inchikey sets and no spectra lost or duplicated.
    """
    val, test, train = split_spectra_in_random_inchikey_sets(larger_sample_spectra, 4, None)

    val_keys = set(_unique_inchikeys_in_spectra(val))
    test_keys = set(_unique_inchikeys_in_spectra(test))
    train_keys = set(_unique_inchikeys_in_spectra(train))

    assert val_keys.isdisjoint(test_keys)
    assert val_keys.isdisjoint(train_keys)
    assert test_keys.isdisjoint(train_keys)
    assert len(val) + len(test) + len(train) == len(larger_sample_spectra)
0 commit comments