mirror of
https://github.com/opencv/opencv.git
synced 2024-11-29 22:00:25 +08:00
Merge pull request #15565 from alalek:issue_15561
This commit is contained in:
commit
b40fd6de32
@ -920,6 +920,7 @@ public:
|
||||
|
||||
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(!trainData.empty());
|
||||
const int MAX_ITER = 1000;
|
||||
const double DEFAULT_EPSILON = FLT_EPSILON;
|
||||
|
||||
@ -955,6 +956,7 @@ public:
|
||||
}
|
||||
int train_anneal(const Ptr<TrainData>& trainData)
|
||||
{
|
||||
CV_Assert(!trainData.empty());
|
||||
SimulatedAnnealingANN_MLP s(*this, trainData);
|
||||
trained = true; // Enable call to CalcError
|
||||
int iter = simulatedAnnealingSolver(s, params.initialT, params.finalT, params.coolingRatio, params.itePerStep, NULL, params.rEnergy);
|
||||
|
@ -88,6 +88,7 @@ public:
|
||||
|
||||
void startTraining( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(!trainData.empty());
|
||||
DTreesImpl::startTraining(trainData, flags);
|
||||
sumResult.assign(w->sidx.size(), 0.);
|
||||
|
||||
@ -184,6 +185,7 @@ public:
|
||||
|
||||
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(!trainData.empty());
|
||||
startTraining(trainData, flags);
|
||||
int treeidx, ntrees = bparams.weakCount >= 0 ? bparams.weakCount : 10000;
|
||||
vector<int> sidx = w->sidx;
|
||||
@ -482,6 +484,7 @@ public:
|
||||
|
||||
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(!trainData.empty());
|
||||
return impl.train(trainData, flags);
|
||||
}
|
||||
|
||||
|
@ -112,6 +112,7 @@ public:
|
||||
|
||||
bool train(const Ptr<TrainData>& data, int) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(!data.empty());
|
||||
Mat samples = data->getTrainSamples(), labels;
|
||||
return trainEM(samples, labels, noArray(), noArray());
|
||||
}
|
||||
|
@ -59,9 +59,10 @@ bool StatModel::empty() const { return !isTrained(); }
|
||||
|
||||
int StatModel::getVarCount() const { return 0; }
|
||||
|
||||
bool StatModel::train( const Ptr<TrainData>&, int )
|
||||
bool StatModel::train(const Ptr<TrainData>& trainData, int )
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_Assert(!trainData.empty());
|
||||
CV_Error(CV_StsNotImplemented, "");
|
||||
return false;
|
||||
}
|
||||
@ -69,6 +70,7 @@ bool StatModel::train( const Ptr<TrainData>&, int )
|
||||
bool StatModel::train( InputArray samples, int layout, InputArray responses )
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_Assert(!samples.empty());
|
||||
return train(TrainData::create(samples, layout, responses));
|
||||
}
|
||||
|
||||
@ -134,6 +136,7 @@ public:
|
||||
float StatModel::calcError(const Ptr<TrainData>& data, bool testerr, OutputArray _resp) const
|
||||
{
|
||||
CV_TRACE_FUNCTION_SKIP_NESTED();
|
||||
CV_Assert(!data.empty());
|
||||
Mat samples = data->getSamples();
|
||||
Mat sidx = testerr ? data->getTestSampleIdx() : data->getTrainSampleIdx();
|
||||
Mat weights = testerr ? data->getTestSampleWeights() : data->getTrainSampleWeights();
|
||||
|
@ -73,6 +73,7 @@ public:
|
||||
|
||||
bool train( const Ptr<TrainData>& data, int flags )
|
||||
{
|
||||
CV_Assert(!data.empty());
|
||||
Mat new_samples = data->getTrainSamples(ROW_SAMPLE);
|
||||
Mat new_responses;
|
||||
data->getTrainResponses().convertTo(new_responses, CV_32F);
|
||||
@ -494,6 +495,7 @@ public:
|
||||
|
||||
bool train( const Ptr<TrainData>& data, int flags ) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(!data.empty());
|
||||
return impl->train(data, flags);
|
||||
}
|
||||
|
||||
|
@ -142,12 +142,10 @@ Ptr<LogisticRegression> LogisticRegression::load(const String& filepath, const S
|
||||
bool LogisticRegressionImpl::train(const Ptr<TrainData>& trainData, int)
|
||||
{
|
||||
CV_TRACE_FUNCTION_SKIP_NESTED();
|
||||
CV_Assert(!trainData.empty());
|
||||
|
||||
// return value
|
||||
bool ok = false;
|
||||
|
||||
if (trainData.empty()) {
|
||||
return false;
|
||||
}
|
||||
clear();
|
||||
Mat _data_i = trainData->getSamples();
|
||||
Mat _labels_i = trainData->getResponses();
|
||||
|
@ -54,6 +54,7 @@ public:
|
||||
|
||||
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(!trainData.empty());
|
||||
const float min_variation = FLT_EPSILON;
|
||||
Mat responses = trainData->getNormCatResponses();
|
||||
Mat __cls_labels = trainData->getClassLabels();
|
||||
|
@ -111,6 +111,7 @@ public:
|
||||
void startTraining( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_Assert(!trainData.empty());
|
||||
DTreesImpl::startTraining(trainData, flags);
|
||||
int nvars = w->data->getNVars();
|
||||
int i, m = rparams.nactiveVars > 0 ? rparams.nactiveVars : cvRound(std::sqrt((double)nvars));
|
||||
@ -133,6 +134,7 @@ public:
|
||||
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_Assert(!trainData.empty());
|
||||
startTraining(trainData, flags);
|
||||
int treeidx, ntrees = (rparams.termCrit.type & TermCriteria::COUNT) != 0 ?
|
||||
rparams.termCrit.maxCount : 10000;
|
||||
@ -463,6 +465,7 @@ public:
|
||||
bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_Assert(!trainData.empty());
|
||||
if (impl.getCVFolds() != 0)
|
||||
CV_Error(Error::StsBadArg, "Cross validation for RTrees is not implemented");
|
||||
return impl.train(trainData, flags);
|
||||
|
@ -1613,6 +1613,7 @@ public:
|
||||
|
||||
bool train( const Ptr<TrainData>& data, int ) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(!data.empty());
|
||||
clear();
|
||||
|
||||
checkParams();
|
||||
@ -1739,6 +1740,7 @@ public:
|
||||
ParamGrid nu_grid, ParamGrid coef_grid, ParamGrid degree_grid,
|
||||
bool balanced ) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(!data.empty());
|
||||
checkParams();
|
||||
|
||||
int svmType = params.svmType;
|
||||
|
@ -230,6 +230,7 @@ float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
|
||||
|
||||
bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
||||
{
|
||||
CV_Assert(!data.empty());
|
||||
clear();
|
||||
CV_Assert( isClassifier() ); //toDo: consider
|
||||
|
||||
|
@ -98,6 +98,7 @@ DTrees::Split::Split()
|
||||
|
||||
DTreesImpl::WorkData::WorkData(const Ptr<TrainData>& _data)
|
||||
{
|
||||
CV_Assert(!_data.empty());
|
||||
data = _data;
|
||||
vector<int> subsampleIdx;
|
||||
Mat sidx0 = _data->getTrainSampleIdx();
|
||||
@ -136,6 +137,7 @@ void DTreesImpl::clear()
|
||||
|
||||
void DTreesImpl::startTraining( const Ptr<TrainData>& data, int )
|
||||
{
|
||||
CV_Assert(!data.empty());
|
||||
clear();
|
||||
w = makePtr<WorkData>(data);
|
||||
|
||||
@ -223,6 +225,7 @@ void DTreesImpl::endTraining()
|
||||
|
||||
bool DTreesImpl::train( const Ptr<TrainData>& trainData, int flags )
|
||||
{
|
||||
CV_Assert(!trainData.empty());
|
||||
startTraining(trainData, flags);
|
||||
bool ok = addTree( w->sidx ) >= 0;
|
||||
w.release();
|
||||
|
@ -94,11 +94,7 @@ void CV_LRTest::run( int /*start_from*/ )
|
||||
// initialize variables from the popular Iris Dataset
|
||||
string dataFileName = ts->get_data_path() + "iris.data";
|
||||
Ptr<TrainData> tdata = TrainData::loadFromCSV(dataFileName, 0);
|
||||
|
||||
if (tdata.empty()) {
|
||||
ts->set_failed_test_info(cvtest::TS::FAIL_INVALID_TEST_DATA);
|
||||
return;
|
||||
}
|
||||
ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << dataFileName;
|
||||
|
||||
// run LR classifier train classifier
|
||||
Ptr<LogisticRegression> p = LogisticRegression::create();
|
||||
@ -156,6 +152,7 @@ void CV_LRTest_SaveLoad::run( int /*start_from*/ )
|
||||
// initialize variables from the popular Iris Dataset
|
||||
string dataFileName = ts->get_data_path() + "iris.data";
|
||||
Ptr<TrainData> tdata = TrainData::loadFromCSV(dataFileName, 0);
|
||||
ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << dataFileName;
|
||||
|
||||
Mat responses1, responses2;
|
||||
Mat learnt_mat1, learnt_mat2;
|
||||
|
@ -105,6 +105,7 @@ int str_to_ann_activation_function(String& str)
|
||||
void ann_check_data( Ptr<TrainData> _data )
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_Assert(!_data.empty());
|
||||
Mat values = _data->getSamples();
|
||||
Mat var_idx = _data->getVarIdx();
|
||||
int nvars = (int)var_idx.total();
|
||||
@ -118,6 +119,7 @@ void ann_check_data( Ptr<TrainData> _data )
|
||||
Mat ann_get_new_responses( Ptr<TrainData> _data, map<int, int>& cls_map )
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_Assert(!_data.empty());
|
||||
Mat train_sidx = _data->getTrainSampleIdx();
|
||||
int* train_sidx_ptr = train_sidx.ptr<int>();
|
||||
Mat responses = _data->getResponses();
|
||||
@ -150,6 +152,8 @@ Mat ann_get_new_responses( Ptr<TrainData> _data, map<int, int>& cls_map )
|
||||
float ann_calc_error( Ptr<StatModel> ann, Ptr<TrainData> _data, map<int, int>& cls_map, int type, vector<float> *resp_labels )
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_Assert(!ann.empty());
|
||||
CV_Assert(!_data.empty());
|
||||
float err = 0;
|
||||
Mat samples = _data->getSamples();
|
||||
Mat responses = _data->getResponses();
|
||||
@ -264,13 +268,15 @@ TEST_P(ML_ANN_METHOD, Test)
|
||||
String dataname = folder + "waveform" + '_' + methodName;
|
||||
|
||||
Ptr<TrainData> tdata2 = TrainData::loadFromCSV(original_path, 0);
|
||||
ASSERT_FALSE(tdata2.empty()) << "Could not find test data file : " << original_path;
|
||||
|
||||
Mat samples = tdata2->getSamples()(Range(0, N), Range::all());
|
||||
Mat responses(N, 3, CV_32FC1, Scalar(0));
|
||||
for (int i = 0; i < N; i++)
|
||||
responses.at<float>(i, static_cast<int>(tdata2->getResponses().at<float>(i, 0))) = 1;
|
||||
Ptr<TrainData> tdata = TrainData::create(samples, ml::ROW_SAMPLE, responses);
|
||||
ASSERT_FALSE(tdata.empty());
|
||||
|
||||
ASSERT_FALSE(tdata.empty()) << "Could not find test data file : " << original_path;
|
||||
RNG& rng = theRNG();
|
||||
rng.state = 0;
|
||||
tdata->setTrainTestSplitRatio(0.8);
|
||||
|
Loading…
Reference in New Issue
Block a user