diff --git a/modules/ml/include/opencv2/ml/ml.hpp b/modules/ml/include/opencv2/ml/ml.hpp
index 36f86c4749..09cdc36778 100644
--- a/modules/ml/include/opencv2/ml/ml.hpp
+++ b/modules/ml/include/opencv2/ml/ml.hpp
@@ -563,7 +563,7 @@ public:
     enum {COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2, COV_MAT_DEFAULT=COV_MAT_DIAGONAL};
 
     // Default parameters
-    enum {DEFAULT_NCLUSTERS=10, DEFAULT_MAX_ITERS=100};
+    enum {DEFAULT_NCLUSTERS=5, DEFAULT_MAX_ITERS=100};
 
     // The initial step
     enum {START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0};
@@ -635,7 +635,6 @@ protected:
     Mat trainProbs;
     Mat trainLogLikelihoods;
     Mat trainLabels;
-    Mat trainCounts;
 
     CV_PROP Mat weights;
     CV_PROP Mat means;
@@ -2035,7 +2034,7 @@ public:
 
     // returns:
     // 0 - OK
-    // 1 - file can not be opened or is not correct
+    // -1 - file can not be opened or is not correct
     int read_csv( const char* filename );
 
     const CvMat* get_values() const;
diff --git a/modules/ml/src/em.cpp b/modules/ml/src/em.cpp
index 95eda5b33e..545e107789 100644
--- a/modules/ml/src/em.cpp
+++ b/modules/ml/src/em.cpp
@@ -44,7 +44,7 @@
 namespace cv
 {
 
-const double minEigenValue = DBL_MIN;
+const double minEigenValue = DBL_EPSILON;
 
 ///////////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -67,7 +67,6 @@ void EM::clear()
     trainProbs.release();
     trainLogLikelihoods.release();
     trainLabels.release();
-    trainCounts.release();
 
     weights.release();
     means.release();
@@ -469,7 +468,6 @@ bool EM::doTrain(int startStep, OutputArray labels, OutputArray probs, OutputArr
     trainProbs.release();
     trainLabels.release();
     trainLogLikelihoods.release();
-    trainCounts.release();
 
     return true;
 }
@@ -556,97 +554,119 @@ void EM::eStep()
 
 void EM::mStep()
 {
-    trainCounts.create(1, nclusters, CV_32SC1);
-    trainCounts = Scalar(0);
+    // Update means_k, covs_k and weights_k from probs_ik
+    int dim = trainSamples.cols;
 
-    for(int sampleIndex = 0; sampleIndex < trainLabels.rows; sampleIndex++)
-        trainCounts.at<int>(trainLabels.at<int>(sampleIndex))++;
+    // Update weights
+    // not normalized first
+    reduce(trainProbs, weights, 0, CV_REDUCE_SUM);
 
-    if(countNonZero(trainCounts) != (int)trainCounts.total())
+    // Update means
+    means.create(nclusters, dim, CV_64FC1);
+    means = Scalar(0);
+
+    // Clusters whose summed probabilities are at or below this threshold are treated as empty:
+    // they are skipped below and later overwritten with the lightest non-empty cluster.
+    const double minPosWeight = trainSamples.rows * DBL_EPSILON;
+    double minWeight = DBL_MAX;
+    int minWeightClusterIndex = -1;
+    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
     {
-        clusterTrainSamples();
-    }
-    else
-    {
-        // Update means_k, covs_k and weights_k from probs_ik
-        int dim = trainSamples.cols;
+        if(weights.at<double>(clusterIndex) <= minPosWeight)
+            continue;
 
-        // Update weights
-        // not normalized first
-        reduce(trainProbs, weights, 0, CV_REDUCE_SUM);
-
-        // Update means
-        means.create(nclusters, dim, CV_64FC1);
-        means = Scalar(0);
-        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
+        if(weights.at<double>(clusterIndex) < minWeight)
         {
-            Mat clusterMean = means.row(clusterIndex);
-            for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
-                clusterMean += trainProbs.at<double>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex);
-            clusterMean /= weights.at<double>(clusterIndex);
+            minWeight = weights.at<double>(clusterIndex);
+            minWeightClusterIndex = clusterIndex;
         }
 
-        // Update covsEigenValues and invCovsEigenValues
-        covs.resize(nclusters);
-        covsEigenValues.resize(nclusters);
+        Mat clusterMean = means.row(clusterIndex);
+        for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
+            clusterMean += trainProbs.at<double>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex);
+        clusterMean /= weights.at<double>(clusterIndex);
+    }
+
+    // Update covsEigenValues and invCovsEigenValues
+    covs.resize(nclusters);
+    covsEigenValues.resize(nclusters);
+    if(covMatType == EM::COV_MAT_GENERIC)
+        covsRotateMats.resize(nclusters);
+    invCovsEigenValues.resize(nclusters);
+    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
+    {
+        if(weights.at<double>(clusterIndex) <= minPosWeight)
+            continue;
+
+        if(covMatType != EM::COV_MAT_SPHERICAL)
+            covsEigenValues[clusterIndex].create(1, dim, CV_64FC1);
+        else
+            covsEigenValues[clusterIndex].create(1, 1, CV_64FC1);
+
         if(covMatType == EM::COV_MAT_GENERIC)
-            covsRotateMats.resize(nclusters);
-        invCovsEigenValues.resize(nclusters);
-        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
+            covs[clusterIndex].create(dim, dim, CV_64FC1);
+
+        Mat clusterCov = covMatType != EM::COV_MAT_GENERIC ?
+            covsEigenValues[clusterIndex] : covs[clusterIndex];
+
+        clusterCov = Scalar(0);
+
+        Mat centeredSample;
+        for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
         {
-            if(covMatType != EM::COV_MAT_SPHERICAL)
-                covsEigenValues[clusterIndex].create(1, dim, CV_64FC1);
-            else
-                covsEigenValues[clusterIndex].create(1, 1, CV_64FC1);
+            centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex);
 
             if(covMatType == EM::COV_MAT_GENERIC)
-                covs[clusterIndex].create(dim, dim, CV_64FC1);
-
-            Mat clusterCov = covMatType != EM::COV_MAT_GENERIC ?
-                covsEigenValues[clusterIndex] : covs[clusterIndex];
-
-            clusterCov = Scalar(0);
-
-            Mat centeredSample;
-            for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
+                clusterCov += trainProbs.at<double>(sampleIndex, clusterIndex) * centeredSample.t() * centeredSample;
+            else
             {
-                centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex);
-
-                if(covMatType == EM::COV_MAT_GENERIC)
-                    clusterCov += trainProbs.at<double>(sampleIndex, clusterIndex) * centeredSample.t() * centeredSample;
-                else
+                double p = trainProbs.at<double>(sampleIndex, clusterIndex);
+                for(int di = 0; di < dim; di++ )
                 {
-                    double p = trainProbs.at<double>(sampleIndex, clusterIndex);
-                    for(int di = 0; di < dim; di++ )
-                    {
-                        double val = centeredSample.at<double>(di);
-                        clusterCov.at<double>(covMatType != EM::COV_MAT_SPHERICAL ? di : 0) += p*val*val;
-                    }
+                    double val = centeredSample.at<double>(di);
+                    clusterCov.at<double>(covMatType != EM::COV_MAT_SPHERICAL ? di : 0) += p*val*val;
                 }
             }
-
-            if(covMatType == EM::COV_MAT_SPHERICAL)
-                clusterCov /= dim;
-
-            clusterCov /= weights.at<double>(clusterIndex);
-
-            // Update covsRotateMats for EM::COV_MAT_GENERIC only
-            if(covMatType == EM::COV_MAT_GENERIC)
-            {
-                SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
-                covsEigenValues[clusterIndex] = svd.w;
-                covsRotateMats[clusterIndex] = svd.u;
-            }
-
-            max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
-
-            // update invCovsEigenValues
-            invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
         }
 
-        // Normalize weights
-        weights /= trainSamples.rows;
+        if(covMatType == EM::COV_MAT_SPHERICAL)
+            clusterCov /= dim;
+
+        clusterCov /= weights.at<double>(clusterIndex);
+
+        // Update covsRotateMats for EM::COV_MAT_GENERIC only
+        if(covMatType == EM::COV_MAT_GENERIC)
+        {
+            SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
+            covsEigenValues[clusterIndex] = svd.w;
+            covsRotateMats[clusterIndex] = svd.u;
+        }
+
+        max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
+
+        // update invCovsEigenValues
+        invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
     }
+
+    // Re-seed every empty cluster with the parameters of the lightest non-empty one.
+    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
+    {
+        if(weights.at<double>(clusterIndex) <= minPosWeight)
+        {
+            // No donor exists if no cluster passed the threshold; fail loudly instead of indexing row -1.
+            CV_Assert(minWeightClusterIndex >= 0);
+            Mat clusterMean = means.row(clusterIndex);
+            means.row(minWeightClusterIndex).copyTo(clusterMean);
+            covs[minWeightClusterIndex].copyTo(covs[clusterIndex]);
+            covsEigenValues[minWeightClusterIndex].copyTo(covsEigenValues[clusterIndex]);
+            if(covMatType == EM::COV_MAT_GENERIC)
+                covsRotateMats[minWeightClusterIndex].copyTo(covsRotateMats[clusterIndex]);
+            invCovsEigenValues[minWeightClusterIndex].copyTo(invCovsEigenValues[clusterIndex]);
+        }
+    }
+
+    // Normalize weights
+    weights /= trainSamples.rows;
 }
 
 void EM::read(const FileNode& fn)
diff --git a/modules/ml/test/test_emknearestkmeans.cpp b/modules/ml/test/test_emknearestkmeans.cpp
index ba0dec7472..c6f60235dd 100644
--- 
a/modules/ml/test/test_emknearestkmeans.cpp
+++ b/modules/ml/test/test_emknearestkmeans.cpp
@@ -572,7 +572,112 @@ protected:
     }
 };
 
+// Trains one EM model per class on spambase.data and checks the resulting classification error.
+class CV_EMTest_Classification : public cvtest::BaseTest
+{
+public:
+    CV_EMTest_Classification() {}
+protected:
+    virtual void run(int)
+    {
+        // This test classifies spam in the following way:
+        // 1. estimate the distributions of "spam" / "not spam" samples;
+        // 2. predict classID using a Bayes classifier over the estimated distributions.
+
+        CvMLData data;
+        string dataFilename = string(ts->get_data_path()) + "spambase.data";
+
+        if(data.read_csv(dataFilename.c_str()) != 0)
+        {
+            ts->printf(cvtest::TS::LOG, "File with spambase dataset can't be read.\n");
+            ts->set_failed_test_info(cvtest::TS::FAIL_INVALID_TEST_DATA);
+            return;
+        }
+
+        Mat values = data.get_values();
+        CV_Assert(values.cols == 58);
+        int responseIndex = 57;
+
+        Mat samples = values.colRange(0, responseIndex);
+        Mat responses = values.col(responseIndex);
+
+        // Mark half of the samples as training data, then scatter the marks with a fixed-seed RNG.
+        vector<int> trainSamplesMask(samples.rows, 0);
+        int trainSamplesCount = (int)(0.5f * samples.rows);
+        for(int i = 0; i < trainSamplesCount; i++)
+            trainSamplesMask[i] = 1;
+        RNG rng(0);
+        for(size_t i = 0; i < trainSamplesMask.size(); i++)
+        {
+            int i1 = rng((unsigned)trainSamplesMask.size());
+            int i2 = rng((unsigned)trainSamplesMask.size());
+            std::swap(trainSamplesMask[i1], trainSamplesMask[i2]);
+        }
+
+        // Fit one 3-component model per response value, using the training half only.
+        EM model0(3), model1(3);
+        Mat samples0, samples1;
+        for(int i = 0; i < samples.rows; i++)
+        {
+            if(trainSamplesMask[i])
+            {
+                Mat sample = samples.row(i);
+                int resp = (int)responses.at<float>(i);
+                if(resp == 0)
+                    samples0.push_back(sample);
+                else
+                    samples1.push_back(sample);
+            }
+        }
+        model0.train(samples0);
+        model1.train(samples1);
+
+        // Pick class 0 when its value from predict() is at least lambda times that of class 1.
+        Mat trainConfusionMat(2, 2, CV_32SC1, Scalar(0)),
+            testConfusionMat(2, 2, CV_32SC1, Scalar(0));
+        const double lambda = 1.;
+        for(int i = 0; i < samples.rows; i++)
+        {
+            double sampleLogLikelihoods0 = 0, sampleLogLikelihoods1 = 0;
+            Mat sample = samples.row(i);
+            model0.predict(sample, noArray(), &sampleLogLikelihoods0);
+            model1.predict(sample, noArray(), &sampleLogLikelihoods1);
+
+            int classID = sampleLogLikelihoods0 >= lambda * sampleLogLikelihoods1 ? 0 : 1;
+
+            if(trainSamplesMask[i])
+                trainConfusionMat.at<int>((int)responses.at<float>(i), classID)++;
+            else
+                testConfusionMat.at<int>((int)responses.at<float>(i), classID)++;
+        }
+//        std::cout << trainConfusionMat << std::endl;
+//        std::cout << testConfusionMat << std::endl;
+
+        // Error rate = off-diagonal confusion-matrix entries divided by the size of the split.
+        double trainError = (double)(trainConfusionMat.at<int>(1,0) + trainConfusionMat.at<int>(0,1)) / trainSamplesCount;
+        double testError = (double)(testConfusionMat.at<int>(1,0) + testConfusionMat.at<int>(0,1)) / (samples.rows - trainSamplesCount);
+        const double maxTrainError = 0.16;
+        const double maxTestError = 0.19;
+
+        int code = cvtest::TS::OK;
+        if(trainError > maxTrainError)
+        {
+            ts->printf(cvtest::TS::LOG, "Too large train classification error (calc = %f, valid=%f).\n", trainError, maxTrainError);
+            code = cvtest::TS::FAIL_BAD_ACCURACY;
+        }
+        if(testError > maxTestError)
+        {
+            ts->printf(cvtest::TS::LOG, "Too large test classification error (calc = %f, valid=%f).\n", testError, maxTestError);
+            code = cvtest::TS::FAIL_BAD_ACCURACY;
+        }
+
+        ts->set_failed_test_info(code);
+    }
+};
+
 TEST(ML_KMeans, accuracy) { CV_KMeansTest test; test.safe_run(); }
 TEST(ML_KNearest, accuracy) { CV_KNearestTest test; test.safe_run(); }
 TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); }
 TEST(ML_EM, save_load) { CV_EMTest_SaveLoad test; test.safe_run(); }
+TEST(ML_EM, classification) { CV_EMTest_Classification test; test.safe_run(); }
+