mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 17:44:04 +08:00
Merge pull request #11384 from berak:ml_virtual
This commit is contained in:
commit
1e0a60be2a
@ -198,7 +198,7 @@ public:
|
|||||||
CV_WRAP virtual Mat getTestSampleWeights() const = 0;
|
CV_WRAP virtual Mat getTestSampleWeights() const = 0;
|
||||||
CV_WRAP virtual Mat getVarIdx() const = 0;
|
CV_WRAP virtual Mat getVarIdx() const = 0;
|
||||||
CV_WRAP virtual Mat getVarType() const = 0;
|
CV_WRAP virtual Mat getVarType() const = 0;
|
||||||
CV_WRAP Mat getVarSymbolFlags() const;
|
CV_WRAP virtual Mat getVarSymbolFlags() const = 0;
|
||||||
CV_WRAP virtual int getResponseType() const = 0;
|
CV_WRAP virtual int getResponseType() const = 0;
|
||||||
CV_WRAP virtual Mat getTrainSampleIdx() const = 0;
|
CV_WRAP virtual Mat getTrainSampleIdx() const = 0;
|
||||||
CV_WRAP virtual Mat getTestSampleIdx() const = 0;
|
CV_WRAP virtual Mat getTestSampleIdx() const = 0;
|
||||||
@ -234,10 +234,10 @@ public:
|
|||||||
CV_WRAP virtual void shuffleTrainTest() = 0;
|
CV_WRAP virtual void shuffleTrainTest() = 0;
|
||||||
|
|
||||||
/** @brief Returns matrix of test samples */
|
/** @brief Returns matrix of test samples */
|
||||||
CV_WRAP Mat getTestSamples() const;
|
CV_WRAP virtual Mat getTestSamples() const = 0;
|
||||||
|
|
||||||
/** @brief Returns vector of symbolic names captured in loadFromCSV() */
|
/** @brief Returns vector of symbolic names captured in loadFromCSV() */
|
||||||
CV_WRAP void getNames(std::vector<String>& names) const;
|
CV_WRAP virtual void getNames(std::vector<String>& names) const = 0;
|
||||||
|
|
||||||
CV_WRAP static Mat getSubVector(const Mat& vec, const Mat& idx);
|
CV_WRAP static Mat getSubVector(const Mat& vec, const Mat& idx);
|
||||||
|
|
||||||
@ -727,7 +727,7 @@ public:
|
|||||||
regression (SVM::EPS_SVR or SVM::NU_SVR). If it is SVM::ONE_CLASS, no optimization is made and
|
regression (SVM::EPS_SVR or SVM::NU_SVR). If it is SVM::ONE_CLASS, no optimization is made and
|
||||||
the usual %SVM with parameters specified in params is executed.
|
the usual %SVM with parameters specified in params is executed.
|
||||||
*/
|
*/
|
||||||
CV_WRAP bool trainAuto(InputArray samples,
|
CV_WRAP virtual bool trainAuto(InputArray samples,
|
||||||
int layout,
|
int layout,
|
||||||
InputArray responses,
|
InputArray responses,
|
||||||
int kFold = 10,
|
int kFold = 10,
|
||||||
@ -737,7 +737,7 @@ public:
|
|||||||
Ptr<ParamGrid> nuGrid = SVM::getDefaultGridPtr(SVM::NU),
|
Ptr<ParamGrid> nuGrid = SVM::getDefaultGridPtr(SVM::NU),
|
||||||
Ptr<ParamGrid> coeffGrid = SVM::getDefaultGridPtr(SVM::COEF),
|
Ptr<ParamGrid> coeffGrid = SVM::getDefaultGridPtr(SVM::COEF),
|
||||||
Ptr<ParamGrid> degreeGrid = SVM::getDefaultGridPtr(SVM::DEGREE),
|
Ptr<ParamGrid> degreeGrid = SVM::getDefaultGridPtr(SVM::DEGREE),
|
||||||
bool balanced=false);
|
bool balanced=false) = 0;
|
||||||
|
|
||||||
/** @brief Retrieves all the support vectors
|
/** @brief Retrieves all the support vectors
|
||||||
|
|
||||||
@ -752,7 +752,7 @@ public:
|
|||||||
support vector, used for prediction, was derived from. They are returned in a floating-point
|
support vector, used for prediction, was derived from. They are returned in a floating-point
|
||||||
matrix, where the support vectors are stored as matrix rows.
|
matrix, where the support vectors are stored as matrix rows.
|
||||||
*/
|
*/
|
||||||
CV_WRAP Mat getUncompressedSupportVectors() const;
|
CV_WRAP virtual Mat getUncompressedSupportVectors() const = 0;
|
||||||
|
|
||||||
/** @brief Retrieves the decision function
|
/** @brief Retrieves the decision function
|
||||||
|
|
||||||
@ -1273,7 +1273,7 @@ public:
|
|||||||
@param results Array where the result of the calculation will be written.
|
@param results Array where the result of the calculation will be written.
|
||||||
@param flags Flags for defining the type of RTrees.
|
@param flags Flags for defining the type of RTrees.
|
||||||
*/
|
*/
|
||||||
CV_WRAP void getVotes(InputArray samples, OutputArray results, int flags) const;
|
CV_WRAP virtual void getVotes(InputArray samples, OutputArray results, int flags) const = 0;
|
||||||
|
|
||||||
/** Creates the empty model.
|
/** Creates the empty model.
|
||||||
Use StatModel::train to train the model, StatModel::train to create and train the model,
|
Use StatModel::train to train the model, StatModel::train to create and train the model,
|
||||||
|
@ -50,13 +50,6 @@ static const int VAR_MISSED = VAR_ORDERED;
|
|||||||
|
|
||||||
TrainData::~TrainData() {}
|
TrainData::~TrainData() {}
|
||||||
|
|
||||||
Mat TrainData::getTestSamples() const
|
|
||||||
{
|
|
||||||
Mat idx = getTestSampleIdx();
|
|
||||||
Mat samples = getSamples();
|
|
||||||
return idx.empty() ? Mat() : getSubVector(samples, idx);
|
|
||||||
}
|
|
||||||
|
|
||||||
Mat TrainData::getSubVector(const Mat& vec, const Mat& idx)
|
Mat TrainData::getSubVector(const Mat& vec, const Mat& idx)
|
||||||
{
|
{
|
||||||
if( idx.empty() )
|
if( idx.empty() )
|
||||||
@ -119,6 +112,7 @@ Mat TrainData::getSubVector(const Mat& vec, const Mat& idx)
|
|||||||
return subvec;
|
return subvec;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class TrainDataImpl CV_FINAL : public TrainData
|
class TrainDataImpl CV_FINAL : public TrainData
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -155,6 +149,12 @@ public:
|
|||||||
return layout == ROW_SAMPLE ? samples.cols : samples.rows;
|
return layout == ROW_SAMPLE ? samples.cols : samples.rows;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Mat getTestSamples() const CV_OVERRIDE
|
||||||
|
{
|
||||||
|
Mat idx = getTestSampleIdx();
|
||||||
|
return idx.empty() ? Mat() : getSubVector(samples, idx);
|
||||||
|
}
|
||||||
|
|
||||||
Mat getSamples() const CV_OVERRIDE { return samples; }
|
Mat getSamples() const CV_OVERRIDE { return samples; }
|
||||||
Mat getResponses() const CV_OVERRIDE { return responses; }
|
Mat getResponses() const CV_OVERRIDE { return responses; }
|
||||||
Mat getMissing() const CV_OVERRIDE { return missing; }
|
Mat getMissing() const CV_OVERRIDE { return missing; }
|
||||||
@ -987,6 +987,27 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void getNames(std::vector<String>& names) const CV_OVERRIDE
|
||||||
|
{
|
||||||
|
size_t n = nameMap.size();
|
||||||
|
TrainDataImpl::MapType::const_iterator it = nameMap.begin(),
|
||||||
|
it_end = nameMap.end();
|
||||||
|
names.resize(n+1);
|
||||||
|
names[0] = "?";
|
||||||
|
for( ; it != it_end; ++it )
|
||||||
|
{
|
||||||
|
String s = it->first;
|
||||||
|
int label = it->second;
|
||||||
|
CV_Assert( label > 0 && label <= (int)n );
|
||||||
|
names[label] = s;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Mat getVarSymbolFlags() const CV_OVERRIDE
|
||||||
|
{
|
||||||
|
return varSymbolFlags;
|
||||||
|
}
|
||||||
|
|
||||||
FILE* file;
|
FILE* file;
|
||||||
int layout;
|
int layout;
|
||||||
Mat samples, missing, varType, varIdx, varSymbolFlags, responses, missingSubst;
|
Mat samples, missing, varType, varIdx, varSymbolFlags, responses, missingSubst;
|
||||||
@ -996,30 +1017,6 @@ public:
|
|||||||
MapType nameMap;
|
MapType nameMap;
|
||||||
};
|
};
|
||||||
|
|
||||||
void TrainData::getNames(std::vector<String>& names) const
|
|
||||||
{
|
|
||||||
const TrainDataImpl* impl = dynamic_cast<const TrainDataImpl*>(this);
|
|
||||||
CV_Assert(impl != 0);
|
|
||||||
size_t n = impl->nameMap.size();
|
|
||||||
TrainDataImpl::MapType::const_iterator it = impl->nameMap.begin(),
|
|
||||||
it_end = impl->nameMap.end();
|
|
||||||
names.resize(n+1);
|
|
||||||
names[0] = "?";
|
|
||||||
for( ; it != it_end; ++it )
|
|
||||||
{
|
|
||||||
String s = it->first;
|
|
||||||
int label = it->second;
|
|
||||||
CV_Assert( label > 0 && label <= (int)n );
|
|
||||||
names[label] = s;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Mat TrainData::getVarSymbolFlags() const
|
|
||||||
{
|
|
||||||
const TrainDataImpl* impl = dynamic_cast<const TrainDataImpl*>(this);
|
|
||||||
CV_Assert(impl != 0);
|
|
||||||
return impl->varSymbolFlags;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ptr<TrainData> TrainData::loadFromCSV(const String& filename,
|
Ptr<TrainData> TrainData::loadFromCSV(const String& filename,
|
||||||
int headerLines,
|
int headerLines,
|
||||||
|
@ -453,6 +453,7 @@ public:
|
|||||||
inline void setRegressionAccuracy(float val) CV_OVERRIDE { impl.params.setRegressionAccuracy(val); }
|
inline void setRegressionAccuracy(float val) CV_OVERRIDE { impl.params.setRegressionAccuracy(val); }
|
||||||
inline cv::Mat getPriors() const CV_OVERRIDE { return impl.params.getPriors(); }
|
inline cv::Mat getPriors() const CV_OVERRIDE { return impl.params.getPriors(); }
|
||||||
inline void setPriors(const cv::Mat& val) CV_OVERRIDE { impl.params.setPriors(val); }
|
inline void setPriors(const cv::Mat& val) CV_OVERRIDE { impl.params.setPriors(val); }
|
||||||
|
inline void getVotes(InputArray input, OutputArray output, int flags) const CV_OVERRIDE {return impl.getVotes(input,output,flags);}
|
||||||
|
|
||||||
RTreesImpl() {}
|
RTreesImpl() {}
|
||||||
virtual ~RTreesImpl() CV_OVERRIDE {}
|
virtual ~RTreesImpl() CV_OVERRIDE {}
|
||||||
@ -485,12 +486,6 @@ public:
|
|||||||
impl.read(fn);
|
impl.read(fn);
|
||||||
}
|
}
|
||||||
|
|
||||||
void getVotes_( InputArray samples, OutputArray results, int flags ) const
|
|
||||||
{
|
|
||||||
CV_TRACE_FUNCTION();
|
|
||||||
impl.getVotes(samples, results, flags);
|
|
||||||
}
|
|
||||||
|
|
||||||
Mat getVarImportance() const CV_OVERRIDE { return Mat_<float>(impl.varImportance, true); }
|
Mat getVarImportance() const CV_OVERRIDE { return Mat_<float>(impl.varImportance, true); }
|
||||||
int getVarCount() const CV_OVERRIDE { return impl.getVarCount(); }
|
int getVarCount() const CV_OVERRIDE { return impl.getVarCount(); }
|
||||||
|
|
||||||
@ -519,15 +514,6 @@ Ptr<RTrees> RTrees::load(const String& filepath, const String& nodeName)
|
|||||||
return Algorithm::load<RTrees>(filepath, nodeName);
|
return Algorithm::load<RTrees>(filepath, nodeName);
|
||||||
}
|
}
|
||||||
|
|
||||||
void RTrees::getVotes(InputArray input, OutputArray output, int flags) const
|
|
||||||
{
|
|
||||||
CV_TRACE_FUNCTION();
|
|
||||||
const RTreesImpl* this_ = dynamic_cast<const RTreesImpl*>(this);
|
|
||||||
if(!this_)
|
|
||||||
CV_Error(Error::StsNotImplemented, "the class is not RTreesImpl");
|
|
||||||
return this_->getVotes_(input, output, flags);
|
|
||||||
}
|
|
||||||
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
// End of file.
|
// End of file.
|
||||||
|
@ -1250,7 +1250,7 @@ public:
|
|||||||
uncompressed_sv.release();
|
uncompressed_sv.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
Mat getUncompressedSupportVectors_() const
|
Mat getUncompressedSupportVectors() const CV_OVERRIDE
|
||||||
{
|
{
|
||||||
return uncompressed_sv;
|
return uncompressed_sv;
|
||||||
}
|
}
|
||||||
@ -1982,10 +1982,10 @@ public:
|
|||||||
bool returnDFVal;
|
bool returnDFVal;
|
||||||
};
|
};
|
||||||
|
|
||||||
bool trainAuto_(InputArray samples, int layout,
|
bool trainAuto(InputArray samples, int layout,
|
||||||
InputArray responses, int kfold, Ptr<ParamGrid> Cgrid,
|
InputArray responses, int kfold, Ptr<ParamGrid> Cgrid,
|
||||||
Ptr<ParamGrid> gammaGrid, Ptr<ParamGrid> pGrid, Ptr<ParamGrid> nuGrid,
|
Ptr<ParamGrid> gammaGrid, Ptr<ParamGrid> pGrid, Ptr<ParamGrid> nuGrid,
|
||||||
Ptr<ParamGrid> coeffGrid, Ptr<ParamGrid> degreeGrid, bool balanced)
|
Ptr<ParamGrid> coeffGrid, Ptr<ParamGrid> degreeGrid, bool balanced) CV_OVERRIDE
|
||||||
{
|
{
|
||||||
Ptr<TrainData> data = TrainData::create(samples, layout, responses);
|
Ptr<TrainData> data = TrainData::create(samples, layout, responses);
|
||||||
return this->trainAuto(
|
return this->trainAuto(
|
||||||
@ -2353,26 +2353,6 @@ Ptr<SVM> SVM::load(const String& filepath)
|
|||||||
return svm;
|
return svm;
|
||||||
}
|
}
|
||||||
|
|
||||||
Mat SVM::getUncompressedSupportVectors() const
|
|
||||||
{
|
|
||||||
const SVMImpl* this_ = dynamic_cast<const SVMImpl*>(this);
|
|
||||||
if(!this_)
|
|
||||||
CV_Error(Error::StsNotImplemented, "the class is not SVMImpl");
|
|
||||||
return this_->getUncompressedSupportVectors_();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool SVM::trainAuto(InputArray samples, int layout,
|
|
||||||
InputArray responses, int kfold, Ptr<ParamGrid> Cgrid,
|
|
||||||
Ptr<ParamGrid> gammaGrid, Ptr<ParamGrid> pGrid, Ptr<ParamGrid> nuGrid,
|
|
||||||
Ptr<ParamGrid> coeffGrid, Ptr<ParamGrid> degreeGrid, bool balanced)
|
|
||||||
{
|
|
||||||
SVMImpl* this_ = dynamic_cast<SVMImpl*>(this);
|
|
||||||
if (!this_) {
|
|
||||||
CV_Error(Error::StsNotImplemented, "the class is not SVMImpl");
|
|
||||||
}
|
|
||||||
return this_->trainAuto_(samples, layout, responses,
|
|
||||||
kfold, Cgrid, gammaGrid, pGrid, nuGrid, coeffGrid, degreeGrid, balanced);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user