mirror of
https://github.com/opencv/opencv.git
synced 2025-06-12 20:42:53 +08:00
ml: refactor ML_ANN test
This commit is contained in:
parent
88b689bcf1
commit
12d2bd4adb
@ -252,31 +252,35 @@ TEST(ML_ANN, ActivationFunction)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ML_ANN, Method)
|
CV_ENUM(ANN_MLP_METHOD, ANN_MLP::RPROP, ANN_MLP::ANNEAL)
|
||||||
|
|
||||||
|
typedef tuple<ANN_MLP_METHOD, string, int> ML_ANN_METHOD_Params;
|
||||||
|
typedef TestWithParam<ML_ANN_METHOD_Params> ML_ANN_METHOD;
|
||||||
|
|
||||||
|
TEST_P(ML_ANN_METHOD, Test)
|
||||||
{
|
{
|
||||||
|
int methodType = get<0>(GetParam());
|
||||||
|
string methodName = get<1>(GetParam());
|
||||||
|
int N = get<2>(GetParam());
|
||||||
|
|
||||||
String folder = string(cvtest::TS::ptr()->get_data_path());
|
String folder = string(cvtest::TS::ptr()->get_data_path());
|
||||||
String original_path = folder + "waveform.data";
|
String original_path = folder + "waveform.data";
|
||||||
String dataname = folder + "waveform";
|
String dataname = folder + "waveform" + '_' + methodName;
|
||||||
|
|
||||||
Ptr<TrainData> tdata2 = TrainData::loadFromCSV(original_path, 0);
|
Ptr<TrainData> tdata2 = TrainData::loadFromCSV(original_path, 0);
|
||||||
Mat responses(tdata2->getResponses().rows, 3, CV_32FC1, Scalar(0));
|
Mat samples = tdata2->getSamples()(Range(0, N), Range::all());
|
||||||
for (int i = 0; i<tdata2->getResponses().rows; i++)
|
Mat responses(N, 3, CV_32FC1, Scalar(0));
|
||||||
|
for (int i = 0; i < N; i++)
|
||||||
responses.at<float>(i, static_cast<int>(tdata2->getResponses().at<float>(i, 0))) = 1;
|
responses.at<float>(i, static_cast<int>(tdata2->getResponses().at<float>(i, 0))) = 1;
|
||||||
Ptr<TrainData> tdata = TrainData::create(tdata2->getSamples(), ml::ROW_SAMPLE, responses);
|
Ptr<TrainData> tdata = TrainData::create(samples, ml::ROW_SAMPLE, responses);
|
||||||
|
|
||||||
ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << original_path;
|
ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << original_path;
|
||||||
RNG& rng = theRNG();
|
RNG& rng = theRNG();
|
||||||
rng.state = 0;
|
rng.state = 0;
|
||||||
tdata->setTrainTestSplitRatio(0.8);
|
tdata->setTrainTestSplitRatio(0.8);
|
||||||
|
|
||||||
vector<int> methodType;
|
Mat testSamples = tdata->getTestSamples();
|
||||||
methodType.push_back(ml::ANN_MLP::RPROP);
|
|
||||||
methodType.push_back(ml::ANN_MLP::ANNEAL);
|
|
||||||
// methodType.push_back(ml::ANN_MLP::BACKPROP); -----> NO BACKPROP TEST
|
|
||||||
vector<String> methodName;
|
|
||||||
methodName.push_back("_rprop");
|
|
||||||
methodName.push_back("_anneal");
|
|
||||||
// methodName.push_back("_backprop"); -----> NO BACKPROP TEST
|
|
||||||
#ifdef GENERATE_TESTDATA
|
#ifdef GENERATE_TESTDATA
|
||||||
{
|
{
|
||||||
Ptr<ml::ANN_MLP> xx = ml::ANN_MLP_ANNEAL::create();
|
Ptr<ml::ANN_MLP> xx = ml::ANN_MLP_ANNEAL::create();
|
||||||
@ -296,14 +300,13 @@ TEST(ML_ANN, Method)
|
|||||||
fs.release();
|
fs.release();
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
for (size_t i = 0; i < methodType.size(); i++)
|
|
||||||
{
|
{
|
||||||
FileStorage fs;
|
FileStorage fs;
|
||||||
fs.open(dataname + "_init_weight.yml.gz", FileStorage::READ + FileStorage::BASE64);
|
fs.open(dataname + "_init_weight.yml.gz", FileStorage::READ);
|
||||||
Ptr<ml::ANN_MLP> x = ml::ANN_MLP_ANNEAL::create();
|
Ptr<ml::ANN_MLP> x = ml::ANN_MLP_ANNEAL::create();
|
||||||
x->read(fs.root());
|
x->read(fs.root());
|
||||||
x->setTrainMethod(methodType[i]);
|
x->setTrainMethod(methodType);
|
||||||
if (methodType[i] == ml::ANN_MLP::ANNEAL)
|
if (methodType == ml::ANN_MLP::ANNEAL)
|
||||||
{
|
{
|
||||||
x->setAnnealEnergyRNG(RNG(CV_BIG_INT(0xffffffff)));
|
x->setAnnealEnergyRNG(RNG(CV_BIG_INT(0xffffffff)));
|
||||||
x->setAnnealInitialT(12);
|
x->setAnnealInitialT(12);
|
||||||
@ -313,28 +316,50 @@ TEST(ML_ANN, Method)
|
|||||||
}
|
}
|
||||||
x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 100, 0.01));
|
x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 100, 0.01));
|
||||||
x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE + ml::ANN_MLP::NO_INPUT_SCALE + ml::ANN_MLP::UPDATE_WEIGHTS);
|
x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE + ml::ANN_MLP::NO_INPUT_SCALE + ml::ANN_MLP::UPDATE_WEIGHTS);
|
||||||
ASSERT_TRUE(x->isTrained()) << "Could not train networks with " << methodName[i];
|
ASSERT_TRUE(x->isTrained()) << "Could not train networks with " << methodName;
|
||||||
|
string filename = dataname + ".yml.gz";
|
||||||
|
Mat r_gold;
|
||||||
#ifdef GENERATE_TESTDATA
|
#ifdef GENERATE_TESTDATA
|
||||||
x->save(dataname + methodName[i] + ".yml.gz");
|
x->save(filename);
|
||||||
|
x->predict(testSamples, r_gold);
|
||||||
|
{
|
||||||
|
FileStorage fs_response(dataname + "_response.yml.gz", FileStorage::WRITE + FileStorage::BASE64);
|
||||||
|
fs_response << "response" << r_gold;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
{
|
||||||
|
FileStorage fs_response(dataname + "_response.yml.gz", FileStorage::READ);
|
||||||
|
fs_response["response"] >> r_gold;
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(dataname + methodName[i] + ".yml.gz");
|
ASSERT_FALSE(r_gold.empty());
|
||||||
ASSERT_TRUE(y != NULL) << "Could not load " << dataname + methodName[i] + ".yml";
|
Ptr<ml::ANN_MLP> y = Algorithm::load<ANN_MLP>(filename);
|
||||||
Mat testSamples = tdata->getTestSamples();
|
ASSERT_TRUE(y != NULL) << "Could not load " << filename;
|
||||||
Mat rx, ry, dst;
|
Mat rx, ry;
|
||||||
for (int j = 0; j < 4; j++)
|
for (int j = 0; j < 4; j++)
|
||||||
{
|
{
|
||||||
rx = x->getWeights(j);
|
rx = x->getWeights(j);
|
||||||
ry = y->getWeights(j);
|
ry = y->getWeights(j);
|
||||||
double n = cvtest::norm(rx, ry, NORM_INF);
|
double n = cvtest::norm(rx, ry, NORM_INF);
|
||||||
EXPECT_LT(n, FLT_EPSILON) << "Weights are not equal for " << dataname + methodName[i] + ".yml and " << methodName[i] << " layer : " << j;
|
EXPECT_LT(n, FLT_EPSILON) << "Weights are not equal for layer: " << j;
|
||||||
}
|
}
|
||||||
x->predict(testSamples, rx);
|
x->predict(testSamples, rx);
|
||||||
y->predict(testSamples, ry);
|
y->predict(testSamples, ry);
|
||||||
double n = cvtest::norm(rx, ry, NORM_INF);
|
double n = cvtest::norm(ry, rx, NORM_INF);
|
||||||
EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal for " << dataname + methodName[i] + ".yml and " << methodName[i];
|
EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal to result of the saved model";
|
||||||
|
n = cvtest::norm(r_gold, rx, NORM_INF);
|
||||||
|
EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal to 'gold' response";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(/*none*/, ML_ANN_METHOD,
|
||||||
|
testing::Values(
|
||||||
|
make_tuple<ANN_MLP_METHOD, string, int>(ml::ANN_MLP::RPROP, "rprop", 5000),
|
||||||
|
make_tuple<ANN_MLP_METHOD, string, int>(ml::ANN_MLP::ANNEAL, "anneal", 1000)
|
||||||
|
//make_pair<ANN_MLP_METHOD, string>(ml::ANN_MLP::BACKPROP, "backprop", 5000); -----> NO BACKPROP TEST
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
// 6. dtree
|
// 6. dtree
|
||||||
// 7. boost
|
// 7. boost
|
||||||
|
Loading…
Reference in New Issue
Block a user