mirror of
https://github.com/opencv/opencv.git
synced 2025-06-06 00:43:52 +08:00
ml: fix adjusting K in KNearest (#12358)
This commit is contained in:
parent
4b03a4a841
commit
e13f6ded7f
@ -140,13 +140,12 @@ public:
|
||||
String getModelName() const CV_OVERRIDE { return NAME_BRUTE_FORCE; }
|
||||
int getType() const CV_OVERRIDE { return ml::KNearest::BRUTE_FORCE; }
|
||||
|
||||
void findNearestCore( const Mat& _samples, int k0, const Range& range,
|
||||
void findNearestCore( const Mat& _samples, int k, const Range& range,
|
||||
Mat* results, Mat* neighbor_responses,
|
||||
Mat* dists, float* presult ) const
|
||||
{
|
||||
int testidx, baseidx, i, j, d = samples.cols, nsamples = samples.rows;
|
||||
int testcount = range.end - range.start;
|
||||
int k = std::min(k0, nsamples);
|
||||
|
||||
AutoBuffer<float> buf(testcount*k*2);
|
||||
float* dbuf = buf.data();
|
||||
@ -215,7 +214,7 @@ public:
|
||||
float* nr = neighbor_responses->ptr<float>(testidx + range.start);
|
||||
for( j = 0; j < k; j++ )
|
||||
nr[j] = rbuf[testidx*k + j];
|
||||
for( ; j < k0; j++ )
|
||||
for( ; j < k; j++ )
|
||||
nr[j] = 0.f;
|
||||
}
|
||||
|
||||
@ -224,7 +223,7 @@ public:
|
||||
float* dptr = dists->ptr<float>(testidx + range.start);
|
||||
for( j = 0; j < k; j++ )
|
||||
dptr[j] = dbuf[testidx*k + j];
|
||||
for( ; j < k0; j++ )
|
||||
for( ; j < k; j++ )
|
||||
dptr[j] = 0.f;
|
||||
}
|
||||
|
||||
@ -307,6 +306,7 @@ public:
|
||||
{
|
||||
float result = 0.f;
|
||||
CV_Assert( 0 < k );
|
||||
k = std::min(k, samples.rows);
|
||||
|
||||
Mat test_samples = _samples.getMat();
|
||||
CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols );
|
||||
@ -363,6 +363,7 @@ public:
|
||||
{
|
||||
float result = 0.f;
|
||||
CV_Assert( 0 < k );
|
||||
k = std::min(k, samples.rows);
|
||||
|
||||
Mat test_samples = _samples.getMat();
|
||||
CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols );
|
||||
|
@ -702,4 +702,26 @@ TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); }
|
||||
TEST(ML_EM, save_load) { CV_EMTest_SaveLoad test; test.safe_run(); }
|
||||
TEST(ML_EM, classification) { CV_EMTest_Classification test; test.safe_run(); }
|
||||
|
||||
TEST(ML_KNearest, regression_12347)
|
||||
{
|
||||
Mat xTrainData = (Mat_<float>(5,2) << 1, 1.1, 1.1, 1, 2, 2, 2.1, 2, 2.1, 2.1);
|
||||
Mat yTrainLabels = (Mat_<float>(5,1) << 1, 1, 2, 2, 2);
|
||||
Ptr<KNearest> knn = KNearest::create();
|
||||
knn->train(xTrainData, ml::ROW_SAMPLE, yTrainLabels);
|
||||
|
||||
Mat xTestData = (Mat_<float>(2,2) << 1.1, 1.1, 2, 2.2);
|
||||
Mat zBestLabels, neighbours, dist;
|
||||
// check output shapes:
|
||||
int K = 16, Kexp = std::min(K, xTrainData.rows);
|
||||
knn->findNearest(xTestData, K, zBestLabels, neighbours, dist);
|
||||
EXPECT_EQ(xTestData.rows, zBestLabels.rows);
|
||||
EXPECT_EQ(neighbours.cols, Kexp);
|
||||
EXPECT_EQ(dist.cols, Kexp);
|
||||
// see if the result is still correct:
|
||||
K = 2;
|
||||
knn->findNearest(xTestData, K, zBestLabels, neighbours, dist);
|
||||
EXPECT_EQ(1, zBestLabels.at<float>(0,0));
|
||||
EXPECT_EQ(2, zBestLabels.at<float>(1,0));
|
||||
}
|
||||
|
||||
}} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user