Wednesday, March 3, 2010

Statistical Learning Algorithm - Expectation Maximization

In machine learning, we are usually given a set of observed data on which we can perform classification, clustering, and so on. Classical techniques for these tasks include the Naive Bayes classifier, decision tree learning, K-means, and KNN (K-nearest neighbors). However, there are situations in which unobserved data or hidden variables are present. Expectation Maximization (EM) is an important statistical algorithm for estimating the parameters of probabilistic models that may depend on unobserved latent variables. In this sense, EM is similar to the MLE (Maximum Likelihood Estimation) method. However, EM is an iterative method that refines its estimates at every iteration, which also makes it similar to K-means. Both are essentially "hill climbing" procedures, and the drawback of such approaches is obvious: if the algorithm is not initialized well, it may get stuck at a local maximum!
We will start with a description of the general EM algorithm; after that we focus on its application to the Gaussian Mixture Model (GMM), which is important and widely used.

The EM algorithm seeks to find the MLE by iteratively applying the following two major steps:

Expectation Step:

Calculate the expected value of the log-likelihood function, as discussed in the MLE method. This expectation is taken with respect to the conditional distribution of the hidden variables given the observed data, under the current estimates of the parameters. We call this expected value G(θ), a function of the candidate parameters θ.
Maximization Step:

Update the parameters by choosing new values that maximize G(θ). In many models this maximization has a closed form; in the Gaussian mixture case below, for example, it amounts to taking weighted averages of the data points.

Of course, EM needs initial values for the parameters being estimated, just as K-means does; in general, any hill-climbing technique has the same problem. How should we select the initial values? Randomly (the default method)? Or with some special technique? The result of the EM algorithm, as will be discussed later, depends heavily on "where we started".
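To make this overall structure concrete, here is a minimal Matlab sketch of the generic EM loop (not the GMM implementation attached later). The function handles eStep and mStep, the numeric parameter vector params, and the tolerance tol are all assumptions introduced for illustration.

function params = genericEM(data, params, eStep, mStep, maxIter, tol)
% Generic EM skeleton (illustrative sketch only).
% 'eStep' and 'mStep' are function handles supplied by the caller;
% 'params' is assumed to be a numeric vector of model parameters.
for iter = 1 : maxIter
    oldParams = params;
    hidden = eStep(data, params);    % E-step: expected values of the hidden variables
    params = mStep(data, hidden);    % M-step: parameters maximizing the expected log-likelihood
    if norm(params - oldParams) < tol
        break;                       % stop once the estimates barely move
    end
end
end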

Example: EM Algorithm in Gaussian Mixture Models

Clusters in data often come from a mixture distribution with k components. A data point is generated by first choosing a component and then sampling from that component. For continuous data, a natural choice of component model is the multivariate Gaussian, which gives a mixture of Gaussian distributions. The parameters of a GMM are Wi = P(C = i), the weight of each component; Ui, the mean of each component; and Σi, the covariance of each component.
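To make this generative story concrete, here is a small sketch that draws points from a two-component, two-dimensional Gaussian mixture; the particular weights, means, and covariances are made up for illustration, and mvnrnd comes from Matlab's Statistics Toolbox.

% Sample n points from a 2-component, 2-D Gaussian mixture (illustrative parameters)
n = 500;
W = [0.4, 0.6];                          % component weights Wi = P(C = i)
U = [0 0; 4 4];                          % component means Ui, one row per component
S = cat(3, eye(2), [2 0.5; 0.5 1]);      % component covariances Sigma_i
data = zeros(n, 2);
for j = 1 : n
    i = find(rand <= cumsum(W), 1);      % choose a component with probability Wi
    data(j, :) = mvnrnd(U(i, :), S(:, :, i));   % then sample from that Gaussian
end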

The basic idea of EM in this context is to pretend that we know the parameters of the model and to infer the probability that each data point belongs to each component. After that, we refit the components to the data, where each component is fitted to the entire data set with each point weighted by the probability that it belongs to that component. The process iterates until convergence. The hidden variables, in this case, indicate which mixture component each data point belongs to. For the mixture of Gaussians, we randomly initialize the parameters of the mixture model and then repeat the following steps:

1. E - Step

Compute Pij = P(C = i | Xj), the probability that data point Xj was generated by component Ci. By Bayes' rule, Pij = alpha*P(Xj | C = i)*P(C = i), where alpha is a normalizing constant. P(Xj | C = i) is just the probability density of component Ci at Xj, and P(C = i) is the weight Wi of Ci. Let Pi = Σj Pij.
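As a small vectorized sketch of this step (reusing the data, W, U, and S variables from the sampling example above, with mvnpdf again from the Statistics Toolbox), the Pij table can be filled as follows:

% E-step sketch: P(i, j) holds Pij = P(C = i | Xj)
k = numel(W);
n = size(data, 1);
P = zeros(k, n);
for i = 1 : k
    P(i, :) = W(i) * mvnpdf(data, U(i, :), S(:, :, i))';   % P(Xj | C = i)*P(C = i), not yet normalized
end
P = bsxfun(@rdivide, P, sum(P, 1));   % apply the normalizer alpha so that each column sums to 1
Pi = sum(P, 2);                       % Pi = sum over j of Pij, used in the M-step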

2. M - Step

Update the means, covariances, and component weights as follows.

Ui <- Σj Pij*Xj / Pi

Σi <- Σj Pij*(Xj - Ui)(Xj - Ui)' / Pi

Wi <- Pi / n, where n is the number of data points (i.e., the Pi normalized so that the weights sum to one).
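Continuing the same sketch (same data, P, Pi, k, and n as above), these updates can be written compactly as:

% M-step sketch: re-estimate means, covariances, and weights from the responsibilities P
for i = 1 : k
    U(i, :) = (P(i, :) * data) / Pi(i);                                      % Ui
    centered = bsxfun(@minus, data, U(i, :));                                % rows are Xj - Ui
    S(:, :, i) = (centered' * bsxfun(@times, P(i, :)', centered)) / Pi(i);   % Sigma_i
end
W = (Pi / n)';                                                               % Wi = Pi / n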

One thing worth mentioning is that EM increases the log-likelihood of the data at every iteration. In some situations, as discussed above, EM will only reach a local maximum. A Matlab implementation of the EM algorithm for Gaussian Mixture Models is attached below.
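A simple sanity check, again reusing the variables from the sketches above, is to track the log-likelihood after each iteration and verify that it never decreases:

% Log-likelihood of the data under the current mixture parameters
like = zeros(n, 1);
for i = 1 : k
    like = like + W(i) * mvnpdf(data, U(i, :), S(:, :, i));   % sum over i of Wi * N(Xj; Ui, Sigma_i)
end
logLik = sum(log(like));   % should be non-decreasing from one EM iteration to the next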

--------------------Matlab Codes for EM-GMM-------------------------

function [means, covs, weights, probDens, maxVals, maxIDs] = EMGaussianMixtureModel(dataset, k)
% This is an EM algorithm for estimating parameters in a Gaussian Mixture Model
% 'dataset' is the data to be clustered, one data point per row
% 'k' is the number of Gaussian components

% Get the dimension of each data point
dim = length(dataset(1, :));
% Get the number of data points in this dataset
num = length(dataset(:, 1));
% Initialize the weight uniformly for each Gaussian model
weights = zeros(k, 1);
for i = 1 : k
    weights(i) = 1/k;
end
% Initialize the means by picking k random data points (random initialization)
perm = randperm(num);
means = dataset(perm(1 : k), :);
% Initialize a random covariance matrix for each Gaussian model
covs = zeros(dim, dim, k);
for i = 1 : k
    A = randn(dim, dim);
    covs(:, :, i) = A*A';
end
% Initialize the probability table where probTable(j, i) stores the
% probability that data point j was generated by component i
probTable = zeros(num, k);
% Begin EM Algorithm
notConverge = 1;
iterCount = 0;
maxCount = 1000;
while(notConverge)
    % Expectation Step
    % Iterate the dataset to fill 'probTable'
    iterCount = iterCount + 1;
    for i = 1 : k
        for j = 1 : num
            p = mvnpdf(dataset(j, :), means(i, :), covs(:, :, i));
            w = weights(i);
            % Save this probability to the corresponding position in the table
            % Note that the values computed here need to be normalized later
            probTable(j, i) = p * w;
        end
    end
    % Normalize the probabilities we have just computed
    for i = 1 : num
        total = sum(probTable(i, :));
        for j = 1 : k
            probTable(i, j) = probTable(i, j)/total;
        end
    end
    % Maximization Step
    % Compute new mean, covariance matrix and weights for each Gaussian model
    % We need to save current parameters for convergence checking
    tempWeights = weights;
    tempMeans = means;
    tempCovs = covs;
    for j = 1 : k
        pj = sum(probTable(:, j));   % renamed from 'pi' to avoid shadowing MATLAB's built-in pi
        % Update the mean of component j first, then its covariance, so that
        % the covariance is computed against the fully updated mean
        means(j, :) = zeros(1, dim);
        for i = 1 : num
            means(j, :) = means(j, :) + (probTable(i, j)/pj) * dataset(i, :);
        end
        covs(:, :, j) = zeros(dim, dim);
        for i = 1 : num
            data = dataset(i, :) - means(j, :);
            covs(:, :, j) = covs(:, :, j) + (probTable(i, j)/pj) * (data' * data);
        end
        weights(j) = pj;
    end
    % Normalize the weights
    total = sum(weights);
    for i = 1 : k
        weights(i) = weights(i)/total;
    end
    % Check convergence: stop when no parameter moved more than the threshold
    threshold = 0.001;
    flag = 1;
    if (max(abs(tempWeights(:) - weights(:))) > threshold)
        flag = 0;
    end
    if (max(abs(tempMeans(:) - means(:))) > threshold)
        flag = 0;
    end
    if (max(abs(tempCovs(:) - covs(:))) > threshold)
        flag = 0;
    end
    if flag == 1
        notConverge = 0;
    end
    % If iteration takes too long, we need to terminate it
    if iterCount > maxCount
        notConverge = 0;
    end
end
% Now compute the posterior probability P(C = i | Xj)
probDens = zeros(num, k);
% This is P(C = i, Xj)
for j = 1 : k
    for i = 1 : num
        probDens(i, j) = weights(j) * mvnpdf(dataset(i, :), means(j, :), covs(:, :, j));
    end
end
% This is P(C = i | Xj)
for i = 1 : num
    % This is P(Xj)
    total = sum(probDens(i, :));
    for j = 1 : k
        probDens(i, j) = probDens(i, j)/total;
    end
end
% Save the max values and their indices
[maxVals, maxIDs] = max(probDens, [], 2);
% Begin plotting the results (this assumes 2-D data and at most 3 components)
figure;
hold on;
for i = 1 : num
    if maxIDs(i) == 1
        plot(dataset(i, 1), dataset(i, 2), 'o', 'MarkerEdgeColor', 'r', 'MarkerFaceColor', 'r');
    elseif maxIDs(i) == 2
        plot(dataset(i, 1), dataset(i, 2), 'o', 'MarkerEdgeColor', 'g', 'MarkerFaceColor', 'g');
    elseif maxIDs(i) == 3
        plot(dataset(i, 1), dataset(i, 2), 'o', 'MarkerEdgeColor', 'b', 'MarkerFaceColor', 'b');
    end
end
hold off;
end
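As a usage example (a sketch, assuming dataset is any n-by-2 matrix of points, such as the synthetic data generated in the earlier sampling snippet), the function above could be called like this:

% Example call: cluster a 2-D dataset into 3 Gaussian components
k = 3;
[means, covs, weights, probDens, maxVals, maxIDs] = EMGaussianMixtureModel(dataset, k);
% maxIDs(j) is the most likely component for data point j, and
% maxVals(j) is the corresponding posterior probability P(C = maxIDs(j) | Xj)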
