handle single-point sets in kmeans properly

This commit is contained in:
Vadim Pisarevsky 2012-04-06 13:22:08 +00:00
parent 71d8769314
commit 61c7c441b9
2 changed files with 20 additions and 13 deletions

View File

@ -2459,8 +2459,9 @@ double cv::kmeans( InputArray _data, int K,
{ {
const int SPP_TRIALS = 3; const int SPP_TRIALS = 3;
Mat data = _data.getMat(); Mat data = _data.getMat();
int N = data.rows > 1 ? data.rows : data.cols; bool isrow = data.rows == 1 && data.channels() > 1;
int dims = (data.rows > 1 ? data.cols : 1)*data.channels(); int N = !isrow ? data.rows : data.cols;
int dims = (!isrow ? data.cols : 1)*data.channels();
int type = data.depth(); int type = data.depth();
attempts = std::max(attempts, 1); attempts = std::max(attempts, 1);

View File

@ -2437,42 +2437,48 @@ public:
protected: protected:
void run(int) void run(int)
{ {
int i, iter = 0, N = 0, N0 = 0, K = 0, dims = 0;
Mat labels;
try try
{ {
RNG& rng = theRNG(); RNG& rng = theRNG();
const int MAX_DIM=5; const int MAX_DIM=5;
int MAX_POINTS = 100; int MAX_POINTS = 100, maxIter = 100;
for( int iter = 0; iter < 100; iter++ ) for( iter = 0; iter < maxIter; iter++ )
{ {
ts->update_context(this, iter, true); ts->update_context(this, iter, true);
int dims = rng.uniform(1, MAX_DIM+1); dims = rng.uniform(1, MAX_DIM+1);
int N = rng.uniform(1, MAX_POINTS+1); N = rng.uniform(1, MAX_POINTS+1);
int N0 = rng.uniform(1, N/10+1); N0 = rng.uniform(1, MAX(N/10, 2));
int K = rng.uniform(1, N+1); K = rng.uniform(1, N+1);
Mat data0(N0, dims, CV_32F), labels; Mat data0(N0, dims, CV_32F);
rng.fill(data0, RNG::UNIFORM, -1, 1); rng.fill(data0, RNG::UNIFORM, -1, 1);
Mat data(N, dims, CV_32F); Mat data(N, dims, CV_32F);
for( int i = 0; i < N; i++ ) for( i = 0; i < N; i++ )
data0.row(rng.uniform(0, N0)).copyTo(data.row(i)); data0.row(rng.uniform(0, N0)).copyTo(data.row(i));
kmeans(data, K, labels, TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 30, 0), kmeans(data, K, labels, TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 30, 0),
5, KMEANS_PP_CENTERS); 5, KMEANS_PP_CENTERS);
Mat hist(K, 1, CV_32S, Scalar(0)); Mat hist(K, 1, CV_32S, Scalar(0));
for( int i = 0; i < N; i++ ) for( i = 0; i < N; i++ )
{ {
int l = labels.at<int>(i); int l = labels.at<int>(i);
CV_Assert( 0 <= l && l < K ); CV_Assert(0 <= l && l < K);
hist.at<int>(l)++; hist.at<int>(l)++;
} }
for( int i = 0; i < K; i++ ) for( i = 0; i < K; i++ )
CV_Assert( hist.at<int>(i) != 0 ); CV_Assert( hist.at<int>(i) != 0 );
} }
} }
catch(...) catch(...)
{ {
ts->printf(cvtest::TS::LOG,
"context: iteration=%d, N=%d, N0=%d, K=%d\n",
iter, N, N0, K);
std::cout << labels << std::endl;
ts->set_failed_test_info(cvtest::TS::FAIL_MISMATCH); ts->set_failed_test_info(cvtest::TS::FAIL_MISMATCH);
} }
} }