From 45e0e5f8e947a9c1b4a995477c43a006ec2df43f Mon Sep 17 00:00:00 2001 From: Pierre-Emmanuel Viel Date: Tue, 17 Dec 2013 12:51:58 +0100 Subject: [PATCH] Pick centers in KMeans++ with a probability proportional to their distance^2, instead of simple distance, to previous centers --- .../opencv2/flann/hierarchical_clustering_index.h | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/modules/flann/include/opencv2/flann/hierarchical_clustering_index.h b/modules/flann/include/opencv2/flann/hierarchical_clustering_index.h index ce2d622450..02fc278448 100644 --- a/modules/flann/include/opencv2/flann/hierarchical_clustering_index.h +++ b/modules/flann/include/opencv2/flann/hierarchical_clustering_index.h @@ -210,8 +210,11 @@ private: assert(index >=0 && index < n); centers[0] = dsindices[index]; + // Computing distance^2 will have the advantage of even higher probability further to pick new centers + // far from previous centers (and this complies to "k-means++: the advantages of careful seeding" article) for (int i = 0; i < n; i++) { closestDistSq[i] = distance(dataset[dsindices[i]], dataset[dsindices[index]], dataset.cols); + closestDistSq[i] *= closestDistSq[i]; currentPot += closestDistSq[i]; } @@ -237,7 +240,10 @@ private: // Compute the new potential double newPot = 0; - for (int i = 0; i < n; i++) newPot += std::min( distance(dataset[dsindices[i]], dataset[dsindices[index]], dataset.cols), closestDistSq[i] ); + for (int i = 0; i < n; i++) { + DistanceType dist = distance(dataset[dsindices[i]], dataset[dsindices[index]], dataset.cols); + newPot += std::min( dist*dist, closestDistSq[i] ); + } // Store the best result if ((bestNewPot < 0)||(newPot < bestNewPot)) { @@ -249,7 +255,10 @@ private: // Add the appropriate center centers[centerCount] = dsindices[bestNewIndex]; currentPot = bestNewPot; - for (int i = 0; i < n; i++) closestDistSq[i] = std::min( distance(dataset[dsindices[i]], dataset[dsindices[bestNewIndex]], dataset.cols), closestDistSq[i] ); + for (int i = 0; i < n; i++) { + DistanceType dist = distance(dataset[dsindices[i]], dataset[dsindices[bestNewIndex]], dataset.cols); + closestDistSq[i] = std::min( dist*dist, closestDistSq[i] ); + } } centers_length = centerCount;