def factorized_rtbm_probability(v, bv, bh, t, w, q, mode=1):
    """Evaluate the RTBM probability with a per-hidden-unit factorized numerator.

    The numerator Riemann theta is computed as a product of one-dimensional
    thetas, one per column of ``v.T.dot(w) + bh.T``, each built from the
    single diagonal entry ``q[i, i]``.  Off-diagonal entries of ``q`` are
    ignored in the numerator; the denominator theta is still evaluated with
    the full matrix ``-q + w.T inv(t) w``.

    Args:
        v: visible data.  Axis 0 is used as the visible dimension in the
            Gaussian normalization and axis 1 as the sample index
            (presumably shape ``(Nv, Ndata)`` -- confirm against callers).
        bv: visible bias (presumably a ``(Nv, 1)`` column -- confirm).
        bh: hidden bias (presumably a ``(Nh, 1)`` column -- confirm).
        t: visible-visible coupling matrix.  Its determinant and inverse
            are taken, so it must be square and non-singular.
        w: visible-hidden coupling matrix.
        q: hidden-hidden coupling matrix; only its diagonal enters the
            factorized numerator.
        mode: forwarded unchanged to ``RiemannTheta.parts_eval``.

    Returns:
        The probability evaluated for the samples in ``v``.  The
        accumulators are complex, so the result carries a complex dtype.
    """
    detT = np.linalg.det(t)
    invT = np.linalg.inv(t)
    vT = v.T
    BvT = bv.T
    BhT = bh.T

    vTv = np.dot(np.dot(vT, t), v)
    Bvv = np.dot(BvT, v)
    # bv^T . inv(t) is shared by the next two products; compute it once.
    BvT_invT = np.dot(BvT, invT)
    BiTB = np.dot(BvT_invT, bv)
    BtiTW = np.dot(BvT_invT, w)
    WtiTW = np.dot(np.dot(w.T, invT), w)

    ExpF = np.exp(-0.5 * vTv.diagonal() - Bvv - BiTB * np.ones(v.shape[1]))

    two_pi_i = 2.0j * np.pi

    # Factorize the numerator theta over the hidden units (diagonal of q).
    vWb = vT.dot(w) + BhT
    # parts_eval yields (u, v) that are recombined below as v * exp(u).
    # A product of thetas is therefore exp(sum of u) * (product of v):
    # the exponent accumulator starts at 0 and is summed, while the
    # oscillatory accumulator starts at 1 and is multiplied.
    uR1 = np.zeros(vWb.shape[0], dtype=complex)
    vR1 = np.ones(vWb.shape[0], dtype=complex)

    for i in range(vWb.shape[1]):
        q_ii = np.matrix([[q[i, i]]], dtype=complex)
        tuR1, tvR1 = RiemannTheta.parts_eval(
            vWb[:, [i]] / two_pi_i, -q_ii / two_pi_i, mode,
            epsilon=RTBM_precision)
        uR1 += tuR1
        vR1 *= tvR1

    uR2, vR2 = RiemannTheta.parts_eval(
        (BhT - BtiTW) / two_pi_i, (-q + WtiTW) / two_pi_i, mode,
        epsilon=RTBM_precision)

    norm = np.sqrt(detT / (2.0 * np.pi) ** (v.shape[0]))
    return norm * ExpF * vR1 / vR2 * np.exp(uR1 - uR2)
check_normalization_consistency, check_pos_def + check_normalization_consistency, check_pos_def, factorized_rtbm_probability from theta.riemann_theta.riemann_theta import RiemannTheta + class AssignError(Exception): pass @@ -18,7 +19,7 @@ class Mode: Expectation = 2 def __init__(self, visible_units, hidden_units, mode=Mode.Probability, - init_max_param_bound=2, random_bound=1, phase=1, diagonal_T=False): + init_max_param_bound=2, random_bound=1, phase=1, diagonal_T=False, diagonal_Q=False): """Setup operators for BM based on the number of visible and hidden units Args: @@ -40,14 +41,21 @@ def __init__(self, visible_units, hidden_units, mode=Mode.Probability, self._w = np.zeros([visible_units, hidden_units]) self._q = np.zeros([hidden_units, hidden_units]) self._diagonal_T = diagonal_T + self._diagonal_Q = diagonal_Q self._mode = None self._call = None self._parameters = None self._X = None self._phase = phase self._check_positivity = True - if diagonal_T: - self._size = 2 * self._Nv + self._Nh + (self._Nh**2+self._Nh)//2 + self._Nv*self._Nh + + self._size = self._Nv + self._Nh + self._Nv*self._Nh + if diagonal_T and diagonal_Q: + self._size += self._Nv + self._Nh + elif diagonal_T and not diagonal_Q: + self._size += self._Nv + (self._Nh**2+self._Nh)//2 + elif not diagonal_T and diagonal_Q: + self._size += self._Nh + (self._Nv**2+self._Nv)//2 else: self._size = self._Nv + self._Nh + (self._Nv**2+self._Nv+self._Nh**2+self._Nh)//2+self._Nv*self._Nh @@ -119,7 +127,7 @@ def random_init(self, bound): a_size = (self._Nv+self._Nh)**2 params = np.random.uniform(-bound, bound, a_size+self._Nv+self._Nh) - if self._diagonal_T: + if self._diagonal_T or self._diagonal_Q: x = np.eye(a_shape[0]) np.fill_diagonal(x, params[:self._Nv+self._Nh]) else: @@ -134,10 +142,18 @@ def random_init(self, bound): self._bh = self._phase * params[-self._Nh:].reshape(self._bh.shape) # store parameters having in mind that Q and T are symmetric. 
- if self._diagonal_T: + if self._diagonal_T and not self._diagonal_Q: self._parameters = np.concatenate([self._bv.flatten(), self._bh.flatten(), self._w.flatten(), self._t.diagonal(), self._q[np.triu_indices(self._Nh)]]) + elif not self._diagonal_T and self._diagonal_Q: + self._parameters = np.concatenate([self._bv.flatten(), self._bh.flatten(), + self._w.flatten(), self._t[np.triu_indices(self._Nv)], + self._q.diagonal()]) + elif self._diagonal_T and self._diagonal_Q: + self._parameters = np.concatenate([self._bv.flatten(), self._bh.flatten(), + self._w.flatten(), self._t.diagonal(), + self._q.diagonal()]) else: self._parameters = np.concatenate([self._bv.flatten(), self._bh.flatten(), self._w.flatten(), self._t[np.triu_indices(self._Nv)], @@ -176,9 +192,13 @@ def set_parameters(self, params): self._t[(inds[1], inds[0])] = params[index:index+(self._Nv**2+self._Nv)//2] index += (self._Nv**2+self._Nv)//2 - inds = np.triu_indices_from(self._q) - self._q[inds] = params[index:index+(self._Nh**2+self._Nh)//2] - self._q[(inds[1], inds[0])] = params[index:index+(self._Nh**2+self._Nh)//2] + if self._diagonal_Q: + np.fill_diagonal(self._q, params[index:index+self._Nh]) + index += self._Nh + else: + inds = np.triu_indices_from(self._q) + self._q[inds] = params[index:index+(self._Nh**2+self._Nh)//2] + self._q[(inds[1], inds[0])] = params[index:index+(self._Nh**2+self._Nh)//2] if self._check_positivity: if not check_normalization_consistency(self._t, self._q, self._w) or \ @@ -194,12 +214,14 @@ def get_gradients(self): """Return flat array with calculated gradients [Gbh,Gbv,Gw,Gt,Gq] """ - inds = np.triu_indices_from(self._gradQ) - if(self._diagonal_T): + if self._diagonal_T and not self._diagonal_Q: return np.real(np.concatenate((self._gradBv.flatten(),self._gradBh.flatten(),self._gradW.flatten(), self._gradT.diagonal(), self._gradQ[inds].flatten() ))) - + elif not self._diagonal_T and self._diagonal_Q: + return 
np.real(np.concatenate((self._gradBv.flatten(),self._gradBh.flatten(),self._gradW.flatten(), self._gradT.flatten(), self._gradQ.diagonal() ))) + elif self._diagonal_T and self._diagonal_Q: + return np.real(np.concatenate((self._gradBv.flatten(),self._gradBh.flatten(),self._gradW.flatten(), self._gradT.diagonal(), self._gradQ.diagonal() ))) else: return np.real(np.concatenate((self._gradBv.flatten(),self._gradBh.flatten(),self._gradW.flatten(), self._gradT.flatten(), self._gradQ[inds].flatten() ))) @@ -240,7 +262,10 @@ def mode(self, value): mode = 2 if value is self.Mode.Probability: - self._call = lambda data: np.real(rtbm_probability(data, self._bv, self._bh, self._t, self._w, self._q, mode)) + if self._diagonal_Q: + self._call = lambda data: np.real(factorized_rtbm_probability(data, self._bv, self._bh, self._t, self._w, self._q, mode)) + else: + self._call = lambda data: np.real(rtbm_probability(data, self._bv, self._bh, self._t, self._w, self._q, mode)) elif value is self.Mode.LogProbability: self._call = lambda data: np.real(rtbm_log_probability(data, self._bv, self._bh, self._t, self._w, self._q, mode)) elif value is self.Mode.Expectation: