mirror of
https://github.com/opencv/opencv.git
synced 2025-06-12 20:42:53 +08:00
ml(test): test different samples layout of TrainData
This commit is contained in:
parent
828cb4286d
commit
7ee69740e8
@ -721,5 +721,68 @@ void CV_MLBaseTest::load( const char* filename )
|
||||
CV_Error( CV_StsNotImplemented, "invalid stat model name");
|
||||
}
|
||||
|
||||
|
||||
|
||||
TEST(TrainDataGet, layout_ROW_SAMPLE) // Details: #12236
|
||||
{
|
||||
cv::Mat test = cv::Mat::ones(150, 30, CV_32FC1) * 2;
|
||||
test.col(3) += Scalar::all(3);
|
||||
cv::Mat labels = cv::Mat::ones(150, 3, CV_32SC1) * 5;
|
||||
labels.col(1) += 1;
|
||||
cv::Ptr<cv::ml::TrainData> train_data = cv::ml::TrainData::create(test, cv::ml::ROW_SAMPLE, labels);
|
||||
train_data->setTrainTestSplitRatio(0.9);
|
||||
|
||||
Mat tidx = train_data->getTestSampleIdx();
|
||||
EXPECT_EQ((size_t)15, tidx.total());
|
||||
|
||||
Mat tresp = train_data->getTestResponses();
|
||||
EXPECT_EQ(15, tresp.rows);
|
||||
EXPECT_EQ(labels.cols, tresp.cols);
|
||||
EXPECT_EQ(5, tresp.at<int>(0, 0)) << tresp;
|
||||
EXPECT_EQ(6, tresp.at<int>(0, 1)) << tresp;
|
||||
EXPECT_EQ(6, tresp.at<int>(14, 1)) << tresp;
|
||||
EXPECT_EQ(5, tresp.at<int>(14, 2)) << tresp;
|
||||
|
||||
Mat tsamples = train_data->getTestSamples();
|
||||
EXPECT_EQ(15, tsamples.rows);
|
||||
EXPECT_EQ(test.cols, tsamples.cols);
|
||||
EXPECT_EQ(2, tsamples.at<float>(0, 0)) << tsamples;
|
||||
EXPECT_EQ(5, tsamples.at<float>(0, 3)) << tsamples;
|
||||
EXPECT_EQ(2, tsamples.at<float>(14, test.cols - 1)) << tsamples;
|
||||
EXPECT_EQ(5, tsamples.at<float>(14, 3)) << tsamples;
|
||||
}
|
||||
|
||||
TEST(TrainDataGet, layout_COL_SAMPLE) // Details: #12236
|
||||
{
|
||||
cv::Mat test = cv::Mat::ones(30, 150, CV_32FC1) * 3;
|
||||
test.row(3) += Scalar::all(3);
|
||||
cv::Mat labels = cv::Mat::ones(3, 150, CV_32SC1) * 5;
|
||||
labels.row(1) += 1;
|
||||
cv::Ptr<cv::ml::TrainData> train_data = cv::ml::TrainData::create(test, cv::ml::COL_SAMPLE, labels);
|
||||
train_data->setTrainTestSplitRatio(0.9);
|
||||
|
||||
Mat tidx = train_data->getTestSampleIdx();
|
||||
EXPECT_EQ((size_t)15, tidx.total());
|
||||
|
||||
Mat tresp = train_data->getTestResponses(); // always row-based, transposed
|
||||
EXPECT_EQ(15, tresp.rows);
|
||||
EXPECT_EQ(labels.rows, tresp.cols);
|
||||
EXPECT_EQ(5, tresp.at<int>(0, 0)) << tresp;
|
||||
EXPECT_EQ(6, tresp.at<int>(0, 1)) << tresp;
|
||||
EXPECT_EQ(6, tresp.at<int>(14, 1)) << tresp;
|
||||
EXPECT_EQ(5, tresp.at<int>(14, 2)) << tresp;
|
||||
|
||||
|
||||
Mat tsamples = train_data->getTestSamples();
|
||||
EXPECT_EQ(15, tsamples.cols);
|
||||
EXPECT_EQ(test.rows, tsamples.rows);
|
||||
EXPECT_EQ(3, tsamples.at<float>(0, 0)) << tsamples;
|
||||
EXPECT_EQ(6, tsamples.at<float>(3, 0)) << tsamples;
|
||||
EXPECT_EQ(6, tsamples.at<float>(3, 14)) << tsamples;
|
||||
EXPECT_EQ(3, tsamples.at<float>(test.rows - 1, 14)) << tsamples;
|
||||
}
|
||||
|
||||
|
||||
|
||||
} // namespace
|
||||
/* End of file. */
|
||||
|
Loading…
Reference in New Issue
Block a user