opencv/modules/ml/test/test_lr.cpp

82 lines
2.8 KiB
C++
Raw Normal View History

// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
//
// AUTHOR: Rahul Kavi rahulkavi[at]live[at]com
2013-08-05 21:34:53 +08:00
//
// Test data uses subset of data from the popular Iris Dataset (1936):
// - http://archive.ics.uci.edu/ml/datasets/Iris
// - https://en.wikipedia.org/wiki/Iris_flower_data_set
2013-08-05 21:34:53 +08:00
//
#include "test_precomp.hpp"
namespace opencv_test { namespace {
2013-08-05 21:34:53 +08:00
TEST(ML_LR, accuracy)
2013-08-05 21:34:53 +08:00
{
std::string dataFileName = findDataFile("iris.data");
Ptr<TrainData> tdata = TrainData::loadFromCSV(dataFileName, 0);
ASSERT_FALSE(tdata.empty());
Ptr<LogisticRegression> p = LogisticRegression::create();
p->setLearningRate(1.0);
p->setIterations(10001);
p->setRegularization(LogisticRegression::REG_L2);
p->setTrainMethod(LogisticRegression::BATCH);
p->setMiniBatchSize(10);
p->train(tdata);
2013-08-05 21:34:53 +08:00
Mat responses;
p->predict(tdata->getSamples(), responses);
2013-08-05 21:34:53 +08:00
float error = 1000;
EXPECT_TRUE(calculateError(responses, tdata->getResponses(), error));
EXPECT_LE(error, 0.05f);
2013-08-05 21:34:53 +08:00
}
//==================================================================================================
2013-08-05 21:34:53 +08:00
TEST(ML_LR, save_load)
2013-08-05 21:34:53 +08:00
{
string dataFileName = findDataFile("iris.data");
Ptr<TrainData> tdata = TrainData::loadFromCSV(dataFileName, 0);
ASSERT_FALSE(tdata.empty());
2013-08-05 21:34:53 +08:00
Mat responses1, responses2;
Mat learnt_mat1, learnt_mat2;
String filename = tempfile(".xml");
2013-08-05 21:34:53 +08:00
{
Ptr<LogisticRegression> lr1 = LogisticRegression::create();
lr1->setLearningRate(1.0);
lr1->setIterations(10001);
lr1->setRegularization(LogisticRegression::REG_L2);
lr1->setTrainMethod(LogisticRegression::BATCH);
lr1->setMiniBatchSize(10);
ASSERT_NO_THROW(lr1->train(tdata));
ASSERT_NO_THROW(lr1->predict(tdata->getSamples(), responses1));
ASSERT_NO_THROW(lr1->save(filename));
learnt_mat1 = lr1->get_learnt_thetas();
2013-08-05 21:34:53 +08:00
}
{
Ptr<LogisticRegression> lr2;
ASSERT_NO_THROW(lr2 = Algorithm::load<LogisticRegression>(filename));
ASSERT_NO_THROW(lr2->predict(tdata->getSamples(), responses2));
learnt_mat2 = lr2->get_learnt_thetas();
2013-08-05 21:34:53 +08:00
}
// compare difference in prediction outputs and stored inputs
EXPECT_MAT_NEAR(responses1, responses2, 0.f);
2013-08-05 21:34:53 +08:00
Mat comp_learnt_mats;
2013-08-05 21:34:53 +08:00
comp_learnt_mats = (learnt_mat1 == learnt_mat2);
comp_learnt_mats = comp_learnt_mats.reshape(1, comp_learnt_mats.rows*comp_learnt_mats.cols);
comp_learnt_mats.convertTo(comp_learnt_mats, CV_32S);
comp_learnt_mats = comp_learnt_mats/255;
// check if there is any difference between computed learnt mat and retrieved mat
EXPECT_EQ(comp_learnt_mats.rows, sum(comp_learnt_mats)[0]);
2013-08-05 21:34:53 +08:00
remove( filename.c_str() );
}
}} // namespace