2010-05-12 01:44:00 +08:00
|
|
|
/*M///////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//
|
|
|
|
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
|
|
|
//
|
|
|
|
// By downloading, copying, installing or using the software you agree to this license.
|
|
|
|
// If you do not agree to this license, do not download, install,
|
|
|
|
// copy or use the software.
|
|
|
|
//
|
|
|
|
//
|
|
|
|
// Intel License Agreement
|
|
|
|
// For Open Source Computer Vision Library
|
|
|
|
//
|
|
|
|
// Copyright (C) 2000, Intel Corporation, all rights reserved.
|
|
|
|
// Third party copyrights are property of their respective owners.
|
|
|
|
//
|
|
|
|
// Redistribution and use in source and binary forms, with or without modification,
|
|
|
|
// are permitted provided that the following conditions are met:
|
|
|
|
//
|
|
|
|
// * Redistribution's of source code must retain the above copyright notice,
|
|
|
|
// this list of conditions and the following disclaimer.
|
|
|
|
//
|
|
|
|
// * Redistribution's in binary form must reproduce the above copyright notice,
|
|
|
|
// this list of conditions and the following disclaimer in the documentation
|
|
|
|
// and/or other materials provided with the distribution.
|
|
|
|
//
|
|
|
|
// * The name of Intel Corporation may not be used to endorse or promote products
|
|
|
|
// derived from this software without specific prior written permission.
|
|
|
|
//
|
|
|
|
// This software is provided by the copyright holders and contributors "as is" and
|
|
|
|
// any express or implied warranties, including, but not limited to, the implied
|
|
|
|
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
|
|
|
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
|
|
|
// indirect, incidental, special, exemplary, or consequential damages
|
|
|
|
//(including, but not limited to, procurement of substitute goods or services;
|
|
|
|
// loss of use, data, or profits; or business interruption) however caused
|
|
|
|
// and on any theory of liability, whether in contract, strict liability,
|
|
|
|
// or tort(including negligence or otherwise) arising in any way out of
|
|
|
|
// the use of this software, even if advised of the possibility of such damage.
|
|
|
|
//
|
|
|
|
//M*/
|
|
|
|
|
|
|
|
#include "precomp.hpp"
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
namespace cv
|
2011-06-24 20:48:00 +08:00
|
|
|
{
|
|
|
|
|
2012-04-06 18:04:22 +08:00
|
|
|
// Floor applied to covariance eigenvalues (via cv::max) before they are inverted or passed to
// log(), so that neither operation sees a zero eigenvalue.
const float minEigenValue = 1e-3f;
|
2012-04-06 17:26:11 +08:00
|
|
|
|
|
|
|
// Stores the training configuration verbatim. The pointer arguments (probs, weights, means,
// covs) are kept as pointers, not copied, so the pointed-to objects must stay alive for as
// long as this Params instance is used.
EM::Params::Params( int nclusters, int covMatType, int startStep, const cv::TermCriteria& termCrit,
                    const cv::Mat* probs, const cv::Mat* weights,
                    const cv::Mat* means, const std::vector<cv::Mat>* covs )
    : nclusters(nclusters),
      covMatType(covMatType),
      startStep(startStep),
      probs(probs),
      weights(weights),
      means(means),
      covs(covs),
      termCrit(termCrit)
{
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
// Creates an untrained model; no members are explicitly initialised here.
EM::EM()
{
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
// Convenience constructor: forwards every argument to train().
// The boolean result of train() is discarded here; call isTrained() afterwards to find out
// whether training succeeded.
EM::EM(const cv::Mat& samples, const cv::Mat samplesMask,
       const EM::Params& params, cv::Mat* labels, cv::Mat* probs, cv::Mat* likelihoods)
{
    train(samples, samplesMask, params, labels, probs, likelihoods);
}
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
// Releases all training buffers and model matrices through clear().
EM::~EM()
{
    clear();
}
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
void EM::clear()
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
trainSamples.release();
|
|
|
|
trainProbs.release();
|
|
|
|
trainLikelihoods.release();
|
|
|
|
trainLabels.release();
|
|
|
|
trainCounts.release();
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
weights.release();
|
|
|
|
means.release();
|
|
|
|
covs.clear();
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
covsEigenValues.clear();
|
|
|
|
invCovsEigenValues.clear();
|
|
|
|
covsRotateMats.clear();
|
|
|
|
|
|
|
|
logWeightDivDet.release();
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
// Trains the model on `samples` (optionally restricted by `samplesMask`) using `params`.
//
// On success the per-sample results are handed over to the caller through the optional output
// pointers (each may be null), the temporary training buffers are released and true is
// returned. On failure the whole object is cleared and false is returned.
bool EM::train(const cv::Mat& samples, const cv::Mat& samplesMask,
               const EM::Params& params, cv::Mat* labels, cv::Mat* probs, cv::Mat* likelihoods)
{
    setTrainData(samples, samplesMask, params);

    if(!doTrain(params.termCrit))
    {
        clear();
        return false;
    }

    // Hand the results over without copying: the caller's matrices are swapped with ours.
    if(labels)
        cv::swap(*labels, trainLabels);
    if(probs)
        cv::swap(*probs, trainProbs);
    if(likelihoods)
        cv::swap(*likelihoods, trainLikelihoods);

    // Whatever is left in the training buffers is no longer needed.
    trainSamples.release();
    trainProbs.release();
    trainLabels.release();
    trainLikelihoods.release();
    trainCounts.release();

    return true;
}
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
// Classifies a single CV_32FC1 sample with the trained model.
//
// Returns the index of the most likely mixture component. If `_probs` is non-null it receives
// the per-component probabilities; if `_likelihood` is non-null it receives the sample
// likelihood value computed by computeProbabilities(), widened to double.
// Asserts that the model is trained and that the sample is non-empty and of type CV_32FC1.
int EM::predict(const cv::Mat& sample, cv::Mat* _probs, double* _likelihood) const
{
    CV_Assert(isTrained());

    CV_Assert(!sample.empty());
    CV_Assert(sample.type() == CV_32FC1);

    int label = 0;
    float likelihood = 0.f;
    // Only request the likelihood when the caller actually wants it.
    float* likelihoodPtr = _likelihood ? &likelihood : 0;
    computeProbabilities(sample, label, _probs, likelihoodPtr);

    if(_likelihood)
        *_likelihood = static_cast<double>(likelihood);

    return label;
}
|
2011-06-07 18:05:23 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
bool EM::isTrained() const
|
|
|
|
{
|
|
|
|
return !means.empty();
|
|
|
|
}
|
2011-06-07 18:05:23 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
int EM::getNClusters() const
|
|
|
|
{
|
|
|
|
return isTrained() ? nclusters : -1;
|
2011-06-07 18:05:23 +08:00
|
|
|
}
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
int EM::getCovMatType() const
|
2011-06-07 18:05:23 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
return isTrained() ? covMatType : -1;
|
|
|
|
}
|
2011-06-07 18:05:23 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
// Read-only access to the mixture weights.
// The assertion enforces consistency between the trained state and the stored weights:
// the weights must be non-empty exactly when the model is trained.
const cv::Mat& EM::getWeights() const
{
    CV_Assert((isTrained() && !weights.empty()) || (!isTrained() && weights.empty()));

    return weights;
}
|
2011-06-07 18:05:23 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
// Read-only access to the component means (empty when the model is not trained).
// The assertion checks that the trained state and the emptiness of `means` agree.
const cv::Mat& EM::getMeans() const
{
    CV_Assert((isTrained() && !means.empty()) || (!isTrained() && means.empty()));

    return means;
}
|
2011-06-07 18:05:23 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
const std::vector<cv::Mat>& EM::getCovs() const
|
|
|
|
{
|
|
|
|
CV_Assert((isTrained() && !covs.empty()) || (!isTrained() && covs.empty()));
|
|
|
|
return covs;
|
|
|
|
}
|
2011-06-07 18:05:23 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
static
|
|
|
|
void checkTrainData(const cv::Mat& samples, const cv::Mat& samplesMask, const EM::Params& params)
|
|
|
|
{
|
|
|
|
// Check samples.
|
|
|
|
CV_Assert(!samples.empty());
|
|
|
|
CV_Assert(samples.type() == CV_32FC1);
|
|
|
|
|
|
|
|
int nsamples = samples.rows;
|
|
|
|
int dim = samples.cols;
|
|
|
|
|
|
|
|
// Check samples indices.
|
|
|
|
CV_Assert(samplesMask.empty() ||
|
|
|
|
((samplesMask.rows == 1 || samplesMask.cols == 1) &&
|
|
|
|
static_cast<int>(samplesMask.total()) == nsamples && samplesMask.type() == CV_8UC1));
|
|
|
|
|
|
|
|
// Check training params.
|
|
|
|
CV_Assert(params.nclusters > 0);
|
|
|
|
CV_Assert(params.nclusters <= nsamples);
|
|
|
|
CV_Assert(params.startStep == EM::START_AUTO_STEP || params.startStep == EM::START_E_STEP || params.startStep == EM::START_M_STEP);
|
|
|
|
|
|
|
|
CV_Assert(!params.probs ||
|
|
|
|
(!params.probs->empty() &&
|
|
|
|
params.probs->rows == nsamples && params.probs->cols == params.nclusters &&
|
|
|
|
params.probs->type() == CV_32FC1));
|
|
|
|
|
|
|
|
CV_Assert(!params.weights ||
|
|
|
|
(!params.weights->empty() &&
|
|
|
|
(params.weights->cols == 1 || params.weights->rows == 1) && static_cast<int>(params.weights->total()) == params.nclusters &&
|
|
|
|
params.weights->type() == CV_32FC1));
|
|
|
|
|
|
|
|
CV_Assert(!params.means ||
|
|
|
|
(!params.means->empty() &&
|
|
|
|
params.means->rows == params.nclusters && params.means->cols == dim &&
|
|
|
|
params.means->type() == CV_32FC1));
|
|
|
|
|
|
|
|
CV_Assert(!params.covs ||
|
|
|
|
(!params.covs->empty() &&
|
|
|
|
static_cast<int>(params.covs->size()) == params.nclusters));
|
|
|
|
if(params.covs)
|
|
|
|
{
|
|
|
|
const cv::Size covSize(dim, dim);
|
|
|
|
for(size_t i = 0; i < params.covs->size(); i++)
|
|
|
|
{
|
|
|
|
const cv::Mat& m = (*params.covs)[i];
|
|
|
|
CV_Assert(!m.empty() && m.size() == covSize && (m.type() == CV_32FC1));
|
|
|
|
}
|
|
|
|
}
|
2011-06-07 18:05:23 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
if(params.startStep == EM::START_E_STEP)
|
2011-06-07 18:05:23 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
CV_Assert(params.means);
|
2011-06-07 18:05:23 +08:00
|
|
|
}
|
2012-04-06 17:26:11 +08:00
|
|
|
else if(params.startStep == EM::START_M_STEP)
|
2011-06-07 18:05:23 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
CV_Assert(params.probs);
|
2011-06-07 18:05:23 +08:00
|
|
|
}
|
2012-04-06 17:26:11 +08:00
|
|
|
}
|
2011-06-07 18:05:23 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
static
|
|
|
|
void preprocessSampleData(const cv::Mat& src, cv::Mat& dst, int dstType, const cv::Mat& samplesMask, bool isAlwaysClone)
|
|
|
|
{
|
|
|
|
if(samplesMask.empty() || cv::countNonZero(samplesMask) == src.rows)
|
2011-06-07 18:05:23 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
if(src.type() == dstType && !isAlwaysClone)
|
|
|
|
dst = src;
|
|
|
|
else
|
|
|
|
src.convertTo(dst, dstType);
|
2011-06-07 18:05:23 +08:00
|
|
|
}
|
|
|
|
else
|
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
dst.release();
|
|
|
|
for(int sampleIndex = 0; sampleIndex < src.rows; sampleIndex++)
|
|
|
|
{
|
|
|
|
if(samplesMask.at<uchar>(sampleIndex))
|
|
|
|
{
|
|
|
|
cv::Mat sample = src.row(sampleIndex);
|
|
|
|
cv::Mat sample_dbl;
|
|
|
|
sample.convertTo(sample_dbl, dstType);
|
|
|
|
dst.push_back(sample_dbl);
|
|
|
|
}
|
|
|
|
}
|
2011-06-07 18:05:23 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
static
|
|
|
|
void preprocessProbability(cv::Mat& probs)
|
2011-06-07 18:05:23 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
cv::max(probs, 0., probs);
|
2011-06-07 18:05:23 +08:00
|
|
|
|
2012-04-06 18:04:22 +08:00
|
|
|
const float uniformProbability = (float)(1./probs.cols);
|
2012-04-06 17:26:11 +08:00
|
|
|
for(int y = 0; y < probs.rows; y++)
|
|
|
|
{
|
|
|
|
cv::Mat sampleProbs = probs.row(y);
|
2011-06-07 18:05:23 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
double maxVal = 0;
|
|
|
|
cv::minMaxLoc(sampleProbs, 0, &maxVal);
|
|
|
|
if(maxVal < FLT_EPSILON)
|
|
|
|
sampleProbs.setTo(uniformProbability);
|
|
|
|
else
|
|
|
|
cv::normalize(sampleProbs, sampleProbs, 1, 0, cv::NORM_L1);
|
|
|
|
}
|
2011-06-07 18:05:23 +08:00
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
// Clears the model, validates the inputs (checkTrainData asserts on bad data) and stores the
// training samples together with whichever initial parameters apply to the chosen start step:
//   - probs            are used only with START_M_STEP;
//   - means            are used with START_E_STEP and START_AUTO_STEP;
//   - weights and covs are used only with START_E_STEP and only when BOTH are supplied.
void EM::setTrainData(const cv::Mat& samples, const cv::Mat& samplesMask, const EM::Params& params)
{
    clear();

    checkTrainData(samples, samplesMask, params);

    // Set checked data
    nclusters = params.nclusters;
    covMatType = params.covMatType;
    startStep = params.startStep;

    preprocessSampleData(samples, trainSamples, CV_32FC1, samplesMask, false);

    // set probs
    if(params.probs && startStep == EM::START_M_STEP)
    {
        // Always cloned (last argument) because preprocessProbability() edits the data in place.
        preprocessSampleData(*params.probs, trainProbs, CV_32FC1, samplesMask, true);
        preprocessProbability(trainProbs);
    }

    // set weights
    if(params.weights && (startStep == EM::START_E_STEP && params.covs))
    {
        params.weights->convertTo(weights, CV_32FC1);
        // Mat::reshape() returns a new header and leaves the object untouched, so the result
        // has to be assigned back. The weights may arrive as a column vector; they must become
        // a single row because preprocessProbability() normalises row by row (a column vector
        // would otherwise end up with every weight equal to 1).
        weights = weights.reshape(1, 1);
        preprocessProbability(weights);
    }

    // set means
    if(params.means && (startStep == EM::START_E_STEP || startStep == EM::START_AUTO_STEP))
        params.means->convertTo(means, CV_32FC1);

    // set covs
    if(params.covs && (startStep == EM::START_E_STEP && params.weights))
    {
        covs.resize(nclusters);
        for(size_t i = 0; i < params.covs->size(); i++)
            (*params.covs)[i].convertTo(covs[i], CV_32FC1);
    }
}
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
void EM::decomposeCovs()
|
2011-06-07 18:05:23 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
CV_Assert(!covs.empty());
|
|
|
|
covsEigenValues.resize(nclusters);
|
|
|
|
if(covMatType == EM::COV_MAT_GENERIC)
|
|
|
|
covsRotateMats.resize(nclusters);
|
|
|
|
invCovsEigenValues.resize(nclusters);
|
|
|
|
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
|
|
|
|
{
|
|
|
|
CV_Assert(!covs[clusterIndex].empty());
|
2011-06-07 18:05:23 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
cv::SVD svd(covs[clusterIndex], cv::SVD::MODIFY_A + cv::SVD::FULL_UV);
|
|
|
|
CV_DbgAssert(svd.w.rows == 1 || svd.w.cols == 1);
|
|
|
|
CV_DbgAssert(svd.w.type() == CV_32FC1 && svd.u.type() == CV_32FC1);
|
2011-06-07 18:05:23 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
if(covMatType == EM::COV_MAT_SPHERICAL)
|
2011-06-07 18:05:23 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
float maxSingularVal = svd.w.at<float>(0);
|
|
|
|
covsEigenValues[clusterIndex] = cv::Mat(1, 1, CV_32FC1, cv::Scalar(maxSingularVal));
|
2011-06-07 18:05:23 +08:00
|
|
|
}
|
2012-04-06 17:26:11 +08:00
|
|
|
else if(covMatType == EM::COV_MAT_DIAGONAL)
|
2011-06-07 18:05:23 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
covsEigenValues[clusterIndex] = svd.w;
|
2011-06-07 18:05:23 +08:00
|
|
|
}
|
2012-04-06 17:26:11 +08:00
|
|
|
else //EM::COV_MAT_GENERIC
|
|
|
|
{
|
|
|
|
covsEigenValues[clusterIndex] = svd.w;
|
|
|
|
covsRotateMats[clusterIndex] = svd.u;
|
|
|
|
}
|
|
|
|
cv::max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
|
|
|
|
invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
|
2011-06-07 18:05:23 +08:00
|
|
|
}
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
void EM::clusterTrainSamples()
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
int nsamples = trainSamples.rows;
|
|
|
|
|
|
|
|
// Cluster samples, compute/update means
|
|
|
|
cv::Mat labels;
|
|
|
|
cv::kmeans(trainSamples, nclusters, labels,
|
|
|
|
cv::TermCriteria(cv::TermCriteria::COUNT, means.empty() ? 10 : 1, 0.5),
|
|
|
|
10, cv::KMEANS_PP_CENTERS, means);
|
|
|
|
CV_Assert(means.type() == CV_32FC1);
|
|
|
|
|
|
|
|
// Compute weights and covs
|
|
|
|
weights = cv::Mat(1, nclusters, CV_32FC1, cv::Scalar(0));
|
|
|
|
covs.resize(nclusters);
|
|
|
|
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
cv::Mat clusterSamples;
|
|
|
|
for(int sampleIndex = 0; sampleIndex < nsamples; sampleIndex++)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
if(labels.at<int>(sampleIndex) == clusterIndex)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
const cv::Mat sample = trainSamples.row(sampleIndex);
|
|
|
|
clusterSamples.push_back(sample);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
}
|
2012-04-06 17:26:11 +08:00
|
|
|
CV_Assert(!clusterSamples.empty());
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
cv::calcCovarMatrix(clusterSamples, covs[clusterIndex], means.row(clusterIndex),
|
|
|
|
CV_COVAR_NORMAL + CV_COVAR_ROWS + CV_COVAR_USE_AVG + CV_COVAR_SCALE, CV_32FC1);
|
|
|
|
weights.at<float>(clusterIndex) = static_cast<float>(clusterSamples.rows)/static_cast<float>(nsamples);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
decomposeCovs();
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
void EM::computeLogWeightDivDet()
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
CV_Assert(!covsEigenValues.empty());
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
cv::Mat logWeights;
|
|
|
|
cv::log(weights, logWeights);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
logWeightDivDet.create(1, nclusters, CV_32FC1);
|
|
|
|
// note: logWeightDivDet = log(weight_k) - 0.5 * log(|det(cov_k)|)
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
float logDetCov = 0.;
|
|
|
|
for(int di = 0; di < covsEigenValues[clusterIndex].cols; di++)
|
|
|
|
logDetCov += std::log(covsEigenValues[clusterIndex].at<float>(covMatType != EM::COV_MAT_SPHERICAL ? di : 0));
|
2011-06-07 18:05:23 +08:00
|
|
|
|
2012-04-06 18:04:22 +08:00
|
|
|
logWeightDivDet.at<float>(clusterIndex) = logWeights.at<float>(clusterIndex) - 0.5f * logDetCov;
|
2012-04-06 17:26:11 +08:00
|
|
|
}
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
// Runs the EM iterations on the data prepared by setTrainData().
//
// The loop alternates eStep() and mStep() and stops when
//   - termCrit.maxCount E-steps have been performed, or
//   - (from the second iteration on) the summed likelihood decreased by more than DBL_EPSILON
//     or increased by less than termCrit.epsilon * |likelihood|.
// Returns false when the final likelihood is not above -DBL_MAX/10000, true otherwise. On
// success the full covariance matrices of the spherical and diagonal models are rebuilt from
// their eigenvalues.
bool EM::doTrain(const cv::TermCriteria& termCrit)
{
    const int dim = trainSamples.cols;

    // Without initial weights (START_E_STEP / START_AUTO_STEP) bootstrap the model by k-means.
    if(startStep != EM::START_M_STEP && weights.empty())
    {
        CV_Assert(covs.empty());
        clusterTrainSamples();
    }

    // Covariances supplied by the caller still need their eigen-decomposition.
    if(!covs.empty() && covsEigenValues.empty() )
    {
        CV_Assert(invCovsEigenValues.empty());
        decomposeCovs();
    }

    if(startStep == EM::START_M_STEP)
        mStep();

    double trainLikelihood = 0.;
    double prevTrainLikelihood = 0.;
    int iter = 0;
    while(true)
    {
        eStep();
        trainLikelihood = cv::sum(trainLikelihoods)[0];

        if(iter >= termCrit.maxCount - 1)
            break;

        // The convergence test needs a previous value, so it is skipped on the first pass.
        if(iter != 0)
        {
            const double likelihoodDelta = trainLikelihood - prevTrainLikelihood;
            const bool gotWorse = likelihoodDelta < -DBL_EPSILON;
            const bool converged = likelihoodDelta < termCrit.epsilon * std::fabs(trainLikelihood);
            if(gotWorse || converged)
                break;
        }

        mStep();

        prevTrainLikelihood = trainLikelihood;
        iter++;
    }

    if( trainLikelihood <= -DBL_MAX/10000. )
        return false;

    // postprocess covs: expand the compact eigenvalue storage into full matrices
    covs.resize(nclusters);
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        switch(covMatType)
        {
        case EM::COV_MAT_SPHERICAL:
            covs[clusterIndex].create(dim, dim, CV_32FC1);
            cv::setIdentity(covs[clusterIndex], cv::Scalar(covsEigenValues[clusterIndex].at<float>(0)));
            break;
        case EM::COV_MAT_DIAGONAL:
            covs[clusterIndex] = cv::Mat::diag(covsEigenValues[clusterIndex].t());
            break;
        default:
            // Other types: covs[clusterIndex] is left as it is.
            break;
        }
    }

    return true;
}
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
// Evaluates the mixture for one sample (a single CV_32FC1 row).
//
//   label       receives q, the index of the component with the largest L_ik;
//   probs       (optional) receives a 1 x nclusters CV_32FC1 row of posterior probabilities;
//   likelihood  (optional) receives the log-likelihood of the sample.
//
// With
//   L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 * (x_i - mean_k)' cov_k^(-1) (x_i - mean_k)
//   q    = arg(max_k(L_ik))
// the outputs are computed in the numerically stable log-sum-exp form
//   probs_ik     = exp(L_ik - L_iq) / sum_j exp(L_ij - L_iq)
//   likelihood_i = L_iq + log(sum_j exp(L_ij - L_iq)) - 0.5 * dim * ln(2*pi)
// Subtracting L_iq inside every exponent keeps the largest term equal to 1, so the sum can
// neither underflow to zero nor overflow, and the probabilities add up to one.
void EM::computeProbabilities(const cv::Mat& sample, int& label, cv::Mat* probs, float* likelihood) const
{
    CV_DbgAssert(sample.rows == 1);
    CV_DbgAssert(!logWeightDivDet.empty());

    int dim = sample.cols;

    cv::Mat L(1, nclusters, CV_32FC1);

    label = 0;
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
    {
        const cv::Mat centeredSample = sample - means.row(clusterIndex);

        // For the generic type, rotate into the eigenbasis so that the quadratic form below
        // becomes a weighted sum of squares.
        cv::Mat rotatedCenteredSample = covMatType != EM::COV_MAT_GENERIC ?
                centeredSample : centeredSample * covsRotateMats[clusterIndex];

        float Lval = 0;
        for(int di = 0; di < dim; di++)
        {
            // The spherical type stores a single (shared) inverse eigenvalue.
            float w = invCovsEigenValues[clusterIndex].at<float>(covMatType != EM::COV_MAT_SPHERICAL ? di : 0);
            float val = rotatedCenteredSample.at<float>(di);
            Lval += w * val * val;
        }
        Lval = logWeightDivDet.at<float>(clusterIndex) - 0.5f * Lval;
        L.at<float>(clusterIndex) = Lval;

        if(Lval > L.at<float>(label))
            label = clusterIndex;
    }

    if(!probs && !likelihood)
        return;

    // exp(L_ik - L_iq) for every k; the entry at `label` is exactly 1.
    const float maxLval = L.at<float>(label);
    cv::Mat expLDiff;
    cv::exp(L - L.at<float>(label), expLDiff);

    float expDiffSum = 0; // sum_j exp(L_ij - L_iq), always >= 1
    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        expDiffSum += expLDiff.at<float>(clusterIndex);

    if(probs)
    {
        const float factor = 1.f/expDiffSum;
        // create() is a no-op when *probs already has this size and type, so a row header
        // passed in by the caller is written through.
        probs->create(1, nclusters, CV_32FC1);
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
            probs->at<float>(clusterIndex) = expLDiff.at<float>(clusterIndex) * factor;
    }

    if(likelihood)
    {
        // log(sum_j exp(L_ij)) = L_iq + log(sum_j exp(L_ij - L_iq))
        *likelihood = maxLval + std::log(expDiffSum) - (float)(0.5 * dim * CV_LOG2PI);
    }
}
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
void EM::eStep()
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
// Compute probs_ik from means_k, covs_k and weights_k.
|
|
|
|
trainProbs.create(trainSamples.rows, nclusters, CV_32FC1);
|
|
|
|
trainLabels.create(trainSamples.rows, 1, CV_32SC1);
|
|
|
|
trainLikelihoods.create(trainSamples.rows, 1, CV_32FC1);
|
2012-03-29 16:55:43 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
computeLogWeightDivDet();
|
2012-03-29 16:55:43 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
|
|
|
|
{
|
|
|
|
cv::Mat sampleProbs = trainProbs.row(sampleIndex);
|
|
|
|
computeProbabilities(trainSamples.row(sampleIndex), trainLabels.at<int>(sampleIndex),
|
|
|
|
&sampleProbs, &trainLikelihoods.at<float>(sampleIndex));
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
void EM::mStep()
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
trainCounts.create(1, nclusters, CV_32SC1);
|
|
|
|
trainCounts = cv::Scalar(0);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
for(int sampleIndex = 0; sampleIndex < trainLabels.rows; sampleIndex++)
|
|
|
|
trainCounts.at<int>(trainLabels.at<int>(sampleIndex))++;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
if(cv::countNonZero(trainCounts) != (int)trainCounts.total())
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
clusterTrainSamples();
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2012-04-06 17:26:11 +08:00
|
|
|
else
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
// Update means_k, covs_k and weights_k from probs_ik
|
|
|
|
int dim = trainSamples.cols;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
// Update weights
|
|
|
|
// not normalized first
|
|
|
|
cv::reduce(trainProbs, weights, 0, CV_REDUCE_SUM);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
// Update means
|
|
|
|
means.create(nclusters, dim, CV_32FC1);
|
|
|
|
means = cv::Scalar(0);
|
|
|
|
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
cv::Mat clusterMean = means.row(clusterIndex);
|
|
|
|
for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
|
|
|
|
clusterMean += trainProbs.at<float>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex);
|
|
|
|
clusterMean /= weights.at<float>(clusterIndex);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
// Update covsEigenValues and invCovsEigenValues
|
|
|
|
covs.resize(nclusters);
|
|
|
|
covsEigenValues.resize(nclusters);
|
|
|
|
if(covMatType == EM::COV_MAT_GENERIC)
|
|
|
|
covsRotateMats.resize(nclusters);
|
|
|
|
invCovsEigenValues.resize(nclusters);
|
|
|
|
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
if(covMatType != EM::COV_MAT_SPHERICAL)
|
|
|
|
covsEigenValues[clusterIndex].create(1, dim, CV_32FC1);
|
|
|
|
else
|
|
|
|
covsEigenValues[clusterIndex].create(1, 1, CV_32FC1);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
if(covMatType == EM::COV_MAT_GENERIC)
|
|
|
|
covs[clusterIndex].create(dim, dim, CV_32FC1);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
cv::Mat clusterCov = covMatType != EM::COV_MAT_GENERIC ?
|
|
|
|
covsEigenValues[clusterIndex] : covs[clusterIndex];
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
clusterCov = cv::Scalar(0);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
cv::Mat centeredSample;
|
|
|
|
for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
if(covMatType == EM::COV_MAT_GENERIC)
|
|
|
|
clusterCov += trainProbs.at<float>(sampleIndex, clusterIndex) * centeredSample.t() * centeredSample;
|
2010-05-12 01:44:00 +08:00
|
|
|
else
|
2012-04-06 17:26:11 +08:00
|
|
|
{
|
|
|
|
float p = trainProbs.at<float>(sampleIndex, clusterIndex);
|
|
|
|
for(int di = 0; di < dim; di++ )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
float val = centeredSample.at<float>(di);
|
|
|
|
clusterCov.at<float>(covMatType != EM::COV_MAT_SPHERICAL ? di : 0) += p*val*val;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2012-04-06 17:26:11 +08:00
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
if(covMatType == EM::COV_MAT_SPHERICAL)
|
|
|
|
clusterCov /= dim;
|
|
|
|
|
|
|
|
clusterCov /= weights.at<float>(clusterIndex);
|
|
|
|
|
|
|
|
// Update covsRotateMats for EM::COV_MAT_GENERIC only
|
|
|
|
if(covMatType == EM::COV_MAT_GENERIC)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
cv::SVD svd(covs[clusterIndex], cv::SVD::MODIFY_A + cv::SVD::FULL_UV);
|
|
|
|
covsEigenValues[clusterIndex] = svd.w;
|
|
|
|
covsRotateMats[clusterIndex] = svd.u;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
cv::max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
// update invCovsEigenValues
|
|
|
|
invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
|
2011-06-07 18:05:23 +08:00
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
// Normalize weights
|
|
|
|
weights /= trainSamples.rows;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
void EM::read(const FileNode& fn)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
Algorithm::read(fn);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
decomposeCovs();
|
|
|
|
computeLogWeightDivDet();
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
static Algorithm* createEM()
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2012-04-06 17:26:11 +08:00
|
|
|
return new EM;
|
2012-03-29 16:55:43 +08:00
|
|
|
}
|
2012-04-06 17:26:11 +08:00
|
|
|
static AlgorithmInfo em_info("StatModel.EM", createEM);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-06 17:26:11 +08:00
|
|
|
// Returns the AlgorithmInfo record of EM. On the first call the parameters nclusters,
// covMatType, weights, means and covs are registered, using a temporary EM object to
// identify the members.
// NOTE(review): the `initialized` flag is a plain volatile bool, which gives no protection if
// two threads make the first call concurrently - confirm whether that can happen.
AlgorithmInfo* EM::info() const
{
    static volatile bool initialized = false;

    if( initialized )
        return &em_info;

    EM obj;
    em_info.addParam(obj, "nclusters", obj.nclusters);
    em_info.addParam(obj, "covMatType", obj.covMatType);

    em_info.addParam(obj, "weights", obj.weights);
    em_info.addParam(obj, "means", obj.means);
    em_info.addParam(obj, "covs", obj.covs);

    initialized = true;

    return &em_info;
}
|
2012-04-06 17:26:11 +08:00
|
|
|
} // namespace cv
|
2010-05-12 01:44:00 +08:00
|
|
|
|
|
|
|
/* End of file. */
|