Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion theta/mathtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def rtbm_parts(v, bv, bh, t, w, q, mode=1):
return ( np.sqrt(detT / (2.0 * np.pi) ** (v.shape[0])) * ExpF ), ( vR1 / vR2 * np.exp(uR1-uR2) )



def rtbm_probability(v, bv, bh, t, w, q, mode=1):
"""Implements the RTBM probability"""
detT = np.linalg.det(t)
Expand All @@ -58,6 +57,37 @@ def rtbm_probability(v, bv, bh, t, w, q, mode=1):
return np.sqrt(detT / (2.0 * np.pi) ** (v.shape[0])) * ExpF * vR1 / vR2 * np.exp(uR1-uR2)


def factorized_rtbm_probability(v, bv, bh, t, w, q, mode=1):
"""Implements the RTBM probability"""
detT = np.linalg.det(t)
invT = np.linalg.inv(t)
vT = v.T
vTv = np.dot(np.dot(vT, t), v)
BvT = bv.T
BhT = bh.T
Bvv = np.dot(BvT, v)
BiTB = np.dot(np.dot(BvT, invT), bv)
BtiTW = np.dot(np.dot(BvT, invT), w)
WtiTW = np.dot(np.dot(w.T, invT), w)

ExpF = np.exp(-0.5 * vTv.diagonal() - Bvv - BiTB * np.ones(v.shape[1]))

# factorize Q
vWb = (vT.dot(w) + BhT)
uR1 = np.ones(vWb.shape[0], dtype=complex)
vR1 = np.ones(vWb.shape[0], dtype=complex)

for i in range(vWb.shape[1]):
O = np.matrix([[q[i, i]]], dtype=complex)
tuR1, tvR1 = RiemannTheta.parts_eval(vWb[:, [i]] / (2.0j * np.pi), -O / (2.0j * np.pi), mode, epsilon=RTBM_precision)
uR1 *= tuR1
vR1 *= tvR1

uR2, vR2 = RiemannTheta.parts_eval((BhT - BtiTW) / (2.0j * np.pi), (-q + WtiTW) / (2.0j * np.pi), mode, epsilon=RTBM_precision)

return np.sqrt(detT / (2.0 * np.pi) ** (v.shape[0])) * ExpF * vR1 / vR2 * np.exp(uR1-uR2)


def rtbm_log_probability(v, bv, bh, t, w, q, mode=1):
"""Implements the RTBM probability"""
detT = np.linalg.det(t)
Expand Down
51 changes: 38 additions & 13 deletions theta/rtbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from __future__ import absolute_import
import numpy as np
from theta.mathtools import rtbm_probability, hidden_expectations, rtbm_log_probability, \
check_normalization_consistency, check_pos_def
check_normalization_consistency, check_pos_def, factorized_rtbm_probability

from theta.riemann_theta.riemann_theta import RiemannTheta


class AssignError(Exception):
pass

Expand All @@ -18,7 +19,7 @@ class Mode:
Expectation = 2

def __init__(self, visible_units, hidden_units, mode=Mode.Probability,
init_max_param_bound=2, random_bound=1, phase=1, diagonal_T=False):
init_max_param_bound=2, random_bound=1, phase=1, diagonal_T=False, diagonal_Q=False):
"""Setup operators for BM based on the number of visible and hidden units

Args:
Expand All @@ -40,14 +41,21 @@ def __init__(self, visible_units, hidden_units, mode=Mode.Probability,
self._w = np.zeros([visible_units, hidden_units])
self._q = np.zeros([hidden_units, hidden_units])
self._diagonal_T = diagonal_T
self._diagonal_Q = diagonal_Q
self._mode = None
self._call = None
self._parameters = None
self._X = None
self._phase = phase
self._check_positivity = True
if diagonal_T:
self._size = 2 * self._Nv + self._Nh + (self._Nh**2+self._Nh)//2 + self._Nv*self._Nh

self._size = self._Nv + self._Nh + self._Nv*self._Nh
if diagonal_T and diagonal_Q:
self._size += self._Nv + self._Nh
elif diagonal_T and not diagonal_Q:
self._size += self._Nv + (self._Nh**2+self._Nh)//2
elif not diagonal_T and diagonal_Q:
self._size += self._Nh + (self._Nv**2+self._Nv)//2
else:
self._size = self._Nv + self._Nh + (self._Nv**2+self._Nv+self._Nh**2+self._Nh)//2+self._Nv*self._Nh

Expand Down Expand Up @@ -119,7 +127,7 @@ def random_init(self, bound):
a_size = (self._Nv+self._Nh)**2

params = np.random.uniform(-bound, bound, a_size+self._Nv+self._Nh)
if self._diagonal_T:
if self._diagonal_T or self._diagonal_Q:
x = np.eye(a_shape[0])
np.fill_diagonal(x, params[:self._Nv+self._Nh])
else:
Expand All @@ -134,10 +142,18 @@ def random_init(self, bound):
self._bh = self._phase * params[-self._Nh:].reshape(self._bh.shape)

# store parameters having in mind that Q and T are symmetric.
if self._diagonal_T:
if self._diagonal_T and not self._diagonal_Q:
self._parameters = np.concatenate([self._bv.flatten(), self._bh.flatten(),
self._w.flatten(), self._t.diagonal(),
self._q[np.triu_indices(self._Nh)]])
elif not self._diagonal_T and self._diagonal_Q:
self._parameters = np.concatenate([self._bv.flatten(), self._bh.flatten(),
self._w.flatten(), self._t[np.triu_indices(self._Nv)],
self._q.diagonal()])
elif self._diagonal_T and self._diagonal_Q:
self._parameters = np.concatenate([self._bv.flatten(), self._bh.flatten(),
self._w.flatten(), self._t.diagonal(),
self._q.diagonal()])
else:
self._parameters = np.concatenate([self._bv.flatten(), self._bh.flatten(),
self._w.flatten(), self._t[np.triu_indices(self._Nv)],
Expand Down Expand Up @@ -176,9 +192,13 @@ def set_parameters(self, params):
self._t[(inds[1], inds[0])] = params[index:index+(self._Nv**2+self._Nv)//2]
index += (self._Nv**2+self._Nv)//2

inds = np.triu_indices_from(self._q)
self._q[inds] = params[index:index+(self._Nh**2+self._Nh)//2]
self._q[(inds[1], inds[0])] = params[index:index+(self._Nh**2+self._Nh)//2]
if self._diagonal_Q:
np.fill_diagonal(self._q, params[index:index+self._Nh])
index += self._Nh
else:
inds = np.triu_indices_from(self._q)
self._q[inds] = params[index:index+(self._Nh**2+self._Nh)//2]
self._q[(inds[1], inds[0])] = params[index:index+(self._Nh**2+self._Nh)//2]

if self._check_positivity:
if not check_normalization_consistency(self._t, self._q, self._w) or \
Expand All @@ -194,12 +214,14 @@ def get_gradients(self):
"""Return flat array with calculated gradients
[Gbh,Gbv,Gw,Gt,Gq]
"""

inds = np.triu_indices_from(self._gradQ)

if(self._diagonal_T):
if self._diagonal_T and not self._diagonal_Q:
return np.real(np.concatenate((self._gradBv.flatten(),self._gradBh.flatten(),self._gradW.flatten(), self._gradT.diagonal(), self._gradQ[inds].flatten() )))

elif not self._diagonal_T and self._diagonal_Q:
return np.real(np.concatenate((self._gradBv.flatten(),self._gradBh.flatten(),self._gradW.flatten(), self._gradT.flatten(), self._gradQ.diagonal() )))
elif self._diagonal_T and self._diagonal_Q:
return np.real(np.concatenate((self._gradBv.flatten(),self._gradBh.flatten(),self._gradW.flatten(), self._gradT.diagonal(), self._gradQ.diagonal() )))
else:
return np.real(np.concatenate((self._gradBv.flatten(),self._gradBh.flatten(),self._gradW.flatten(), self._gradT.flatten(), self._gradQ[inds].flatten() )))

Expand Down Expand Up @@ -240,7 +262,10 @@ def mode(self, value):
mode = 2

if value is self.Mode.Probability:
self._call = lambda data: np.real(rtbm_probability(data, self._bv, self._bh, self._t, self._w, self._q, mode))
if self._diagonal_Q:
self._call = lambda data: np.real(factorized_rtbm_probability(data, self._bv, self._bh, self._t, self._w, self._q, mode))
else:
self._call = lambda data: np.real(rtbm_probability(data, self._bv, self._bh, self._t, self._w, self._q, mode))
elif value is self.Mode.LogProbability:
self._call = lambda data: np.real(rtbm_log_probability(data, self._bv, self._bh, self._t, self._w, self._q, mode))
elif value is self.Mode.Expectation:
Expand Down