diff --git a/modules/ml/test/test_mltests2.cpp b/modules/ml/test/test_mltests2.cpp index 7d6bc1dd9a..2ff0c93acb 100644 --- a/modules/ml/test/test_mltests2.cpp +++ b/modules/ml/test/test_mltests2.cpp @@ -252,31 +252,35 @@ TEST(ML_ANN, ActivationFunction) } } -TEST(ML_ANN, Method) +CV_ENUM(ANN_MLP_METHOD, ANN_MLP::RPROP, ANN_MLP::ANNEAL) + +typedef tuple ML_ANN_METHOD_Params; +typedef TestWithParam ML_ANN_METHOD; + +TEST_P(ML_ANN_METHOD, Test) { + int methodType = get<0>(GetParam()); + string methodName = get<1>(GetParam()); + int N = get<2>(GetParam()); + String folder = string(cvtest::TS::ptr()->get_data_path()); String original_path = folder + "waveform.data"; - String dataname = folder + "waveform"; + String dataname = folder + "waveform" + '_' + methodName; Ptr tdata2 = TrainData::loadFromCSV(original_path, 0); - Mat responses(tdata2->getResponses().rows, 3, CV_32FC1, Scalar(0)); - for (int i = 0; igetResponses().rows; i++) + Mat samples = tdata2->getSamples()(Range(0, N), Range::all()); + Mat responses(N, 3, CV_32FC1, Scalar(0)); + for (int i = 0; i < N; i++) responses.at(i, static_cast(tdata2->getResponses().at(i, 0))) = 1; - Ptr tdata = TrainData::create(tdata2->getSamples(), ml::ROW_SAMPLE, responses); + Ptr tdata = TrainData::create(samples, ml::ROW_SAMPLE, responses); ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << original_path; RNG& rng = theRNG(); rng.state = 0; tdata->setTrainTestSplitRatio(0.8); - vector methodType; - methodType.push_back(ml::ANN_MLP::RPROP); - methodType.push_back(ml::ANN_MLP::ANNEAL); -// methodType.push_back(ml::ANN_MLP::BACKPROP); -----> NO BACKPROP TEST - vector methodName; - methodName.push_back("_rprop"); - methodName.push_back("_anneal"); -// methodName.push_back("_backprop"); -----> NO BACKPROP TEST + Mat testSamples = tdata->getTestSamples(); + #ifdef GENERATE_TESTDATA { Ptr xx = ml::ANN_MLP_ANNEAL::create(); @@ -296,14 +300,13 @@ TEST(ML_ANN, Method) fs.release(); } #endif - for (size_t i = 0; i < methodType.size(); i++) { FileStorage fs; - fs.open(dataname + "_init_weight.yml.gz", FileStorage::READ + FileStorage::BASE64); + fs.open(dataname + "_init_weight.yml.gz", FileStorage::READ); Ptr x = ml::ANN_MLP_ANNEAL::create(); x->read(fs.root()); - x->setTrainMethod(methodType[i]); - if (methodType[i] == ml::ANN_MLP::ANNEAL) + x->setTrainMethod(methodType); + if (methodType == ml::ANN_MLP::ANNEAL) { x->setAnnealEnergyRNG(RNG(CV_BIG_INT(0xffffffff))); x->setAnnealInitialT(12); @@ -313,28 +316,50 @@ TEST(ML_ANN, Method) } x->setTermCriteria(TermCriteria(TermCriteria::COUNT, 100, 0.01)); x->train(tdata, ml::ANN_MLP::NO_OUTPUT_SCALE + ml::ANN_MLP::NO_INPUT_SCALE + ml::ANN_MLP::UPDATE_WEIGHTS); - ASSERT_TRUE(x->isTrained()) << "Could not train networks with " << methodName[i]; + ASSERT_TRUE(x->isTrained()) << "Could not train networks with " << methodName; + string filename = dataname + ".yml.gz"; + Mat r_gold; #ifdef GENERATE_TESTDATA - x->save(dataname + methodName[i] + ".yml.gz"); + x->save(filename); + x->predict(testSamples, r_gold); + { + FileStorage fs_response(dataname + "_response.yml.gz", FileStorage::WRITE + FileStorage::BASE64); + fs_response << "response" << r_gold; + } +#else + { + FileStorage fs_response(dataname + "_response.yml.gz", FileStorage::READ); + fs_response["response"] >> r_gold; + } #endif - Ptr y = Algorithm::load(dataname + methodName[i] + ".yml.gz"); - ASSERT_TRUE(y != NULL) << "Could not load " << dataname + methodName[i] + ".yml"; - Mat testSamples = tdata->getTestSamples(); - Mat rx, ry, dst; + ASSERT_FALSE(r_gold.empty()); + Ptr y = Algorithm::load(filename); + ASSERT_TRUE(y != NULL) << "Could not load " << filename; + Mat rx, ry; for (int j = 0; j < 4; j++) { rx = x->getWeights(j); ry = y->getWeights(j); double n = cvtest::norm(rx, ry, NORM_INF); - EXPECT_LT(n, FLT_EPSILON) << "Weights are not equal for " << dataname + methodName[i] + ".yml and " << methodName[i] << " layer : " << j; + EXPECT_LT(n, FLT_EPSILON) << "Weights are not equal for layer: " << j; } x->predict(testSamples, rx); y->predict(testSamples, ry); - double n = cvtest::norm(rx, ry, NORM_INF); - EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal for " << dataname + methodName[i] + ".yml and " << methodName[i]; + double n = cvtest::norm(ry, rx, NORM_INF); + EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal to result of the saved model"; + n = cvtest::norm(r_gold, rx, NORM_INF); + EXPECT_LT(n, FLT_EPSILON) << "Predict are not equal to 'gold' response"; } } +INSTANTIATE_TEST_CASE_P(/*none*/, ML_ANN_METHOD, + testing::Values( + make_tuple(ml::ANN_MLP::RPROP, "rprop", 5000), + make_tuple(ml::ANN_MLP::ANNEAL, "anneal", 1000) + //make_pair(ml::ANN_MLP::BACKPROP, "backprop", 5000); -----> NO BACKPROP TEST + ) +); + // 6. dtree // 7. boost