diff --git a/modules/ml/src/knearest.cpp b/modules/ml/src/knearest.cpp index d608012dfb..df48b00f24 100644 --- a/modules/ml/src/knearest.cpp +++ b/modules/ml/src/knearest.cpp @@ -140,13 +140,12 @@ public: String getModelName() const CV_OVERRIDE { return NAME_BRUTE_FORCE; } int getType() const CV_OVERRIDE { return ml::KNearest::BRUTE_FORCE; } - void findNearestCore( const Mat& _samples, int k0, const Range& range, + void findNearestCore( const Mat& _samples, int k, const Range& range, Mat* results, Mat* neighbor_responses, Mat* dists, float* presult ) const { int testidx, baseidx, i, j, d = samples.cols, nsamples = samples.rows; int testcount = range.end - range.start; - int k = std::min(k0, nsamples); AutoBuffer buf(testcount*k*2); float* dbuf = buf.data(); @@ -215,7 +214,7 @@ public: float* nr = neighbor_responses->ptr(testidx + range.start); for( j = 0; j < k; j++ ) nr[j] = rbuf[testidx*k + j]; - for( ; j < k0; j++ ) + for( ; j < k; j++ ) nr[j] = 0.f; } @@ -224,7 +223,7 @@ public: float* dptr = dists->ptr(testidx + range.start); for( j = 0; j < k; j++ ) dptr[j] = dbuf[testidx*k + j]; - for( ; j < k0; j++ ) + for( ; j < k; j++ ) dptr[j] = 0.f; } @@ -307,6 +306,7 @@ public: { float result = 0.f; CV_Assert( 0 < k ); + k = std::min(k, samples.rows); Mat test_samples = _samples.getMat(); CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols ); @@ -363,6 +363,7 @@ public: { float result = 0.f; CV_Assert( 0 < k ); + k = std::min(k, samples.rows); Mat test_samples = _samples.getMat(); CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols ); diff --git a/modules/ml/test/test_emknearestkmeans.cpp b/modules/ml/test/test_emknearestkmeans.cpp index 6755c2e9e4..691815c52a 100644 --- a/modules/ml/test/test_emknearestkmeans.cpp +++ b/modules/ml/test/test_emknearestkmeans.cpp @@ -702,4 +702,26 @@ TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); } TEST(ML_EM, save_load) { CV_EMTest_SaveLoad test; test.safe_run(); } TEST(ML_EM, classification) { CV_EMTest_Classification test; test.safe_run(); } +TEST(ML_KNearest, regression_12347) +{ + Mat xTrainData = (Mat_(5,2) << 1, 1.1, 1.1, 1, 2, 2, 2.1, 2, 2.1, 2.1); + Mat yTrainLabels = (Mat_(5,1) << 1, 1, 2, 2, 2); + Ptr knn = KNearest::create(); + knn->train(xTrainData, ml::ROW_SAMPLE, yTrainLabels); + + Mat xTestData = (Mat_(2,2) << 1.1, 1.1, 2, 2.2); + Mat zBestLabels, neighbours, dist; + // check output shapes: + int K = 16, Kexp = std::min(K, xTrainData.rows); + knn->findNearest(xTestData, K, zBestLabels, neighbours, dist); + EXPECT_EQ(xTestData.rows, zBestLabels.rows); + EXPECT_EQ(neighbours.cols, Kexp); + EXPECT_EQ(dist.cols, Kexp); + // see if the result is still correct: + K = 2; + knn->findNearest(xTestData, K, zBestLabels, neighbours, dist); + EXPECT_EQ(1, zBestLabels.at(0,0)); + EXPECT_EQ(2, zBestLabels.at(1,0)); +} + }} // namespace