Merge pull request #15565 from alalek:issue_15561

Alexander Alekhin 2019-09-23 09:21:49 +00:00
commit b40fd6de32
13 changed files with 33 additions and 11 deletions
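The hunks below add CV_Assert(!...empty()) guards at the entry points of the ml training code, so that an empty Ptr<TrainData> fails with a clear assertion instead of being used further down. A minimal sketch of the situation the guards catch, assuming an illustrative file name ("missing.csv" is hypothetical; loadFromCSV() returns an empty Ptr when the file cannot be read):

    #include <opencv2/ml.hpp>
    #include <iostream>

    int main()
    {
        using namespace cv;
        using namespace cv::ml;

        // loadFromCSV() returns an empty Ptr<TrainData> when the file cannot be read.
        Ptr<TrainData> data = TrainData::loadFromCSV("missing.csv", 0); // hypothetical path

        Ptr<RTrees> model = RTrees::create();
        try
        {
            // With the guards from this commit, train() raises a cv::Exception here
            // instead of working on the empty pointer.
            model->train(data);
        }
        catch (const cv::Exception& e)
        {
            std::cerr << "train() rejected empty TrainData: " << e.what() << std::endl;
        }
        return 0;
    }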


@@ -920,6 +920,7 @@ public:
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
{
+ CV_Assert(!trainData.empty());
const int MAX_ITER = 1000;
const double DEFAULT_EPSILON = FLT_EPSILON;
@@ -955,6 +956,7 @@ public:
}
int train_anneal(const Ptr<TrainData>& trainData)
{
+ CV_Assert(!trainData.empty());
SimulatedAnnealingANN_MLP s(*this, trainData);
trained = true; // Enable call to CalcError
int iter = simulatedAnnealingSolver(s, params.initialT, params.finalT, params.coolingRatio, params.itePerStep, NULL, params.rEnergy);


@@ -88,6 +88,7 @@ public:
void startTraining( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
{
+ CV_Assert(!trainData.empty());
DTreesImpl::startTraining(trainData, flags);
sumResult.assign(w->sidx.size(), 0.);
@@ -184,6 +185,7 @@ public:
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
{
+ CV_Assert(!trainData.empty());
startTraining(trainData, flags);
int treeidx, ntrees = bparams.weakCount >= 0 ? bparams.weakCount : 10000;
vector<int> sidx = w->sidx;
@@ -482,6 +484,7 @@ public:
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
{
+ CV_Assert(!trainData.empty());
return impl.train(trainData, flags);
}


@@ -112,6 +112,7 @@ public:
bool train(const Ptr<TrainData>& data, int) CV_OVERRIDE
{
+ CV_Assert(!data.empty());
Mat samples = data->getTrainSamples(), labels;
return trainEM(samples, labels, noArray(), noArray());
}


@@ -59,9 +59,10 @@ bool StatModel::empty() const { return !isTrained(); }
int StatModel::getVarCount() const { return 0; }
- bool StatModel::train( const Ptr<TrainData>&, int )
+ bool StatModel::train(const Ptr<TrainData>& trainData, int )
{
CV_TRACE_FUNCTION();
+ CV_Assert(!trainData.empty());
CV_Error(CV_StsNotImplemented, "");
return false;
}
@@ -69,6 +70,7 @@ bool StatModel::train( const Ptr<TrainData>&, int )
bool StatModel::train( InputArray samples, int layout, InputArray responses )
{
CV_TRACE_FUNCTION();
+ CV_Assert(!samples.empty());
return train(TrainData::create(samples, layout, responses));
}
@@ -134,6 +136,7 @@ public:
float StatModel::calcError(const Ptr<TrainData>& data, bool testerr, OutputArray _resp) const
{
CV_TRACE_FUNCTION_SKIP_NESTED();
+ CV_Assert(!data.empty());
Mat samples = data->getSamples();
Mat sidx = testerr ? data->getTestSampleIdx() : data->getTrainSampleIdx();
Mat weights = testerr ? data->getTestSampleWeights() : data->getTrainSampleWeights();


@@ -73,6 +73,7 @@ public:
bool train( const Ptr<TrainData>& data, int flags )
{
+ CV_Assert(!data.empty());
Mat new_samples = data->getTrainSamples(ROW_SAMPLE);
Mat new_responses;
data->getTrainResponses().convertTo(new_responses, CV_32F);
@@ -494,6 +495,7 @@ public:
bool train( const Ptr<TrainData>& data, int flags ) CV_OVERRIDE
{
+ CV_Assert(!data.empty());
return impl->train(data, flags);
}


@@ -142,12 +142,10 @@ Ptr<LogisticRegression> LogisticRegression::load(const String& filepath, const S
bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
{
CV_TRACE_FUNCTION_SKIP_NESTED();
+ CV_Assert(!trainData.empty());
// return value
bool ok = false;
- if (trainData.empty()) {
- return false;
- }
clear();
Mat _data_i = trainData->getSamples();
Mat _labels_i = trainData->getResponses();
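In LogisticRegressionImpl::train() this replaces the old soft failure: an empty Ptr<TrainData> previously made train() return false quietly, while it now triggers the same assertion as the other ml models. A small sketch, assuming a caller wants to keep the old boolean-style behaviour (trainIfValid is a hypothetical helper, not part of OpenCV):

    // Emulate the pre-change "return false on empty data" behaviour in the caller.
    static bool trainIfValid(const cv::Ptr<cv::ml::LogisticRegression>& lr,
                             const cv::Ptr<cv::ml::TrainData>& data)
    {
        if (data.empty())
            return false;        // previously handled inside LogisticRegressionImpl::train()
        return lr->train(data);  // train() now asserts on empty data
    }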


@@ -54,6 +54,7 @@ public:
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
{
+ CV_Assert(!trainData.empty());
const float min_variation = FLT_EPSILON;
Mat responses = trainData->getNormCatResponses();
Mat __cls_labels = trainData->getClassLabels();


@@ -111,6 +111,7 @@ public:
void startTraining( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
+ CV_Assert(!trainData.empty());
DTreesImpl::startTraining(trainData, flags);
int nvars = w->data->getNVars();
int i, m = rparams.nactiveVars > 0 ? rparams.nactiveVars : cvRound(std::sqrt((double)nvars));
@@ -133,6 +134,7 @@ public:
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
+ CV_Assert(!trainData.empty());
startTraining(trainData, flags);
int treeidx, ntrees = (rparams.termCrit.type & TermCriteria::COUNT) != 0 ?
rparams.termCrit.maxCount : 10000;
@@ -463,6 +465,7 @@ public:
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
+ CV_Assert(!trainData.empty());
if (impl.getCVFolds() != 0)
CV_Error(Error::StsBadArg, "Cross validation for RTrees is not implemented");
return impl.train(trainData, flags);


@@ -1613,6 +1613,7 @@ public:
bool train( const Ptr<TrainData>& data, int ) CV_OVERRIDE
{
+ CV_Assert(!data.empty());
clear();
checkParams();
@@ -1739,6 +1740,7 @@ public:
ParamGrid nu_grid, ParamGrid coef_grid, ParamGrid degree_grid,
bool balanced ) CV_OVERRIDE
{
+ CV_Assert(!data.empty());
checkParams();
int svmType = params.svmType;


@@ -230,6 +230,7 @@ float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
{
+ CV_Assert(!data.empty());
clear();
CV_Assert( isClassifier() ); //toDo: consider


@@ -98,6 +98,7 @@ DTrees::Split::Split()
DTreesImpl::WorkData::WorkData(const Ptr<TrainData>& _data)
{
+ CV_Assert(!_data.empty());
data = _data;
vector<int> subsampleIdx;
Mat sidx0 = _data->getTrainSampleIdx();
@@ -136,6 +137,7 @@ void DTreesImpl::clear()
void DTreesImpl::startTraining( const Ptr<TrainData>& data, int )
{
+ CV_Assert(!data.empty());
clear();
w = makePtr<WorkData>(data);
@@ -223,6 +225,7 @@ void DTreesImpl::endTraining()
bool DTreesImpl::train( const Ptr<TrainData>& trainData, int flags )
{
+ CV_Assert(!trainData.empty());
startTraining(trainData, flags);
bool ok = addTree( w->sidx ) >= 0;
w.release();


@@ -94,11 +94,7 @@ void CV_LRTest::run( int /*start_from*/ )
// initialize variables from the popular Iris Dataset
string dataFileName = ts->get_data_path() + "iris.data";
Ptr<TrainData> tdata = TrainData::loadFromCSV(dataFileName, 0);
- if (tdata.empty()) {
- ts->set_failed_test_info(cvtest::TS::FAIL_INVALID_TEST_DATA);
- return;
- }
+ ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << dataFileName;
// run LR classifier train classifier
Ptr<LogisticRegression> p = LogisticRegression::create();
@@ -156,6 +152,7 @@ void CV_LRTest_SaveLoad::run( int /*start_from*/ )
// initialize variables from the popular Iris Dataset
string dataFileName = ts->get_data_path() + "iris.data";
Ptr<TrainData> tdata = TrainData::loadFromCSV(dataFileName, 0);
+ ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << dataFileName;
Mat responses1, responses2;
Mat learnt_mat1, learnt_mat2;


@@ -105,6 +105,7 @@ int str_to_ann_activation_function(String& str)
void ann_check_data( Ptr<TrainData> _data )
{
CV_TRACE_FUNCTION();
+ CV_Assert(!_data.empty());
Mat values = _data->getSamples();
Mat var_idx = _data->getVarIdx();
int nvars = (int)var_idx.total();
@@ -118,6 +119,7 @@ void ann_check_data( Ptr<TrainData> _data )
Mat ann_get_new_responses( Ptr<TrainData> _data, map<int, int>& cls_map )
{
CV_TRACE_FUNCTION();
+ CV_Assert(!_data.empty());
Mat train_sidx = _data->getTrainSampleIdx();
int* train_sidx_ptr = train_sidx.ptr<int>();
Mat responses = _data->getResponses();
@@ -150,6 +152,8 @@ Mat ann_get_new_responses( Ptr<TrainData> _data, map<int, int>& cls_map )
float ann_calc_error( Ptr<StatModel> ann, Ptr<TrainData> _data, map<int, int>& cls_map, int type, vector<float> *resp_labels )
{
CV_TRACE_FUNCTION();
+ CV_Assert(!ann.empty());
+ CV_Assert(!_data.empty());
float err = 0;
Mat samples = _data->getSamples();
Mat responses = _data->getResponses();
@@ -264,13 +268,15 @@ TEST_P(ML_ANN_METHOD, Test)
String dataname = folder + "waveform" + '_' + methodName;
Ptr<TrainData> tdata2 = TrainData::loadFromCSV(original_path, 0);
+ ASSERT_FALSE(tdata2.empty()) << "Could not find test data file : " << original_path;
Mat samples = tdata2->getSamples()(Range(0, N), Range::all());
Mat responses(N, 3, CV_32FC1, Scalar(0));
for (int i = 0; i < N; i++)
responses.at<float>(i, static_cast<int>(tdata2->getResponses().at<float>(i, 0))) = 1;
Ptr<TrainData> tdata = TrainData::create(samples, ml::ROW_SAMPLE, responses);
- ASSERT_FALSE(tdata.empty());
+ ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << original_path;
RNG& rng = theRNG();
rng.state = 0;
tdata->setTrainTestSplitRatio(0.8);