Example: EM Algorithm in Gaussian Mixture Models
Clusters in data often come from a mixture distribution with k components. A data point is generated by first choosing a component and then sampling from that component's distribution. For continuous data, a natural probabilistic model for each component is the multivariate Gaussian, giving a mixture of Gaussians (GMM). The parameters of a GMM are the component weights w_i = P(C = i), the component means μ_i, and the component covariances Σ_i, so the overall density is p(x) = Σ_i w_i · N(x; μ_i, Σ_i).
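To make the generative story concrete, here is a minimal sketch of drawing a single point from a 2-D GMM; the weights, means, and covariances are toy values assumed purely for illustration (mvnrnd, like the mvnpdf call used later, requires the Statistics Toolbox):

% Toy GMM parameters, assumed for illustration only
w  = [0.5; 0.3; 0.2];                  % component weights, sum to 1
mu = [0 0; 4 4; -4 2];                 % one mean per row
S  = cat(3, eye(2), 2*eye(2), eye(2)); % one covariance per component
% Choose a component with probability equal to its weight, then sample
% from that component's Gaussian
i = find(rand <= cumsum(w), 1);
x = mvnrnd(mu(i, :), S(:, :, i));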
The basic idea of EM in this context is to pretend that we know the parameters of the model and then infer the probability that each data point belongs to each component. We then refit the components to the data, where each component is fitted to the entire data set, with each point weighted by the probability that it belongs to that component. This process iterates until convergence. The hidden variables here are which mixture component each data point belongs to. For the mixture of Gaussians, we randomly initialize the parameters of the mixture model and then repeat the following steps (a vectorized sketch of one full iteration follows the two steps):
1. E-Step
Compute P_ij = P(C = i | x_j), the probability that data point x_j was generated by component i. By Bayes' rule, P_ij = α · P(x_j | C = i) · P(C = i), where α is a normalizing constant. P(x_j | C = i) is just the probability density of component i at x_j, and P(C = i) is its weight w_i. Let P_i = Σ_j P_ij, the effective number of points currently assigned to component i.
2. M-Step
Update the mean, covariance, and component weights as follows.
μ_i ← Σ_j P_ij · x_j / P_i
Σ_i ← Σ_j P_ij · (x_j − μ_i)(x_j − μ_i)' / P_i
w_i ← P_i / N, where N is the number of data points (i.e., w_i ∝ P_i, normalized so the weights sum to 1).
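As mentioned above, here is a vectorized sketch of one full EM iteration; the variable names are assumptions (X is the num-by-dim data matrix, and w, mu, S hold the current weights, means, and covariances of the k components). The complete listing below implements the same steps with explicit loops:

% E-step: P(j, i) proportional to P(Xj | C = i) * P(C = i)
P = zeros(num, k);
for i = 1 : k
    P(:, i) = w(i) * mvnpdf(X, mu(i, :), S(:, :, i));
end
P = P ./ sum(P, 2);   % normalize each row (implicit expansion, R2016b+)

% M-step: refit each component using the responsibilities as weights
for i = 1 : k
    Pi = sum(P(:, i));                          % effective count P_i
    mu(i, :) = P(:, i)' * X / Pi;               % weighted mean
    D = X - mu(i, :);                           % data centered at new mean
    S(:, :, i) = (D' * (D .* P(:, i))) / Pi;    % weighted covariance
    w(i) = Pi / num;                            % new component weight
end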
Note that EM increases the log likelihood of the data at every iteration. In some situations, as discussed above, EM will only reach a local maximum. A Matlab implementation of the EM algorithm for Gaussian mixture models is attached below.
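Since the log likelihood never decreases, it also makes a convenient convergence monitor. A minimal sketch, assuming Punnorm is the num-by-k matrix of unnormalized E-step values P(Xj | C = i) * P(C = i) (the listing below instead checks how much the parameters change between iterations):

% Log likelihood of the data under the current parameters:
% log P(X) = sum_j log( sum_i Wi * P(Xj | C = i) )
logLik = sum(log(sum(Punnorm, 2)));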
--------------------Matlab Code for EM-GMM-------------------------
function [means, covs, weights, probDens, maxVals, maxIDs] = EMGaussianMixtureModel(dataset, k)
% This is an EM algorithm for estimating the parameters of a Gaussian
% Mixture Model
% 'dataset' is the data to be clustered, one data point per row
% 'k' is the number of Gaussian components
% Get the dimension of each data point
dim = size(dataset, 2);
% Get the number of data points in this dataset
num = size(dataset, 1);
% Initialize the weights uniformly for each Gaussian model
weights = ones(k, 1) / k;
% Initialize the means by picking k distinct data points at random;
% starting every mean at zero would leave the components distinguished
% only by their covariances, which makes it much easier to get stuck
means = dataset(randperm(num, k), :);
% Initialize a random covariance matrix for each Gaussian model; A*A' is
% symmetric positive semidefinite, and the added identity term keeps it
% safely positive definite for mvnpdf
covs = zeros(dim, dim, k);
for i = 1 : k
    A = randn(dim, dim);
    covs(:, :, i) = A*A' + 0.1*eye(dim);
end
% Initialize the probability table where probTable(j, i) stores the
% probability that data point j was generated by component i
probTable = zeros(num, k);
% Begin EM Algorithm
notConverge = 1;
iterCount = 0;
maxCount = 1000;
while (notConverge)
    iterCount = iterCount + 1;

    % Expectation Step
    % Iterate over the dataset to fill 'probTable' with the unnormalized
    % values P(Xj | C = i) * P(C = i); they are normalized below
    for i = 1 : k
        for j = 1 : num
            p = mvnpdf(dataset(j, :), means(i, :), covs(:, :, i));
            w = weights(i);
            probTable(j, i) = p * w;
        end
    end
    % Normalize each row so that probTable(j, i) = P(C = i | Xj)
    for j = 1 : num
        total = sum(probTable(j, :));
        probTable(j, :) = probTable(j, :) / total;
    end

    % Maximization Step
    % Compute the new mean, covariance matrix, and weight for each Gaussian
    % model; save the current parameters for the convergence check
    tempWeights = weights;
    tempMeans = means;
    tempCovs = covs;
    for j = 1 : k
        % Effective number of points assigned to component j (the name
        % 'pi' would shadow MATLAB's built-in constant)
        pSum = sum(probTable(:, j));
        % First compute the new mean ...
        means(j, :) = zeros(1, dim);
        for i = 1 : num
            means(j, :) = means(j, :) + (probTable(i, j)/pSum) * dataset(i, :);
        end
        % ... then compute the covariance about the fully updated mean
        covs(:, :, j) = zeros(dim, dim);
        for i = 1 : num
            d = dataset(i, :) - means(j, :);
            covs(:, :, j) = covs(:, :, j) + (probTable(i, j)/pSum) * (d' * d);
        end
        weights(j) = pSum;
    end
    % Normalize the weights
    weights = weights / sum(weights);

    % Check convergence: stop once no parameter has moved by more than
    % 'threshold' since the previous iteration
    threshold = 0.001;
    if max(abs(tempWeights(:) - weights(:))) <= threshold && ...
       max(abs(tempMeans(:) - means(:))) <= threshold && ...
       max(abs(tempCovs(:) - covs(:))) <= threshold
        notConverge = 0;
    end

    % If the iteration takes too long, terminate it
    if iterCount > maxCount
        notConverge = 0;
    end
end
% Compute the posterior probability P(C = i | Xj) under the final parameters
probDens = zeros(num, k);

% First the joint probability P(C = i, Xj)
for j = 1 : k
    for i = 1 : num
        probDens(i, j) = weights(j) * mvnpdf(dataset(i, :), means(j, :), covs(:, :, j));
    end
end

% Then divide by P(Xj) to get P(C = i | Xj)
for i = 1 : num
    total = sum(probDens(i, :));
    for j = 1 : k
        probDens(i, j) = probDens(i, j)/total;
    end
end
% Save the max posterior value and its component index for each data point
[maxVals, maxIDs] = max(probDens, [], 2);
% Plot the results (this assumes 2-D data and at most 3 components)
figure;
hold on;
for i = 1 : num
    if maxIDs(i) == 1
        plot(dataset(i, 1), dataset(i, 2), 'o', 'MarkerEdgeColor', 'r', 'MarkerFaceColor', 'r');
    elseif maxIDs(i) == 2
        plot(dataset(i, 1), dataset(i, 2), 'o', 'MarkerEdgeColor', 'g', 'MarkerFaceColor', 'g');
    elseif maxIDs(i) == 3
        plot(dataset(i, 1), dataset(i, 2), 'o', 'MarkerEdgeColor', 'b', 'MarkerFaceColor', 'b');
    end
end
hold off;
end
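For completeness, a minimal sketch of calling the function on synthetic data; the three components' parameters below are assumptions chosen only so the clusters are visibly separated:

% Generate a small synthetic dataset from three assumed 2-D Gaussians
rng(0);   % fix the random seed for repeatability
X = [mvnrnd([0 0],  eye(2),           100);
     mvnrnd([5 5],  2*eye(2),         100);
     mvnrnd([-5 4], [1 0.5; 0.5 1],   100)];

% Run EM with k = 3 components and inspect the fitted parameters
[means, covs, weights, probDens, maxVals, maxIDs] = EMGaussianMixtureModel(X, 3);
disp(weights');   % fitted component weights
disp(means);      % fitted component means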