mirror of
https://github.com/opencv/opencv.git
synced 2024-11-24 19:20:28 +08:00
integrated parallel SVM prediction; fixed warnings after meanshift integration
This commit is contained in:
parent
537a36115f
commit
17a2480a21
@ -543,7 +543,8 @@ public:
|
|||||||
bool balanced=false );
|
bool balanced=false );
|
||||||
|
|
||||||
virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
|
virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
|
||||||
|
virtual float predict( const CvMat* samples, CvMat* results ) const;
|
||||||
|
|
||||||
#ifndef SWIG
|
#ifndef SWIG
|
||||||
CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
|
CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
|
||||||
const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
|
const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
|
||||||
@ -563,7 +564,7 @@ public:
|
|||||||
CvParamGrid coeffGrid = CvSVM::get_default_grid(CvSVM::COEF),
|
CvParamGrid coeffGrid = CvSVM::get_default_grid(CvSVM::COEF),
|
||||||
CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
|
CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
|
||||||
bool balanced=false);
|
bool balanced=false);
|
||||||
CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
|
CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
CV_WRAP virtual int get_support_vector_count() const;
|
CV_WRAP virtual int get_support_vector_count() const;
|
||||||
|
@ -2081,7 +2081,7 @@ float CvSVM::predict( const CvMat* sample, bool returnDFVal ) const
|
|||||||
CV_CALL( cvPreparePredictData( sample, var_all, var_idx,
|
CV_CALL( cvPreparePredictData( sample, var_all, var_idx,
|
||||||
class_count, 0, &row_sample ));
|
class_count, 0, &row_sample ));
|
||||||
result = predict( row_sample, get_var_count(), returnDFVal );
|
result = predict( row_sample, get_var_count(), returnDFVal );
|
||||||
|
|
||||||
__END__;
|
__END__;
|
||||||
|
|
||||||
if( sample && (!CV_IS_MAT(sample) || sample->data.fl != row_sample) )
|
if( sample && (!CV_IS_MAT(sample) || sample->data.fl != row_sample) )
|
||||||
@ -2090,6 +2090,44 @@ float CvSVM::predict( const CvMat* sample, bool returnDFVal ) const
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct predict_body {
|
||||||
|
predict_body(const CvSVM* _pointer, float* _result, const CvMat* _samples, CvMat* _results)
|
||||||
|
{
|
||||||
|
pointer = _pointer;
|
||||||
|
result = _result;
|
||||||
|
samples = _samples;
|
||||||
|
results = _results;
|
||||||
|
}
|
||||||
|
|
||||||
|
const CvSVM* pointer;
|
||||||
|
float* result;
|
||||||
|
const CvMat* samples;
|
||||||
|
CvMat* results;
|
||||||
|
|
||||||
|
void operator()( const cv::BlockedRange& range ) const
|
||||||
|
{
|
||||||
|
for(int i = range.begin(); i < range.end(); i++ )
|
||||||
|
{
|
||||||
|
CvMat sample;
|
||||||
|
cvGetRow( samples, &sample, i );
|
||||||
|
int r = (int)pointer->predict(&sample);
|
||||||
|
if (results)
|
||||||
|
results->data.fl[i] = r;
|
||||||
|
if (i == 0)
|
||||||
|
*result = r;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
float CvSVM::predict(const CvMat* samples, CV_OUT CvMat* results) const
|
||||||
|
{
|
||||||
|
float result = 0;
|
||||||
|
cv::parallel_for(cv::BlockedRange(0, samples->rows),
|
||||||
|
predict_body(this, &result, samples, results)
|
||||||
|
);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
CvSVM::CvSVM( const Mat& _train_data, const Mat& _responses,
|
CvSVM::CvSVM( const Mat& _train_data, const Mat& _responses,
|
||||||
const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params )
|
const Mat& _var_idx, const Mat& _sample_idx, CvSVMParams _params )
|
||||||
|
@ -1396,7 +1396,7 @@ int createSchedule(const CvLSVMFeaturePyramid *H, const CvLSVMFilterObject **all
|
|||||||
const int threadsNum, int *kLevels, int **processingLevels)
|
const int threadsNum, int *kLevels, int **processingLevels)
|
||||||
{
|
{
|
||||||
int rootFilterDim, sumPartFiltersDim, i, numLevels, dbx, dby, numDotProducts;
|
int rootFilterDim, sumPartFiltersDim, i, numLevels, dbx, dby, numDotProducts;
|
||||||
int averNumDotProd, j, minValue, argMin, tmp, lambda, maxValue, k;
|
int averNumDotProd, j, minValue, argMin, lambda, maxValue, k;
|
||||||
int *dotProd, *weights, *disp;
|
int *dotProd, *weights, *disp;
|
||||||
if (H == NULL || all_F == NULL)
|
if (H == NULL || all_F == NULL)
|
||||||
{
|
{
|
||||||
|
@ -44,11 +44,17 @@
|
|||||||
using namespace cv;
|
using namespace cv;
|
||||||
|
|
||||||
MeanshiftGrouping::MeanshiftGrouping(const Point3d& densKer, const vector<Point3d>& posV,
|
MeanshiftGrouping::MeanshiftGrouping(const Point3d& densKer, const vector<Point3d>& posV,
|
||||||
const vector<double>& wV, double modeEps, int maxIter):
|
const vector<double>& wV, double modeEps, int maxIter)
|
||||||
densityKernel(densKer), weightsV(wV), positionsV(posV), positionsCount(posV.size()),
|
|
||||||
meanshiftV(positionsCount), distanceV(positionsCount), modeEps(modeEps),
|
|
||||||
iterMax (maxIter)
|
|
||||||
{
|
{
|
||||||
|
densityKernel = densKer;
|
||||||
|
weightsV = wV;
|
||||||
|
positionsV = posV;
|
||||||
|
positionsCount = posV.size();
|
||||||
|
meanshiftV.resize(positionsCount);
|
||||||
|
distanceV.resize(positionsCount);
|
||||||
|
modeEps = modeEps;
|
||||||
|
iterMax = maxIter;
|
||||||
|
|
||||||
for (unsigned i=0; i<positionsV.size(); i++)
|
for (unsigned i=0; i<positionsV.size(); i++)
|
||||||
{
|
{
|
||||||
meanshiftV[i] = getNewValue(positionsV[i]);
|
meanshiftV[i] = getNewValue(positionsV[i]);
|
||||||
|
@ -9,7 +9,7 @@
|
|||||||
void help()
|
void help()
|
||||||
{
|
{
|
||||||
printf("\nThe sample demonstrates how to train Random Trees classifier\n"
|
printf("\nThe sample demonstrates how to train Random Trees classifier\n"
|
||||||
"(or Boosting classifier, or MLP, or Knearest, or Nbayes - see main()) using the provided dataset.\n"
|
"(or Boosting classifier, or MLP, or Knearest, or Nbayes, or Support Vector Machines - see main()) using the provided dataset.\n"
|
||||||
"\n"
|
"\n"
|
||||||
"We use the sample database letter-recognition.data\n"
|
"We use the sample database letter-recognition.data\n"
|
||||||
"from UCI Repository, here is the link:\n"
|
"from UCI Repository, here is the link:\n"
|
||||||
@ -28,7 +28,7 @@ void help()
|
|||||||
"The usage: letter_recog [-data <path to letter-recognition.data>] \\\n"
|
"The usage: letter_recog [-data <path to letter-recognition.data>] \\\n"
|
||||||
" [-save <output XML file for the classifier>] \\\n"
|
" [-save <output XML file for the classifier>] \\\n"
|
||||||
" [-load <XML file with the pre-trained classifier>] \\\n"
|
" [-load <XML file with the pre-trained classifier>] \\\n"
|
||||||
" [-boost|-mlp|-knearest|-nbayes] # to use boost/mlp/knearest classifier instead of default Random Trees\n" );
|
" [-boost|-mlp|-knearest|-nbayes|-svm] # to use boost/mlp/knearest/SVM classifier instead of default Random Trees\n" );
|
||||||
}
|
}
|
||||||
|
|
||||||
// This function reads data and responses from the file <filename>
|
// This function reads data and responses from the file <filename>
|
||||||
@ -630,6 +630,78 @@ int build_nbayes_classifier( char* data_filename )
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static
|
||||||
|
int build_svm_classifier( char* data_filename )
|
||||||
|
{
|
||||||
|
CvMat* data = 0;
|
||||||
|
CvMat* responses = 0;
|
||||||
|
CvMat train_data;
|
||||||
|
int nsamples_all = 0, ntrain_samples = 0;
|
||||||
|
int var_count;
|
||||||
|
CvSVM svm;
|
||||||
|
|
||||||
|
int ok = read_num_class_data( data_filename, 16, &data, &responses );
|
||||||
|
if( !ok )
|
||||||
|
{
|
||||||
|
printf( "Could not read the database %s\n", data_filename );
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
////////// SVM parameters ///////////////////////////////
|
||||||
|
CvSVMParams param;
|
||||||
|
param.kernel_type=CvSVM::LINEAR;
|
||||||
|
param.svm_type=CvSVM::C_SVC;
|
||||||
|
param.C=1;
|
||||||
|
///////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
printf( "The database %s is loaded.\n", data_filename );
|
||||||
|
nsamples_all = data->rows;
|
||||||
|
ntrain_samples = (int)(nsamples_all*0.1);
|
||||||
|
var_count = data->cols;
|
||||||
|
|
||||||
|
// train classifier
|
||||||
|
printf( "Training the classifier (may take a few minutes)...\n");
|
||||||
|
cvGetRows( data, &train_data, 0, ntrain_samples );
|
||||||
|
CvMat* train_resp = cvCreateMat( ntrain_samples, 1, CV_32FC1);
|
||||||
|
for (int i = 0; i < ntrain_samples; i++)
|
||||||
|
train_resp->data.fl[i] = responses->data.fl[i];
|
||||||
|
svm.train(&train_data, train_resp, 0, 0, param);
|
||||||
|
|
||||||
|
// classification
|
||||||
|
float _sample[var_count * (nsamples_all - ntrain_samples)];
|
||||||
|
CvMat sample = cvMat( nsamples_all - ntrain_samples, 16, CV_32FC1, _sample );
|
||||||
|
float true_results[nsamples_all - ntrain_samples];
|
||||||
|
for (int j = ntrain_samples; j < nsamples_all; j++)
|
||||||
|
{
|
||||||
|
float *s = data->data.fl + j * var_count;
|
||||||
|
|
||||||
|
for (int i = 0; i < var_count; i++)
|
||||||
|
{
|
||||||
|
sample.data.fl[(j - ntrain_samples) * var_count + i] = s[i];
|
||||||
|
}
|
||||||
|
true_results[j - ntrain_samples] = responses->data.fl[j];
|
||||||
|
}
|
||||||
|
CvMat *result = cvCreateMat(1, nsamples_all - ntrain_samples, CV_32FC1);
|
||||||
|
|
||||||
|
printf("Classification (may take a few minutes)...\n");
|
||||||
|
(int)svm.predict(&sample, result);
|
||||||
|
|
||||||
|
int true_resp = 0;
|
||||||
|
for (int i = 0; i < nsamples_all - ntrain_samples; i++)
|
||||||
|
{
|
||||||
|
if (result->data.fl[i] == true_results[i])
|
||||||
|
true_resp++;
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("true_resp = %f%%\n", (float)true_resp / (nsamples_all - ntrain_samples) * 100);
|
||||||
|
|
||||||
|
cvReleaseMat( &train_resp );
|
||||||
|
cvReleaseMat( &result );
|
||||||
|
cvReleaseMat( &data );
|
||||||
|
cvReleaseMat( &responses );
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
int main( int argc, char *argv[] )
|
int main( int argc, char *argv[] )
|
||||||
{
|
{
|
||||||
char* filename_to_save = 0;
|
char* filename_to_save = 0;
|
||||||
@ -672,6 +744,10 @@ int main( int argc, char *argv[] )
|
|||||||
{
|
{
|
||||||
method = 4;
|
method = 4;
|
||||||
}
|
}
|
||||||
|
else if ( strcmp(argv[i], "-svm") == 0)
|
||||||
|
{
|
||||||
|
method = 5;
|
||||||
|
}
|
||||||
else
|
else
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -687,6 +763,8 @@ int main( int argc, char *argv[] )
|
|||||||
build_knearest_classifier( data_filename, 10 ) :
|
build_knearest_classifier( data_filename, 10 ) :
|
||||||
method == 4 ?
|
method == 4 ?
|
||||||
build_nbayes_classifier( data_filename) :
|
build_nbayes_classifier( data_filename) :
|
||||||
|
method == 5 ?
|
||||||
|
build_svm_classifier( data_filename ):
|
||||||
-1) < 0)
|
-1) < 0)
|
||||||
{
|
{
|
||||||
help();
|
help();
|
||||||
|
Loading…
Reference in New Issue
Block a user