diff --git a/modules/ml/src/lr.cpp b/modules/ml/src/lr.cpp index 8b0a670d4e..585162cf98 100644 --- a/modules/ml/src/lr.cpp +++ b/modules/ml/src/lr.cpp @@ -96,11 +96,11 @@ public: CV_IMPL_PROPERTY(TermCriteria, TermCriteria, params.term_crit) virtual bool train( const Ptr& trainData, int=0 ); - virtual float predict(InputArray samples, OutputArray results, int) const; + virtual float predict(InputArray samples, OutputArray results, int flags=0) const; virtual void clear(); virtual void write(FileStorage& fs) const; virtual void read(const FileNode& fn); - virtual Mat get_learnt_thetas() const; + virtual Mat get_learnt_thetas() const { return learnt_thetas; } virtual int getVarCount() const { return learnt_thetas.cols; } virtual bool isTrained() const { return !learnt_thetas.empty(); } virtual bool isClassifier() const { return true; } @@ -129,57 +129,48 @@ Ptr LogisticRegression::create() bool LogisticRegressionImpl::train(const Ptr& trainData, int) { + // return value + bool ok = false; + clear(); Mat _data_i = trainData->getSamples(); Mat _labels_i = trainData->getResponses(); + // check size and type of training data CV_Assert( !_labels_i.empty() && !_data_i.empty()); - - // check the number of columns if(_labels_i.cols != 1) { - CV_Error( CV_StsBadArg, "_labels_i should be a column matrix" ); + CV_Error( CV_StsBadArg, "labels should be a column matrix" ); } - - // check data type. - // data should be of floating type CV_32FC1 - - if((_data_i.type() != CV_32FC1) || (_labels_i.type() != CV_32FC1)) + if(_data_i.type() != CV_32FC1 || _labels_i.type() != CV_32FC1) { CV_Error( CV_StsBadArg, "data and labels must be a floating point matrix" ); } + if(_labels_i.rows != _data_i.rows) + { + CV_Error( CV_StsBadArg, "number of rows in data and labels should be equal" ); + } - bool ok = false; - - Mat labels; - + // class labels set_label_map(_labels_i); + Mat labels_l = remap_labels(_labels_i, this->forward_mapper); int num_classes = (int) this->forward_mapper.size(); - - // add a column of ones - Mat data_t; - hconcat( cv::Mat::ones( _data_i.rows, 1, CV_32F ), _data_i, data_t ); - if(num_classes < 2) { CV_Error( CV_StsBadArg, "data should have atleast 2 classes" ); } - if(_labels_i.rows != _data_i.rows) - { - CV_Error( CV_StsBadArg, "number of rows in data and labels should be the equal" ); - } + // add a column of ones to the data (bias/intercept term) + Mat data_t; + hconcat( cv::Mat::ones( _data_i.rows, 1, CV_32F ), _data_i, data_t ); - - Mat thetas = Mat::zeros(num_classes, data_t.cols, CV_32F); + // coefficient matrix (zero-initialized) + Mat thetas; Mat init_theta = Mat::zeros(data_t.cols, 1, CV_32F); - Mat labels_l = remap_labels(_labels_i, this->forward_mapper); - Mat new_local_labels; - - int ii=0; + // fit the model (handles binary and multiclass cases) Mat new_theta; - + Mat labels; if(num_classes == 2) { labels_l.convertTo(labels, CV_32F); @@ -193,12 +184,14 @@ bool LogisticRegressionImpl::train(const Ptr& trainData, int) { /* take each class and rename classes you will get a theta per class as in multi class class scenario, we will have n thetas for n classes */ - ii = 0; - + thetas.create(num_classes, data_t.cols, CV_32F); + Mat labels_binary; + int ii = 0; for(map::iterator it = this->forward_mapper.begin(); it != this->forward_mapper.end(); ++it) { - new_local_labels = (labels_l == it->second)/255; - new_local_labels.convertTo(labels, CV_32F); + // one-vs-rest (OvR) scheme + labels_binary = (labels_l == it->second)/255; + labels_binary.convertTo(labels, CV_32F); if(this->params.train_method == LogisticRegression::BATCH) new_theta = batch_gradient_descent(data_t, labels, init_theta); else @@ -208,38 +201,28 @@ bool LogisticRegressionImpl::train(const Ptr& trainData, int) } } + // check that the estimates are stable and finite this->learnt_thetas = thetas.clone(); if( cvIsNaN( (double)sum(this->learnt_thetas)[0] ) ) { CV_Error( CV_StsBadArg, "check training parameters. Invalid training classifier" ); } + + // success ok = true; return ok; } float LogisticRegressionImpl::predict(InputArray samples, OutputArray results, int flags) const { - /* returns a class of the predicted class - class names can be 1,2,3,4, .... etc */ - Mat thetas, data, pred_labs; - data = samples.getMat(); - - const bool rawout = flags & StatModel::RAW_OUTPUT; - // check if learnt_mats array is populated - if(this->learnt_thetas.total()<=0) + if(!this->isTrained()) { CV_Error( CV_StsBadArg, "classifier should be trained first" ); } - if(data.type() != CV_32F) - { - CV_Error( CV_StsBadArg, "data must be of floating type" ); - } - - // add a column of ones - Mat data_t; - hconcat( cv::Mat::ones( data.rows, 1, CV_32F ), data, data_t ); + // coefficient matrix + Mat thetas; if ( learnt_thetas.type() == CV_32F ) { thetas = learnt_thetas; @@ -248,53 +231,65 @@ float LogisticRegressionImpl::predict(InputArray samples, OutputArray results, i { this->learnt_thetas.convertTo( thetas, CV_32F ); } - CV_Assert(thetas.rows > 0); - double min_val; - double max_val; + // data samples + Mat data = samples.getMat(); + if(data.type() != CV_32F) + { + CV_Error( CV_StsBadArg, "data must be of floating type" ); + } - Point min_loc; - Point max_loc; + // add a column of ones to the data (bias/intercept term) + Mat data_t; + hconcat( cv::Mat::ones( data.rows, 1, CV_32F ), data, data_t ); + CV_Assert(data_t.cols == thetas.cols); - Mat labels; + // predict class labels for samples (handles binary and multiclass cases) Mat labels_c; + Mat pred_m; Mat temp_pred; - Mat pred_m = Mat::zeros(data_t.rows, thetas.rows, data.type()); - if(thetas.rows == 1) { - temp_pred = calc_sigmoid(data_t*thetas.t()); + // apply sigmoid function + temp_pred = calc_sigmoid(data_t * thetas.t()); CV_Assert(temp_pred.cols==1); pred_m = temp_pred.clone(); // if greater than 0.5, predict class 0 or predict class 1 - temp_pred = (temp_pred>0.5)/255; + temp_pred = (temp_pred > 0.5f) / 255; temp_pred.convertTo(labels_c, CV_32S); } else { - for(int i = 0;ireverse_mapper); - // convert pred_labs to integer type + + // return label of the predicted class. class names can be 1,2,3,... + Mat pred_labs = remap_labels(labels_c, this->reverse_mapper); pred_labs.convertTo(pred_labs, CV_32S); // return either the labels or the raw output if ( results.needed() ) { - if ( rawout ) + if ( flags & StatModel::RAW_OUTPUT ) { pred_m.copyTo( results ); } @@ -304,7 +299,7 @@ float LogisticRegressionImpl::predict(InputArray samples, OutputArray results, i } } - return ( pred_labs.empty() ? 0.f : (float) pred_labs.at< int >( 0 ) ); + return ( pred_labs.empty() ? 0.f : static_cast(pred_labs.at(0)) ); } Mat LogisticRegressionImpl::calc_sigmoid(const Mat& data) const @@ -596,11 +591,6 @@ void LogisticRegressionImpl::read(const FileNode& fn) } } -Mat LogisticRegressionImpl::get_learnt_thetas() const -{ - return this->learnt_thetas; -} - } }