2011-02-10 04:55:11 +08:00
|
|
|
/*M///////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//
|
|
|
|
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
|
|
|
//
|
|
|
|
// By downloading, copying, installing or using the software you agree to this license.
|
|
|
|
// If you do not agree to this license, do not download, install,
|
|
|
|
// copy or use the software.
|
|
|
|
//
|
|
|
|
//
|
|
|
|
// Intel License Agreement
|
|
|
|
// For Open Source Computer Vision Library
|
|
|
|
//
|
|
|
|
// Copyright (C) 2000, Intel Corporation, all rights reserved.
|
|
|
|
// Third party copyrights are property of their respective owners.
|
|
|
|
//
|
|
|
|
// Redistribution and use in source and binary forms, with or without modification,
|
|
|
|
// are permitted provided that the following conditions are met:
|
|
|
|
//
|
|
|
|
// * Redistribution's of source code must retain the above copyright notice,
|
|
|
|
// this list of conditions and the following disclaimer.
|
|
|
|
//
|
|
|
|
// * Redistribution's in binary form must reproduce the above copyright notice,
|
|
|
|
// this list of conditions and the following disclaimer in the documentation
|
|
|
|
// and/or other materials provided with the distribution.
|
|
|
|
//
|
|
|
|
// * The name of Intel Corporation may not be used to endorse or promote products
|
|
|
|
// derived from this software without specific prior written permission.
|
|
|
|
//
|
|
|
|
// This software is provided by the copyright holders and contributors "as is" and
|
|
|
|
// any express or implied warranties, including, but not limited to, the implied
|
|
|
|
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
|
|
|
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
|
|
|
// indirect, incidental, special, exemplary, or consequential damages
|
|
|
|
// (including, but not limited to, procurement of substitute goods or services;
|
|
|
|
// loss of use, data, or profits; or business interruption) however caused
|
|
|
|
// and on any theory of liability, whether in contract, strict liability,
|
|
|
|
// or tort (including negligence or otherwise) arising in any way out of
|
|
|
|
// the use of this software, even if advised of the possibility of such damage.
|
|
|
|
//
|
|
|
|
//M*/
|
|
|
|
|
|
|
|
#include "test_precomp.hpp"
|
|
|
|
|
2017-12-15 22:40:08 +08:00
|
|
|
//#define GENERATE_TESTDATA
|
|
|
|
|
2017-11-05 21:48:40 +08:00
|
|
|
namespace opencv_test { namespace {
|
2011-02-10 04:55:11 +08:00
|
|
|
|
2013-03-21 00:13:46 +08:00
|
|
|
// Maps an SVM type name read from the validation file ("C_SVC", "NU_SVC",
// "ONE_CLASS", "EPS_SVR" or "NU_SVR") to the matching SVM constant.
// Raises CV_StsBadArg (via CV_Error) for any other string.
int str_to_svm_type(String& str)
{
    static const struct { const char* name; int id; } kSvmTypes[] =
    {
        { "C_SVC",     SVM::C_SVC },
        { "NU_SVC",    SVM::NU_SVC },
        { "ONE_CLASS", SVM::ONE_CLASS },
        { "EPS_SVR",   SVM::EPS_SVR },
        { "NU_SVR",    SVM::NU_SVR }
    };
    for( size_t k = 0; k < sizeof(kSvmTypes)/sizeof(kSvmTypes[0]); ++k )
    {
        if( str.compare(kSvmTypes[k].name) == 0 )
            return kSvmTypes[k].id;
    }
    CV_Error( CV_StsBadArg, "incorrect svm type string" );
}
|
2013-03-21 00:13:46 +08:00
|
|
|
// Maps an SVM kernel name read from the validation file ("LINEAR", "POLY",
// "RBF" or "SIGMOID") to the matching SVM kernel constant.
// Raises CV_StsBadArg (via CV_Error) for any other string.
int str_to_svm_kernel_type( String& str )
{
    if( !str.compare("LINEAR") )
        return SVM::LINEAR;
    if( !str.compare("POLY") )
        return SVM::POLY;
    if( !str.compare("RBF") )
        return SVM::RBF;
    if( !str.compare("SIGMOID") )
        return SVM::SIGMOID;
    // The message names the kernel type, so that a bad "kernel_type" entry is
    // not reported as a bad SVM type (str_to_svm_type uses that wording).
    CV_Error( CV_StsBadArg, "incorrect svm kernel type string" );
}
|
2014-07-30 03:54:23 +08:00
|
|
|
|
2011-02-10 04:55:11 +08:00
|
|
|
// 4. em
|
|
|
|
// 5. ann
|
2013-03-21 00:13:46 +08:00
|
|
|
// Maps an ANN_MLP training method name ("BACKPROP", "RPROP" or "ANNEAL")
// to the matching ANN_MLP constant.
// Raises CV_StsBadArg (via CV_Error) for any other string.
int str_to_ann_train_method( String& str )
{
    int method = -1;
    if( str.compare("BACKPROP") == 0 )
        method = ANN_MLP::BACKPROP;
    else if( str.compare("RPROP") == 0 )
        method = ANN_MLP::RPROP;
    else if( str.compare("ANNEAL") == 0 )
        method = ANN_MLP::ANNEAL;
    else
        CV_Error( CV_StsBadArg, "incorrect ann train method string" );
    return method;
}
|
2014-07-30 03:54:23 +08:00
|
|
|
|
2017-11-05 21:48:40 +08:00
|
|
|
#if 0
// Maps an ANN_MLP activation function name ("IDENTITY", "SIGMOID_SYM",
// "GAUSSIAN", "RELU" or "LEAKYRELU") to the matching ANN_MLP constant.
// Raises CV_StsBadArg (via CV_Error) for any other string.
// This helper is currently compiled out.
int str_to_ann_activation_function(String& str)
{
    static const struct { const char* name; int id; } kActivations[] =
    {
        { "IDENTITY",    ANN_MLP::IDENTITY },
        { "SIGMOID_SYM", ANN_MLP::SIGMOID_SYM },
        { "GAUSSIAN",    ANN_MLP::GAUSSIAN },
        { "RELU",        ANN_MLP::RELU },
        { "LEAKYRELU",   ANN_MLP::LEAKYRELU }
    };
    for (size_t k = 0; k < sizeof(kActivations)/sizeof(kActivations[0]); ++k)
    {
        if (str.compare(kActivations[k].name) == 0)
            return kActivations[k].id;
    }
    CV_Error(CV_StsBadArg, "incorrect ann activation function string");
}
#endif
|
2017-11-23 05:07:23 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// Rejects training data that the ANN helpers of this test cannot handle:
// a var_idx whose size differs from the number of sample columns, or any
// missing values. Raises CV_StsBadArg (via CV_Error) in either case.
void ann_check_data( Ptr<TrainData> _data )
{
    CV_TRACE_FUNCTION();
    const Mat samples = _data->getSamples();
    const int selectedVars = (int)_data->getVarIdx().total();
    const bool usesVarSubset = selectedVars != 0 && selectedVars != samples.cols;
    if( usesVarSubset )
        CV_Error( CV_StsBadArg, "var_idx is not supported" );
    if( !_data->getMissing().empty() )
        CV_Error( CV_StsBadArg, "missing values are not supported" );
}
|
2014-07-30 03:54:23 +08:00
|
|
|
|
|
|
|
// unroll the categorical responses to binary vectors
|
|
|
|
Mat ann_get_new_responses( Ptr<TrainData> _data, map<int, int>& cls_map )
|
2011-02-10 04:55:11 +08:00
|
|
|
{
|
2017-05-25 23:59:01 +08:00
|
|
|
CV_TRACE_FUNCTION();
|
2014-07-30 03:54:23 +08:00
|
|
|
Mat train_sidx = _data->getTrainSampleIdx();
|
|
|
|
int* train_sidx_ptr = train_sidx.ptr<int>();
|
|
|
|
Mat responses = _data->getResponses();
|
2011-02-10 04:55:11 +08:00
|
|
|
int cls_count = 0;
|
|
|
|
// construct cls_map
|
|
|
|
cls_map.clear();
|
2014-07-30 03:54:23 +08:00
|
|
|
int nresponses = (int)responses.total();
|
|
|
|
int si, n = !train_sidx.empty() ? (int)train_sidx.total() : nresponses;
|
|
|
|
|
|
|
|
for( si = 0; si < n; si++ )
|
2011-02-10 04:55:11 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
int sidx = train_sidx_ptr ? train_sidx_ptr[si] : si;
|
|
|
|
int r = cvRound(responses.at<float>(sidx));
|
|
|
|
CV_DbgAssert( fabs(responses.at<float>(sidx) - r) < FLT_EPSILON );
|
|
|
|
map<int,int>::iterator it = cls_map.find(r);
|
|
|
|
if( it == cls_map.end() )
|
2011-02-10 04:55:11 +08:00
|
|
|
cls_map[r] = cls_count++;
|
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
Mat new_responses = Mat::zeros( nresponses, cls_count, CV_32F );
|
|
|
|
for( si = 0; si < n; si++ )
|
2011-02-10 04:55:11 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
int sidx = train_sidx_ptr ? train_sidx_ptr[si] : si;
|
|
|
|
int r = cvRound(responses.at<float>(sidx));
|
2011-02-10 04:55:11 +08:00
|
|
|
int cidx = cls_map[r];
|
2014-07-30 03:54:23 +08:00
|
|
|
new_responses.at<float>(sidx, cidx) = 1.f;
|
2011-02-10 04:55:11 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
return new_responses;
|
2011-02-10 04:55:11 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
|
|
|
|
// Classification error (in percent) of `ann` on a subset of `_data`.
// type: CV_TEST_ERROR evaluates the test sample index. Any other value evaluates
//       the train sample index; for CV_TRAIN_ERROR with no train index, all rows
//       are used.
// cls_map: maps original integer labels to network output columns, e.g. as built
//          by ann_get_new_responses. It is only read here, never modified.
// resp_labels: if non-null, resized to the sample count and filled with the
//              predicted column index (arg-max of the network output) per sample.
// Returns -FLT_MAX when there are no samples to evaluate.
float ann_calc_error( Ptr<StatModel> ann, Ptr<TrainData> _data, map<int, int>& cls_map, int type, vector<float> *resp_labels )
{
    CV_TRACE_FUNCTION();
    float err = 0;
    Mat samples = _data->getSamples();
    Mat responses = _data->getResponses();
    Mat sample_idx = (type == CV_TEST_ERROR) ? _data->getTestSampleIdx() : _data->getTrainSampleIdx();
    int* sidx = !sample_idx.empty() ? sample_idx.ptr<int>() : 0;
    ann_check_data( _data );
    int sample_count = (int)sample_idx.total();
    sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? samples.rows : sample_count;
    float* pred_resp = 0;
    vector<float> innresp;
    if( sample_count > 0 )
    {
        if( resp_labels )
        {
            resp_labels->resize( sample_count );
            pred_resp = &((*resp_labels)[0]);
        }
        else
        {
            innresp.resize( sample_count );
            pred_resp = &(innresp[0]);
        }
    }
    int cls_count = (int)cls_map.size();
    Mat output( 1, cls_count, CV_32FC1 );

    for( int i = 0; i < sample_count; i++ )
    {
        int si = sidx ? sidx[i] : i;
        Mat sample = samples.row(si);
        ann->predict( sample, output );
        Point best_cls;
        minMaxLoc(output, 0, 0, 0, &best_cls, 0);
        int r = cvRound(responses.at<float>(si));
        CV_DbgAssert( fabs(responses.at<float>(si) - r) < FLT_EPSILON );
        // Look the label up with find() rather than operator[]: a label that
        // never occurred in training must not be inserted into cls_map.
        // Inserting it would map it to column 0 and count it as correct
        // whenever class 0 is predicted. Such samples are always misclassified.
        map<int, int>::const_iterator it = cls_map.find(r);
        int d = (it != cls_map.end() && best_cls.x == it->second) ? 0 : 1;
        err += d;
        pred_resp[i] = (float)best_cls.x;
    }
    err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
    return err;
}
|
|
|
|
|
2017-11-23 05:07:23 +08:00
|
|
|
// For each supported activation function, trains an ANN_MLP on waveform.data.
// The network has layers nvars-100-100-<response columns>, uses RPROP and runs
// 300 iterations.
// Its predictions on the test split are compared with those of the reference
// model stored as waveform<suffix>.yml in the test data folder.
// With GENERATE_TESTDATA defined, the reference models are written instead.
TEST(ML_ANN, ActivationFunction)
{
    String folder = string(cvtest::TS::ptr()->get_data_path());
    String original_path = folder + "waveform.data";
    String dataname = folder + "waveform";

    Ptr<TrainData> tdata = TrainData::loadFromCSV(original_path, 0);

    ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << original_path;
    // Fixed global RNG state so the train/test split is reproducible
    // (assumes setTrainTestSplit shuffles using theRNG()).
    RNG& rng = theRNG();
    rng.state = 1027401484159173092;
    tdata->setTrainTestSplit(500);

    // activationType[i] and activationName[i] (the reference file suffix) go together.
    vector<int> activationType;
    activationType.push_back(ml::ANN_MLP::IDENTITY);
    activationType.push_back(ml::ANN_MLP::SIGMOID_SYM);
    activationType.push_back(ml::ANN_MLP::GAUSSIAN);
    activationType.push_back(ml::ANN_MLP::RELU);
    activationType.push_back(ml::ANN_MLP::LEAKYRELU);
    vector<String> activationName;
    activationName.push_back("_identity");
    activationName.push_back("_sigmoid_sym");
    activationName.push_back("_gaussian");
    activationName.push_back("_relu");
    activationName.push_back("_leakyrelu");
    for (size_t i = 0; i < activationType.size(); i++)
    {
        Ptr<ml::ANN_MLP> x = ml::ANN_MLP::create();
        Mat_<int> layerSizes(1, 4);
        layerSizes(0, 0) = tdata->getNVars();
        layerSizes(0, 1) = 100;
        layerSizes(0, 2) = 100;
        layerSizes(0, 3) = tdata->getResponses().cols;
        x->setLayerSizes(layerSizes);
        x->setActivationFunction(activationType[i]);
        x->setTrainMethod(ml::ANN_MLP::RPROP, 0.01, 0.1);
        x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 300, 0.01));
        x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE);
        ASSERT_TRUE(x->isTrained()) << "Could not train networks with " << activationName[i];
#ifdef GENERATE_TESTDATA
        x->save(dataname + activationName[i] + ".yml");
#else
        Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(dataname + activationName[i] + ".yml");
        ASSERT_TRUE(y != NULL) << "Could not load " << dataname + activationName[i] + ".yml";
        Mat testSamples = tdata->getTestSamples();
        Mat rx, ry;
        x->predict(testSamples, rx);
        y->predict(testSamples, ry);
        double n = cvtest::norm(rx, ry, NORM_INF);
        EXPECT_LT(n,FLT_EPSILON) << "Predict are not equal for " << dataname + activationName[i] + ".yml and " << activationName[i];
#endif
    }
}
|
2017-12-15 22:40:08 +08:00
|
|
|
|
2018-02-20 00:45:04 +08:00
|
|
|
// Test-parameter type for ML_ANN_METHOD, restricted to the ANN_MLP training
// methods exercised by the parameterized test (RPROP and ANNEAL). The test
// reads it back as an int via get<0>(GetParam()).
CV_ENUM(ANN_MLP_METHOD, ANN_MLP::RPROP, ANN_MLP::ANNEAL)
|
|
|
|
|
|
|
|
// Parameters of ML_ANN_METHOD.Test:
//   (training method, method name used in the data file names,
//    number of waveform.data rows to train/test on).
typedef tuple<ANN_MLP_METHOD, string, int> ML_ANN_METHOD_Params;
|
|
|
|
// Fixture for the value-parameterized ANN_MLP training-method test.
typedef TestWithParam<ML_ANN_METHOD_Params> ML_ANN_METHOD;
|
|
|
|
|
|
|
|
// Trains an ANN_MLP with the parameterized training method on the first N rows
// of waveform.data. Training starts from the stored initial weights in
// waveform_<method>_init_weight.yml.gz. The test then checks that:
//  - the reference model stored in waveform_<method>.yml.gz has the same
//    per-layer weights and predictions as the freshly trained model, and
//  - the trained model's predictions on the test split match the stored 'gold'
//    responses in waveform_<method>_response.yml.gz.
// With GENERATE_TESTDATA defined, the initial weights, the model and the gold
// responses are written instead of read.
TEST_P(ML_ANN_METHOD, Test)
{
    int methodType = get<0>(GetParam());
    string methodName = get<1>(GetParam());
    int N = get<2>(GetParam());

    String folder = string(cvtest::TS::ptr()->get_data_path());
    String original_path = folder + "waveform.data";
    String dataname = folder + "waveform" + '_' + methodName;

    Ptr<TrainData> tdata2 = TrainData::loadFromCSV(original_path, 0);
    // Check the load before tdata2 is dereferenced below; otherwise a missing
    // data file crashes the test instead of failing it.
    ASSERT_FALSE(tdata2.empty()) << "Could not find test data file : " << original_path;
    Mat samples = tdata2->getSamples()(Range(0, N), Range::all());
    // One-hot responses for the first N samples. This assumes the labels in
    // waveform.data are 0, 1 or 2, since the matrix has exactly 3 columns.
    Mat responses(N, 3, CV_32FC1, Scalar(0));
    for (int i = 0; i < N; i++)
        responses.at<float>(i, static_cast<int>(tdata2->getResponses().at<float>(i, 0))) = 1;
    Ptr<TrainData> tdata = TrainData::create(samples, ml::ROW_SAMPLE, responses);

    ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << original_path;
    // Fixed global RNG state so the train/test split matches the stored data
    // (assumes setTrainTestSplitRatio shuffles using theRNG()).
    RNG& rng = theRNG();
    rng.state = 0;
    tdata->setTrainTestSplitRatio(0.8);

    Mat testSamples = tdata->getTestSamples();

#ifdef GENERATE_TESTDATA
    {
        Ptr<ml::ANN_MLP> xx = ml::ANN_MLP_ANNEAL::create();
        Mat_<int> layerSizesXX(1, 4);
        layerSizesXX(0, 0) = tdata->getNVars();
        layerSizesXX(0, 1) = 30;
        layerSizesXX(0, 2) = 30;
        layerSizesXX(0, 3) = tdata->getResponses().cols;
        xx->setLayerSizes(layerSizesXX);
        xx->setActivationFunction(ml::ANN_MLP::SIGMOID_SYM);
        xx->setTrainMethod(ml::ANN_MLP::RPROP);
        xx->setTermCriteria(TermCriteria(TermCriteria::COUNT, 1, 0.01));
        xx->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE + ml::ANN_MLP::NO_INPUT_SCALE);
        FileStorage fs;
        fs.open(dataname + "_init_weight.yml.gz", FileStorage::WRITE + FileStorage::BASE64);
        xx->write(fs);
        fs.release();
    }
#endif
    {
        FileStorage fs;
        fs.open(dataname + "_init_weight.yml.gz", FileStorage::READ);
        Ptr<ml::ANN_MLP> x = ml::ANN_MLP_ANNEAL::create();
        x->read(fs.root());
        x->setTrainMethod(methodType);
        if (methodType == ml::ANN_MLP::ANNEAL)
        {
            x->setAnnealEnergyRNG(RNG(CV_BIG_INT(0xffffffff)));
            x->setAnnealInitialT(12);
            x->setAnnealFinalT(0.15);
            x->setAnnealCoolingRatio(0.96);
            x->setAnnealItePerStep(11);
        }
        x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 100, 0.01));
        // UPDATE_WEIGHTS: continue from the weights read above instead of
        // computing them from scratch.
        x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE + ml::ANN_MLP::NO_INPUT_SCALE + ml::ANN_MLP::UPDATE_WEIGHTS);
        ASSERT_TRUE(x->isTrained()) << "Could not train networks with " << methodName;
        string filename = dataname + ".yml.gz";
        Mat r_gold;
#ifdef GENERATE_TESTDATA
        x->save(filename);
        x->predict(testSamples, r_gold);
        {
            FileStorage fs_response(dataname + "_response.yml.gz", FileStorage::WRITE + FileStorage::BASE64);
            fs_response << "response" << r_gold;
        }
#else
        {
            FileStorage fs_response(dataname + "_response.yml.gz", FileStorage::READ);
            fs_response["response"] >> r_gold;
        }
#endif
        ASSERT_FALSE(r_gold.empty());
        Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(filename);
        ASSERT_TRUE(y != NULL) << "Could not load " << filename;
        Mat rx, ry;
        for (int j = 0; j < 4; j++)
        {
            rx = x->getWeights(j);
            ry = y->getWeights(j);
            double n = cvtest::norm(rx, ry, NORM_INF);
            EXPECT_LT(n, FLT_EPSILON) << "Weights are not equal for layer: " << j;
        }
        x->predict(testSamples, rx);
        y->predict(testSamples, ry);
        double n = cvtest::norm(ry, rx, NORM_INF);
        EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal to result of the saved model";
        n = cvtest::norm(r_gold, rx, NORM_INF);
        EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal to 'gold' response";
    }
}
|
|
|
|
|
2018-02-20 00:45:04 +08:00
|
|
|
// One ML_ANN_METHOD.Test instance per training method:
//   (method, name used in the data file names, number of waveform.data rows).
// BACKPROP is deliberately not tested. Its entry would read:
//   make_tuple<ANN_MLP_METHOD, string, int>(ml::ANN_MLP::BACKPROP, "backprop", 5000)
INSTANTIATE_TEST_CASE_P(/*none*/, ML_ANN_METHOD, testing::Values(
    make_tuple<ANN_MLP_METHOD, string, int>(ml::ANN_MLP::RPROP,  "rprop",  5000),
    make_tuple<ANN_MLP_METHOD, string, int>(ml::ANN_MLP::ANNEAL, "anneal", 1000)
));
|
|
|
|
|
2017-11-23 05:07:23 +08:00
|
|
|
|
2011-02-10 04:55:11 +08:00
|
|
|
// 6. dtree
|
|
|
|
// 7. boost
|
2013-03-21 00:13:46 +08:00
|
|
|
// Maps a boosting type name read from the validation file ("DISCRETE", "REAL",
// "LOGIT" or "GENTLE") to the matching Boost constant.
// Raises CV_StsBadArg (via CV_Error) for any other string.
int str_to_boost_type( String& str )
{
    static const struct { const char* name; int id; } kBoostTypes[] =
    {
        { "DISCRETE", Boost::DISCRETE },
        { "REAL",     Boost::REAL },
        { "LOGIT",    Boost::LOGIT },
        { "GENTLE",   Boost::GENTLE }
    };
    for ( size_t k = 0; k < sizeof(kBoostTypes)/sizeof(kBoostTypes[0]); ++k )
    {
        if ( str.compare(kBoostTypes[k].name) == 0 )
            return kBoostTypes[k].id;
    }
    CV_Error( CV_StsBadArg, "incorrect boost type string" );
}
|
|
|
|
|
|
|
|
// 8. rtrees
|
|
|
|
// 9. ertrees
|
|
|
|
|
2016-01-20 17:59:44 +08:00
|
|
|
// Maps an SVMSGD algorithm name ("SGD" or "ASGD") to the matching SVMSGD
// constant. Raises CV_StsBadArg (via CV_Error) for any other string.
int str_to_svmsgd_type( String& str )
{
    const bool isSgd = str.compare("SGD") == 0;
    if ( !isSgd && str.compare("ASGD") != 0 )
        CV_Error( CV_StsBadArg, "incorrect svmsgd type string" );
    return isSgd ? SVMSGD::SGD : SVMSGD::ASGD;
}
|
|
|
|
|
2016-02-09 23:42:23 +08:00
|
|
|
// Maps an SVMSGD margin type name ("SOFT_MARGIN" or "HARD_MARGIN") to the
// matching SVMSGD constant. Raises CV_StsBadArg (via CV_Error) otherwise.
int str_to_margin_type( String& str )
{
    if ( str.compare("SOFT_MARGIN") == 0 )
        return SVMSGD::SOFT_MARGIN;
    if ( str.compare("HARD_MARGIN") != 0 )
        CV_Error( CV_StsBadArg, "incorrect svmsgd margin type string" );
    return SVMSGD::HARD_MARGIN;
}
|
2017-11-05 21:48:40 +08:00
|
|
|
|
|
|
|
}
|
2011-02-10 04:55:11 +08:00
|
|
|
// ---------------------------------- MLBaseTest ---------------------------------------------------
|
|
|
|
|
|
|
|
// Saves the current global RNG state in initSeed (the destructor restores it).
// Then reseeds theRNG() with one of a fixed set of seeds, picked with the
// current RNG, and records the model name under test.
CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
{
    static const int64 kSeeds[] =
    {
        CV_BIG_INT(0x00009fff4f9c8d52),
        CV_BIG_INT(0x0000a17166072c7c),
        CV_BIG_INT(0x0201b32115cd1f9a),
        CV_BIG_INT(0x0513cb37abcd1234),
        CV_BIG_INT(0x0001a2b3c4d5f678)
    };
    const int seedCount = (int)(sizeof(kSeeds) / sizeof(kSeeds[0]));

    RNG& rng = theRNG();
    initSeed = rng.state;
    const unsigned pick = rng(seedCount);
    rng.state = kSeeds[pick];

    modelName = _modelName;
}
|
|
|
|
|
|
|
|
// Closes the validation file if it is still open and restores the global RNG
// state that the constructor saved in initSeed.
CV_MLBaseTest::~CV_MLBaseTest()
{
    if( validationFS.isOpened() )
    {
        validationFS.release();
    }
    RNG& rng = theRNG();
    rng.state = initSeed;
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
int CV_MLBaseTest::read_params( CvFileStorage* __fs )
|
2011-02-10 04:55:11 +08:00
|
|
|
{
|
2017-05-25 23:59:01 +08:00
|
|
|
CV_TRACE_FUNCTION();
|
2014-07-30 03:54:23 +08:00
|
|
|
FileStorage _fs(__fs, false);
|
|
|
|
if( !_fs.isOpened() )
|
2011-02-10 04:55:11 +08:00
|
|
|
test_case_count = -1;
|
|
|
|
else
|
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
FileNode fn = _fs.getFirstTopLevelNode()["run_params"][modelName];
|
|
|
|
test_case_count = (int)fn.size();
|
|
|
|
if( test_case_count <= 0 )
|
|
|
|
test_case_count = -1;
|
2011-02-10 04:55:11 +08:00
|
|
|
if( test_case_count > 0 )
|
|
|
|
{
|
|
|
|
dataSetNames.resize( test_case_count );
|
2014-07-30 03:54:23 +08:00
|
|
|
FileNodeIterator it = fn.begin();
|
|
|
|
for( int i = 0; i < test_case_count; i++, ++it )
|
|
|
|
{
|
|
|
|
dataSetNames[i] = (string)*it;
|
|
|
|
}
|
2011-02-10 04:55:11 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
return cvtest::TS::OK;;
|
|
|
|
}
|
|
|
|
|
2012-06-21 01:57:26 +08:00
|
|
|
void CV_MLBaseTest::run( int )
|
2011-02-10 04:55:11 +08:00
|
|
|
{
|
2017-05-25 23:59:01 +08:00
|
|
|
CV_TRACE_FUNCTION();
|
2011-02-10 04:55:11 +08:00
|
|
|
string filename = ts->get_data_path();
|
|
|
|
filename += get_validation_filename();
|
|
|
|
validationFS.open( filename, FileStorage::READ );
|
|
|
|
read_params( *validationFS );
|
2012-06-21 01:57:26 +08:00
|
|
|
|
2011-02-10 04:55:11 +08:00
|
|
|
int code = cvtest::TS::OK;
|
|
|
|
for (int i = 0; i < test_case_count; i++)
|
|
|
|
{
|
2017-05-25 23:59:01 +08:00
|
|
|
CV_TRACE_REGION("iteration");
|
2011-02-10 04:55:11 +08:00
|
|
|
int temp_code = run_test_case( i );
|
|
|
|
if (temp_code == cvtest::TS::OK)
|
|
|
|
temp_code = validate_test_results( i );
|
|
|
|
if (temp_code != cvtest::TS::OK)
|
|
|
|
code = temp_code;
|
|
|
|
}
|
|
|
|
if ( test_case_count <= 0)
|
|
|
|
{
|
|
|
|
ts->printf( cvtest::TS::LOG, "validation file is not determined or not correct" );
|
|
|
|
code = cvtest::TS::FAIL_INVALID_TEST_DATA;
|
|
|
|
}
|
|
|
|
ts->set_failed_test_info( code );
|
|
|
|
}
|
|
|
|
|
|
|
|
int CV_MLBaseTest::prepare_test_case( int test_case_idx )
|
|
|
|
{
|
2017-05-25 23:59:01 +08:00
|
|
|
CV_TRACE_FUNCTION();
|
2011-02-10 04:55:11 +08:00
|
|
|
clear();
|
|
|
|
|
|
|
|
string dataPath = ts->get_data_path();
|
|
|
|
if ( dataPath.empty() )
|
|
|
|
{
|
|
|
|
ts->printf( cvtest::TS::LOG, "data path is empty" );
|
|
|
|
return cvtest::TS::FAIL_INVALID_TEST_DATA;
|
|
|
|
}
|
|
|
|
|
|
|
|
string dataName = dataSetNames[test_case_idx],
|
|
|
|
filename = dataPath + dataName + ".data";
|
|
|
|
|
|
|
|
FileNode dataParamsNode = validationFS.getFirstTopLevelNode()["validation"][modelName][dataName]["data_params"];
|
|
|
|
CV_DbgAssert( !dataParamsNode.empty() );
|
|
|
|
|
|
|
|
CV_DbgAssert( !dataParamsNode["LS"].empty() );
|
2014-07-30 03:54:23 +08:00
|
|
|
int trainSampleCount = (int)dataParamsNode["LS"];
|
2011-02-10 04:55:11 +08:00
|
|
|
|
|
|
|
CV_DbgAssert( !dataParamsNode["resp_idx"].empty() );
|
2014-07-30 03:54:23 +08:00
|
|
|
int respIdx = (int)dataParamsNode["resp_idx"];
|
2011-02-10 04:55:11 +08:00
|
|
|
|
|
|
|
CV_DbgAssert( !dataParamsNode["types"].empty() );
|
2014-07-30 03:54:23 +08:00
|
|
|
String varTypes = (String)dataParamsNode["types"];
|
2011-02-10 04:55:11 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
data = TrainData::loadFromCSV(filename, 0, respIdx, respIdx+1, varTypes);
|
|
|
|
if( data.empty() )
|
|
|
|
{
|
|
|
|
ts->printf( cvtest::TS::LOG, "file %s can not be read\n", filename.c_str() );
|
|
|
|
return cvtest::TS::FAIL_INVALID_TEST_DATA;
|
|
|
|
}
|
|
|
|
|
|
|
|
data->setTrainTestSplit(trainSampleCount);
|
2011-02-10 04:55:11 +08:00
|
|
|
return cvtest::TS::OK;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Returns, by reference, the validation file name. run() appends it to the
// test data path.
string& CV_MLBaseTest::get_validation_filename()
{
    string& name = validationFN;
    return name;
}
|
|
|
|
|
|
|
|
int CV_MLBaseTest::train( int testCaseIdx )
|
|
|
|
{
|
2017-05-25 23:59:01 +08:00
|
|
|
CV_TRACE_FUNCTION();
|
2011-02-10 04:55:11 +08:00
|
|
|
bool is_trained = false;
|
2012-06-21 01:57:26 +08:00
|
|
|
FileNode modelParamsNode =
|
2011-02-10 04:55:11 +08:00
|
|
|
validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"];
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( modelName == CV_NBAYES )
|
|
|
|
model = NormalBayesClassifier::create();
|
|
|
|
else if( modelName == CV_KNEAREST )
|
2011-02-10 04:55:11 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
model = KNearest::create();
|
2011-02-10 04:55:11 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
else if( modelName == CV_SVM )
|
2011-02-10 04:55:11 +08:00
|
|
|
{
|
2013-03-21 00:13:46 +08:00
|
|
|
String svm_type_str, kernel_type_str;
|
2011-02-10 04:55:11 +08:00
|
|
|
modelParamsNode["svm_type"] >> svm_type_str;
|
|
|
|
modelParamsNode["kernel_type"] >> kernel_type_str;
|
2015-02-11 18:24:14 +08:00
|
|
|
Ptr<SVM> m = SVM::create();
|
|
|
|
m->setType(str_to_svm_type( svm_type_str ));
|
|
|
|
m->setKernel(str_to_svm_kernel_type( kernel_type_str ));
|
|
|
|
m->setDegree(modelParamsNode["degree"]);
|
|
|
|
m->setGamma(modelParamsNode["gamma"]);
|
|
|
|
m->setCoef0(modelParamsNode["coef0"]);
|
|
|
|
m->setC(modelParamsNode["C"]);
|
|
|
|
m->setNu(modelParamsNode["nu"]);
|
|
|
|
m->setP(modelParamsNode["p"]);
|
|
|
|
model = m;
|
2011-02-10 04:55:11 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
else if( modelName == CV_EM )
|
2011-02-10 04:55:11 +08:00
|
|
|
{
|
|
|
|
assert( 0 );
|
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
else if( modelName == CV_ANN )
|
2011-02-10 04:55:11 +08:00
|
|
|
{
|
2013-03-21 00:13:46 +08:00
|
|
|
String train_method_str;
|
2011-02-10 04:55:11 +08:00
|
|
|
double param1, param2;
|
|
|
|
modelParamsNode["train_method"] >> train_method_str;
|
|
|
|
modelParamsNode["param1"] >> param1;
|
|
|
|
modelParamsNode["param2"] >> param2;
|
2014-07-30 03:54:23 +08:00
|
|
|
Mat new_responses = ann_get_new_responses( data, cls_map );
|
|
|
|
// binarize the responses
|
|
|
|
data = TrainData::create(data->getSamples(), data->getLayout(), new_responses,
|
|
|
|
data->getVarIdx(), data->getTrainSampleIdx());
|
|
|
|
int layer_sz[] = { data->getNAllVars(), 100, 100, (int)cls_map.size() };
|
|
|
|
Mat layer_sizes( 1, (int)(sizeof(layer_sz)/sizeof(layer_sz[0])), CV_32S, layer_sz );
|
2015-02-11 18:24:14 +08:00
|
|
|
Ptr<ANN_MLP> m = ANN_MLP::create();
|
|
|
|
m->setLayerSizes(layer_sizes);
|
|
|
|
m->setActivationFunction(ANN_MLP::SIGMOID_SYM, 0, 0);
|
|
|
|
m->setTermCriteria(TermCriteria(TermCriteria::COUNT,300,0.01));
|
|
|
|
m->setTrainMethod(str_to_ann_train_method(train_method_str), param1, param2);
|
|
|
|
model = m;
|
|
|
|
|
2011-02-10 04:55:11 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
else if( modelName == CV_DTREE )
|
2011-02-10 04:55:11 +08:00
|
|
|
{
|
|
|
|
int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS;
|
|
|
|
float REG_ACCURACY = 0;
|
2014-07-30 03:54:23 +08:00
|
|
|
bool USE_SURROGATE = false, IS_PRUNED;
|
2011-02-10 04:55:11 +08:00
|
|
|
modelParamsNode["max_depth"] >> MAX_DEPTH;
|
|
|
|
modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
|
2014-07-30 03:54:23 +08:00
|
|
|
//modelParamsNode["use_surrogate"] >> USE_SURROGATE;
|
2011-02-10 04:55:11 +08:00
|
|
|
modelParamsNode["max_categories"] >> MAX_CATEGORIES;
|
|
|
|
modelParamsNode["cv_folds"] >> CV_FOLDS;
|
|
|
|
modelParamsNode["is_pruned"] >> IS_PRUNED;
|
2015-02-11 18:24:14 +08:00
|
|
|
|
|
|
|
Ptr<DTrees> m = DTrees::create();
|
|
|
|
m->setMaxDepth(MAX_DEPTH);
|
|
|
|
m->setMinSampleCount(MIN_SAMPLE_COUNT);
|
|
|
|
m->setRegressionAccuracy(REG_ACCURACY);
|
|
|
|
m->setUseSurrogates(USE_SURROGATE);
|
|
|
|
m->setMaxCategories(MAX_CATEGORIES);
|
|
|
|
m->setCVFolds(CV_FOLDS);
|
|
|
|
m->setUse1SERule(false);
|
|
|
|
m->setTruncatePrunedTree(IS_PRUNED);
|
|
|
|
m->setPriors(Mat());
|
|
|
|
model = m;
|
2011-02-10 04:55:11 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
else if( modelName == CV_BOOST )
|
2011-02-10 04:55:11 +08:00
|
|
|
{
|
|
|
|
int BOOST_TYPE, WEAK_COUNT, MAX_DEPTH;
|
|
|
|
float WEIGHT_TRIM_RATE;
|
2014-07-30 03:54:23 +08:00
|
|
|
bool USE_SURROGATE = false;
|
2013-03-21 00:13:46 +08:00
|
|
|
String typeStr;
|
2011-02-10 04:55:11 +08:00
|
|
|
modelParamsNode["type"] >> typeStr;
|
|
|
|
BOOST_TYPE = str_to_boost_type( typeStr );
|
|
|
|
modelParamsNode["weak_count"] >> WEAK_COUNT;
|
|
|
|
modelParamsNode["weight_trim_rate"] >> WEIGHT_TRIM_RATE;
|
|
|
|
modelParamsNode["max_depth"] >> MAX_DEPTH;
|
2014-07-30 03:54:23 +08:00
|
|
|
//modelParamsNode["use_surrogate"] >> USE_SURROGATE;
|
2015-02-11 18:24:14 +08:00
|
|
|
|
|
|
|
Ptr<Boost> m = Boost::create();
|
|
|
|
m->setBoostType(BOOST_TYPE);
|
|
|
|
m->setWeakCount(WEAK_COUNT);
|
|
|
|
m->setWeightTrimRate(WEIGHT_TRIM_RATE);
|
|
|
|
m->setMaxDepth(MAX_DEPTH);
|
|
|
|
m->setUseSurrogates(USE_SURROGATE);
|
|
|
|
m->setPriors(Mat());
|
|
|
|
model = m;
|
2011-02-10 04:55:11 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
else if( modelName == CV_RTREES )
|
2011-02-10 04:55:11 +08:00
|
|
|
{
|
|
|
|
int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM;
|
|
|
|
float REG_ACCURACY = 0, OOB_EPS = 0.0;
|
2014-07-30 03:54:23 +08:00
|
|
|
bool USE_SURROGATE = false, IS_PRUNED;
|
2011-02-10 04:55:11 +08:00
|
|
|
modelParamsNode["max_depth"] >> MAX_DEPTH;
|
|
|
|
modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
|
2014-07-30 03:54:23 +08:00
|
|
|
//modelParamsNode["use_surrogate"] >> USE_SURROGATE;
|
2011-02-10 04:55:11 +08:00
|
|
|
modelParamsNode["max_categories"] >> MAX_CATEGORIES;
|
|
|
|
modelParamsNode["cv_folds"] >> CV_FOLDS;
|
|
|
|
modelParamsNode["is_pruned"] >> IS_PRUNED;
|
|
|
|
modelParamsNode["nactive_vars"] >> NACTIVE_VARS;
|
|
|
|
modelParamsNode["max_trees_num"] >> MAX_TREES_NUM;
|
2015-02-11 18:24:14 +08:00
|
|
|
|
|
|
|
Ptr<RTrees> m = RTrees::create();
|
|
|
|
m->setMaxDepth(MAX_DEPTH);
|
|
|
|
m->setMinSampleCount(MIN_SAMPLE_COUNT);
|
|
|
|
m->setRegressionAccuracy(REG_ACCURACY);
|
|
|
|
m->setUseSurrogates(USE_SURROGATE);
|
|
|
|
m->setMaxCategories(MAX_CATEGORIES);
|
|
|
|
m->setPriors(Mat());
|
|
|
|
m->setCalculateVarImportance(true);
|
|
|
|
m->setActiveVarCount(NACTIVE_VARS);
|
|
|
|
m->setTermCriteria(TermCriteria(TermCriteria::COUNT, MAX_TREES_NUM, OOB_EPS));
|
|
|
|
model = m;
|
2011-02-10 04:55:11 +08:00
|
|
|
}
|
|
|
|
|
2016-01-20 17:59:44 +08:00
|
|
|
else if( modelName == CV_SVMSGD )
|
|
|
|
{
|
|
|
|
String svmsgdTypeStr;
|
|
|
|
modelParamsNode["svmsgdType"] >> svmsgdTypeStr;
|
2016-02-09 23:42:23 +08:00
|
|
|
|
|
|
|
Ptr<SVMSGD> m = SVMSGD::create();
|
|
|
|
int svmsgdType = str_to_svmsgd_type( svmsgdTypeStr );
|
|
|
|
m->setSvmsgdType(svmsgdType);
|
|
|
|
|
|
|
|
String marginTypeStr;
|
|
|
|
modelParamsNode["marginType"] >> marginTypeStr;
|
|
|
|
int marginType = str_to_margin_type( marginTypeStr );
|
|
|
|
m->setMarginType(marginType);
|
|
|
|
|
2016-02-24 18:22:07 +08:00
|
|
|
m->setMarginRegularization(modelParamsNode["marginRegularization"]);
|
|
|
|
m->setInitialStepSize(modelParamsNode["initialStepSize"]);
|
|
|
|
m->setStepDecreasingPower(modelParamsNode["stepDecreasingPower"]);
|
2016-01-20 17:59:44 +08:00
|
|
|
m->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 10000, 0.00001));
|
|
|
|
model = m;
|
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( !model.empty() )
|
|
|
|
is_trained = model->train(data, 0);
|
|
|
|
|
2011-02-10 04:55:11 +08:00
|
|
|
if( !is_trained )
|
|
|
|
{
|
|
|
|
ts->printf( cvtest::TS::LOG, "in test case %d model training was failed", testCaseIdx );
|
|
|
|
return cvtest::TS::FAIL_INVALID_OUTPUT;
|
|
|
|
}
|
|
|
|
return cvtest::TS::OK;
|
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
float CV_MLBaseTest::get_test_error( int /*testCaseIdx*/, vector<float> *resp )
|
2011-02-10 04:55:11 +08:00
|
|
|
{
|
2017-05-25 23:59:01 +08:00
|
|
|
CV_TRACE_FUNCTION();
|
2014-07-30 03:54:23 +08:00
|
|
|
int type = CV_TEST_ERROR;
|
2011-02-10 04:55:11 +08:00
|
|
|
float err = 0;
|
2014-07-30 03:54:23 +08:00
|
|
|
Mat _resp;
|
|
|
|
if( modelName == CV_EM )
|
2011-02-10 04:55:11 +08:00
|
|
|
assert( 0 );
|
2014-07-30 03:54:23 +08:00
|
|
|
else if( modelName == CV_ANN )
|
|
|
|
err = ann_calc_error( model, data, cls_map, type, resp );
|
|
|
|
else if( modelName == CV_DTREE || modelName == CV_BOOST || modelName == CV_RTREES ||
|
2016-01-20 17:59:44 +08:00
|
|
|
modelName == CV_SVM || modelName == CV_NBAYES || modelName == CV_KNEAREST || modelName == CV_SVMSGD )
|
2014-07-30 03:54:23 +08:00
|
|
|
err = model->calcError( data, true, _resp );
|
|
|
|
if( !_resp.empty() && resp )
|
|
|
|
_resp.convertTo(*resp, CV_32F);
|
2011-02-10 04:55:11 +08:00
|
|
|
return err;
|
|
|
|
}
|
|
|
|
|
|
|
|
void CV_MLBaseTest::save( const char* filename )
|
|
|
|
{
|
2017-05-25 23:59:01 +08:00
|
|
|
CV_TRACE_FUNCTION();
|
2014-07-30 03:54:23 +08:00
|
|
|
model->save( filename );
|
2011-02-10 04:55:11 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
void CV_MLBaseTest::load( const char* filename )
|
|
|
|
{
|
2017-05-25 23:59:01 +08:00
|
|
|
CV_TRACE_FUNCTION();
|
2014-07-30 03:54:23 +08:00
|
|
|
if( modelName == CV_NBAYES )
|
2015-04-07 21:44:26 +08:00
|
|
|
model = Algorithm::load<NormalBayesClassifier>( filename );
|
2014-07-30 03:54:23 +08:00
|
|
|
else if( modelName == CV_KNEAREST )
|
2015-04-07 21:44:26 +08:00
|
|
|
model = Algorithm::load<KNearest>( filename );
|
2014-07-30 03:54:23 +08:00
|
|
|
else if( modelName == CV_SVM )
|
2015-04-07 21:44:26 +08:00
|
|
|
model = Algorithm::load<SVM>( filename );
|
2014-07-30 03:54:23 +08:00
|
|
|
else if( modelName == CV_ANN )
|
2015-04-07 21:44:26 +08:00
|
|
|
model = Algorithm::load<ANN_MLP>( filename );
|
2014-07-30 03:54:23 +08:00
|
|
|
else if( modelName == CV_DTREE )
|
2015-04-07 21:44:26 +08:00
|
|
|
model = Algorithm::load<DTrees>( filename );
|
2014-07-30 03:54:23 +08:00
|
|
|
else if( modelName == CV_BOOST )
|
2015-04-07 21:44:26 +08:00
|
|
|
model = Algorithm::load<Boost>( filename );
|
2014-07-30 03:54:23 +08:00
|
|
|
else if( modelName == CV_RTREES )
|
2015-04-07 21:44:26 +08:00
|
|
|
model = Algorithm::load<RTrees>( filename );
|
2016-01-20 17:59:44 +08:00
|
|
|
else if( modelName == CV_SVMSGD )
|
|
|
|
model = Algorithm::load<SVMSGD>( filename );
|
2014-07-30 03:54:23 +08:00
|
|
|
else
|
|
|
|
CV_Error( CV_StsNotImplemented, "invalid stat model name");
|
2011-02-10 04:55:11 +08:00
|
|
|
}
|
|
|
|
|
2017-11-05 21:48:40 +08:00
|
|
|
} // namespace
|
2011-02-10 04:55:11 +08:00
|
|
|
/* End of file. */
|