Refactored the train and predict methods of EM

This commit is contained in:
Maria Dimashova 2012-04-17 06:29:40 +00:00
parent 8f7e5811b6
commit 3dfa917879
7 changed files with 56 additions and 65 deletions

View File

@ -213,7 +213,7 @@ void CvHybridTracker::updateTrackerWithEM(Mat image) {
cv::Mat lbls;
EM em_model(1, EM::COV_MAT_SPHERICAL, TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 10000, 0.001));
em_model.train(cvarrToMat(samples), lbls);
em_model.train(cvarrToMat(samples), noArray(), lbls);
if(labels)
lbls.copyTo(cvarrToMat(labels));

View File

@ -1826,7 +1826,7 @@ public:
CV_WRAP cv::Mat getWeights() const;
CV_WRAP cv::Mat getProbs() const;
CV_WRAP inline double getLikelihood() const { return emObj.isTrained() ? likelihood : DBL_MAX; }
CV_WRAP inline double getLikelihood() const { return emObj.isTrained() ? logLikelihood : DBL_MAX; }
#endif
CV_WRAP virtual void clear();
@ -1847,7 +1847,7 @@ protected:
cv::EM emObj;
cv::Mat probs;
double likelihood;
double logLikelihood;
CvMat meansHdr;
std::vector<CvMat> covsHdrs;

View File

@ -56,12 +56,12 @@ CvEMParams::CvEMParams( int _nclusters, int _cov_mat_type, int _start_step,
probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
{}
CvEM::CvEM() : likelihood(DBL_MAX)
CvEM::CvEM() : logLikelihood(DBL_MAX)
{
}
CvEM::CvEM( const CvMat* samples, const CvMat* sample_idx,
CvEMParams params, CvMat* labels ) : likelihood(DBL_MAX)
CvEMParams params, CvMat* labels ) : logLikelihood(DBL_MAX)
{
train(samples, sample_idx, params, labels);
}
@ -96,16 +96,14 @@ void CvEM::write( CvFileStorage* _fs, const char* name ) const
double CvEM::calcLikelihood( const Mat &input_sample ) const
{
double likelihood;
emObj.predict(input_sample, noArray(), &likelihood);
return likelihood;
return emObj.predict(input_sample)[0];
}
float
CvEM::predict( const CvMat* _sample, CvMat* _probs ) const
{
Mat prbs0 = cvarrToMat(_probs), prbs = prbs0, sample = cvarrToMat(_sample);
int cls = emObj.predict(sample, _probs ? _OutputArray(prbs) : cv::noArray());
int cls = static_cast<int>(emObj.predict(sample, _probs ? _OutputArray(prbs) : cv::noArray())[1]);
if(_probs)
{
if( prbs.data != prbs0.data )
@ -203,29 +201,27 @@ bool CvEM::train( const Mat& _samples, const Mat& _sample_idx,
CvEMParams _params, Mat* _labels )
{
CV_Assert(_sample_idx.empty());
Mat prbs, weights, means, likelihoods;
Mat prbs, weights, means, logLikelihoods;
std::vector<Mat> covsHdrs;
init_params(_params, prbs, weights, means, covsHdrs);
emObj = EM(_params.nclusters, _params.cov_mat_type, _params.term_crit);
bool isOk = false;
if( _params.start_step == EM::START_AUTO_STEP )
isOk = emObj.train(_samples, _labels ? _OutputArray(*_labels) : cv::noArray(),
probs, likelihoods);
isOk = emObj.train(_samples,
logLikelihoods, _labels ? _OutputArray(*_labels) : cv::noArray(), probs);
else if( _params.start_step == EM::START_E_STEP )
isOk = emObj.trainE(_samples, means, covsHdrs, weights,
_labels ? _OutputArray(*_labels) : cv::noArray(),
probs, likelihoods);
logLikelihoods, _labels ? _OutputArray(*_labels) : cv::noArray(), probs);
else if( _params.start_step == EM::START_M_STEP )
isOk = emObj.trainM(_samples, prbs,
_labels ? _OutputArray(*_labels) : cv::noArray(),
probs, likelihoods);
logLikelihoods, _labels ? _OutputArray(*_labels) : cv::noArray(), probs);
else
CV_Error(CV_StsBadArg, "Bad start type of EM algorithm");
if(isOk)
{
likelihoods = sum(likelihoods).val[0];
logLikelihood = sum(logLikelihoods).val[0];
set_mat_hdrs();
}
@ -235,8 +231,7 @@ bool CvEM::train( const Mat& _samples, const Mat& _sample_idx,
float
CvEM::predict( const Mat& _sample, Mat* _probs ) const
{
int cls = emObj.predict(_sample, _probs ? _OutputArray(*_probs) : cv::noArray());
return (float)cls;
return static_cast<float>(emObj.predict(_sample, _probs ? _OutputArray(*_probs) : cv::noArray())[1]);
}
int CvEM::getNClusters() const

View File

@ -577,27 +577,26 @@ public:
CV_WRAP virtual void clear();
CV_WRAP virtual bool train(InputArray samples,
OutputArray logLikelihoods=noArray(),
OutputArray labels=noArray(),
OutputArray probs=noArray(),
OutputArray logLikelihoods=noArray());
OutputArray probs=noArray());
CV_WRAP virtual bool trainE(InputArray samples,
InputArray means0,
InputArray covs0=noArray(),
InputArray weights0=noArray(),
OutputArray logLikelihoods=noArray(),
OutputArray labels=noArray(),
OutputArray probs=noArray(),
OutputArray logLikelihoods=noArray());
OutputArray probs=noArray());
CV_WRAP virtual bool trainM(InputArray samples,
InputArray probs0,
OutputArray logLikelihoods=noArray(),
OutputArray labels=noArray(),
OutputArray probs=noArray(),
OutputArray logLikelihoods=noArray());
OutputArray probs=noArray());
CV_WRAP int predict(InputArray sample,
OutputArray probs=noArray(),
CV_OUT double* logLikelihood=0) const;
CV_WRAP Vec2d predict(InputArray sample,
OutputArray probs=noArray()) const;
CV_WRAP bool isTrained() const;
@ -613,9 +612,9 @@ protected:
const Mat* weights0);
bool doTrain(int startStep,
OutputArray logLikelihoods,
OutputArray labels,
OutputArray probs,
OutputArray logLikelihoods);
OutputArray probs);
virtual void eStep();
virtual void mStep();
@ -623,7 +622,7 @@ protected:
void decomposeCovs();
void computeLogWeightDivDet();
void computeProbabilities(const Mat& sample, int& label, Mat* probs, double* logLikelihood) const;
Vec2d computeProbabilities(const Mat& sample, Mat* probs) const;
// all inner matrices have type CV_64FC1
CV_PROP_RW int nclusters;

View File

@ -81,22 +81,22 @@ void EM::clear()
bool EM::train(InputArray samples,
OutputArray logLikelihoods,
OutputArray labels,
OutputArray probs,
OutputArray logLikelihoods)
OutputArray probs)
{
Mat samplesMat = samples.getMat();
setTrainData(START_AUTO_STEP, samplesMat, 0, 0, 0, 0);
return doTrain(START_AUTO_STEP, labels, probs, logLikelihoods);
return doTrain(START_AUTO_STEP, logLikelihoods, labels, probs);
}
bool EM::trainE(InputArray samples,
InputArray _means0,
InputArray _covs0,
InputArray _weights0,
OutputArray logLikelihoods,
OutputArray labels,
OutputArray probs,
OutputArray logLikelihoods)
OutputArray probs)
{
Mat samplesMat = samples.getMat();
vector<Mat> covs0;
@ -106,24 +106,24 @@ bool EM::trainE(InputArray samples,
setTrainData(START_E_STEP, samplesMat, 0, !_means0.empty() ? &means0 : 0,
!_covs0.empty() ? &covs0 : 0, _weights0.empty() ? &weights0 : 0);
return doTrain(START_E_STEP, labels, probs, logLikelihoods);
return doTrain(START_E_STEP, logLikelihoods, labels, probs);
}
bool EM::trainM(InputArray samples,
InputArray _probs0,
OutputArray logLikelihoods,
OutputArray labels,
OutputArray probs,
OutputArray logLikelihoods)
OutputArray probs)
{
Mat samplesMat = samples.getMat();
Mat probs0 = _probs0.getMat();
setTrainData(START_M_STEP, samplesMat, !_probs0.empty() ? &probs0 : 0, 0, 0, 0);
return doTrain(START_M_STEP, labels, probs, logLikelihoods);
return doTrain(START_M_STEP, logLikelihoods, labels, probs);
}
int EM::predict(InputArray _sample, OutputArray _probs, double* logLikelihood) const
Vec2d EM::predict(InputArray _sample, OutputArray _probs) const
{
Mat sample = _sample.getMat();
CV_Assert(isTrained());
@ -136,16 +136,14 @@ int EM::predict(InputArray _sample, OutputArray _probs, double* logLikelihood) c
sample = tmp;
}
int label;
Mat probs;
if( _probs.needed() )
{
_probs.create(1, nclusters, CV_64FC1);
probs = _probs.getMat();
}
computeProbabilities(sample, label, !probs.empty() ? &probs : 0, logLikelihood);
return label;
return computeProbabilities(sample, !probs.empty() ? &probs : 0);
}
bool EM::isTrained() const
@ -394,7 +392,7 @@ void EM::computeLogWeightDivDet()
}
}
bool EM::doTrain(int startStep, OutputArray labels, OutputArray probs, OutputArray logLikelihoods)
bool EM::doTrain(int startStep, OutputArray logLikelihoods, OutputArray labels, OutputArray probs)
{
int dim = trainSamples.cols;
// Precompute the empty initial train data in the cases of EM::START_E_STEP and START_AUTO_STEP
@ -472,7 +470,7 @@ bool EM::doTrain(int startStep, OutputArray labels, OutputArray probs, OutputArr
return true;
}
void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double* logLikelihood) const
Vec2d EM::computeProbabilities(const Mat& sample, Mat* probs) const
{
// L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 *(x_i - mean_k)' cov_k^(-1) (x_i - mean_k)]
// q = arg(max_k(L_ik))
@ -488,7 +486,7 @@ void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double*
int dim = sample.cols;
Mat L(1, nclusters, CV_64FC1);
label = 0;
int label = 0;
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
{
const Mat centeredSample = sample - means.row(clusterIndex);
@ -511,9 +509,6 @@ void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double*
label = clusterIndex;
}
if(!probs && !logLikelihood)
return;
double maxLVal = L.at<double>(label);
Mat expL_Lmax = L; // exp(L_ij - L_iq)
for(int i = 0; i < L.cols; i++)
@ -528,8 +523,11 @@ void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double*
expL_Lmax.copyTo(*probs);
}
if(logLikelihood)
*logLikelihood = std::log(expDiffSum) + maxLVal - 0.5 * dim * CV_LOG2PI;
Vec2d res;
res[0] = std::log(expDiffSum) + maxLVal - 0.5 * dim * CV_LOG2PI;
res[1] = label;
return res;
}
void EM::eStep()
@ -547,8 +545,9 @@ void EM::eStep()
for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
{
Mat sampleProbs = trainProbs.row(sampleIndex);
computeProbabilities(trainSamples.row(sampleIndex), trainLabels.at<int>(sampleIndex),
&sampleProbs, &trainLogLikelihoods.at<double>(sampleIndex));
Vec2d res = computeProbabilities(trainSamples.row(sampleIndex), &sampleProbs);
trainLogLikelihoods.at<double>(sampleIndex) = res[0];
trainLabels.at<int>(sampleIndex) = static_cast<int>(res[1]);
}
}

View File

@ -373,11 +373,11 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
cv::EM em(params.nclusters, params.covMatType, params.termCrit);
if( params.startStep == EM::START_AUTO_STEP )
em.train( trainData, labels );
em.train( trainData, noArray(), labels );
else if( params.startStep == EM::START_E_STEP )
em.trainE( trainData, *params.means, *params.covs, *params.weights, labels );
em.trainE( trainData, *params.means, *params.covs, *params.weights, noArray(), labels );
else if( params.startStep == EM::START_M_STEP )
em.trainM( trainData, *params.probs, labels );
em.trainM( trainData, *params.probs, noArray(), labels );
// check train error
if( !calcErr( labels, trainLabels, sizes, err , false, false ) )
@ -396,9 +396,8 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
for( int i = 0; i < testData.rows; i++ )
{
Mat sample = testData.row(i);
double likelihood = 0;
Mat probs;
labels.at<int>(i,0) = (int)em.predict( sample, probs, &likelihood );
labels.at<int>(i) = static_cast<int>(em.predict( sample, probs )[1]);
}
if( !calcErr( labels, testLabels, sizes, err, false, false ) )
{
@ -523,7 +522,7 @@ protected:
Mat firstResult(samples.rows, 1, CV_32SC1);
for( int i = 0; i < samples.rows; i++)
firstResult.at<int>(i) = em.predict(samples.row(i));
firstResult.at<int>(i) = static_cast<int>(em.predict(samples.row(i))[1]);
// Write out
string filename = tempfile() + ".xml";
@ -564,7 +563,7 @@ protected:
int errCaseCount = 0;
for( int i = 0; i < samples.rows; i++)
errCaseCount = std::abs(em.predict(samples.row(i)) - firstResult.at<int>(i)) < FLT_EPSILON ? 0 : 1;
errCaseCount = std::abs(em.predict(samples.row(i))[1] - firstResult.at<int>(i)) < FLT_EPSILON ? 0 : 1;
if( errCaseCount > 0 )
{
@ -637,10 +636,9 @@ protected:
const double lambda = 1.;
for(int i = 0; i < samples.rows; i++)
{
double sampleLogLikelihoods0 = 0, sampleLogLikelihoods1 = 0;
Mat sample = samples.row(i);
model0.predict(sample, noArray(), &sampleLogLikelihoods0);
model1.predict(sample, noArray(), &sampleLogLikelihoods1);
double sampleLogLikelihoods0 = model0.predict(sample)[0];
double sampleLogLikelihoods1 = model1.predict(sample)[0];
int classID = sampleLogLikelihoods0 >= lambda * sampleLogLikelihoods1 ? 0 : 1;

View File

@ -478,7 +478,7 @@ void find_decision_boundary_EM()
for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++)
{
if(em_models[modelIndex].isTrained())
em_models[modelIndex].predict( testSample, noArray(), &logLikelihoods.at<double>(modelIndex) );
logLikelihoods.at<double>(modelIndex) = em_models[modelIndex].predict(testSample)[0];
}
Point maxLoc;
minMaxLoc(logLikelihoods, 0, 0, 0, &maxLoc);