-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdbnclassprobs.py
47 lines (37 loc) · 1.32 KB
/
dbnclassprobs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from rbmclassprobs import rbmclassprobs
import sys
sys.path.insert(0, './util/')
from sigm import sigm
# function class_probs = dbnclassprobs( dbn,x, batchsize )
# %DBNCLASSPROBS calculates p(y|x) in a classification DBN
# %
# % INPUTS
# % dbn : A dbn struct
# % x : matrix of samples (n_samples-by-n_features)
# % batchsize : optionally takes a minibatch size in which case the result
# % is calculated in minibatches to save memory
# %
# % OUTPUT
# % class_probs : class probabilities for each class (n_samples-by-n_classes)
# %
# % EXAMPLE
# % class_probs = dbnclassprobs( dbn,x )
# % pred = predict(x)
# %
# % See also, DBNPREDICT
#
# % Copyright Søren Sønderby July 2014
def dbnclassprobs(dbn, x, batchsize=None):
    """Calculate class probabilities p(y|x) with a classification DBN.

    The input is passed upward through every RBM below the top one using
    ``sigm`` as the activation. The top RBM, which must have ``classRBM``
    set, then converts the resulting representation into class
    probabilities via ``rbmclassprobs``.

    Parameters
    ----------
    dbn : sequence of RBM objects
        The DBN's layers, indexed bottom-first; ``dbn[len(dbn) - 1]`` is the
        top (classification) RBM.
    x : matrix of samples
        Documented as n_samples-by-n_features.
    batchsize : int, optional
        Minibatch size forwarded to ``rbmclassprobs``. A caller-supplied value
        is now honoured (it used to be overwritten). When omitted, ``1`` is
        forwarded, which is the value the previous implementation always
        passed.

    Returns
    -------
    class_probs
        Whatever ``rbmclassprobs`` returns for the top RBM. It is documented
        as n_samples-by-n_classes.

    Raises
    ------
    ValueError
        If ``dbn`` is empty, or if its top RBM is not a classification RBM.
    """
    if len(dbn) == 0:
        raise ValueError("dbn must contain at least one RBM")

    top = len(dbn) - 1  # index of the top (classification) RBM
    if not dbn[top].classRBM:
        raise ValueError("Class probabilities can only be calc. for classification DBN")

    # Pass the data from the input up to the top RBM through every layer
    # *below* it, i.e. indices 0 .. top-1. The previous range(top - 1)
    # skipped the layer directly under the top RBM.
    for i in range(top):
        rbm = dbn[i]
        # NOTE(review): call convention kept from the original port (the RBM
        # is passed again explicitly, with an empty third argument). Confirm
        # this against rbmup's definition.
        x = rbm.rbmup(rbm, x, [], sigm)

    if batchsize is None:
        # NOTE(review): the documented intent is "no minibatching" when the
        # size is omitted. 1 preserves the value previously always passed;
        # confirm what rbmclassprobs expects here.
        batchsize = 1

    # At the top RBM, calculate the class probabilities.
    return rbmclassprobs(dbn[top], x, batchsize)