2010-05-12 01:44:00 +08:00
|
|
|
#ifndef _OPENCV_BOOST_H_
|
|
|
|
#define _OPENCV_BOOST_H_
|
|
|
|
|
|
|
|
#include "traincascade_features.h"
|
2014-08-03 05:41:09 +08:00
|
|
|
#include "old_ml.hpp"
|
2010-05-12 01:44:00 +08:00
|
|
|
|
|
|
|
struct CvCascadeBoostParams : CvBoostParams
|
|
|
|
{
|
|
|
|
float minHitRate;
|
|
|
|
float maxFalseAlarm;
|
2012-10-17 15:12:04 +08:00
|
|
|
|
2010-05-12 01:44:00 +08:00
|
|
|
CvCascadeBoostParams();
|
|
|
|
CvCascadeBoostParams( int _boostType, float _minHitRate, float _maxFalseAlarm,
|
|
|
|
double _weightTrimRate, int _maxDepth, int _maxWeakCount );
|
|
|
|
virtual ~CvCascadeBoostParams() {}
|
2013-11-16 23:56:08 +08:00
|
|
|
void write( cv::FileStorage &fs ) const;
|
|
|
|
bool read( const cv::FileNode &node );
|
2010-05-12 01:44:00 +08:00
|
|
|
virtual void printDefaults() const;
|
|
|
|
virtual void printAttrs() const;
|
2013-02-25 00:14:01 +08:00
|
|
|
virtual bool scanAttr( const std::string prmName, const std::string val);
|
2010-05-12 01:44:00 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
struct CvCascadeBoostTrainData : CvDTreeTrainData
|
|
|
|
{
|
|
|
|
CvCascadeBoostTrainData( const CvFeatureEvaluator* _featureEvaluator,
|
|
|
|
const CvDTreeParams& _params );
|
|
|
|
CvCascadeBoostTrainData( const CvFeatureEvaluator* _featureEvaluator,
|
|
|
|
int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
|
|
|
|
const CvDTreeParams& _params = CvDTreeParams() );
|
|
|
|
virtual void setData( const CvFeatureEvaluator* _featureEvaluator,
|
|
|
|
int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
|
|
|
|
const CvDTreeParams& _params=CvDTreeParams() );
|
|
|
|
void precalculate();
|
|
|
|
|
2011-12-22 19:19:27 +08:00
|
|
|
virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
|
|
|
|
|
2010-05-12 01:44:00 +08:00
|
|
|
virtual const int* get_class_labels( CvDTreeNode* n, int* labelsBuf );
|
|
|
|
virtual const int* get_cv_labels( CvDTreeNode* n, int* labelsBuf);
|
|
|
|
virtual const int* get_sample_indices( CvDTreeNode* n, int* indicesBuf );
|
2012-10-17 15:12:04 +08:00
|
|
|
|
2010-05-12 01:44:00 +08:00
|
|
|
virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ordValuesBuf, int* sortedIndicesBuf,
|
|
|
|
const float** ordValues, const int** sortedIndices, int* sampleIndicesBuf );
|
|
|
|
virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* catValuesBuf );
|
|
|
|
virtual float getVarValue( int vi, int si );
|
|
|
|
virtual void free_train_data();
|
|
|
|
|
|
|
|
const CvFeatureEvaluator* featureEvaluator;
|
2013-11-16 23:56:08 +08:00
|
|
|
cv::Mat valCache; // precalculated feature values (CV_32FC1)
|
2010-05-12 01:44:00 +08:00
|
|
|
CvMat _resp; // for casting
|
|
|
|
int numPrecalcVal, numPrecalcIdx;
|
|
|
|
};
|
|
|
|
|
|
|
|
class CvCascadeBoostTree : public CvBoostTree
|
|
|
|
{
|
|
|
|
public:
|
|
|
|
virtual CvDTreeNode* predict( int sampleIdx ) const;
|
2013-11-16 23:56:08 +08:00
|
|
|
void write( cv::FileStorage &fs, const cv::Mat& featureMap );
|
|
|
|
void read( const cv::FileNode &node, CvBoost* _ensemble, CvDTreeTrainData* _data );
|
|
|
|
void markFeaturesInMap( cv::Mat& featureMap );
|
2010-05-12 01:44:00 +08:00
|
|
|
protected:
|
|
|
|
virtual void split_node_data( CvDTreeNode* n );
|
|
|
|
};
|
|
|
|
|
|
|
|
class CvCascadeBoost : public CvBoost
|
|
|
|
{
|
|
|
|
public:
|
|
|
|
virtual bool train( const CvFeatureEvaluator* _featureEvaluator,
|
|
|
|
int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
|
|
|
|
const CvCascadeBoostParams& _params=CvCascadeBoostParams() );
|
|
|
|
virtual float predict( int sampleIdx, bool returnSum = false ) const;
|
|
|
|
|
2011-12-22 19:19:27 +08:00
|
|
|
float getThreshold() const { return threshold; }
|
2013-11-16 23:56:08 +08:00
|
|
|
void write( cv::FileStorage &fs, const cv::Mat& featureMap ) const;
|
|
|
|
bool read( const cv::FileNode &node, const CvFeatureEvaluator* _featureEvaluator,
|
2010-05-12 01:44:00 +08:00
|
|
|
const CvCascadeBoostParams& _params );
|
2013-11-16 23:56:08 +08:00
|
|
|
void markUsedFeaturesInMap( cv::Mat& featureMap );
|
2010-05-12 01:44:00 +08:00
|
|
|
protected:
|
|
|
|
virtual bool set_params( const CvBoostParams& _params );
|
|
|
|
virtual void update_weights( CvBoostTree* tree );
|
|
|
|
virtual bool isErrDesired();
|
|
|
|
|
|
|
|
float threshold;
|
|
|
|
float minHitRate, maxFalseAlarm;
|
|
|
|
};
|
|
|
|
|
|
|
|
#endif
|