import inspect
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
import pytest
import tensorflow as tf

from corel.kernel import Hellinger
from corel.kernel import HellingerReference
from corel.kernel import WeightedHellinger
from corel.kernel.hellinger import _hellinger_distance
from corel.kernel.hellinger import _k
from corel.kernel.hellinger import get_mean_and_amplitude

# define test sequences, test alphabet, and test weighting distributions
SEED = 12
N = 20
L = 15
AA = 3
np.random.seed(SEED)

simulated_decoding_distributions = np.stack([
    np.random.dirichlet(np.ones(AA), L) for _ in range(N)
])

simulated_weighting_vec = np.random.dirichlet(np.ones(AA), L)


@pytest.mark.parametrize("dist", [simulated_decoding_distributions])  # , simulated_weighting_vec])
def test_simulated_dist_is_probabilities(dist):
    summed_dist = np.sum(dist, axis=-1)
    np.testing.assert_almost_equal(summed_dist, np.ones((N, L)))
    np.testing.assert_array_less(dist, np.ones_like(dist))
    np.testing.assert_array_less(np.zeros_like(dist), dist)


def test_simulated_w_vec_is_probabilities():
    summed_dist = np.sum(simulated_weighting_vec, axis=-1)
    np.testing.assert_almost_equal(summed_dist, np.ones_like(summed_dist))


def really_naive_r(p_x: np.ndarray, q_y: np.ndarray):
    assert p_x.shape[0] == q_y.shape[0] and p_x.shape[1] == q_y.shape[1], "Input distributions inconsistent"
    L = p_x.shape[0]
    AA = p_x.shape[1]
    dist_prod_sum_across_sequence = 0.
    for l in range(L):
        for a in range(AA):
            dist_prod_sum_across_sequence += np.sqrt(p_x[l, a] * q_y[l, a])
    return np.sqrt(1 - dist_prod_sum_across_sequence)
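

# Extra sanity check: really_naive_r and naive_r should agree for single-position
# distributions (L == 1). For longer sequences really_naive_r sums the Bhattacharyya
# terms across positions, so the value under its square root can become negative.
def test_really_naive_r_matches_naive_r_for_single_position():
    p = simulated_decoding_distributions[0, :1]  # shape (1, AA)
    q = simulated_decoding_distributions[1, :1]
    np.testing.assert_almost_equal(really_naive_r(p, q), naive_r(p, q))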


def naive_r(p_x: np.ndarray, q_y: np.ndarray):
    """
    p, q are probability distributions (i.e. decoder distributions).
    """
    # assumption: sequences x, y are of shape (L, |AA|) with L the sequence length and |AA| the size of the alphabet
    assert p_x.shape[0] == q_y.shape[0] and p_x.shape[1] == q_y.shape[1], "Input distributions inconsistent"
    if np.all(p_x == q_y):
        # The Hellinger distance between equal distributions is 0, but numerically this could fail:
        # summed_pq_vals_across_sequence can become slightly larger than 1, resulting in NaNs when taking the square root.
        return 0.
    L = p_x.shape[0]
    AA = p_x.shape[1]
    summed_pq_vals_across_sequence = []
    for l in range(L):
        alphabet_prod_vals = []
        for a in range(AA):
            pq_sqrt_prod = np.sqrt(p_x[l, a] * q_y[l, a])  # TODO: for weighting: add weighting dist here
            alphabet_prod_vals.append(pq_sqrt_prod)
        summed_alphabet_vals = np.sum(alphabet_prod_vals)
        summed_pq_vals_across_sequence.append(summed_alphabet_vals)
    dist_prod_sum_across_sequence = np.prod(summed_pq_vals_across_sequence)
    assert dist_prod_sum_across_sequence <= 1
    return np.sqrt(1 - dist_prod_sum_across_sequence)
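

# Extra sanity check: naive_r should agree with a vectorized form of the Hellinger distance
# between product distributions, H(p, q) = sqrt(1 - prod_l sum_a sqrt(p[l, a] * q[l, a])),
# and should return 0 for identical inputs.
def test_naive_r_matches_vectorized_form():
    p = simulated_decoding_distributions[0]
    q = simulated_decoding_distributions[1]
    bhattacharyya_per_position = np.sum(np.sqrt(p * q), axis=-1)  # shape (L,)
    expected = np.sqrt(1. - np.prod(bhattacharyya_per_position))
    np.testing.assert_almost_equal(naive_r(p, q), expected)
    np.testing.assert_almost_equal(naive_r(p, p), 0.)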


def naive_r_w(p_x: np.ndarray, q_y: np.ndarray, w: np.ndarray):
    """
    p, q are probability distributions (i.e. decoder distributions),
    w is the weighting distribution (i.e. decoder output).
    """
    # assumption: sequences x, y are of shape (L, |AA|) with L the sequence length and |AA| the size of the alphabet
    assert p_x.shape[0] == q_y.shape[0] and p_x.shape[1] == q_y.shape[1] and p_x.shape[0] == w.shape[0], "Input distributions inconsistent"
    if np.all(p_x == q_y):
        # The Hellinger distance between equal distributions is 0, but numerically this could fail:
        # summed_pq_vals_across_sequence can become slightly larger than 1, resulting in NaNs when taking the square root.
        return 0.
    L = p_x.shape[0]
    AA = p_x.shape[1]
    summed_pq_vals_across_sequence = []
    for l in range(L):
        alphabet_prod_vals = []
        for a in range(AA):
            pq_sqrt_prod = w[l, a] * np.sqrt(p_x[l, a] * q_y[l, a])
            alphabet_prod_vals.append(pq_sqrt_prod)
        summed_alphabet_vals = np.sum(alphabet_prod_vals)
        summed_pq_vals_across_sequence.append(summed_alphabet_vals)
    dist_prod_sum_across_sequence = np.prod(summed_pq_vals_across_sequence)
    assert dist_prod_sum_across_sequence <= 1
    lhs_weighted_pq_values = []
    for l in range(L):
        alphabet_prod_vals = []
        for a in range(AA):
            weighted_pq_sum = 1 / 2 * w[l, a] * p_x[l, a] + 1 / 2 * w[l, a] * q_y[l, a]
            alphabet_prod_vals.append(weighted_pq_sum)
        summed_alphabet_vals = np.sum(alphabet_prod_vals)
        lhs_weighted_pq_values.append(summed_alphabet_vals)
    lhs_expectation = np.prod(lhs_weighted_pq_values)
    return np.sqrt(lhs_expectation - dist_prod_sum_across_sequence)
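

# Extra consistency check: with an all-ones weighting (not a distribution, used only to
# switch the weighting off), naive_r_w should reduce to naive_r, since the left-hand
# expectation term becomes prod_l sum_a (p + q) / 2 = 1.
def test_naive_r_w_reduces_to_naive_r_for_unit_weights():
    p = simulated_decoding_distributions[0]
    q = simulated_decoding_distributions[1]
    unit_w = np.ones((L, AA))
    np.testing.assert_almost_equal(naive_r_w(p, q, unit_w), naive_r(p, q))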


# naive Hellinger kernel reference implementation
def naive_kernel(p: np.ndarray, q: np.ndarray, theta: float, lam: float) -> np.ndarray:
    """
    Naive Hellinger distance and covariance computation.
    inputs: p, q distribution vectors
            theta, lam covariance function parameters
    returns:
        kernel (covariance) matrix
    """
    distance_mat = np.zeros((p.shape[0], q.shape[0]))
    for i in range(distance_mat.shape[0]):
        for j in range(distance_mat.shape[1]):
            distance_mat[i, j] = naive_r(p[i], q[j])
            if not np.isfinite(distance_mat[i, j]):
                print("Introduced NaN here!")
    return theta * np.exp(-lam * distance_mat)
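

# Extra property check: the naive kernel matrix should be symmetric with diagonal theta,
# since naive_r(p_i, p_i) = 0 and theta * exp(-lam * 0) = theta.
def test_naive_kernel_is_symmetric_with_unit_diagonal():
    theta, lam = 1., 1.
    k = naive_kernel(simulated_decoding_distributions, simulated_decoding_distributions, theta=theta, lam=lam)
    np.testing.assert_allclose(k, k.T, rtol=1e-12)
    np.testing.assert_allclose(np.diag(k), theta * np.ones(N), rtol=1e-12)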


def naive_weighted_kernel(p: np.ndarray, q: np.ndarray, w: np.ndarray, theta: float, lam: float) -> np.ndarray:
    """
    Naive weighted Hellinger kernel: like naive_kernel, but using the weighted distance naive_r_w.
    """
    distance_mat = np.zeros((p.shape[0], q.shape[0]))
    for i in range(distance_mat.shape[0]):
        for j in range(distance_mat.shape[1]):
            distance_mat[i, j] = naive_r_w(p[i], q[j], w)
            if not np.isfinite(distance_mat[i, j]):
                print("NaN here!")
    return theta * np.exp(-lam * distance_mat)
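

# Extra property check: by the AM-GM inequality the value under the square root in naive_r_w
# is non-negative, so the weighted kernel matrix should be finite, symmetric, and have
# diagonal theta.
def test_naive_weighted_kernel_is_finite_and_symmetric():
    theta, lam = 1., 1.
    k = naive_weighted_kernel(simulated_decoding_distributions, simulated_decoding_distributions,
                              simulated_weighting_vec, theta=theta, lam=lam)
    assert np.all(np.isfinite(k))
    np.testing.assert_allclose(k, k.T, rtol=1e-12)
    np.testing.assert_allclose(np.diag(k), theta * np.ones(N), rtol=1e-12)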


# NOTE: in Simon's implementation p_0 is [N, 1]

# def test_kernel_functions_distance_against_naive():  # TODO: this is not the same: _hellinger_distance expects a 1-d vector, not one that is |alphabet| elements deep!
#     p_0 = simulated_decoding_distributions[:, :, 0]  # not weighted p-vec; for comparison against [N, 1], make a ones vector of that
#     dist_kernelmodule_function = _hellinger_distance(p_0)
#     dist_naive = naive_r(p_0, p_0)
#     np.testing.assert_almost_equal(dist_kernelmodule_function, dist_naive)


# def test_kernel_functions_k_against_naive_k():  # TODO: requires understanding of HD
#     lam = 0.5
#     noise = 0.01
#     module_dist = _hellinger_distance(simulated_decoding_distributions)
#     # NOTE: Simon's _k is defined over atomic distributions for numerical efficiency, therefore not comparable here.
#     # The test works only on one-hot vectors.
#     module_k = _k(module_dist, lengthscale=np.log(lam), log_noise=np.log(noise))
#     naive_k = naive_kernel(simulated_decoding_distributions, simulated_decoding_distributions, theta=1, lam=lam)
#     np.testing.assert_almost_equal(module_k, naive_k)


def test_kernel_implementation_naive():
    """
    Test the naive sum-product reference implementation against the GPFlow reference implementation.
    """
    theta = 1.
    lam = 1.
    naive_k_matrix = naive_kernel(simulated_decoding_distributions, simulated_decoding_distributions, theta=theta, lam=lam)
    hk = Hellinger(L=L, AA=AA, lengthscale=lam)
    hk_matrix = hk.K(simulated_decoding_distributions, simulated_decoding_distributions)[0].numpy()
    np.testing.assert_allclose(hk_matrix, naive_k_matrix, rtol=1e-6)


def test_weighted_kernel_implementation_naive():
    """
    Test the naive weighted reference implementation against the GPFlow WeightedHellinger implementation.
    """
    theta = 1.
    lam = 1.
    whk = WeightedHellinger(w=tf.convert_to_tensor(simulated_weighting_vec), L=L, AA=AA, lengthscale=lam)
    whk_matrix = whk.K(simulated_decoding_distributions, simulated_decoding_distributions)
    naive_whk_matrix = naive_weighted_kernel(simulated_decoding_distributions, simulated_decoding_distributions,
                                             simulated_weighting_vec, lam=lam, theta=theta)
    # NOTE: rtol=5 is a very loose tolerance
    np.testing.assert_allclose(naive_whk_matrix, whk_matrix[0], rtol=5)  # TODO: cov. matrix shape from WHK


# def test_kernel_functions_k():
#     # TODO: test k function in hellinger module
#     assert False

# TODO: test src/corel/kernel GPFlow weighted implementation

# def test_kernel_functions_hd():
#     # TODO: test distance function in hellinger module
#     assert False


if __name__ == "__main__":  # NOTE: added for the debugger to work!
    test_kernel_implementation_naive()
    test_weighted_kernel_implementation_naive()