2010-05-12 01:44:00 +08:00
|
|
|
/*M///////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//
|
|
|
|
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
|
|
|
//
|
|
|
|
// By downloading, copying, installing or using the software you agree to this license.
|
|
|
|
// If you do not agree to this license, do not download, install,
|
|
|
|
// copy or use the software.
|
|
|
|
//
|
|
|
|
//
|
2014-07-30 03:54:23 +08:00
|
|
|
// License Agreement
|
|
|
|
// For Open Source Computer Vision Library
|
2010-05-12 01:44:00 +08:00
|
|
|
//
|
|
|
|
// Copyright (C) 2000, Intel Corporation, all rights reserved.
|
2014-07-30 03:54:23 +08:00
|
|
|
// Copyright (C) 2014, Itseez Inc, all rights reserved.
|
2010-05-12 01:44:00 +08:00
|
|
|
// Third party copyrights are property of their respective owners.
|
|
|
|
//
|
|
|
|
// Redistribution and use in source and binary forms, with or without modification,
|
|
|
|
// are permitted provided that the following conditions are met:
|
|
|
|
//
|
|
|
|
// * Redistribution's of source code must retain the above copyright notice,
|
|
|
|
// this list of conditions and the following disclaimer.
|
|
|
|
//
|
|
|
|
// * Redistribution's in binary form must reproduce the above copyright notice,
|
|
|
|
// this list of conditions and the following disclaimer in the documentation
|
|
|
|
// and/or other materials provided with the distribution.
|
|
|
|
//
|
2014-07-30 03:54:23 +08:00
|
|
|
// * The name of the copyright holders may not be used to endorse or promote products
|
2010-05-12 01:44:00 +08:00
|
|
|
// derived from this software without specific prior written permission.
|
|
|
|
//
|
|
|
|
// This software is provided by the copyright holders and contributors "as is" and
|
|
|
|
// any express or implied warranties, including, but not limited to, the implied
|
|
|
|
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
|
|
|
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
|
|
|
// indirect, incidental, special, exemplary, or consequential damages
|
|
|
|
// (including, but not limited to, procurement of substitute goods or services;
|
|
|
|
// loss of use, data, or profits; or business interruption) however caused
|
|
|
|
// and on any theory of liability, whether in contract, strict liability,
|
|
|
|
// or tort (including negligence or otherwise) arising in any way out of
|
|
|
|
// the use of this software, even if advised of the possibility of such damage.
|
|
|
|
//
|
|
|
|
//M*/
|
|
|
|
|
|
|
|
#include "precomp.hpp"
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
namespace cv { namespace ml {
|
|
|
|
|
2010-05-12 01:44:00 +08:00
|
|
|
/** Logit of a probability: log(val / (1 - val)).
 *
 * The argument is first clamped to [1e-5, 1 - 1e-5] so that the result stays
 * finite even when val is 0, 1, or outside [0, 1].
 *
 * @param val  probability-like value
 * @return     log(p / (1 - p)) with p = clamp(val, 1e-5, 1 - 1e-5)
 */
static inline double
log_ratio( double val )
{
    const double eps = 1e-5;
    const double p = std::min( std::max( val, eps ), 1. - eps );
    const double q = 1. - p;
    return std::log( p/q );
}
|
|
|
|
|
|
|
|
|
2015-02-11 18:24:14 +08:00
|
|
|
// Default boosting configuration: Real AdaBoost with 100 weak learners and a
// weight trimming rate of 0.95. Delegates to the three-argument constructor.
BoostTreeParams::BoostTreeParams()
    : BoostTreeParams(Boost::REAL, 100, 0.95)
{
}
|
|
|
|
|
2015-02-11 18:24:14 +08:00
|
|
|
// Stores the given boosting type (a Boost::Types value), the number of weak
// learners and the weight trimming rate verbatim; no validation is done here.
BoostTreeParams::BoostTreeParams( int _boostType, int _weak_count,
                                  double _weightTrimRate)
    : boostType(_boostType),
      weakCount(_weak_count),
      weightTrimRate(_weightTrimRate)
{
}
|
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
class DTreesImplForBoost CV_FINAL : public DTreesImpl
|
2013-03-28 20:12:13 +08:00
|
|
|
{
|
|
|
|
public:
|
2015-02-11 18:24:14 +08:00
|
|
|
DTreesImplForBoost()
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2015-02-11 18:24:14 +08:00
|
|
|
params.setCVFolds(0);
|
|
|
|
params.setMaxDepth(1);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2015-02-11 18:24:14 +08:00
|
|
|
virtual ~DTreesImplForBoost() {}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
bool isClassifier() const CV_OVERRIDE { return true; }
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
void clear() CV_OVERRIDE
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
DTreesImpl::clear();
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
void startTraining( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2019-09-22 19:11:08 +08:00
|
|
|
CV_Assert(!trainData.empty());
|
2014-07-30 03:54:23 +08:00
|
|
|
DTreesImpl::startTraining(trainData, flags);
|
2014-08-03 05:41:09 +08:00
|
|
|
sumResult.assign(w->sidx.size(), 0.);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( bparams.boostType != Boost::DISCRETE )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
_isClassifier = false;
|
|
|
|
int i, n = (int)w->cat_responses.size();
|
|
|
|
w->ord_responses.resize(n);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
double a = -1, b = 1;
|
2014-08-03 05:41:09 +08:00
|
|
|
if( bparams.boostType == Boost::LOGIT )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
a = -2, b = 2;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
for( i = 0; i < n; i++ )
|
|
|
|
w->ord_responses[i] = w->cat_responses[i] > 0 ? b : a;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
normalizeWeights();
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void normalizeWeights()
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
int i, n = (int)w->sidx.size();
|
|
|
|
double sumw = 0, a, b;
|
2010-05-12 01:44:00 +08:00
|
|
|
for( i = 0; i < n; i++ )
|
2014-07-30 03:54:23 +08:00
|
|
|
sumw += w->sample_weights[w->sidx[i]];
|
|
|
|
if( sumw > DBL_EPSILON )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
a = 1./sumw;
|
|
|
|
b = 0;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
else
|
2014-07-30 03:54:23 +08:00
|
|
|
{
|
|
|
|
a = 0;
|
|
|
|
b = 1;
|
|
|
|
}
|
|
|
|
for( i = 0; i < n; i++ )
|
|
|
|
{
|
|
|
|
double& wval = w->sample_weights[w->sidx[i]];
|
|
|
|
wval = wval*a + b;
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
void endTraining() CV_OVERRIDE
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
DTreesImpl::endTraining();
|
|
|
|
vector<double> e;
|
|
|
|
std::swap(sumResult, e);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
|
|
|
|
void scaleTree( int root, double scale )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
int nidx = root, pidx = 0;
|
|
|
|
Node *node = 0;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// traverse the tree and save all the nodes in depth-first order
|
|
|
|
for(;;)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
for(;;)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
node = &nodes[nidx];
|
|
|
|
node->value *= scale;
|
|
|
|
if( node->left < 0 )
|
|
|
|
break;
|
|
|
|
nidx = node->left;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
for( pidx = node->parent; pidx >= 0 && nodes[pidx].right == nidx;
|
|
|
|
nidx = pidx, pidx = nodes[pidx].parent )
|
|
|
|
;
|
2014-08-03 07:08:25 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( pidx < 0 )
|
|
|
|
break;
|
2014-08-03 07:08:25 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
nidx = nodes[pidx].right;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
void calcValue( int nidx, const vector<int>& _sidx ) CV_OVERRIDE
|
2014-07-30 03:54:23 +08:00
|
|
|
{
|
|
|
|
DTreesImpl::calcValue(nidx, _sidx);
|
|
|
|
WNode* node = &w->wnodes[nidx];
|
|
|
|
if( bparams.boostType == Boost::DISCRETE )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
node->value = node->class_idx == 0 ? -1 : 1;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
else if( bparams.boostType == Boost::REAL )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-08-03 05:41:09 +08:00
|
|
|
double p = (node->value+1)*0.5;
|
2014-07-30 03:54:23 +08:00
|
|
|
node->value = 0.5*log_ratio(p);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
}
|
2013-01-30 01:11:52 +08:00
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
2014-07-30 03:54:23 +08:00
|
|
|
{
|
2019-09-22 19:11:08 +08:00
|
|
|
CV_Assert(!trainData.empty());
|
2014-07-30 03:54:23 +08:00
|
|
|
startTraining(trainData, flags);
|
|
|
|
int treeidx, ntrees = bparams.weakCount >= 0 ? bparams.weakCount : 10000;
|
|
|
|
vector<int> sidx = w->sidx;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
for( treeidx = 0; treeidx < ntrees; treeidx++ )
|
|
|
|
{
|
|
|
|
int root = addTree( sidx );
|
|
|
|
if( root < 0 )
|
|
|
|
return false;
|
|
|
|
updateWeightsAndTrim( treeidx, sidx );
|
|
|
|
}
|
|
|
|
endTraining();
|
|
|
|
return true;
|
|
|
|
}
|
2012-06-09 23:00:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void updateWeightsAndTrim( int treeidx, vector<int>& sidx )
|
|
|
|
{
|
|
|
|
int i, n = (int)w->sidx.size();
|
|
|
|
int nvars = (int)varIdx.size();
|
2014-08-03 05:41:09 +08:00
|
|
|
double sumw = 0., C = 1.;
|
2014-08-03 16:46:28 +08:00
|
|
|
cv::AutoBuffer<double> buf(n + nvars);
|
2018-06-11 06:42:00 +08:00
|
|
|
double* result = buf.data();
|
2014-08-03 16:46:28 +08:00
|
|
|
float* sbuf = (float*)(result + n);
|
2014-07-30 03:54:23 +08:00
|
|
|
Mat sample(1, nvars, CV_32F, sbuf);
|
|
|
|
int predictFlags = bparams.boostType == Boost::DISCRETE ? (PREDICT_MAX_VOTE | RAW_OUTPUT) : PREDICT_SUM;
|
|
|
|
predictFlags |= COMPRESSED_INPUT;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
for( i = 0; i < n; i++ )
|
|
|
|
{
|
|
|
|
w->data->getSample(varIdx, w->sidx[i], sbuf );
|
|
|
|
result[i] = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// now update weights and other parameters for each type of boosting
|
2014-07-30 03:54:23 +08:00
|
|
|
if( bparams.boostType == Boost::DISCRETE )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
|
|
|
// Discrete AdaBoost:
|
|
|
|
// weak_eval[i] (=f(x_i)) is in {-1,1}
|
|
|
|
// err = sum(w_i*(f(x_i) != y_i))/sum(w_i)
|
|
|
|
// C = log((1-err)/err)
|
|
|
|
// w_i *= exp(C*(f(x_i) != y_i))
|
2014-07-30 03:54:23 +08:00
|
|
|
double err = 0.;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
|
|
|
for( i = 0; i < n; i++ )
|
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
int si = w->sidx[i];
|
|
|
|
double wval = w->sample_weights[si];
|
|
|
|
sumw += wval;
|
|
|
|
err += wval*(result[i] != w->cat_responses[si]);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
if( sumw != 0 )
|
|
|
|
err /= sumw;
|
2014-08-03 05:41:09 +08:00
|
|
|
C = -log_ratio( err );
|
2014-07-30 03:54:23 +08:00
|
|
|
double scale = std::exp(C);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
|
|
|
sumw = 0;
|
|
|
|
for( i = 0; i < n; i++ )
|
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
int si = w->sidx[i];
|
|
|
|
double wval = w->sample_weights[si];
|
|
|
|
if( result[i] != w->cat_responses[si] )
|
|
|
|
wval *= scale;
|
|
|
|
sumw += wval;
|
|
|
|
w->sample_weights[si] = wval;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
scaleTree(roots[treeidx], C);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
else if( bparams.boostType == Boost::REAL || bparams.boostType == Boost::GENTLE )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
|
|
|
// Real AdaBoost:
|
|
|
|
// weak_eval[i] = f(x_i) = 0.5*log(p(x_i)/(1-p(x_i))), p(x_i)=P(y=1|x_i)
|
|
|
|
// w_i *= exp(-y_i*f(x_i))
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// Gentle AdaBoost:
|
|
|
|
// weak_eval[i] = f(x_i) in [-1,1]
|
|
|
|
// w_i *= exp(-y_i*f(x_i))
|
2010-05-12 01:44:00 +08:00
|
|
|
for( i = 0; i < n; i++ )
|
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
int si = w->sidx[i];
|
2014-08-03 05:41:09 +08:00
|
|
|
CV_Assert( std::abs(w->ord_responses[si]) == 1 );
|
2014-07-30 03:54:23 +08:00
|
|
|
double wval = w->sample_weights[si]*std::exp(-result[i]*w->ord_responses[si]);
|
|
|
|
sumw += wval;
|
|
|
|
w->sample_weights[si] = wval;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
else if( bparams.boostType == Boost::LOGIT )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
|
|
|
// LogitBoost:
|
|
|
|
// weak_eval[i] = f(x_i) in [-z_max,z_max]
|
|
|
|
// sum_response = F(x_i).
|
|
|
|
// F(x_i) += 0.5*f(x_i)
|
|
|
|
// p(x_i) = exp(F(x_i))/(exp(F(x_i)) + exp(-F(x_i))=1/(1+exp(-2*F(x_i)))
|
|
|
|
// reuse weak_eval: weak_eval[i] <- p(x_i)
|
|
|
|
// w_i = p(x_i)*1(1 - p(x_i))
|
|
|
|
// z_i = ((y_i+1)/2 - p(x_i))/(p(x_i)*(1 - p(x_i)))
|
|
|
|
// store z_i to the data->data_root as the new target responses
|
|
|
|
const double lb_weight_thresh = FLT_EPSILON;
|
|
|
|
const double lb_z_max = 10.;
|
|
|
|
|
|
|
|
for( i = 0; i < n; i++ )
|
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
int si = w->sidx[i];
|
|
|
|
sumResult[i] += 0.5*result[i];
|
|
|
|
double p = 1./(1 + std::exp(-2*sumResult[i]));
|
|
|
|
double wval = std::max( p*(1 - p), lb_weight_thresh ), z;
|
|
|
|
w->sample_weights[si] = wval;
|
|
|
|
sumw += wval;
|
|
|
|
if( w->ord_responses[si] > 0 )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
|
|
|
z = 1./p;
|
2014-07-30 03:54:23 +08:00
|
|
|
w->ord_responses[si] = std::min(z, lb_z_max);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
else
|
|
|
|
{
|
|
|
|
z = 1./(1-p);
|
2014-07-30 03:54:23 +08:00
|
|
|
w->ord_responses[si] = -std::min(z, lb_z_max);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
else
|
2024-03-04 20:51:05 +08:00
|
|
|
CV_Error(cv::Error::StsNotImplemented, "Unknown boosting type");
|
2014-08-03 05:41:09 +08:00
|
|
|
|
|
|
|
/*if( bparams.boostType != Boost::LOGIT )
|
|
|
|
{
|
|
|
|
double err = 0;
|
|
|
|
for( i = 0; i < n; i++ )
|
|
|
|
{
|
|
|
|
sumResult[i] += result[i]*C;
|
|
|
|
if( bparams.boostType != Boost::DISCRETE )
|
|
|
|
err += sumResult[i]*w->ord_responses[w->sidx[i]] < 0;
|
|
|
|
else
|
|
|
|
err += sumResult[i]*w->cat_responses[w->sidx[i]] < 0;
|
|
|
|
}
|
|
|
|
printf("%d trees. C=%.2f, training error=%.1f%%, working set size=%d (out of %d)\n", (int)roots.size(), C, err*100./n, (int)sidx.size(), n);
|
|
|
|
}*/
|
2014-08-03 07:08:25 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// renormalize weights
|
|
|
|
if( sumw > FLT_EPSILON )
|
|
|
|
normalizeWeights();
|
2012-06-09 23:00:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( bparams.weightTrimRate <= 0. || bparams.weightTrimRate >= 1. )
|
|
|
|
return;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
for( i = 0; i < n; i++ )
|
|
|
|
result[i] = w->sample_weights[w->sidx[i]];
|
|
|
|
std::sort(result, result + n);
|
2012-06-09 23:00:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// as weight trimming occurs immediately after updating the weights,
|
|
|
|
// where they are renormalized, we assume that the weight sum = 1.
|
|
|
|
sumw = 1. - bparams.weightTrimRate;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
for( i = 0; i < n; i++ )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
double wval = result[i];
|
|
|
|
if( sumw <= 0 )
|
|
|
|
break;
|
|
|
|
sumw -= wval;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
double threshold = i < n ? result[i] : DBL_MAX;
|
|
|
|
sidx.clear();
|
2012-06-09 23:00:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
for( i = 0; i < n; i++ )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
int si = w->sidx[i];
|
|
|
|
if( w->sample_weights[si] >= threshold )
|
|
|
|
sidx.push_back(si);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
float predictTrees( const Range& range, const Mat& sample, int flags0 ) const CV_OVERRIDE
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
int flags = (flags0 & ~PREDICT_MASK) | PREDICT_SUM;
|
|
|
|
float val = DTreesImpl::predictTrees(range, sample, flags);
|
|
|
|
if( flags != flags0 )
|
|
|
|
{
|
|
|
|
int ival = (int)(val > 0);
|
|
|
|
if( !(flags0 & RAW_OUTPUT) )
|
|
|
|
ival = classLabels[ival];
|
|
|
|
val = (float)ival;
|
|
|
|
}
|
|
|
|
return val;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
void writeTrainingParams( FileStorage& fs ) const CV_OVERRIDE
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
fs << "boosting_type" <<
|
|
|
|
(bparams.boostType == Boost::DISCRETE ? "DiscreteAdaboost" :
|
|
|
|
bparams.boostType == Boost::REAL ? "RealAdaboost" :
|
|
|
|
bparams.boostType == Boost::LOGIT ? "LogitBoost" :
|
|
|
|
bparams.boostType == Boost::GENTLE ? "GentleAdaboost" : "Unknown");
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
DTreesImpl::writeTrainingParams(fs);
|
|
|
|
fs << "weight_trimming_rate" << bparams.weightTrimRate;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2012-06-09 23:00:04 +08:00
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
void write( FileStorage& fs ) const CV_OVERRIDE
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
if( roots.empty() )
|
2024-03-04 20:51:05 +08:00
|
|
|
CV_Error( cv::Error::StsBadArg, "RTrees have not been trained" );
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2016-03-23 06:19:42 +08:00
|
|
|
writeFormat(fs);
|
2014-07-30 03:54:23 +08:00
|
|
|
writeParams(fs);
|
2012-12-29 04:30:52 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
int k, ntrees = (int)roots.size();
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
fs << "ntrees" << ntrees
|
|
|
|
<< "trees" << "[";
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
for( k = 0; k < ntrees; k++ )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
fs << "{";
|
|
|
|
writeTree(fs, roots[k]);
|
|
|
|
fs << "}";
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
fs << "]";
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
void readParams( const FileNode& fn ) CV_OVERRIDE
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
DTreesImpl::readParams(fn);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
FileNode tparams_node = fn["training_params"];
|
2014-12-16 23:15:50 +08:00
|
|
|
// check for old layout
|
|
|
|
String bts = (String)(fn["boosting_type"].empty() ?
|
|
|
|
tparams_node["boosting_type"] : fn["boosting_type"]);
|
2014-07-30 03:54:23 +08:00
|
|
|
bparams.boostType = (bts == "DiscreteAdaboost" ? Boost::DISCRETE :
|
|
|
|
bts == "RealAdaboost" ? Boost::REAL :
|
|
|
|
bts == "LogitBoost" ? Boost::LOGIT :
|
|
|
|
bts == "GentleAdaboost" ? Boost::GENTLE : -1);
|
|
|
|
_isClassifier = bparams.boostType == Boost::DISCRETE;
|
2014-12-16 23:15:50 +08:00
|
|
|
// check for old layout
|
|
|
|
bparams.weightTrimRate = (double)(fn["weight_trimming_rate"].empty() ?
|
|
|
|
tparams_node["weight_trimming_rate"] : fn["weight_trimming_rate"]);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
void read( const FileNode& fn ) CV_OVERRIDE
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
clear();
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
int ntrees = (int)fn["ntrees"];
|
|
|
|
readParams(fn);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
FileNode trees_node = fn["trees"];
|
|
|
|
FileNodeIterator it = trees_node.begin();
|
|
|
|
CV_Assert( ntrees == (int)trees_node.size() );
|
2014-08-03 07:08:25 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
for( int treeidx = 0; treeidx < ntrees; treeidx++, ++it )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
FileNode nfn = (*it)["nodes"];
|
|
|
|
readTree(nfn);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
}
|
2014-08-03 07:08:25 +08:00
|
|
|
|
2015-02-11 18:24:14 +08:00
|
|
|
BoostTreeParams bparams;
|
2014-07-30 03:54:23 +08:00
|
|
|
vector<double> sumResult;
|
|
|
|
};
|
2010-05-12 01:44:00 +08:00
|
|
|
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
class BoostImpl : public Boost
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
public:
|
|
|
|
BoostImpl() {}
|
|
|
|
virtual ~BoostImpl() {}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
inline int getBoostType() const CV_OVERRIDE { return impl.bparams.boostType; }
|
|
|
|
inline void setBoostType(int val) CV_OVERRIDE { impl.bparams.boostType = val; }
|
|
|
|
inline int getWeakCount() const CV_OVERRIDE { return impl.bparams.weakCount; }
|
|
|
|
inline void setWeakCount(int val) CV_OVERRIDE { impl.bparams.weakCount = val; }
|
|
|
|
inline double getWeightTrimRate() const CV_OVERRIDE { return impl.bparams.weightTrimRate; }
|
|
|
|
inline void setWeightTrimRate(double val) CV_OVERRIDE { impl.bparams.weightTrimRate = val; }
|
|
|
|
|
|
|
|
inline int getMaxCategories() const CV_OVERRIDE { return impl.params.getMaxCategories(); }
|
|
|
|
inline void setMaxCategories(int val) CV_OVERRIDE { impl.params.setMaxCategories(val); }
|
|
|
|
inline int getMaxDepth() const CV_OVERRIDE { return impl.params.getMaxDepth(); }
|
|
|
|
inline void setMaxDepth(int val) CV_OVERRIDE { impl.params.setMaxDepth(val); }
|
|
|
|
inline int getMinSampleCount() const CV_OVERRIDE { return impl.params.getMinSampleCount(); }
|
|
|
|
inline void setMinSampleCount(int val) CV_OVERRIDE { impl.params.setMinSampleCount(val); }
|
|
|
|
inline int getCVFolds() const CV_OVERRIDE { return impl.params.getCVFolds(); }
|
|
|
|
inline void setCVFolds(int val) CV_OVERRIDE { impl.params.setCVFolds(val); }
|
|
|
|
inline bool getUseSurrogates() const CV_OVERRIDE { return impl.params.getUseSurrogates(); }
|
|
|
|
inline void setUseSurrogates(bool val) CV_OVERRIDE { impl.params.setUseSurrogates(val); }
|
|
|
|
inline bool getUse1SERule() const CV_OVERRIDE { return impl.params.getUse1SERule(); }
|
|
|
|
inline void setUse1SERule(bool val) CV_OVERRIDE { impl.params.setUse1SERule(val); }
|
|
|
|
inline bool getTruncatePrunedTree() const CV_OVERRIDE { return impl.params.getTruncatePrunedTree(); }
|
|
|
|
inline void setTruncatePrunedTree(bool val) CV_OVERRIDE { impl.params.setTruncatePrunedTree(val); }
|
|
|
|
inline float getRegressionAccuracy() const CV_OVERRIDE { return impl.params.getRegressionAccuracy(); }
|
|
|
|
inline void setRegressionAccuracy(float val) CV_OVERRIDE { impl.params.setRegressionAccuracy(val); }
|
|
|
|
inline cv::Mat getPriors() const CV_OVERRIDE { return impl.params.getPriors(); }
|
|
|
|
inline void setPriors(const cv::Mat& val) CV_OVERRIDE { impl.params.setPriors(val); }
|
|
|
|
|
|
|
|
String getDefaultName() const CV_OVERRIDE { return "opencv_ml_boost"; }
|
|
|
|
|
|
|
|
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2019-09-22 19:11:08 +08:00
|
|
|
CV_Assert(!trainData.empty());
|
2014-07-30 03:54:23 +08:00
|
|
|
return impl.train(trainData, flags);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
float predict( InputArray samples, OutputArray results, int flags ) const CV_OVERRIDE
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2021-04-13 03:05:52 +08:00
|
|
|
CV_CheckEQ(samples.cols(), getVarCount(), "");
|
2014-07-30 03:54:23 +08:00
|
|
|
return impl.predict(samples, results, flags);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
void write( FileStorage& fs ) const CV_OVERRIDE
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
impl.write(fs);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
void read( const FileNode& fn ) CV_OVERRIDE
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
impl.read(fn);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
int getVarCount() const CV_OVERRIDE { return impl.getVarCount(); }
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
bool isTrained() const CV_OVERRIDE { return impl.isTrained(); }
|
|
|
|
bool isClassifier() const CV_OVERRIDE { return impl.isClassifier(); }
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2018-03-15 21:16:58 +08:00
|
|
|
const vector<int>& getRoots() const CV_OVERRIDE { return impl.getRoots(); }
|
|
|
|
const vector<Node>& getNodes() const CV_OVERRIDE { return impl.getNodes(); }
|
|
|
|
const vector<Split>& getSplits() const CV_OVERRIDE { return impl.getSplits(); }
|
|
|
|
const vector<int>& getSubsets() const CV_OVERRIDE { return impl.getSubsets(); }
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
DTreesImplForBoost impl;
|
|
|
|
};
|
2010-05-12 01:44:00 +08:00
|
|
|
|
|
|
|
|
2015-02-11 18:24:14 +08:00
|
|
|
// Factory: returns a new, untrained Boost model holding default parameters.
Ptr<Boost> Boost::create()
{
    Ptr<BoostImpl> model = makePtr<BoostImpl>();
    return model;
}
|
|
|
|
|
2017-01-29 18:13:01 +08:00
|
|
|
// Loads a serialized Boost model from `filepath`, reading from the node named
// `nodeName`; delegates to the generic Algorithm::load instantiated for Boost.
Ptr<Boost> Boost::load(const String& filepath, const String& nodeName)
{
    Ptr<Boost> model = Algorithm::load<Boost>(filepath, nodeName);
    return model;
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
}}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
|
|
|
/* End of file. */
|