mirror of
https://github.com/opencv/opencv.git
synced 2025-06-08 01:53:19 +08:00
Merge pull request #19884 from danielenricocahall:fix-prediction-features-bug
Fix bug with predictions in RTrees/Boost * address bug where predict functions with invalid feature count in rtrees/boost models * compact matrix rep in tests * check 1..n-1 and n+1 in feature size validation test
This commit is contained in:
parent
76860933f0
commit
a9a6801c6d
@ -490,6 +490,7 @@ public:
|
|||||||
|
|
||||||
float predict( InputArray samples, OutputArray results, int flags ) const CV_OVERRIDE
|
float predict( InputArray samples, OutputArray results, int flags ) const CV_OVERRIDE
|
||||||
{
|
{
|
||||||
|
CV_Assert( samples.cols() == getVarCount() && samples.type() == CV_32F );
|
||||||
return impl.predict(samples, results, flags);
|
return impl.predict(samples, results, flags);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -479,6 +479,7 @@ public:
|
|||||||
float predict( InputArray samples, OutputArray results, int flags ) const CV_OVERRIDE
|
float predict( InputArray samples, OutputArray results, int flags ) const CV_OVERRIDE
|
||||||
{
|
{
|
||||||
CV_TRACE_FUNCTION();
|
CV_TRACE_FUNCTION();
|
||||||
|
CV_Assert( samples.cols() == getVarCount() && samples.type() == CV_32F );
|
||||||
return impl.predict(samples, results, flags);
|
return impl.predict(samples, results, flags);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,6 +95,25 @@ TEST(ML_RTrees, 11142_sample_weights_classification)
|
|||||||
EXPECT_GE(error_with_weights, error_without_weights);
|
EXPECT_GE(error_with_weights, error_without_weights);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ML_RTrees, bug_12974_throw_exception_when_predict_different_feature_count)
|
||||||
|
{
|
||||||
|
int numFeatures = 5;
|
||||||
|
// create a 5 feature dataset and train the model
|
||||||
|
cv::Ptr<RTrees> model = RTrees::create();
|
||||||
|
Mat samples(10, numFeatures, CV_32F);
|
||||||
|
randu(samples, 0, 10);
|
||||||
|
Mat labels = (Mat_<int>(10,1) << 0,0,0,0,0,1,1,1,1,1);
|
||||||
|
cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, labels);
|
||||||
|
model->train(trainData);
|
||||||
|
// try to predict on data which have fewer features - this should throw an exception
|
||||||
|
for(int i = 1; i < numFeatures - 1; ++i) {
|
||||||
|
Mat test(1, i, CV_32FC1);
|
||||||
|
ASSERT_THROW(model->predict(test), Exception);
|
||||||
|
}
|
||||||
|
// try to predict on data which have more features - this should also throw an exception
|
||||||
|
Mat test(1, numFeatures + 1, CV_32FC1);
|
||||||
|
ASSERT_THROW(model->predict(test), Exception);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}} // namespace
|
}} // namespace
|
||||||
|
Loading…
Reference in New Issue
Block a user