mirror of
https://github.com/opencv/opencv.git
synced 2025-06-12 12:22:51 +08:00
Merge pull request #8116 from mrquorr:master
This commit is contained in:
commit
e0ee2f769a
@ -1206,6 +1206,17 @@ public:
|
||||
*/
|
||||
CV_WRAP virtual Mat getVarImportance() const = 0;
|
||||
|
||||
/** Returns the result of each individual tree in the forest.
|
||||
In case the model is a regression problem, the method will return each of the trees'
|
||||
results for each of the sample cases. If the model is a classifier, it will return
|
||||
a Mat with samples + 1 rows, where the first row gives the class number and the
|
||||
following rows return the votes each class had for each sample.
|
||||
@param samples Array containg the samples for which votes will be calculated.
|
||||
@param results Array where the result of the calculation will be written.
|
||||
@param flags Flags for defining the type of RTrees.
|
||||
*/
|
||||
CV_WRAP void getVotes(InputArray samples, OutputArray results, int flags) const;
|
||||
|
||||
/** Creates the empty model.
|
||||
Use StatModel::train to train the model, StatModel::train to create and train the model,
|
||||
Algorithm::load to load the pre-trained model.
|
||||
|
@ -349,6 +349,60 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void getVotes( InputArray input, OutputArray output, int flags ) const
|
||||
{
|
||||
CV_Assert( !roots.empty() );
|
||||
int nclasses = (int)classLabels.size(), ntrees = (int)roots.size();
|
||||
Mat samples = input.getMat(), results;
|
||||
int i, j, nsamples = samples.rows;
|
||||
|
||||
int predictType = flags & PREDICT_MASK;
|
||||
if( predictType == PREDICT_AUTO )
|
||||
{
|
||||
predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
|
||||
PREDICT_SUM : PREDICT_MAX_VOTE;
|
||||
}
|
||||
|
||||
if( predictType == PREDICT_SUM )
|
||||
{
|
||||
output.create(nsamples, ntrees, CV_32F);
|
||||
results = output.getMat();
|
||||
for( i = 0; i < nsamples; i++ )
|
||||
{
|
||||
for( j = 0; j < ntrees; j++ )
|
||||
{
|
||||
float val = predictTrees( Range(j, j+1), samples.row(i), flags);
|
||||
results.at<float> (i, j) = val;
|
||||
}
|
||||
}
|
||||
} else
|
||||
{
|
||||
vector<int> votes;
|
||||
output.create(nsamples+1, nclasses, CV_32S);
|
||||
results = output.getMat();
|
||||
|
||||
for ( j = 0; j < nclasses; j++)
|
||||
{
|
||||
results.at<int> (0, j) = classLabels[j];
|
||||
}
|
||||
|
||||
for( i = 0; i < nsamples; i++ )
|
||||
{
|
||||
votes.clear();
|
||||
for( j = 0; j < ntrees; j++ )
|
||||
{
|
||||
int val = (int)predictTrees( Range(j, j+1), samples.row(i), flags);
|
||||
votes.push_back(val);
|
||||
}
|
||||
|
||||
for ( j = 0; j < nclasses; j++)
|
||||
{
|
||||
results.at<int> (i+1, j) = (int)std::count(votes.begin(), votes.end(), classLabels[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RTreeParams rparams;
|
||||
double oobError;
|
||||
vector<float> varImportance;
|
||||
@ -401,6 +455,11 @@ public:
|
||||
impl.read(fn);
|
||||
}
|
||||
|
||||
void getVotes_( InputArray samples, OutputArray results, int flags ) const
|
||||
{
|
||||
impl.getVotes(samples, results, flags);
|
||||
}
|
||||
|
||||
Mat getVarImportance() const { return Mat_<float>(impl.varImportance, true); }
|
||||
int getVarCount() const { return impl.getVarCount(); }
|
||||
|
||||
@ -427,6 +486,14 @@ Ptr<RTrees> RTrees::load(const String& filepath, const String& nodeName)
|
||||
return Algorithm::load<RTrees>(filepath, nodeName);
|
||||
}
|
||||
|
||||
void RTrees::getVotes(InputArray input, OutputArray output, int flags) const
|
||||
{
|
||||
const RTreesImpl* this_ = dynamic_cast<const RTreesImpl*>(this);
|
||||
if(!this_)
|
||||
CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl");
|
||||
return this_->getVotes_(input, output, flags);
|
||||
}
|
||||
|
||||
}}
|
||||
|
||||
// End of file.
|
||||
|
@ -172,4 +172,49 @@ TEST(ML_NBAYES, regression_5911)
|
||||
EXPECT_EQ(sum(P1 == P3)[0], 255 * P3.total());
|
||||
}
|
||||
|
||||
TEST(ML_RTrees, getVotes)
|
||||
{
|
||||
int n = 12;
|
||||
int count, i;
|
||||
int label_size = 3;
|
||||
int predicted_class = 0;
|
||||
int max_votes = -1;
|
||||
int val;
|
||||
// RTrees for classification
|
||||
Ptr<ml::RTrees> rt = cv::ml::RTrees::create();
|
||||
|
||||
//data
|
||||
Mat data(n, 4, CV_32F);
|
||||
randu(data, 0, 10);
|
||||
|
||||
//labels
|
||||
Mat labels = (Mat_<int>(n,1) << 0,0,0,0, 1,1,1,1, 2,2,2,2);
|
||||
|
||||
rt->train(data, ml::ROW_SAMPLE, labels);
|
||||
|
||||
//run function
|
||||
Mat test(1, 4, CV_32F);
|
||||
Mat result;
|
||||
randu(test, 0, 10);
|
||||
rt->getVotes(test, result, 0);
|
||||
|
||||
//count vote amount and find highest vote
|
||||
count = 0;
|
||||
const int* result_row = result.ptr<int>(1);
|
||||
for( i = 0; i < label_size; i++ )
|
||||
{
|
||||
val = result_row[i];
|
||||
//predicted_class = max_votes < val? i;
|
||||
if( max_votes < val )
|
||||
{
|
||||
max_votes = val;
|
||||
predicted_class = i;
|
||||
}
|
||||
count += val;
|
||||
}
|
||||
|
||||
EXPECT_EQ(count, (int)rt->getRoots().size());
|
||||
EXPECT_EQ(result.at<float>(0, predicted_class), rt->predict(test));
|
||||
}
|
||||
|
||||
/* End of file. */
|
||||
|
Loading…
Reference in New Issue
Block a user