diff --git a/modules/ml/src/inner_functions.cpp b/modules/ml/src/inner_functions.cpp index eaf70d4d27..00595975ec 100644 --- a/modules/ml/src/inner_functions.cpp +++ b/modules/ml/src/inner_functions.cpp @@ -431,8 +431,6 @@ cvPreprocessIndexArray( const CvMat* idx_arr, int data_arr_size, bool check_for_ if( idx_selected == 0 ) CV_ERROR( CV_StsOutOfRange, "No components/input_variables is selected!" ); - //if( idx_selected == idx_total ) - // EXIT; break; case CV_32SC1: // idx_arr is array of integer indices of selected components diff --git a/modules/ml/src/tree.cpp b/modules/ml/src/tree.cpp index fd75738440..dcc54f682e 100644 --- a/modules/ml/src/tree.cpp +++ b/modules/ml/src/tree.cpp @@ -666,6 +666,8 @@ CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx ) CvMat* isubsample_idx = 0; CvMat* subsample_co = 0; + bool isMakeRootCopy = true; + CV_FUNCNAME( "CvDTreeTrainData::subsample_data" ); __BEGIN__; @@ -674,9 +676,26 @@ CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx ) CV_ERROR( CV_StsError, "No training data has been set" ); if( _subsample_idx ) + { CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count )); - if( !isubsample_idx ) + if( isubsample_idx->cols + isubsample_idx->rows - 1 == sample_count ) + { + const int* sidx = isubsample_idx->data.i; + for( int i = 0; i < sample_count; i++ ) + { + if( sidx[i] != i ) + { + isMakeRootCopy = false; + break; + } + } + } + else + isMakeRootCopy = false; + } + + if( isMakeRootCopy ) { // make a copy of the root node CvDTreeNode temp; @@ -1588,7 +1607,7 @@ bool CvDTree::do_train( const CvMat* _subsample_idx ) CV_CALL( try_split_node(root)); if( data->params.cv_folds > 0 ) - CV_CALL( prune_cv()); + CV_CALL( prune_cv() ); if( !data->shared ) data->free_train_data(); diff --git a/tests/ml/src/amltests.cpp b/tests/ml/src/amltests.cpp index 43f8062d9b..c4bd3f7fa0 100644 --- a/tests/ml/src/amltests.cpp +++ b/tests/ml/src/amltests.cpp @@ -118,6 +118,6 @@ int CV_AMLTest::validate_test_results( int testCaseIdx ) CV_AMLTest amldtree( CV_DTREE, "adtree" ); CV_AMLTest amlboost( CV_BOOST, "aboost" ); CV_AMLTest amlrtrees( CV_RTREES, "artrees" ); -//CV_AMLTest amlertrees( CV_ERTREES, "aertrees" ); +CV_AMLTest amlertrees( CV_ERTREES, "aertrees" ); /* End of file. */ diff --git a/tests/ml/src/mltest_main.cpp b/tests/ml/src/mltest_main.cpp index 1e8c59cede..b40f7a4f93 100644 --- a/tests/ml/src/mltest_main.cpp +++ b/tests/ml/src/mltest_main.cpp @@ -48,6 +48,7 @@ const char* blacklist[] = "adtree", //ticket 662 "artrees", //ticket 460 "aboost", //ticket 474 + "aertrees", 0 };