mirror of
https://github.com/opencv/opencv.git
synced 2025-06-21 02:20:50 +08:00
Merge pull request #8116 from mrquorr:master
This commit is contained in:
commit
e0ee2f769a
@ -1206,6 +1206,17 @@ public:
|
|||||||
*/
|
*/
|
||||||
CV_WRAP virtual Mat getVarImportance() const = 0;
|
CV_WRAP virtual Mat getVarImportance() const = 0;
|
||||||
|
|
||||||
|
/** Returns the result of each individual tree in the forest.
|
||||||
|
In case the model is a regression problem, the method will return each of the trees'
|
||||||
|
results for each of the sample cases. If the model is a classifier, it will return
|
||||||
|
a Mat with samples + 1 rows, where the first row gives the class number and the
|
||||||
|
following rows return the votes each class had for each sample.
|
||||||
|
@param samples Array containg the samples for which votes will be calculated.
|
||||||
|
@param results Array where the result of the calculation will be written.
|
||||||
|
@param flags Flags for defining the type of RTrees.
|
||||||
|
*/
|
||||||
|
CV_WRAP void getVotes(InputArray samples, OutputArray results, int flags) const;
|
||||||
|
|
||||||
/** Creates the empty model.
|
/** Creates the empty model.
|
||||||
Use StatModel::train to train the model, StatModel::train to create and train the model,
|
Use StatModel::train to train the model, StatModel::train to create and train the model,
|
||||||
Algorithm::load to load the pre-trained model.
|
Algorithm::load to load the pre-trained model.
|
||||||
|
@ -349,6 +349,60 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void getVotes( InputArray input, OutputArray output, int flags ) const
|
||||||
|
{
|
||||||
|
CV_Assert( !roots.empty() );
|
||||||
|
int nclasses = (int)classLabels.size(), ntrees = (int)roots.size();
|
||||||
|
Mat samples = input.getMat(), results;
|
||||||
|
int i, j, nsamples = samples.rows;
|
||||||
|
|
||||||
|
int predictType = flags & PREDICT_MASK;
|
||||||
|
if( predictType == PREDICT_AUTO )
|
||||||
|
{
|
||||||
|
predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
|
||||||
|
PREDICT_SUM : PREDICT_MAX_VOTE;
|
||||||
|
}
|
||||||
|
|
||||||
|
if( predictType == PREDICT_SUM )
|
||||||
|
{
|
||||||
|
output.create(nsamples, ntrees, CV_32F);
|
||||||
|
results = output.getMat();
|
||||||
|
for( i = 0; i < nsamples; i++ )
|
||||||
|
{
|
||||||
|
for( j = 0; j < ntrees; j++ )
|
||||||
|
{
|
||||||
|
float val = predictTrees( Range(j, j+1), samples.row(i), flags);
|
||||||
|
results.at<float> (i, j) = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else
|
||||||
|
{
|
||||||
|
vector<int> votes;
|
||||||
|
output.create(nsamples+1, nclasses, CV_32S);
|
||||||
|
results = output.getMat();
|
||||||
|
|
||||||
|
for ( j = 0; j < nclasses; j++)
|
||||||
|
{
|
||||||
|
results.at<int> (0, j) = classLabels[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
for( i = 0; i < nsamples; i++ )
|
||||||
|
{
|
||||||
|
votes.clear();
|
||||||
|
for( j = 0; j < ntrees; j++ )
|
||||||
|
{
|
||||||
|
int val = (int)predictTrees( Range(j, j+1), samples.row(i), flags);
|
||||||
|
votes.push_back(val);
|
||||||
|
}
|
||||||
|
|
||||||
|
for ( j = 0; j < nclasses; j++)
|
||||||
|
{
|
||||||
|
results.at<int> (i+1, j) = (int)std::count(votes.begin(), votes.end(), classLabels[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
RTreeParams rparams;
|
RTreeParams rparams;
|
||||||
double oobError;
|
double oobError;
|
||||||
vector<float> varImportance;
|
vector<float> varImportance;
|
||||||
@ -401,6 +455,11 @@ public:
|
|||||||
impl.read(fn);
|
impl.read(fn);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void getVotes_( InputArray samples, OutputArray results, int flags ) const
|
||||||
|
{
|
||||||
|
impl.getVotes(samples, results, flags);
|
||||||
|
}
|
||||||
|
|
||||||
Mat getVarImportance() const { return Mat_<float>(impl.varImportance, true); }
|
Mat getVarImportance() const { return Mat_<float>(impl.varImportance, true); }
|
||||||
int getVarCount() const { return impl.getVarCount(); }
|
int getVarCount() const { return impl.getVarCount(); }
|
||||||
|
|
||||||
@ -427,6 +486,14 @@ Ptr<RTrees> RTrees::load(const String& filepath, const String& nodeName)
|
|||||||
return Algorithm::load<RTrees>(filepath, nodeName);
|
return Algorithm::load<RTrees>(filepath, nodeName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void RTrees::getVotes(InputArray input, OutputArray output, int flags) const
|
||||||
|
{
|
||||||
|
const RTreesImpl* this_ = dynamic_cast<const RTreesImpl*>(this);
|
||||||
|
if(!this_)
|
||||||
|
CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl");
|
||||||
|
return this_->getVotes_(input, output, flags);
|
||||||
|
}
|
||||||
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
// End of file.
|
// End of file.
|
||||||
|
@ -172,4 +172,49 @@ TEST(ML_NBAYES, regression_5911)
|
|||||||
EXPECT_EQ(sum(P1 == P3)[0], 255 * P3.total());
|
EXPECT_EQ(sum(P1 == P3)[0], 255 * P3.total());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ML_RTrees, getVotes)
|
||||||
|
{
|
||||||
|
int n = 12;
|
||||||
|
int count, i;
|
||||||
|
int label_size = 3;
|
||||||
|
int predicted_class = 0;
|
||||||
|
int max_votes = -1;
|
||||||
|
int val;
|
||||||
|
// RTrees for classification
|
||||||
|
Ptr<ml::RTrees> rt = cv::ml::RTrees::create();
|
||||||
|
|
||||||
|
//data
|
||||||
|
Mat data(n, 4, CV_32F);
|
||||||
|
randu(data, 0, 10);
|
||||||
|
|
||||||
|
//labels
|
||||||
|
Mat labels = (Mat_<int>(n,1) << 0,0,0,0, 1,1,1,1, 2,2,2,2);
|
||||||
|
|
||||||
|
rt->train(data, ml::ROW_SAMPLE, labels);
|
||||||
|
|
||||||
|
//run function
|
||||||
|
Mat test(1, 4, CV_32F);
|
||||||
|
Mat result;
|
||||||
|
randu(test, 0, 10);
|
||||||
|
rt->getVotes(test, result, 0);
|
||||||
|
|
||||||
|
//count vote amount and find highest vote
|
||||||
|
count = 0;
|
||||||
|
const int* result_row = result.ptr<int>(1);
|
||||||
|
for( i = 0; i < label_size; i++ )
|
||||||
|
{
|
||||||
|
val = result_row[i];
|
||||||
|
//predicted_class = max_votes < val? i;
|
||||||
|
if( max_votes < val )
|
||||||
|
{
|
||||||
|
max_votes = val;
|
||||||
|
predicted_class = i;
|
||||||
|
}
|
||||||
|
count += val;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(count, (int)rt->getRoots().size());
|
||||||
|
EXPECT_EQ(result.at<float>(0, predicted_class), rt->predict(test));
|
||||||
|
}
|
||||||
|
|
||||||
/* End of file. */
|
/* End of file. */
|
||||||
|
Loading…
Reference in New Issue
Block a user