mirror of
https://github.com/opencv/opencv.git
synced 2025-07-30 17:37:05 +08:00
Merge pull request #18126 from danielenricocahall:add-oob-error-sample-weighting
Account for sample weights in calculating OOB Error * account for sample weights in oob error calculation * redefine oob error functions * fix ABI compatibility
This commit is contained in:
parent
3835ab394e
commit
c31164bf1e
@ -1294,6 +1294,15 @@ public:
|
||||
*/
|
||||
CV_WRAP void getVotes(InputArray samples, OutputArray results, int flags) const;
|
||||
|
||||
/** Returns the out-of-bag (OOB) error value, computed at the training stage when
 * calcOOBError is set to true. If this flag was set to false, 0 is returned.
 * Each sample's contribution to the OOB error is scaled by its sample weight.
 *
 * When CV_VERSION_MAJOR == 3 this is declared as a regular (non-virtual) member;
 * otherwise it is a pure virtual function that a concrete implementation must
 * provide. NOTE(review): the split presumably exists to keep the 3.x ABI
 * unchanged (the commit message mentions an ABI-compatibility fix) - confirm.
 */
|
||||
#if CV_VERSION_MAJOR == 3
|
||||
CV_WRAP double getOOBError() const;
|
||||
#else
|
||||
/*CV_WRAP*/ virtual double getOOBError() const = 0;
|
||||
#endif
|
||||
|
||||
/** Creates the empty model.
|
||||
Use StatModel::train to train the model, StatModel::train to create and train the model,
|
||||
Algorithm::load to load the pre-trained model.
|
||||
|
@ -216,13 +216,14 @@ public:
|
||||
sample = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
|
||||
|
||||
double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
|
||||
double sample_weight = w->sample_weights[w->sidx[j]];
|
||||
if( !_isClassifier )
|
||||
{
|
||||
oobres[j] += val;
|
||||
oobcount[j]++;
|
||||
double true_val = w->ord_responses[w->sidx[j]];
|
||||
double a = oobres[j]/oobcount[j] - true_val;
|
||||
oobError += a*a;
|
||||
oobError += sample_weight * a*a;
|
||||
val = (val - true_val)/max_response;
|
||||
ncorrect_responses += std::exp( -val*val );
|
||||
}
|
||||
@ -237,7 +238,7 @@ public:
|
||||
if( votes[best_class] < votes[k] )
|
||||
best_class = k;
|
||||
int diff = best_class != w->cat_responses[w->sidx[j]];
|
||||
oobError += diff;
|
||||
oobError += sample_weight * diff;
|
||||
ncorrect_responses += diff == 0;
|
||||
}
|
||||
}
|
||||
@ -421,6 +422,10 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
double getOOBError() const {
|
||||
return oobError;
|
||||
}
|
||||
|
||||
RTreeParams rparams;
|
||||
double oobError;
|
||||
vector<float> varImportance;
|
||||
@ -505,6 +510,12 @@ public:
|
||||
// Read-only accessors returning const references; each one simply forwards
// to the same-named method of the impl member.
const vector<Node>& getNodes() const CV_OVERRIDE { return impl.getNodes(); }
|
||||
const vector<Split>& getSplits() const CV_OVERRIDE { return impl.getSplits(); }
|
||||
const vector<int>& getSubsets() const CV_OVERRIDE { return impl.getSubsets(); }
|
||||
// OOB error accessor; both variants return impl.getOOBError().
//  - CV_VERSION_MAJOR == 3: exposed as the plain helper getOOBError_() (no override).
//  - otherwise: overrides the virtual getOOBError().
#if CV_VERSION_MAJOR == 3
|
||||
double getOOBError_() const { return impl.getOOBError(); }
|
||||
#else
|
||||
double getOOBError() const CV_OVERRIDE { return impl.getOOBError(); }
|
||||
#endif
|
||||
|
||||
|
||||
DTreesImplForRTrees impl;
|
||||
};
|
||||
@ -532,6 +543,17 @@ void RTrees::getVotes(InputArray input, OutputArray output, int flags) const
|
||||
return this_->getVotes_(input, output, flags);
|
||||
}
|
||||
|
||||
#if CV_VERSION_MAJOR == 3
// Non-virtual entry point compiled only when CV_VERSION_MAJOR == 3.
// Downcasts this object to RTreesImpl and returns the result of its
// getOOBError_() helper; raises Error::StsNotImplemented when the object
// is not an RTreesImpl.
double RTrees::getOOBError() const
{
    CV_TRACE_FUNCTION();

    const RTreesImpl* const rtreesImpl = dynamic_cast<const RTreesImpl*>(this);
    if( !rtreesImpl )
    {
        CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl");
    }

    return rtreesImpl->getOOBError_();
}
#endif
|
||||
|
||||
}}
|
||||
|
||||
// End of file.
|
||||
|
@ -51,4 +51,50 @@ TEST(ML_RTrees, getVotes)
|
||||
EXPECT_EQ(result.at<float>(0, predicted_class), rt->predict(test));
|
||||
}
|
||||
|
||||
// Compares the rounded OOB error of a regression forest trained without
// sample weights against the same forest retrained with a weight of 10 on
// every sample.
TEST(ML_RTrees, 11142_sample_weights_regression)
{
    const int sampleCount = 3;

    // RTrees for regression
    Ptr<ml::RTrees> forest = cv::ml::RTrees::create();

    // Toy regression problem: x -> 2x.
    Mat features = (Mat_<float>(sampleCount, 1) << 1, 2, 3);
    Mat targets = (Mat_<float>(sampleCount, 1) << 2, 4, 6);
    Mat sampleWeights = (Mat_<float>(sampleCount, 1) << 10, 10, 10);

    // Baseline run: no sample weights.
    Ptr<TrainData> plainData = TrainData::create(features, ml::ROW_SAMPLE, targets);
    forest->train(plainData);
    double error_without_weights = round(forest->getOOBError());

    // Reset the model, then train again on the weighted data set.
    forest->clear();
    Ptr<TrainData> weightedData =
        TrainData::create(features, ml::ROW_SAMPLE, targets, Mat(), Mat(), sampleWeights);
    forest->train(weightedData);
    double error_with_weights = round(forest->getOOBError());

    // The weighted error is expected to be at least as large as the unweighted one.
    EXPECT_GE(error_with_weights, error_without_weights);
}
|
||||
|
||||
// Compares the rounded OOB error of a classification forest trained without
// sample weights against the same forest retrained with a weight of 10 on
// every sample.
TEST(ML_RTrees, 11142_sample_weights_classification)
{
    int n = 12;
    // RTrees for classification
    Ptr<ml::RTrees> rt = cv::ml::RTrees::create();

    // 12 random 4-dimensional samples, four per class label (0, 1, 2).
    Mat data(n, 4, CV_32F);
    randu(data, 0, 10);
    Mat labels = (Mat_<int>(n,1) << 0,0,0,0, 1,1,1,1, 2,2,2,2);
    Mat weights = (Mat_<float>(n, 1) << 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10);

    // Baseline run without sample weights. The OOB error is read BEFORE
    // clear() so that it reflects the model that was just trained rather
    // than a reset one.
    rt->train(data, ml::ROW_SAMPLE, labels);
    double error_without_weights = round(rt->getOOBError());
    rt->clear();

    // Second run: train on the TrainData that carries the sample weights
    // (previously this object was built but never passed to train()).
    Ptr<TrainData> trainDataWithWeights = TrainData::create(data, ml::ROW_SAMPLE, labels, Mat(), Mat(), weights );
    rt->train(trainDataWithWeights);
    double error_with_weights = round(rt->getOOBError());

    // error with weights should be at least as large as error without weights
    EXPECT_GE(error_with_weights, error_without_weights);
}
|
||||
|
||||
|
||||
|
||||
}} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user