2010-05-12 01:44:00 +08:00
|
|
|
/*M///////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//
|
|
|
|
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
|
|
|
//
|
|
|
|
// By downloading, copying, installing or using the software you agree to this license.
|
|
|
|
// If you do not agree to this license, do not download, install,
|
|
|
|
// copy or use the software.
|
|
|
|
//
|
|
|
|
//
|
|
|
|
// Intel License Agreement
|
|
|
|
// For Open Source Computer Vision Library
|
|
|
|
//
|
|
|
|
// Copyright( C) 2000, Intel Corporation, all rights reserved.
|
|
|
|
// Third party copyrights are property of their respective owners.
|
|
|
|
//
|
|
|
|
// Redistribution and use in source and binary forms, with or without modification,
|
|
|
|
// are permitted provided that the following conditions are met:
|
|
|
|
//
|
|
|
|
// * Redistribution's of source code must retain the above copyright notice,
|
|
|
|
// this list of conditions and the following disclaimer.
|
|
|
|
//
|
|
|
|
// * Redistribution's in binary form must reproduce the above copyright notice,
|
|
|
|
// this list of conditions and the following disclaimer in the documentation
|
|
|
|
// and/or other materials provided with the distribution.
|
|
|
|
//
|
|
|
|
// * The name of Intel Corporation may not be used to endorse or promote products
|
|
|
|
// derived from this software without specific prior written permission.
|
|
|
|
//
|
|
|
|
// This software is provided by the copyright holders and contributors "as is" and
|
|
|
|
// any express or implied warranties, including, but not limited to, the implied
|
|
|
|
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
|
|
|
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
|
|
|
// indirect, incidental, special, exemplary, or consequential damages
|
|
|
|
//(including, but not limited to, procurement of substitute goods or services;
|
|
|
|
// loss of use, data, or profits; or business interruption) however caused
|
|
|
|
// and on any theory of liability, whether in contract, strict liability,
|
|
|
|
// or tort(including negligence or otherwise) arising in any way out of
|
|
|
|
// the use of this software, even ifadvised of the possibility of such damage.
|
|
|
|
//
|
|
|
|
//M*/
|
|
|
|
|
|
|
|
#include "precomp.hpp"
|
|
|
|
|
2012-04-14 05:50:59 +08:00
|
|
|
namespace cv
|
2011-06-24 20:48:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
namespace ml
|
|
|
|
{
|
2011-06-24 20:48:00 +08:00
|
|
|
|
2012-04-30 22:33:52 +08:00
|
|
|
// Lower bound applied to every covariance eigenvalue so that the inverse
// eigenvalues and log-determinants computed from them stay finite.
static const double minEigenValue = DBL_EPSILON;
|
2012-04-14 05:50:59 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
class CV_EXPORTS EMImpl : public EM
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
public:
|
2015-02-11 18:24:14 +08:00
|
|
|
|
|
|
|
int nclusters;
|
|
|
|
int covMatType;
|
|
|
|
TermCriteria termCrit;
|
|
|
|
|
|
|
|
CV_IMPL_PROPERTY_S(TermCriteria, TermCriteria, termCrit)
|
|
|
|
|
|
|
|
void setClustersNumber(int val)
|
2014-07-30 03:54:23 +08:00
|
|
|
{
|
2015-02-11 18:24:14 +08:00
|
|
|
nclusters = val;
|
|
|
|
CV_Assert(nclusters > 1);
|
2014-07-30 03:54:23 +08:00
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2015-02-11 18:24:14 +08:00
|
|
|
int getClustersNumber() const
|
|
|
|
{
|
|
|
|
return nclusters;
|
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
|
2015-02-11 18:24:14 +08:00
|
|
|
void setCovarianceMatrixType(int val)
|
2014-07-30 03:54:23 +08:00
|
|
|
{
|
2015-02-11 18:24:14 +08:00
|
|
|
covMatType = val;
|
|
|
|
CV_Assert(covMatType == COV_MAT_SPHERICAL ||
|
|
|
|
covMatType == COV_MAT_DIAGONAL ||
|
|
|
|
covMatType == COV_MAT_GENERIC);
|
2014-07-30 03:54:23 +08:00
|
|
|
}
|
|
|
|
|
2015-02-11 18:24:14 +08:00
|
|
|
int getCovarianceMatrixType() const
|
2014-07-30 03:54:23 +08:00
|
|
|
{
|
2015-02-11 18:24:14 +08:00
|
|
|
return covMatType;
|
2014-07-30 03:54:23 +08:00
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2015-02-11 18:24:14 +08:00
|
|
|
EMImpl()
|
|
|
|
{
|
|
|
|
nclusters = DEFAULT_NCLUSTERS;
|
|
|
|
covMatType=EM::COV_MAT_DIAGONAL;
|
|
|
|
termCrit = TermCriteria(TermCriteria::COUNT+TermCriteria::EPS, EM::DEFAULT_MAX_ITERS, 1e-6);
|
|
|
|
}
|
|
|
|
|
|
|
|
virtual ~EMImpl() { /* members release their own storage */ }
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void clear()
|
|
|
|
{
|
|
|
|
trainSamples.release();
|
|
|
|
trainProbs.release();
|
|
|
|
trainLogLikelihoods.release();
|
|
|
|
trainLabels.release();
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
weights.release();
|
|
|
|
means.release();
|
|
|
|
covs.clear();
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
covsEigenValues.clear();
|
|
|
|
invCovsEigenValues.clear();
|
|
|
|
covsRotateMats.clear();
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
logWeightDivDet.release();
|
|
|
|
}
|
2012-10-17 15:12:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
bool train(const Ptr<TrainData>& data, int)
|
|
|
|
{
|
|
|
|
Mat samples = data->getTrainSamples(), labels;
|
2015-02-11 18:24:14 +08:00
|
|
|
return trainEM(samples, labels, noArray(), noArray());
|
2014-07-30 03:54:23 +08:00
|
|
|
}
|
|
|
|
|
2015-02-11 18:24:14 +08:00
|
|
|
bool trainEM(InputArray samples,
|
2012-04-30 22:33:52 +08:00
|
|
|
OutputArray logLikelihoods,
|
2012-04-14 05:50:59 +08:00
|
|
|
OutputArray labels,
|
2012-04-30 22:33:52 +08:00
|
|
|
OutputArray probs)
|
2014-07-30 03:54:23 +08:00
|
|
|
{
|
|
|
|
Mat samplesMat = samples.getMat();
|
|
|
|
setTrainData(START_AUTO_STEP, samplesMat, 0, 0, 0, 0);
|
|
|
|
return doTrain(START_AUTO_STEP, logLikelihoods, labels, probs);
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
bool trainE(InputArray samples,
|
2012-04-14 05:50:59 +08:00
|
|
|
InputArray _means0,
|
|
|
|
InputArray _covs0,
|
|
|
|
InputArray _weights0,
|
2012-04-30 22:33:52 +08:00
|
|
|
OutputArray logLikelihoods,
|
2012-04-14 05:50:59 +08:00
|
|
|
OutputArray labels,
|
2012-04-30 22:33:52 +08:00
|
|
|
OutputArray probs)
|
2014-07-30 03:54:23 +08:00
|
|
|
{
|
|
|
|
Mat samplesMat = samples.getMat();
|
|
|
|
std::vector<Mat> covs0;
|
|
|
|
_covs0.getMatVector(covs0);
|
2012-10-17 15:12:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
Mat means0 = _means0.getMat(), weights0 = _weights0.getMat();
|
2012-04-14 05:50:59 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
setTrainData(START_E_STEP, samplesMat, 0, !_means0.empty() ? &means0 : 0,
|
|
|
|
!_covs0.empty() ? &covs0 : 0, !_weights0.empty() ? &weights0 : 0);
|
|
|
|
return doTrain(START_E_STEP, logLikelihoods, labels, probs);
|
|
|
|
}
|
2011-06-07 18:05:23 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
bool trainM(InputArray samples,
|
2012-04-14 05:50:59 +08:00
|
|
|
InputArray _probs0,
|
2012-04-30 22:33:52 +08:00
|
|
|
OutputArray logLikelihoods,
|
2012-04-14 05:50:59 +08:00
|
|
|
OutputArray labels,
|
2012-04-30 22:33:52 +08:00
|
|
|
OutputArray probs)
|
2011-06-07 18:05:23 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
Mat samplesMat = samples.getMat();
|
|
|
|
Mat probs0 = _probs0.getMat();
|
2011-06-07 18:05:23 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
setTrainData(START_M_STEP, samplesMat, !_probs0.empty() ? &probs0 : 0, 0, 0, 0);
|
|
|
|
return doTrain(START_M_STEP, logLikelihoods, labels, probs);
|
2011-06-07 18:05:23 +08:00
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
/** Predicts the most likely component for each sample row of `_inputs`.
 *  @param _inputs  samples, one per row (CV_32F or CV_64F, width = getVarCount()).
 *  @param _outputs optional: receives an nsamples x nclusters matrix of
 *                  posterior probabilities (CV_32F unless the output has a fixed type).
 *  @return the label of the first sample, as float (0 if there are no samples).
 *  When no probabilities are requested, only the first sample is evaluated.
 */
float predict(InputArray _inputs, OutputArray _outputs, int) const
{
    bool needprobs = _outputs.needed();
    Mat samples = _inputs.getMat(), probs, probsrow;
    int ptype = CV_32F;
    float firstres = 0.f;
    int i, nsamples = samples.rows;

    if( needprobs )
    {
        if( _outputs.fixedType() )
            ptype = _outputs.type();
        _outputs.create(samples.rows, nclusters, ptype);
        // Bind the local header to the freshly created output buffer; without
        // this `probs` stays empty and probs.row(i) below fails.
        probs = _outputs.getMat();
    }
    else
        nsamples = std::min(nsamples, 1);

    for( i = 0; i < nsamples; i++ )
    {
        // probsrow is a view into the output, so computeProbabilities()
        // writes the posteriors of sample i directly into row i.
        if( needprobs )
            probsrow = probs.row(i);
        Vec2d res = computeProbabilities(samples.row(i), needprobs ? &probsrow : 0, ptype);
        if( i == 0 )
            firstres = (float)res[1];
    }
    return firstres;
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
/** Evaluates a single sample.
 *  @param _sample one sample as a 1 x dim row or dim x 1 column (any depth;
 *                 converted to CV_64F).
 *  @param _probs  optional: receives the 1 x nclusters posterior probabilities
 *                 (CV_32F unless the output has a fixed type).
 *  @return Vec2d(log-likelihood of the sample, index of the most probable component).
 */
Vec2d predict2(InputArray _sample, OutputArray _probs) const
{
    int ptype = CV_32F;
    Mat sample = _sample.getMat();
    CV_Assert(isTrained());

    CV_Assert(!sample.empty());
    if(sample.type() != CV_64FC1)
    {
        Mat tmp;
        sample.convertTo(tmp, CV_64FC1);
        sample = tmp;
    }
    // Mat::reshape() returns a new header and leaves `sample` untouched, so the
    // result must be assigned back; this lets a dim x 1 column be passed to
    // computeProbabilities(), which requires a 1 x dim row.
    sample = sample.reshape(1, 1);

    Mat probs;
    if( _probs.needed() )
    {
        if( _probs.fixedType() )
            ptype = _probs.type();
        _probs.create(1, nclusters, ptype);
        probs = _probs.getMat();
    }

    return computeProbabilities(sample, !probs.empty() ? &probs : 0, ptype);
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
bool isTrained() const
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
return !means.empty();
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
bool isClassifier() const
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
return true;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
int getVarCount() const
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
return means.cols;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2015-04-07 21:44:26 +08:00
|
|
|
String getDefaultName() const
{
    // Default node name used when the model is serialized.
    const String name("opencv_ml_em");
    return name;
}
|
2011-06-07 18:05:23 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
static void checkTrainData(int startStep, const Mat& samples,
|
|
|
|
int nclusters, int covMatType, const Mat* probs, const Mat* means,
|
|
|
|
const std::vector<Mat>* covs, const Mat* weights)
|
|
|
|
{
|
|
|
|
// Check samples.
|
|
|
|
CV_Assert(!samples.empty());
|
|
|
|
CV_Assert(samples.channels() == 1);
|
|
|
|
|
|
|
|
int nsamples = samples.rows;
|
|
|
|
int dim = samples.cols;
|
|
|
|
|
|
|
|
// Check training params.
|
|
|
|
CV_Assert(nclusters > 0);
|
|
|
|
CV_Assert(nclusters <= nsamples);
|
|
|
|
CV_Assert(startStep == START_AUTO_STEP ||
|
|
|
|
startStep == START_E_STEP ||
|
|
|
|
startStep == START_M_STEP);
|
|
|
|
CV_Assert(covMatType == COV_MAT_GENERIC ||
|
|
|
|
covMatType == COV_MAT_DIAGONAL ||
|
|
|
|
covMatType == COV_MAT_SPHERICAL);
|
|
|
|
|
|
|
|
CV_Assert(!probs ||
|
|
|
|
(!probs->empty() &&
|
|
|
|
probs->rows == nsamples && probs->cols == nclusters &&
|
|
|
|
(probs->type() == CV_32FC1 || probs->type() == CV_64FC1)));
|
|
|
|
|
|
|
|
CV_Assert(!weights ||
|
|
|
|
(!weights->empty() &&
|
|
|
|
(weights->cols == 1 || weights->rows == 1) && static_cast<int>(weights->total()) == nclusters &&
|
|
|
|
(weights->type() == CV_32FC1 || weights->type() == CV_64FC1)));
|
|
|
|
|
|
|
|
CV_Assert(!means ||
|
|
|
|
(!means->empty() &&
|
|
|
|
means->rows == nclusters && means->cols == dim &&
|
|
|
|
means->channels() == 1));
|
|
|
|
|
|
|
|
CV_Assert(!covs ||
|
|
|
|
(!covs->empty() &&
|
|
|
|
static_cast<int>(covs->size()) == nclusters));
|
|
|
|
if(covs)
|
2011-06-07 18:05:23 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
const Size covSize(dim, dim);
|
|
|
|
for(size_t i = 0; i < covs->size(); i++)
|
|
|
|
{
|
|
|
|
const Mat& m = (*covs)[i];
|
|
|
|
CV_Assert(!m.empty() && m.size() == covSize && (m.channels() == 1));
|
|
|
|
}
|
2011-06-07 18:05:23 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
|
|
|
|
if(startStep == START_E_STEP)
|
2011-06-07 18:05:23 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
CV_Assert(means);
|
2011-06-07 18:05:23 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
else if(startStep == START_M_STEP)
|
2012-04-14 05:50:59 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
CV_Assert(probs);
|
2012-04-14 05:50:59 +08:00
|
|
|
}
|
2011-06-07 18:05:23 +08:00
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
static void preprocessSampleData(const Mat& src, Mat& dst, int dstType, bool isAlwaysClone)
|
2012-04-14 05:50:59 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
if(src.type() == dstType && !isAlwaysClone)
|
|
|
|
dst = src;
|
2012-04-14 05:50:59 +08:00
|
|
|
else
|
2014-07-30 03:54:23 +08:00
|
|
|
src.convertTo(dst, dstType);
|
2012-04-14 05:50:59 +08:00
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
/** Turns each row of `probs` into a probability distribution, in place.
 *  Negative entries are clamped to 0. A row whose largest entry is below
 *  FLT_EPSILON becomes uniform (1/cols); any other row is L1-normalized.
 */
static void preprocessProbability(Mat& probs)
{
    max(probs, 0., probs);

    const double uniform = (double)(1./probs.cols);
    for(int row = 0; row < probs.rows; ++row)
    {
        Mat rowProbs = probs.row(row);   // view: edits write back into probs

        double rowMax = 0;
        minMaxLoc(rowProbs, 0, &rowMax);

        if(rowMax < FLT_EPSILON)
            rowProbs.setTo(uniform);
        else
            normalize(rowProbs, rowProbs, 1, 0, NORM_L1);
    }
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void setTrainData(int startStep, const Mat& samples,
|
|
|
|
const Mat* probs0,
|
|
|
|
const Mat* means0,
|
|
|
|
const std::vector<Mat>* covs0,
|
|
|
|
const Mat* weights0)
|
|
|
|
{
|
|
|
|
clear();
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
checkTrainData(startStep, samples, nclusters, covMatType, probs0, means0, covs0, weights0);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
bool isKMeansInit = (startStep == START_AUTO_STEP) || (startStep == START_E_STEP && (covs0 == 0 || weights0 == 0));
|
|
|
|
// Set checked data
|
|
|
|
preprocessSampleData(samples, trainSamples, isKMeansInit ? CV_32FC1 : CV_64FC1, false);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// set probs
|
|
|
|
if(probs0 && startStep == START_M_STEP)
|
|
|
|
{
|
|
|
|
preprocessSampleData(*probs0, trainProbs, CV_64FC1, true);
|
|
|
|
preprocessProbability(trainProbs);
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// set weights
|
|
|
|
if(weights0 && (startStep == START_E_STEP && covs0))
|
|
|
|
{
|
|
|
|
weights0->convertTo(weights, CV_64FC1);
|
|
|
|
weights.reshape(1,1);
|
|
|
|
preprocessProbability(weights);
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// set means
|
|
|
|
if(means0 && (startStep == START_E_STEP/* || startStep == START_AUTO_STEP*/))
|
|
|
|
means0->convertTo(means, isKMeansInit ? CV_32FC1 : CV_64FC1);
|
|
|
|
|
|
|
|
// set covs
|
|
|
|
if(covs0 && (startStep == START_E_STEP && weights0))
|
|
|
|
{
|
|
|
|
covs.resize(nclusters);
|
|
|
|
for(size_t i = 0; i < covs0->size(); i++)
|
|
|
|
(*covs0)[i].convertTo(covs[i], CV_64FC1);
|
|
|
|
}
|
2012-04-14 05:50:59 +08:00
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void decomposeCovs()
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
CV_Assert(!covs.empty());
|
|
|
|
covsEigenValues.resize(nclusters);
|
|
|
|
if(covMatType == COV_MAT_GENERIC)
|
|
|
|
covsRotateMats.resize(nclusters);
|
|
|
|
invCovsEigenValues.resize(nclusters);
|
|
|
|
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
CV_Assert(!covs[clusterIndex].empty());
|
|
|
|
|
|
|
|
SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
|
|
|
|
|
|
|
|
if(covMatType == COV_MAT_SPHERICAL)
|
|
|
|
{
|
|
|
|
double maxSingularVal = svd.w.at<double>(0);
|
|
|
|
covsEigenValues[clusterIndex] = Mat(1, 1, CV_64FC1, Scalar(maxSingularVal));
|
|
|
|
}
|
|
|
|
else if(covMatType == COV_MAT_DIAGONAL)
|
|
|
|
{
|
2015-09-15 08:35:53 +08:00
|
|
|
covsEigenValues[clusterIndex] = covs[clusterIndex].diag().clone(); //Preserve the original order of eigen values.
|
2014-07-30 03:54:23 +08:00
|
|
|
}
|
|
|
|
else //COV_MAT_GENERIC
|
|
|
|
{
|
|
|
|
covsEigenValues[clusterIndex] = svd.w;
|
|
|
|
covsRotateMats[clusterIndex] = svd.u;
|
|
|
|
}
|
|
|
|
max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
|
|
|
|
invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void clusterTrainSamples()
|
2012-04-14 05:50:59 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
int nsamples = trainSamples.rows;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// Cluster samples, compute/update means
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// Convert samples and means to 32F, because kmeans requires this type.
|
|
|
|
Mat trainSamplesFlt, meansFlt;
|
|
|
|
if(trainSamples.type() != CV_32FC1)
|
|
|
|
trainSamples.convertTo(trainSamplesFlt, CV_32FC1);
|
|
|
|
else
|
|
|
|
trainSamplesFlt = trainSamples;
|
|
|
|
if(!means.empty())
|
|
|
|
{
|
|
|
|
if(means.type() != CV_32FC1)
|
|
|
|
means.convertTo(meansFlt, CV_32FC1);
|
|
|
|
else
|
|
|
|
meansFlt = means;
|
|
|
|
}
|
|
|
|
|
|
|
|
Mat labels;
|
|
|
|
kmeans(trainSamplesFlt, nclusters, labels,
|
|
|
|
TermCriteria(TermCriteria::COUNT, means.empty() ? 10 : 1, 0.5),
|
|
|
|
10, KMEANS_PP_CENTERS, meansFlt);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// Convert samples and means back to 64F.
|
|
|
|
CV_Assert(meansFlt.type() == CV_32FC1);
|
|
|
|
if(trainSamples.type() != CV_64FC1)
|
|
|
|
{
|
|
|
|
Mat trainSamplesBuffer;
|
|
|
|
trainSamplesFlt.convertTo(trainSamplesBuffer, CV_64FC1);
|
|
|
|
trainSamples = trainSamplesBuffer;
|
|
|
|
}
|
|
|
|
meansFlt.convertTo(means, CV_64FC1);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// Compute weights and covs
|
|
|
|
weights = Mat(1, nclusters, CV_64FC1, Scalar(0));
|
|
|
|
covs.resize(nclusters);
|
|
|
|
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
|
|
|
|
{
|
|
|
|
Mat clusterSamples;
|
|
|
|
for(int sampleIndex = 0; sampleIndex < nsamples; sampleIndex++)
|
|
|
|
{
|
|
|
|
if(labels.at<int>(sampleIndex) == clusterIndex)
|
|
|
|
{
|
|
|
|
const Mat sample = trainSamples.row(sampleIndex);
|
|
|
|
clusterSamples.push_back(sample);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
CV_Assert(!clusterSamples.empty());
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
calcCovarMatrix(clusterSamples, covs[clusterIndex], means.row(clusterIndex),
|
|
|
|
CV_COVAR_NORMAL + CV_COVAR_ROWS + CV_COVAR_USE_AVG + CV_COVAR_SCALE, CV_64FC1);
|
|
|
|
weights.at<double>(clusterIndex) = static_cast<double>(clusterSamples.rows)/static_cast<double>(nsamples);
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
decomposeCovs();
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2012-04-14 05:50:59 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void computeLogWeightDivDet()
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
CV_Assert(!covsEigenValues.empty());
|
|
|
|
|
|
|
|
Mat logWeights;
|
|
|
|
cv::max(weights, DBL_MIN, weights);
|
|
|
|
log(weights, logWeights);
|
|
|
|
|
|
|
|
logWeightDivDet.create(1, nclusters, CV_64FC1);
|
|
|
|
// note: logWeightDivDet = log(weight_k) - 0.5 * log(|det(cov_k)|)
|
|
|
|
|
|
|
|
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
|
|
|
|
{
|
|
|
|
double logDetCov = 0.;
|
|
|
|
const int evalCount = static_cast<int>(covsEigenValues[clusterIndex].total());
|
|
|
|
for(int di = 0; di < evalCount; di++)
|
2015-02-11 18:24:14 +08:00
|
|
|
logDetCov += std::log(covsEigenValues[clusterIndex].at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0));
|
2014-07-30 03:54:23 +08:00
|
|
|
|
|
|
|
logWeightDivDet.at<double>(clusterIndex) = logWeights.at<double>(clusterIndex) - 0.5 * logDetCov;
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
bool doTrain(int startStep, OutputArray logLikelihoods, OutputArray labels, OutputArray probs)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
int dim = trainSamples.cols;
|
|
|
|
// Precompute the empty initial train data in the cases of START_E_STEP and START_AUTO_STEP
|
|
|
|
if(startStep != START_M_STEP)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
if(covs.empty())
|
|
|
|
{
|
|
|
|
CV_Assert(weights.empty());
|
|
|
|
clusterTrainSamples();
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
|
|
|
|
if(!covs.empty() && covsEigenValues.empty() )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
CV_Assert(invCovsEigenValues.empty());
|
|
|
|
decomposeCovs();
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2012-10-17 15:12:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if(startStep == START_M_STEP)
|
|
|
|
mStep();
|
2012-10-17 15:12:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
double trainLogLikelihood, prevTrainLogLikelihood = 0.;
|
2015-02-11 18:24:14 +08:00
|
|
|
int maxIters = (termCrit.type & TermCriteria::MAX_ITER) ?
|
|
|
|
termCrit.maxCount : DEFAULT_MAX_ITERS;
|
|
|
|
double epsilon = (termCrit.type & TermCriteria::EPS) ? termCrit.epsilon : 0.;
|
2012-04-14 05:50:59 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
for(int iter = 0; ; iter++)
|
|
|
|
{
|
|
|
|
eStep();
|
|
|
|
trainLogLikelihood = sum(trainLogLikelihoods)[0];
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if(iter >= maxIters - 1)
|
|
|
|
break;
|
|
|
|
|
|
|
|
double trainLogLikelihoodDelta = trainLogLikelihood - prevTrainLogLikelihood;
|
|
|
|
if( iter != 0 &&
|
|
|
|
(trainLogLikelihoodDelta < -DBL_EPSILON ||
|
|
|
|
trainLogLikelihoodDelta < epsilon * std::fabs(trainLogLikelihood)))
|
|
|
|
break;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
mStep();
|
2011-06-07 18:05:23 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
prevTrainLogLikelihood = trainLogLikelihood;
|
|
|
|
}
|
|
|
|
|
|
|
|
if( trainLogLikelihood <= -DBL_MAX/10000. )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
clear();
|
|
|
|
return false;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// postprocess covs
|
|
|
|
covs.resize(nclusters);
|
|
|
|
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
|
|
|
|
{
|
2015-02-11 18:24:14 +08:00
|
|
|
if(covMatType == COV_MAT_SPHERICAL)
|
2014-07-30 03:54:23 +08:00
|
|
|
{
|
|
|
|
covs[clusterIndex].create(dim, dim, CV_64FC1);
|
|
|
|
setIdentity(covs[clusterIndex], Scalar(covsEigenValues[clusterIndex].at<double>(0)));
|
|
|
|
}
|
2015-02-11 18:24:14 +08:00
|
|
|
else if(covMatType == COV_MAT_DIAGONAL)
|
2014-07-30 03:54:23 +08:00
|
|
|
{
|
|
|
|
covs[clusterIndex] = Mat::diag(covsEigenValues[clusterIndex]);
|
|
|
|
}
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if(labels.needed())
|
|
|
|
trainLabels.copyTo(labels);
|
|
|
|
if(probs.needed())
|
|
|
|
trainProbs.copyTo(probs);
|
|
|
|
if(logLikelihoods.needed())
|
|
|
|
trainLogLikelihoods.copyTo(logLikelihoods);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
trainSamples.release();
|
|
|
|
trainProbs.release();
|
|
|
|
trainLabels.release();
|
|
|
|
trainLogLikelihoods.release();
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
return true;
|
|
|
|
}
|
2012-04-30 22:33:52 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
/** Scores one sample against every mixture component.
 *  Per component k (the shared -0.5*dim*log(2*pi) term is added at the end):
 *    L_k = log(weight_k) - 0.5*log|det(cov_k)| - 0.5*(x - mean_k)' cov_k^(-1) (x - mean_k)
 *  Posteriors use the log-sum-exp trick: with q = argmax_k L_k,
 *    probs_k = exp(L_k - L_q) / sum_j exp(L_j - L_q)
 *  (see Alex Smola's blog http://blog.smola.org/page/2 for details).
 *  @param sample 1 x dim row, CV_32F or CV_64F.
 *  @param probs  optional: receives the 1 x nclusters posteriors of type ptype.
 *  @param ptype  CV_32F or CV_64F.
 *  @return Vec2d(log-likelihood of the sample, index q of the best component).
 */
Vec2d computeProbabilities(const Mat& sample, Mat* probs, int ptype) const
{
    int stype = sample.type();
    CV_Assert(!means.empty());
    CV_Assert((stype == CV_32F || stype == CV_64F) && (ptype == CV_32F || ptype == CV_64F));
    CV_Assert(sample.size() == Size(means.cols, 1));

    const int dim = sample.cols;

    Mat L(1, nclusters, CV_64FC1), centeredSample(1, dim, CV_64F);
    int label = 0;
    for(int k = 0; k < nclusters; k++)
    {
        // centered = x - mean_k, always computed in double precision.
        const double* meanRow = means.ptr<double>(k);
        double* centered = centeredSample.ptr<double>();
        if( stype == CV_32F )
        {
            const float* x = sample.ptr<float>();
            for(int d = 0; d < dim; d++)
                centered[d] = x[d] - meanRow[d];
        }
        else
        {
            const double* x = sample.ptr<double>();
            for(int d = 0; d < dim; d++)
                centered[d] = x[d] - meanRow[d];
        }

        // Generic covariances are diagonalized by their rotation matrix.
        Mat rotatedCenteredSample;
        if(covMatType == COV_MAT_GENERIC)
            rotatedCenteredSample = centeredSample * covsRotateMats[k];
        else
            rotatedCenteredSample = centeredSample;

        // Mahalanobis distance in the eigenbasis: sum_d x_d^2 / lambda_d
        // (spherical models reuse their single eigenvalue for every axis).
        const Mat& invEvals = invCovsEigenValues[k];
        double Lval = 0;
        for(int d = 0; d < dim; d++)
        {
            double w = invEvals.at<double>(covMatType != COV_MAT_SPHERICAL ? d : 0);
            double val = rotatedCenteredSample.at<double>(d);
            Lval += w * val * val;
        }
        CV_DbgAssert(!logWeightDivDet.empty());
        L.at<double>(k) = logWeightDivDet.at<double>(k) - 0.5 * Lval;

        // Strict '>' keeps the first component on ties.
        if(L.at<double>(k) > L.at<double>(label))
            label = k;
    }

    const double maxLVal = L.at<double>(label);
    double expDiffSum = 0;
    double* lvals = L.ptr<double>();
    for(int k = 0; k < L.cols; k++)
    {
        lvals[k] = std::exp(lvals[k] - maxLVal);
        expDiffSum += lvals[k]; // sum_j(exp(L_ij - L_iq))
    }

    if(probs)
        L.convertTo(*probs, ptype, 1./expDiffSum);

    Vec2d res;
    res[0] = std::log(expDiffSum) + maxLVal - 0.5 * dim * CV_LOG2PI;
    res[1] = label;
    return res;
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void eStep()
|
2012-04-30 22:33:52 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
// Compute probs_ik from means_k, covs_k and weights_k.
|
2015-02-11 18:24:14 +08:00
|
|
|
trainProbs.create(trainSamples.rows, nclusters, CV_64FC1);
|
2014-07-30 03:54:23 +08:00
|
|
|
trainLabels.create(trainSamples.rows, 1, CV_32SC1);
|
|
|
|
trainLogLikelihoods.create(trainSamples.rows, 1, CV_64FC1);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
computeLogWeightDivDet();
|
|
|
|
|
|
|
|
CV_DbgAssert(trainSamples.type() == CV_64FC1);
|
|
|
|
CV_DbgAssert(means.type() == CV_64FC1);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2012-04-30 22:33:52 +08:00
|
|
|
for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
|
2014-07-30 03:54:23 +08:00
|
|
|
{
|
|
|
|
Mat sampleProbs = trainProbs.row(sampleIndex);
|
|
|
|
Vec2d res = computeProbabilities(trainSamples.row(sampleIndex), &sampleProbs, CV_64F);
|
|
|
|
trainLogLikelihoods.at<double>(sampleIndex) = res[0];
|
|
|
|
trainLabels.at<int>(sampleIndex) = static_cast<int>(res[1]);
|
|
|
|
}
|
2012-04-30 22:33:52 +08:00
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void mStep()
|
2012-04-30 22:33:52 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
// Update means_k, covs_k and weights_k from probs_ik
|
|
|
|
int dim = trainSamples.cols;
|
|
|
|
|
|
|
|
// Update weights
|
|
|
|
// not normalized first
|
|
|
|
reduce(trainProbs, weights, 0, CV_REDUCE_SUM);
|
|
|
|
|
|
|
|
// Update means
|
|
|
|
means.create(nclusters, dim, CV_64FC1);
|
|
|
|
means = Scalar(0);
|
|
|
|
|
|
|
|
const double minPosWeight = trainSamples.rows * DBL_EPSILON;
|
|
|
|
double minWeight = DBL_MAX;
|
|
|
|
int minWeightClusterIndex = -1;
|
|
|
|
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
|
|
|
|
{
|
|
|
|
if(weights.at<double>(clusterIndex) <= minPosWeight)
|
|
|
|
continue;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if(weights.at<double>(clusterIndex) < minWeight)
|
|
|
|
{
|
|
|
|
minWeight = weights.at<double>(clusterIndex);
|
|
|
|
minWeightClusterIndex = clusterIndex;
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
Mat clusterMean = means.row(clusterIndex);
|
|
|
|
for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
|
|
|
|
clusterMean += trainProbs.at<double>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex);
|
|
|
|
clusterMean /= weights.at<double>(clusterIndex);
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// Update covsEigenValues and invCovsEigenValues
|
|
|
|
covs.resize(nclusters);
|
|
|
|
covsEigenValues.resize(nclusters);
|
|
|
|
if(covMatType == COV_MAT_GENERIC)
|
|
|
|
covsRotateMats.resize(nclusters);
|
|
|
|
invCovsEigenValues.resize(nclusters);
|
|
|
|
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
|
2012-04-30 22:33:52 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
if(weights.at<double>(clusterIndex) <= minPosWeight)
|
|
|
|
continue;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if(covMatType != COV_MAT_SPHERICAL)
|
|
|
|
covsEigenValues[clusterIndex].create(1, dim, CV_64FC1);
|
2012-04-30 22:33:52 +08:00
|
|
|
else
|
2014-07-30 03:54:23 +08:00
|
|
|
covsEigenValues[clusterIndex].create(1, 1, CV_64FC1);
|
|
|
|
|
|
|
|
if(covMatType == COV_MAT_GENERIC)
|
|
|
|
covs[clusterIndex].create(dim, dim, CV_64FC1);
|
|
|
|
|
|
|
|
Mat clusterCov = covMatType != COV_MAT_GENERIC ?
|
|
|
|
covsEigenValues[clusterIndex] : covs[clusterIndex];
|
|
|
|
|
|
|
|
clusterCov = Scalar(0);
|
|
|
|
|
|
|
|
Mat centeredSample;
|
|
|
|
for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex);
|
|
|
|
|
|
|
|
if(covMatType == COV_MAT_GENERIC)
|
|
|
|
clusterCov += trainProbs.at<double>(sampleIndex, clusterIndex) * centeredSample.t() * centeredSample;
|
|
|
|
else
|
2012-04-14 05:50:59 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
double p = trainProbs.at<double>(sampleIndex, clusterIndex);
|
|
|
|
for(int di = 0; di < dim; di++ )
|
|
|
|
{
|
|
|
|
double val = centeredSample.at<double>(di);
|
|
|
|
clusterCov.at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0) += p*val*val;
|
|
|
|
}
|
2012-04-14 05:50:59 +08:00
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if(covMatType == COV_MAT_SPHERICAL)
|
|
|
|
clusterCov /= dim;
|
|
|
|
|
|
|
|
clusterCov /= weights.at<double>(clusterIndex);
|
|
|
|
|
|
|
|
// Update covsRotateMats for COV_MAT_GENERIC only
|
|
|
|
if(covMatType == COV_MAT_GENERIC)
|
|
|
|
{
|
|
|
|
SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
|
|
|
|
covsEigenValues[clusterIndex] = svd.w;
|
|
|
|
covsRotateMats[clusterIndex] = svd.u;
|
|
|
|
}
|
2012-04-14 05:50:59 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
|
2012-04-14 05:50:59 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// update invCovsEigenValues
|
|
|
|
invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
|
|
|
|
}
|
|
|
|
|
|
|
|
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
|
2012-04-30 22:33:52 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
if(weights.at<double>(clusterIndex) <= minPosWeight)
|
|
|
|
{
|
|
|
|
Mat clusterMean = means.row(clusterIndex);
|
|
|
|
means.row(minWeightClusterIndex).copyTo(clusterMean);
|
|
|
|
covs[minWeightClusterIndex].copyTo(covs[clusterIndex]);
|
|
|
|
covsEigenValues[minWeightClusterIndex].copyTo(covsEigenValues[clusterIndex]);
|
|
|
|
if(covMatType == COV_MAT_GENERIC)
|
|
|
|
covsRotateMats[minWeightClusterIndex].copyTo(covsRotateMats[clusterIndex]);
|
|
|
|
invCovsEigenValues[minWeightClusterIndex].copyTo(invCovsEigenValues[clusterIndex]);
|
|
|
|
}
|
2012-04-30 22:33:52 +08:00
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// Normalize weights
|
|
|
|
weights /= trainSamples.rows;
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void write_params(FileStorage& fs) const
{
    // Write the training hyper-parameters: cluster count, covariance-matrix
    // type (stored by name) and the EM termination criteria.
    // Unrecognized covariance types are written as "unknown_<value>".
    String covTypeName;
    switch (covMatType)
    {
    case COV_MAT_SPHERICAL:
        covTypeName = "spherical";
        break;
    case COV_MAT_DIAGONAL:
        covTypeName = "diagonal";
        break;
    case COV_MAT_GENERIC:
        covTypeName = "generic";
        break;
    default:
        covTypeName = format("unknown_%d", covMatType);
        break;
    }

    fs << "nclusters" << nclusters;
    fs << "cov_mat_type" << covTypeName;
    writeTermCrit(fs, termCrit);
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void write(FileStorage& fs) const
{
    // Persist the model: hyper-parameters are nested under the
    // "training_params" mapping, followed by the mixture weights, the means
    // and a sequence holding one covariance matrix per entry of covs.
    fs << "training_params" << "{";
    write_params(fs);
    fs << "}";

    fs << "weights" << weights << "means" << means;

    fs << "covs" << "[";
    for (std::vector<Mat>::const_iterator covIt = covs.begin(); covIt != covs.end(); ++covIt)
        fs << *covIt;
    fs << "]";
}
|
|
|
|
|
|
|
|
void read_params(const FileNode& fn)
|
|
|
|
{
|
2015-02-11 18:24:14 +08:00
|
|
|
nclusters = (int)fn["nclusters"];
|
2014-07-30 03:54:23 +08:00
|
|
|
String s = (String)fn["cov_mat_type"];
|
2015-02-11 18:24:14 +08:00
|
|
|
covMatType = s == "spherical" ? COV_MAT_SPHERICAL :
|
2014-07-30 03:54:23 +08:00
|
|
|
s == "diagonal" ? COV_MAT_DIAGONAL :
|
|
|
|
s == "generic" ? COV_MAT_GENERIC : -1;
|
2015-02-11 18:24:14 +08:00
|
|
|
CV_Assert(covMatType >= 0);
|
|
|
|
termCrit = readTermCrit(fn);
|
2014-07-30 03:54:23 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
void read(const FileNode& fn)
|
|
|
|
{
|
|
|
|
clear();
|
|
|
|
read_params(fn["training_params"]);
|
|
|
|
|
|
|
|
fn["weights"] >> weights;
|
|
|
|
fn["means"] >> means;
|
|
|
|
|
|
|
|
FileNode cfn = fn["covs"];
|
|
|
|
FileNodeIterator cfn_it = cfn.begin();
|
|
|
|
int i, n = (int)cfn.size();
|
|
|
|
covs.resize(n);
|
|
|
|
|
|
|
|
for( i = 0; i < n; i++, ++cfn_it )
|
|
|
|
(*cfn_it) >> covs[i];
|
|
|
|
|
|
|
|
decomposeCovs();
|
|
|
|
computeLogWeightDivDet();
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2012-04-30 22:33:52 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// Returns the mixture-weights matrix (a Mat header sharing the model's data).
Mat getWeights() const
{
    return weights;
}
|
|
|
|
// Returns the cluster-means matrix (a Mat header sharing the model's data).
Mat getMeans() const
{
    return means;
}
|
|
|
|
void getCovs(std::vector<Mat>& _covs) const
|
|
|
|
{
|
|
|
|
_covs.resize(covs.size());
|
|
|
|
std::copy(covs.begin(), covs.end(), _covs.begin());
|
|
|
|
}
|
|
|
|
|
|
|
|
// all inner matrices have type CV_64FC1
|
|
|
|
Mat trainSamples;
|
|
|
|
Mat trainProbs;
|
|
|
|
Mat trainLogLikelihoods;
|
|
|
|
Mat trainLabels;
|
|
|
|
|
|
|
|
Mat weights;
|
|
|
|
Mat means;
|
|
|
|
std::vector<Mat> covs;
|
|
|
|
|
|
|
|
std::vector<Mat> covsEigenValues;
|
|
|
|
std::vector<Mat> covsRotateMats;
|
|
|
|
std::vector<Mat> invCovsEigenValues;
|
|
|
|
Mat logWeightDivDet;
|
|
|
|
};
|
|
|
|
|
2015-02-11 18:24:14 +08:00
|
|
|
// Factory: returns a newly constructed EMImpl through the abstract EM interface.
Ptr<EM> EM::create()
{
    Ptr<EMImpl> impl = makePtr<EMImpl>();
    return impl;
}
|
|
|
|
|
|
|
|
}
|
2012-04-14 05:50:59 +08:00
|
|
|
} // namespace cv
|
2010-05-12 01:44:00 +08:00
|
|
|
|
|
|
|
/* End of file. */
|