diff --git a/modules/ml/include/opencv2/ml/ml.hpp b/modules/ml/include/opencv2/ml/ml.hpp index 881d0d1070..36f86c4749 100644 --- a/modules/ml/include/opencv2/ml/ml.hpp +++ b/modules/ml/include/opencv2/ml/ml.hpp @@ -125,6 +125,7 @@ CV_INLINE CvParamLattice cvDefaultParamLattice( void ) #define CV_TYPE_NAME_ML_ANN_MLP "opencv-ml-ann-mlp" #define CV_TYPE_NAME_ML_CNN "opencv-ml-cnn" #define CV_TYPE_NAME_ML_RTREES "opencv-ml-random-trees" +#define CV_TYPE_NAME_ML_ERTREES "opencv-ml-extremely-randomized-trees" #define CV_TYPE_NAME_ML_GBT "opencv-ml-gradient-boosting-trees" #define CV_TRAIN_ERROR 0 @@ -1041,6 +1042,7 @@ public: CvForestTree* get_tree(int i) const; protected: + virtual std::string getName() const; virtual bool grow_forest( const CvTermCriteria term_crit ); @@ -1114,6 +1116,7 @@ public: #endif virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() ); protected: + virtual std::string getName() const; virtual bool grow_forest( const CvTermCriteria term_crit ); }; diff --git a/modules/ml/src/ertrees.cpp b/modules/ml/src/ertrees.cpp index b38aa34e2d..0460527ca1 100644 --- a/modules/ml/src/ertrees.cpp +++ b/modules/ml/src/ertrees.cpp @@ -1517,6 +1517,11 @@ CvERTrees::~CvERTrees() { } +std::string CvERTrees::getName() const +{ + return CV_TYPE_NAME_ML_ERTREES; +} + bool CvERTrees::train( const CvMat* _train_data, int _tflag, const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx, const CvMat* _var_type, diff --git a/modules/ml/src/rtrees.cpp b/modules/ml/src/rtrees.cpp index 61614c24e2..81576c3b61 100644 --- a/modules/ml/src/rtrees.cpp +++ b/modules/ml/src/rtrees.cpp @@ -246,6 +246,10 @@ CvRTrees::~CvRTrees() clear(); } +std::string CvRTrees::getName() const +{ + return CV_TYPE_NAME_ML_RTREES; +} CvMat* CvRTrees::get_active_var_mask() { @@ -726,7 +730,8 @@ void CvRTrees::write( CvFileStorage* fs, const char* name ) const if( ntrees < 1 || !trees || nsamples < 1 ) CV_Error( CV_StsBadArg, "Invalid CvRTrees object" ); - cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_RTREES ); + std::string modelNodeName = this->getName(); + cvStartWriteStruct( fs, name, CV_NODE_MAP, modelNodeName.c_str() ); cvWriteInt( fs, "nclasses", nclasses ); cvWriteInt( fs, "nsamples", nsamples );