资源简介
曾经为了研究EM算法，在网上搜寻了一个月的资料，也没有找到EM算法的源代码。后来终于在一位资深教授那里找到了相关资料，特地传上来和大家共享。
代码片段和文件信息
function [test_targets param_struct] = EM(train_patterns train_targets test_patterns Ngaussians)
% Classify using the expectation-maximization algorithm
% Inputs:
% train_patterns - Train patterns
% train_targets - Train targets
% test_patterns - Test patterns
% Ngaussians - Number for Gaussians for each class (vector)
%
% Outputs
% test_targets - Predicted targets
% param_struct - A parameter structure containing the parameters of the Gaussians found
classes = unique(train_targets); %Number of classes in targets
Nclasses = length(classes);
Nalpha = Ngaussians; %Number of Gaussians in each class
Dim = size(train_patterns1);
max_iter = 100;
max_try = 5;
Pw = zeros(Nclassesmax(Ngaussians));
sigma = zeros(Nclassesmax(Ngaussians)size(train_patterns1)size(train_patterns1));
mu = zeros(Nclassesmax(Ngaussians)size(train_patterns1));
%The initial guess is based on k-means preprocessing. If it does not converge after
%max_iter iterations a random guess is used.
disp(‘Using k-means for initial guess‘)
for i = 1:Nclasses
in = find(train_targets==classes(i));
[initial_mu targets labels] = k_means(train_patterns(:in)train_targets(:in)Ngaussians(i));
for j = 1:Ngaussians(i)
gauss_labels = find(labels==j);
Pw(ij) = length(gauss_labels) / length(labels);
sigma(ij::) = diag(std(train_patterns(:in(gauss_labels))‘));
end
mu(i1:Ngaussians(i):) = initial_mu‘;
end
%Do the EM: Estimate mean and covariance for each class
for c = 1:Nclasses
train = find(train_targets == classes(c));
if (Ngaussians(c) == 1)
%If there is only one Gaussian there is no need to do a whole EM procedure
sigma(c1::) = sqrtm(cov(train_patterns(:train)‘1));
mu(c1:) = mean(train_patterns(:train)‘);
else
sigma_i = squeeze(sigma(c:::));
old_sigma = zeros(size(sigma_i)); %Used for the stopping criterion
iter = 0; %Iteration counter
n = length(train); %Number of training points
qi = zeros(Nalpha(c)n); %This will hold qi‘s
P = zeros(1Nalpha(c));
Ntry = 0;
while ((sum(sum(sum(abs(sigma_i-old_sigma)))) > 1e-4) & (Ntry < max_try))
old_sigma = sigma_i;
%E step: Compute Q(theta; theta_i)
for t = 1:n
data = train_patterns(:train(t));
for k = 1:Nalpha(c)
P(k) = Pw(ck) * p_single(data squeeze(mu(ck:)) squeeze(sigma_i(k::)));
end
for i = 1:Nalpha(c)
qi(it) = P(i) / sum(P);
end
end
%M step: theta_i+1 <- argmax(Q(the
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
文件 5336 2007-12-18 20:47 EM.m
文件 1834 2003-06-26 21:14 k_means.m
----------- --------- ---------- ----- ----
7170 2
评论
共有 条评论