K Nearest Neighbors
===================
The algorithm caches all training samples and predicts the response for a new sample by analyzing a certain number (**K**) of the nearest neighbors of the sample (using voting, calculating a weighted sum, and so on). The method is sometimes referred to as "learning by example" because for prediction it looks for the feature vector with a known response that is closest to the given vector.
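
The idea can be illustrated with a minimal brute-force sketch in plain C++, independent of OpenCV. The names ``Sample``, ``squaredDistance``, and ``classifyKnn`` are hypothetical and chosen for this illustration only; the sketch assumes Euclidean distance and a simple majority vote, and it is not how ``CvKNearest`` is implemented internally.

```cpp
#include <algorithm>
#include <cstddef>
#include <map>
#include <utility>
#include <vector>

// A labeled training sample: a feature vector plus a class label (hypothetical type).
struct Sample {
    std::vector<float> features;
    int label;
};

// Squared Euclidean distance between two vectors of equal length.
static float squaredDistance(const std::vector<float>& a, const std::vector<float>& b) {
    float sum = 0.0f;
    for (std::size_t i = 0; i < a.size(); ++i) {
        const float d = a[i] - b[i];
        sum += d * d;
    }
    return sum;
}

// Classifies `query` by a majority vote among its k nearest training samples.
// Assumes `train` is non-empty and k >= 1; ties go to the smallest label.
int classifyKnn(const std::vector<Sample>& train, const std::vector<float>& query, std::size_t k) {
    std::vector<std::pair<float, int> > byDistance;  // (distance, label)
    for (std::size_t i = 0; i < train.size(); ++i)
        byDistance.push_back(std::make_pair(squaredDistance(train[i].features, query), train[i].label));

    // Only the k closest entries need to be in sorted order.
    k = std::min(k, byDistance.size());
    std::partial_sort(byDistance.begin(),
                      byDistance.begin() + static_cast<std::ptrdiff_t>(k),
                      byDistance.end());

    // Count the votes of the k nearest neighbors.
    std::map<int, int> votes;
    for (std::size_t i = 0; i < k; ++i)
        ++votes[byDistance[i].second];

    // Pick the label with the most votes (map order makes ties deterministic).
    int bestLabel = byDistance[0].second;
    int bestCount = 0;
    for (std::map<int, int>::const_iterator it = votes.begin(); it != votes.end(); ++it) {
        if (it->second > bestCount) {
            bestCount = it->second;
            bestLabel = it->first;
        }
    }
    return bestLabel;
}
```

Caching every sample keeps training trivial, at the cost of a distance computation against every stored sample for each prediction in this naive form.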
.. index :: CvKNearest
.. _CvKNearest:
CvKNearest
----------
.. c:type :: CvKNearest

K-Nearest Neighbors model ::

    class CvKNearest : public CvStatModel
    {
    public:

        CvKNearest();
        virtual ~CvKNearest();

        CvKNearest( const Mat& _train_data, const Mat& _responses,
                    const Mat& _sample_idx=Mat(), bool _is_regression=false, int max_k=32 );

        virtual bool train( const Mat& _train_data, const Mat& _responses,
                            const Mat& _sample_idx=Mat(), bool is_regression=false,
                            int _max_k=32, bool _update_base=false );

        virtual float find_nearest( const Mat& _samples, int k, Mat* results=0,
            const float** neighbors=0, Mat* neighbor_responses=0, Mat* dist=0 ) const;

        virtual void clear();
        int get_max_k() const;
        int get_var_count() const;
        int get_sample_count() const;
        bool is_regression() const;

    protected:
        ...
    };
.. index :: CvKNearest::train
.. _CvKNearest :: train:
CvKNearest::train
-----------------
.. cpp:function :: bool CvKNearest::train( const Mat& _train_data, const Mat& _responses, const Mat& _sample_idx=Mat(), bool is_regression=false, int _max_k=32, bool _update_base=false )

    Trains the model.

The method trains the K-Nearest model. It follows the conventions of the generic ``train`` method with the following limitations:

* Only the ``CV_ROW_SAMPLE`` data layout is supported.
* Input variables are all ordered.
* Output variables can be either categorical (``is_regression=false``) or ordered (``is_regression=true``).
* Variable subsets (``var_idx``) and missing measurements are not supported.

The parameter ``_max_k`` specifies the maximum number of neighbors that may be passed to the method ``find_nearest``.

The parameter ``_update_base`` specifies whether the model is trained from scratch (``_update_base=false``) or updated using the new training data (``_update_base=true``). In the latter case, the parameter ``_max_k`` must not be larger than the original value.
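
The two rules above (the neighbor limit and the update constraint) can be modeled with a small toy class. This is only a sketch of the documented behavior, not OpenCV's implementation: ``ToyKnnStore`` and its members are hypothetical names, and keeping the original limit after an update with a smaller ``maxK`` is an assumption made for the illustration.

```cpp
#include <cstddef>
#include <vector>

// Toy stand-in for the sample store of a k-NN model (hypothetical, not OpenCV code).
class ToyKnnStore {
public:
    ToyKnnStore() : maxK_(0) {}

    // updateBase=false: discard old samples and train from scratch.
    // updateBase=true: append the new samples; a maxK larger than the
    // original value is rejected, as the documentation requires.
    bool train(const std::vector<std::vector<float> >& samples, int maxK, bool updateBase) {
        if (maxK <= 0)
            return false;
        if (updateBase) {
            if (maxK > maxK_)
                return false;
            // Assumption: the original limit stays in effect after an update.
        } else {
            samples_.clear();
            maxK_ = maxK;
        }
        samples_.insert(samples_.end(), samples.begin(), samples.end());
        return true;
    }

    // A query may request at most maxK neighbors.
    bool canQuery(int k) const { return k >= 1 && k <= maxK_; }

    std::size_t sampleCount() const { return samples_.size(); }

private:
    std::vector<std::vector<float> > samples_;
    int maxK_;
};
```

In this model, training three samples with ``maxK = 5`` and then updating with two more leaves five cached samples, while an update that asks for ``maxK = 6`` is refused.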
.. index :: CvKNearest::find_nearest
.. _CvKNearest :: find_nearest:
CvKNearest::find_nearest
------------------------
.. cpp:function :: float CvKNearest::find_nearest( const Mat& _samples, int k, Mat* results=0, const float** neighbors=0, Mat* neighbor_responses=0, Mat* dist=0 ) const

    Finds the neighbors for input vectors.

For each input vector (a row of the matrix ``_samples``), the method finds the :math:`\texttt{k} \le \texttt{get\_max\_k()}` nearest neighbors. In case of regression, the predicted result is the mean value of the particular vector's neighbor responses. In case of classification, the class is determined by voting.
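
The regression case can be sketched in a few lines of plain C++. The function name ``regressKnn`` is hypothetical, and one-dimensional features with an absolute-difference distance are assumed only to keep the example short; this is not the OpenCV code path.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

// Predicts a continuous response for `query` as the mean response of the
// k nearest training points. Assumes x and y have equal, non-zero length
// and k >= 1.
float regressKnn(const std::vector<float>& x, const std::vector<float>& y,
                 float query, std::size_t k) {
    std::vector<std::pair<float, float> > byDistance;  // (distance, response)
    for (std::size_t i = 0; i < x.size(); ++i)
        byDistance.push_back(std::make_pair(std::fabs(x[i] - query), y[i]));

    // Bring the k closest points to the front, sorted by distance.
    k = std::min(k, byDistance.size());
    std::partial_sort(byDistance.begin(),
                      byDistance.begin() + static_cast<std::ptrdiff_t>(k),
                      byDistance.end());

    // Average the responses of those k neighbors.
    float sum = 0.0f;
    for (std::size_t i = 0; i < k; ++i)
        sum += byDistance[i].second;
    return sum / static_cast<float>(k);
}
```

For classification, the averaging step would be replaced by counting the labels of the same ``k`` neighbors and returning the most frequent one.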
For a custom classification/regression prediction, the method can optionally return pointers to the neighbor vectors themselves (``neighbors``, an array of ``k*_samples.rows`` pointers), their corresponding output values (``neighbor_responses``, a vector of ``k*_samples.rows`` elements), and the distances from the input vectors to the neighbors (``dist``, also a vector of ``k*_samples.rows`` elements).

For each input vector, the neighbors are sorted by their distances to the vector.
If only a single input vector is passed, all output matrices are optional and the predicted value is returned by the method.

The sample below (currently using the obsolete ``CvMat`` structures) demonstrates the use of the k-nearest classifier for 2D point classification ::

    #include "ml.h"
    #include "highgui.h"

    int main( int argc, char** argv )
    {
        const int K = 10;
        int i, j, k, accuracy;
        float response;
        int train_sample_count = 100;
        CvRNG rng_state = cvRNG(-1);
        CvMat* trainData = cvCreateMat( train_sample_count, 2, CV_32FC1 );
        CvMat* trainClasses = cvCreateMat( train_sample_count, 1, CV_32FC1 );
        IplImage* img = cvCreateImage( cvSize( 500, 500 ), 8, 3 );
        float _sample[2];
        CvMat sample = cvMat( 1, 2, CV_32FC1, _sample );
        cvZero( img );

        CvMat trainData1, trainData2, trainClasses1, trainClasses2;

        // form the training samples
        cvGetRows( trainData, &trainData1, 0, train_sample_count/2 );
        cvRandArr( &rng_state, &trainData1, CV_RAND_NORMAL, cvScalar(200,200), cvScalar(50,50) );

        cvGetRows( trainData, &trainData2, train_sample_count/2, train_sample_count );
        cvRandArr( &rng_state, &trainData2, CV_RAND_NORMAL, cvScalar(300,300), cvScalar(50,50) );

        cvGetRows( trainClasses, &trainClasses1, 0, train_sample_count/2 );
        cvSet( &trainClasses1, cvScalar(1) );

        cvGetRows( trainClasses, &trainClasses2, train_sample_count/2, train_sample_count );
        cvSet( &trainClasses2, cvScalar(2) );

        // learn classifier
        CvKNearest knn( trainData, trainClasses, 0, false, K );
        CvMat* nearests = cvCreateMat( 1, K, CV_32FC1);

        for( i = 0; i < img->height; i++ )
        {
            for( j = 0; j < img->width; j++ )
            {
                sample.data.fl[0] = (float)j;
                sample.data.fl[1] = (float)i;

                // estimate the response and get the neighbors' labels
                response = knn.find_nearest(&sample,K,0,0,nearests,0);

                // compute the number of neighbors representing the majority
                for( k = 0, accuracy = 0; k < K; k++ )
                {
                    if( nearests->data.fl[k] == response)
                        accuracy++;
                }
                // highlight the pixel depending on the accuracy (or confidence)
                cvSet2D( img, i, j, response == 1 ?
                    (accuracy > 5 ? CV_RGB(180,0,0) : CV_RGB(180,120,0)) :
                    (accuracy > 5 ? CV_RGB(0,180,0) : CV_RGB(120,120,0)) );
            }
        }

        // display the original training samples
        for( i = 0; i < train_sample_count/2; i++ )
        {
            CvPoint pt;
            pt.x = cvRound(trainData1.data.fl[i*2]);
            pt.y = cvRound(trainData1.data.fl[i*2+1]);
            cvCircle( img, pt, 2, CV_RGB(255,0,0), CV_FILLED );
            pt.x = cvRound(trainData2.data.fl[i*2]);
            pt.y = cvRound(trainData2.data.fl[i*2+1]);
            cvCircle( img, pt, 2, CV_RGB(0,255,0), CV_FILLED );
        }

        cvNamedWindow( "classifier result", 1 );
        cvShowImage( "classifier result", img );
        cvWaitKey(0);

        // release everything allocated above
        cvReleaseMat( &nearests );
        cvReleaseMat( &trainClasses );
        cvReleaseMat( &trainData );
        cvReleaseImage( &img );
        return 0;
    }