diff --git a/modules/flann/include/opencv2/flann/kmeans_index.h b/modules/flann/include/opencv2/flann/kmeans_index.h index 3bf12047cd..460dc64be9 100644 --- a/modules/flann/include/opencv2/flann/kmeans_index.h +++ b/modules/flann/include/opencv2/flann/kmeans_index.h @@ -53,6 +53,62 @@ namespace cvflann { +template +struct squareDistance +{ + typedef typename Distance::ResultType ResultType; + ResultType operator()( ResultType dist ) { return dist*dist; } +}; + + +template +struct squareDistance, ElementType> +{ + typedef typename L2_Simple::ResultType ResultType; + ResultType operator()( ResultType dist ) { return dist; } +}; + +template +struct squareDistance, ElementType> +{ + typedef typename L2::ResultType ResultType; + ResultType operator()( ResultType dist ) { return dist; } +}; + + +template +struct squareDistance, ElementType> +{ + typedef typename MinkowskiDistance::ResultType ResultType; + ResultType operator()( ResultType dist ) { return dist; } +}; + +template +struct squareDistance, ElementType> +{ + typedef typename HellingerDistance::ResultType ResultType; + ResultType operator()( ResultType dist ) { return dist; } +}; + +template +struct squareDistance, ElementType> +{ + typedef typename ChiSquareDistance::ResultType ResultType; + ResultType operator()( ResultType dist ) { return dist; } +}; + + +template +typename Distance::ResultType ensureSquareDistance( typename Distance::ResultType dist ) +{ + typedef typename Distance::ElementType ElementType; + + squareDistance dummy; + return dummy( dist ); +} + + + struct KMeansIndexParams : public IndexParams { KMeansIndexParams(int branching = 32, int iterations = 11, @@ -211,7 +267,7 @@ public: for (int i = 0; i < n; i++) { closestDistSq[i] = distance_(dataset_[indices[i]], dataset_[indices[index]], dataset_.cols); - closestDistSq[i] *= closestDistSq[i]; + closestDistSq[i] = ensureSquareDistance( closestDistSq[i] ); currentPot += closestDistSq[i]; } @@ -239,7 +295,7 @@ public: double newPot = 0; for (int i = 0; i < n; i++) { DistanceType dist = distance_(dataset_[indices[i]], dataset_[indices[index]], dataset_.cols); - newPot += std::min( dist*dist, closestDistSq[i] ); + newPot += std::min( ensureSquareDistance(dist), closestDistSq[i] ); } // Store the best result @@ -254,7 +310,7 @@ public: currentPot = bestNewPot; for (int i = 0; i < n; i++) { DistanceType dist = distance_(dataset_[indices[i]], dataset_[indices[bestNewIndex]], dataset_.cols); - closestDistSq[i] = std::min( dist*dist, closestDistSq[i] ); + closestDistSq[i] = std::min( ensureSquareDistance(dist), closestDistSq[i] ); } }