-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathtrain_gmm.m
101 lines (83 loc) · 2.23 KB
/
train_gmm.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
function gmm = train_gmm(trdata, nmix, niter, verb, CVPRIOR, mu0)
% gmm = train_gmm(trdata, nmix, niter, verb, cvprior, mu0);
%
% Train a Gaussian mixture model with diagonal covariances using EM.
%
% Inputs:
%   trdata  - training data: a cell array of training sequences (or a
%             single matrix). Each column of a sequence is one
%             observation; all sequences must have the same number of rows.
%   nmix    - number of mixture components. Defaults to 3.
%   niter   - maximum number of EM iterations. Defaults to 10. EM stops
%             early once the total log likelihood changes by less than
%             1e-5 between consecutive iterations.
%   verb    - set to 1 to print the log likelihood at each iteration.
%             Defaults to 0.
%   cvprior - lower bound applied to every variance (to the initial
%             estimate and after each M-step) to avoid overfitting and
%             degenerate zero-variance components. Defaults to 1.
%   mu0     - initial component means, an ndim x nmix matrix. If omitted,
%             empty, or the scalar 1, the means are initialized by calling
%             kmeans on all training data.
%
%   Passing [] for any optional argument selects its default.
%
% Outputs:
%   gmm - structure containing the GMM parameters learned from the
%         training data:
%           gmm.priors            - 1 x nmix LOG mixture weights
%           gmm.nmix              - number of components
%           gmm.means(:,1:nmix)   - component means
%           gmm.covars(:,1:nmix)  - diagonal covariances (variances)
%
% 2007-11-06 [email protected]
if nargin < 2 || isempty(nmix)
  nmix = 3;
end
if nargin < 3 || isempty(niter)
  niter = 10;
end
if nargin < 4 || isempty(verb)
  verb = 0;
end
% Prior (floor) on observation variances to avoid overfitting.
if nargin < 5 || isempty(CVPRIOR)
  CVPRIOR = 1;
end
if nargin < 6
  mu0 = [];
end

if ~iscell(trdata)
  trdata = {trdata};
end
ndata = length(trdata);
ndim = size(trdata{1}, 1);
alldata = cat(2, trdata{:});

% Initialization: uniform mixture weights, stored as logs.
gmm.priors = log(ones(1, nmix) / nmix);
gmm.nmix = nmix;
if isempty(mu0) || (isscalar(mu0) && mu0 == 1)
  % NOTE(review): kmeans is a project helper not shown here; niter/2 is
  % presumably its iteration budget -- confirm against its definition.
  gmm.means = kmeans(alldata, nmix, niter/2);
else
  if size(mu0, 1) ~= ndim || size(mu0, 2) ~= nmix
    error('train_gmm:badMu0', ...
          'mu0 must be a %d x %d matrix (ndim x nmix), got %d x %d.', ...
          ndim, nmix, size(mu0, 1), size(mu0, 2));
  end
  gmm.means = mu0;
end
%gmm.covars = ones(ndim, nmix);
% Initial variances: per-dimension variance of ALL training data. var is
% taken along dim 2 explicitly so a single-observation sequence is not
% reduced along the wrong axis. The CVPRIOR floor keeps a constant
% feature dimension from producing a zero variance before the first E-step.
gmm.covars = repmat(max(var(alldata, 0, 2), CVPRIOR), [1 nmix]);

% Sufficient statistics accumulated over all sequences in each E-step.
acc_norm = zeros(1, nmix);       % total posterior mass per component
acc_means = zeros(ndim, nmix);   % posterior-weighted sum of x
acc_covars = zeros(ndim, nmix);  % posterior-weighted sum of x.^2

last_loglik = -Inf;  % the first iteration can never look "converged"
for iter = 1:niter
  % E-step: accumulate statistics under the current model.
  loglik = 0;
  acc_norm(:) = 0;
  acc_means(:) = 0;
  acc_covars(:) = 0;
  for n = 1:ndata
    x = trdata{n};
    % post is nmix x size(x,2): per-observation component posteriors.
    [ll, post] = eval_gmm(gmm, x);
    loglik = loglik + sum(ll);
    acc_norm = acc_norm + sum(post, 2)';
    acc_means = acc_means + x * post';
    acc_covars = acc_covars + x.^2 * post';
  end

  if verb
    fprintf('Iteration %d: log likelihood = %f\n', iter, loglik);
  end

  % Check for convergence (absolute change in total log likelihood).
  if abs(loglik - last_loglik) < 1e-5
    break
  end
  last_loglik = loglik;

  % M-step
  gmm.priors = log(acc_norm / sum(acc_norm));
  % Components that received no posterior mass keep their previous means
  % and variances instead of becoming NaN (0/0); their log prior is -Inf.
  live = acc_norm > 0;
  w = repmat(1 ./ acc_norm(live), [ndim 1]);
  mu = acc_means(:, live) .* w;
  % (S2 - 2*mu.*S1)/N + mu.^2 = E[x.^2] - mu.^2, the per-dimension variance.
  gmm.covars(:, live) = (acc_covars(:, live) - 2*mu.*acc_means(:, live)) .* w + mu.^2;
  gmm.means(:, live) = mu;
  gmm.covars(gmm.covars < CVPRIOR) = CVPRIOR;
end