
Commit 1689114

Author: Richard Michael (committed)
added HK testing and wHK testing and implementation
1 parent 18388a1 commit 1689114

File tree

7 files changed: +235 −38 lines changed


mlruns/0/meta.yaml (+6)

@@ -0,0 +1,6 @@
+artifact_location: file:///Users/rcml/corel/mlruns/0
+creation_time: 1696424246894
+experiment_id: '0'
+last_update_time: 1696424246894
+lifecycle_stage: active
+name: Default

src/corel/kernel/__init__.py (+1)

@@ -3,3 +3,4 @@
 from .hellinger import _k
 from .hellinger import _hellinger_distance
 from .hellinger_reference import HellingerReference
+from .hellinger import Hellinger

src/corel/kernel/hellinger_reference.py (+13 −8)

@@ -1,6 +1,6 @@
 from typing import Optional
 import tensorflow as tf
-
+import numpy as np
 import gpflow
 from gpflow.kernels import Kernel
 from gpflow.utilities import positive
@@ -18,7 +18,12 @@ def __init__(self, L:int, AA:int, lengthscale: float=1.0, noise: float=0.1, acti
     def restore(self, ps: tf.Tensor) -> tf.Tensor:
         ps = tf.squeeze(ps)
         N = 1 if len(ps.shape) == 1 else ps.shape[0]
-        return tf.reshape(ps, shape=(N, ps.shape[-1] // self.AA, self.AA))
+        if ps.shape[-1] != self.AA:
+            return tf.reshape(ps, shape=(N, ps.shape[-1] // self.AA, self.AA))
+        elif ps.shape[0] == N and ps.shape[1] == self.L and ps.shape[-1] == self.AA:
+            return ps
+        else:
+            raise ValueError(f"Vector p shape incorrect! {ps.shape}")

     def K(self, X, X2=None) -> tf.Tensor:
         if X2 is None:
@@ -32,18 +37,18 @@ def K_diag(self, X) -> tf.Tensor:
         return tf.ones(X.shape[0])

     def _assert_X_values(self, X: tf.Tensor, tol: float) -> bool:
-        return tf.all(tf.abs(tf.math.reduce_sum(X, axis=-1) - 1.) < tol)
+        return (tf.abs(tf.math.reduce_sum(X, axis=-1) - 1.) < tol).numpy().all()

     def _hellinger2(self, X: tf.Tensor, X2: tf.Tensor, tol: float=1e-5):
-        M = tf.zeros([X.shape[0], X2.shape[0]])
+        M = np.zeros([X.shape[0], X2.shape[0]], dtype=np.float64)
         X = self.restore(X)
         X2 = self.restore(X2)
         assert self._assert_X_values(X, tol)
         assert self._assert_X_values(X2, tol)
-        _tmp = tf.zeros(self.L)
+        _tmp = np.zeros(self.L, dtype=np.float64)
         for x_idx in range(X.shape[0]):
             for y_idx in range(X2.shape[0]):
                 for l_idx in range(self.L):
-                    _tmp[l_idx] = tf.math.reduce_sum(tf.math.sqrt(X[x_idx, l_idx, :] * X2[y_idx, l_idx, :]))
-                M[x_idx, y_idx] = 1 - tf.reduce_prod(_tmp)
-        return M
+                    _tmp[l_idx] = tf.math.reduce_sum(tf.math.sqrt(X[x_idx, l_idx, :] * X2[y_idx, l_idx, :])).numpy()  # NOTE: item assignment is supported by numpy arrays, not by TF tensors
+                M[x_idx, y_idx] = (1 - tf.reduce_prod(_tmp)).numpy()  # NOTE: build the np matrix by assignment
+        return tf.convert_to_tensor(M)
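
For orientation (an editorial note, not part of the commit): entry-wise, _hellinger2 computes the squared Hellinger distance between two sequence distributions p and q that factorize over positions,

    d^2(p, q) = 1 - \prod_{l=1}^{L} \sum_{a \in AA} \sqrt{p_{l,a} \, q_{l,a}}

with the inner sum over the alphabet at position l and the outer product over the L positions.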

src/corel/kernel/weighted_hellinger.py (+21 −11)

@@ -9,14 +9,13 @@


 class WeightedHellinger(Hellinger):
-    def __init__(self, z: tf.Tensor, L: int, AA: int, lengthscale: float=1.0, noise: float=0.1, active_dims: Optional[int] = None, name: Optional[str] = None) -> None:
+    def __init__(self, w: tf.Tensor, L: int, AA: int, lengthscale: float=1.0, noise: float=0.1, active_dims: Optional[int] = None, name: Optional[str] = None) -> None:
         super().__init__(L=L, AA=AA, active_dims=active_dims, name=name)
-        self.z = z
-        # TODO assert p in [0,1]
+        self.w = w  # weighting density vector
         self.lengthscale = gpflow.Parameter(lengthscale, transform=positive())  # TODO: log transform here?
         self.noise = gpflow.Parameter(noise, transform=positive())  # TODO: check against Kernel interface

-    def K(self, X, X2=None) -> tf.Tensor:
+    def K(self, X: tf.Tensor, X2: Optional[tf.Tensor] = None) -> tf.Tensor:
         """
         X input is P(X)
         """
@@ -40,15 +39,26 @@ def K(self, X, X2=None) -> tf.Tensor:
         M = tf.reshape(M, shape=(1, M.shape[0], 1, M.shape[1]))  # adhere to [batch..., N1, batch..., N2]
         return M

-    def _H(self, X: tf.Tensor, X2: tf.Tensor):
-        raise NotImplementedError("TODO: implement weighting by expected value")
+    def _get_inner_product(self, X: tf.Tensor, X2: tf.Tensor) -> tf.Tensor:
+        """
+        Compute the RHS of the weighted HK equation: the weighting times sqrt(p[a_l, l] * q[a_l, l])
+        """
+        # M = tf.math.reduce_sum(self.w * tf.sqrt(X[None, ...] * X2[:, None, ...]), axis=-1)
+        # NOTE: the einsum and reduce_sum products should be equivalent
+        M = tf.einsum('ali,bli->abl', tf.sqrt(tf.pow(self.w, 2) * X), tf.sqrt(X2))
+        return tf.math.reduce_prod(M, axis=-1)  # product over L: positions factorize
+
+    def _compute_lhs(self, X: tf.Tensor, X2: tf.Tensor) -> tf.Tensor:
+        w_p = tf.math.reduce_sum(self.w * X[None, ...], axis=-1) / 2
+        w_q = tf.math.reduce_sum(self.w * X2[:, None, ...], axis=-1) / 2
+        return tf.math.reduce_prod(w_p + w_q, axis=-1)
+
+    def _H(self, X: tf.Tensor, X2: tf.Tensor) -> tf.Tensor:
         M = self._get_inner_product(X, X2)
-        # TODO: correctly compute the z vector!
-        # z = tf.reduce_sum(tf.squeeze(self.z), -1)[None:]
-        M = z @ tf.transpose(z) - M
-        # M[M < 0.] = 0.
+        # NOTE: the LHS is the expectation with equal weights; it could use a different weighting
+        weighted_E = self._compute_lhs(X, X2)
+        M = weighted_E - M
         M = tf.where(M < 0., tf.zeros_like(M), M)
-
         M = tf.where(M == 0., tf.zeros_like(M), M)  # fix gradients
         M = tf.exp(-tf.sqrt(M) / tf.square(self.lengthscale))
         return M
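
In the same notation (an editorial sketch grounded in _get_inner_product, _compute_lhs, and _H above), the weighted kernel computes

    d_w^2(p, q) = \prod_{l=1}^{L} \sum_{a \in AA} \frac{w_{l,a} (p_{l,a} + q_{l,a})}{2} - \prod_{l=1}^{L} \sum_{a \in AA} w_{l,a} \sqrt{p_{l,a} \, q_{l,a}}

    k(p, q) = \exp\left( -\sqrt{d_w^2(p, q)} / \lambda^2 \right)

The einsum relies on the identity \sqrt{w^2 p} \cdot \sqrt{q} = w \sqrt{p q} for w \ge 0. Note that the naive test kernel in the new test file below is parameterized as \theta \exp(-\lambda d) instead; the two forms coincide at \theta = \lambda = 1, which are the values the tests fix.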
(new test file, +194 — file path not shown in this view)

@@ -0,0 +1,194 @@
+from typing import Callable
+import pytest
+import inspect
+import numpy as np
+from corel.kernel.hellinger import get_mean_and_amplitude
+from corel.kernel.hellinger import _hellinger_distance
+from corel.kernel.hellinger import _k
+from corel.kernel import HellingerReference
+from corel.kernel import Hellinger
+from corel.kernel import WeightedHellinger
+import tensorflow as tf
+import matplotlib.pyplot as plt
+
+# define test sequences, a test alphabet, and test weighting distributions
+SEED = 12
+N = 20
+L = 15
+AA = 3
+np.random.seed(SEED)
+
+simulated_decoding_distributions = np.stack([
+    np.random.dirichlet(np.ones(AA), L) for _ in range(N)
+])
+
+simulated_weighting_vec = np.random.dirichlet(np.ones(AA), L)
+
+
+@pytest.mark.parametrize("dist", [simulated_decoding_distributions])  # , simulated_weighting_vec])
+def test_simulated_dist_is_probabilities(dist):
+    summed_dist = np.sum(dist, axis=-1)
+    np.testing.assert_almost_equal(summed_dist, np.ones((N, L)))
+    np.testing.assert_array_less(dist, np.ones_like(dist))
+    np.testing.assert_array_less(np.zeros_like(dist), dist)
+
+
+def test_simulated_w_vec_is_probabilities():
+    summed_dist = np.sum(simulated_weighting_vec, axis=-1)
+    np.testing.assert_almost_equal(summed_dist, np.ones_like(summed_dist))
+
+
+def really_naive_r(p_x: np.ndarray, q_y: np.ndarray):
+    # NOTE: unused; sums across positions instead of taking the per-position product,
+    # so for L > 1 the argument of the square root can become negative
+    assert p_x.shape[0] == q_y.shape[0] and p_x.shape[1] == q_y.shape[1], "Input distributions inconsistent"
+    L = p_x.shape[0]
+    AA = p_x.shape[1]
+    dist_prod_sum_across_sequence = 0.
+    for l in range(L):
+        for a in range(AA):
+            dist_prod_sum_across_sequence += np.sqrt(p_x[l, a] * q_y[l, a])
+    return np.sqrt(1 - dist_prod_sum_across_sequence)
+
+
+def naive_r(p_x: np.ndarray, q_y: np.ndarray):
+    """
+    p, q are probability distributions (i.e. decoder distributions)
+    """
+    # assumption: sequences x, y are of shape (L, |AA|), with L the sequence length and |AA| the alphabet size
+    assert p_x.shape[0] == q_y.shape[0] and p_x.shape[1] == q_y.shape[1], "Input distributions inconsistent"
+    if np.all(p_x == q_y):
+        # the Hellinger distance between equal distributions is 0, but numerically this could fail:
+        # summed_pq_vals_across_sequence can become slightly larger than 1, yielding NaNs when taking the square root
+        return 0.
+    L = p_x.shape[0]
+    AA = p_x.shape[1]
+    summed_pq_vals_across_sequence = []
+    for l in range(L):
+        alphabet_prod_vals = []
+        for a in range(AA):
+            pq_sqrt_prod = np.sqrt(p_x[l, a] * q_y[l, a])  # TODO: for weighting, add the weighting dist here
+            alphabet_prod_vals.append(pq_sqrt_prod)
+        summed_alphabet_vals = np.sum(alphabet_prod_vals)
+        summed_pq_vals_across_sequence.append(summed_alphabet_vals)
+    dist_prod_sum_across_sequence = np.prod(summed_pq_vals_across_sequence)
+    assert dist_prod_sum_across_sequence <= 1
+    return np.sqrt(1 - dist_prod_sum_across_sequence)
+
+
+def naive_r_w(p_x: np.ndarray, q_y: np.ndarray, w: np.ndarray):
+    """
+    p, q are probability distributions (i.e. decoder distributions),
+    w is the weighting distribution (i.e. decoder output)
+    """
+    # assumption: sequences x, y are of shape (L, |AA|), with L the sequence length and |AA| the alphabet size
+    assert p_x.shape[0] == q_y.shape[0] and p_x.shape[1] == q_y.shape[1] and p_x.shape[0] == w.shape[0], "Input distributions inconsistent"
+    if np.all(p_x == q_y):
+        # the Hellinger distance between equal distributions is 0, but numerically this could fail:
+        # summed_pq_vals_across_sequence can become slightly larger than 1, yielding NaNs when taking the square root
+        return 0.
+    L = p_x.shape[0]
+    AA = p_x.shape[1]
+    summed_pq_vals_across_sequence = []
+    for l in range(L):
+        alphabet_prod_vals = []
+        for a in range(AA):
+            pq_sqrt_prod = w[l, a] * np.sqrt(p_x[l, a] * q_y[l, a])
+            alphabet_prod_vals.append(pq_sqrt_prod)
+        summed_alphabet_vals = np.sum(alphabet_prod_vals)
+        summed_pq_vals_across_sequence.append(summed_alphabet_vals)
+    dist_prod_sum_across_sequence = np.prod(summed_pq_vals_across_sequence)
+    assert dist_prod_sum_across_sequence <= 1
+    lhs_weighted_pq_values = []
+    for l in range(L):
+        alphabet_prod_vals = []
+        for a in range(AA):
+            weighted_pq_sum = 1/2 * w[l, a] * p_x[l, a] + 1/2 * w[l, a] * q_y[l, a]
+            alphabet_prod_vals.append(weighted_pq_sum)
+        summed_alphabet_vals = np.sum(alphabet_prod_vals)
+        lhs_weighted_pq_values.append(summed_alphabet_vals)
+    lhs_expectation = np.prod(lhs_weighted_pq_values)
+    return np.sqrt(lhs_expectation - dist_prod_sum_across_sequence)
+
+
+# naive Hellinger kernel implementation
+def naive_kernel(p: np.ndarray, q: np.ndarray, theta: float, lam: float) -> float:
+    """
+    Naive Hellinger distance and covariance computation.
+    inputs: p, q distribution vectors;
+            theta, lam covariance function parameters
+    returns: kernel value
+    """
+    distance_mat = np.zeros((p.shape[0], q.shape[0]))
+    for i in range(distance_mat.shape[0]):
+        for j in range(distance_mat.shape[1]):
+            distance_mat[i, j] = naive_r(p[i], q[j])
+            if not np.isfinite(distance_mat[i, j]):
+                print("Introduced NaN here!")
+    return theta * np.exp(-lam * distance_mat)
+
+
+def naive_weighted_kernel(p: np.ndarray, q: np.ndarray, w: np.ndarray, theta: float, lam: float):
+    distance_mat = np.zeros((p.shape[0], q.shape[0]))
+    for i in range(distance_mat.shape[0]):
+        for j in range(distance_mat.shape[1]):
+            distance_mat[i, j] = naive_r_w(p[i], q[j], w)
+            if not np.isfinite(distance_mat[i, j]):
+                print("NaN here!")
+    return theta * np.exp(-lam * distance_mat)
+
+
+# Simon's implementation assumes p_0 is [N, 1]
+
+# def test_kernel_functions_distance_against_naive():  # TODO: not comparable as-is: _hellinger_distance expects a 1d vector, not one that is |AA| elements deep!
+#     p_0 = simulated_decoding_distributions[:, :, 0]  # not a weighted p-vec; for comparison against [N, 1], make a ones vector of that
+#     dist_kernelmodule_function = _hellinger_distance(p_0)
+#     dist_naive = naive_r(p_0, p_0)
+#     np.testing.assert_almost_equal(dist_kernelmodule_function, dist_naive)
+
+
+# def test_kernel_functions_k_against_naive_k():  # TODO: requires understanding of HD
+#     lam = 0.5
+#     noise = 0.01
+#     module_dist = _hellinger_distance(simulated_decoding_distributions)
+#     # NOTE: Simon's _k is defined over atomic distributions for numerical efficiency, therefore not comparable here;
+#     # the test works only on one-hot vectors
+#     module_k = _k(module_dist, lengthscale=np.log(lam), log_noise=np.log(noise))
+#     naive_k = naive_kernel(simulated_decoding_distributions, simulated_decoding_distributions, theta=1, lam=lam)
+#     np.testing.assert_almost_equal(module_k, naive_k)
+
+
+def test_kernel_implementation_naive():
+    """
+    test the naive sum-product reference implementation against the GPflow reference implementation
+    """
+    theta = 1.
+    lam = 1.
+    naive_k_matrix = naive_kernel(simulated_decoding_distributions, simulated_decoding_distributions, theta=theta, lam=lam)
+    hk = Hellinger(L=L, AA=AA, lengthscale=lam)
+    hk_matrix = hk.K(simulated_decoding_distributions, simulated_decoding_distributions)[0].numpy()
+    np.testing.assert_allclose(hk_matrix, naive_k_matrix, rtol=1e-6)
+
+
+def test_weighted_kernel_implementation_naive():
+    theta = 1.
+    lam = 1.
+    whk = WeightedHellinger(w=tf.convert_to_tensor(simulated_weighting_vec), L=L, AA=AA, lengthscale=lam)
+    whk_matrix = whk.K(simulated_decoding_distributions, simulated_decoding_distributions)
+    naive_whk_matrix = naive_weighted_kernel(simulated_decoding_distributions, simulated_decoding_distributions, simulated_weighting_vec,
+                                             lam=lam, theta=theta)
+    np.testing.assert_allclose(naive_whk_matrix, whk_matrix[0], 5)  # TODO: cov. matrix shape from WHK
+
+
+# def test_kernel_functions_k():
+#     # TODO: test the k function in the hellinger module
+#     assert False
+
+# TODO: test the src/corel/kernel GPflow weighted implementation
+
+# def test_kernel_functions_hd():
+#     # TODO: test the distance function in the hellinger module
+#     assert False
+
+
+if __name__ == "__main__":  # NOTE: added so the debugger works!
+    test_kernel_implementation_naive()
+    test_weighted_kernel_implementation_naive()
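
The NOTE in WeightedHellinger._get_inner_product claims the commented-out reduce_sum form and the einsum form are equivalent. Below is a minimal standalone check of that identity (an editorial sketch, not part of the commit; shapes and variable names are illustrative), using Dirichlet inputs as in the tests above:

import numpy as np
import tensorflow as tf

rng = np.random.default_rng(12)
N1, N2, L, AA = 4, 3, 5, 3
X = tf.constant(np.stack([rng.dirichlet(np.ones(AA), L) for _ in range(N1)]))   # (N1, L, AA)
X2 = tf.constant(np.stack([rng.dirichlet(np.ones(AA), L) for _ in range(N2)]))  # (N2, L, AA)
w = tf.constant(rng.dirichlet(np.ones(AA), L))                                  # (L, AA)

# broadcast/reduce_sum form: sum_a w[l,a] * sqrt(p[l,a] * q[l,a]); output shape (N2, N1, L)
M_sum = tf.math.reduce_sum(w * tf.sqrt(X[None, ...] * X2[:, None, ...]), axis=-1)

# einsum form from the commit: sqrt(w^2 * p) . sqrt(q) over the alphabet axis; output shape (N1, N2, L)
M_einsum = tf.einsum('ali,bli->abl', tf.sqrt(tf.pow(w, 2) * X), tf.sqrt(X2))

# since w >= 0, sqrt(w^2 * p) * sqrt(q) == w * sqrt(p * q) elementwise
np.testing.assert_allclose(M_sum.numpy(), tf.transpose(M_einsum, perm=[1, 0, 2]).numpy())

One caveat the NOTE glosses over: the two forms index X and X2 in opposite order, so they agree only up to transposing the first two output axes, as the final assertion shows.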

src/corel/test/test_kernel/test_hellinger_reference.py (−19)

This file was deleted.
