mirror of
https://github.com/opencv/opencv.git
synced 2025-06-12 12:22:51 +08:00
* export SVM::trainAuto to python #7224 * workaround for ABI compatibility of SVM::trainAuto * add parameter comments to new SVM::trainAuto function * Export ParamGrid member variables
This commit is contained in:
parent
1857aa22b3
commit
f70cc29edb
@ -104,7 +104,7 @@ enum SampleTypes
|
|||||||
It is used for optimizing statmodel accuracy by varying model parameters, the accuracy estimate
|
It is used for optimizing statmodel accuracy by varying model parameters, the accuracy estimate
|
||||||
being computed by cross-validation.
|
being computed by cross-validation.
|
||||||
*/
|
*/
|
||||||
class CV_EXPORTS ParamGrid
|
class CV_EXPORTS_W ParamGrid
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
/** @brief Default constructor */
|
/** @brief Default constructor */
|
||||||
@ -112,8 +112,8 @@ public:
|
|||||||
/** @brief Constructor with parameters */
|
/** @brief Constructor with parameters */
|
||||||
ParamGrid(double _minVal, double _maxVal, double _logStep);
|
ParamGrid(double _minVal, double _maxVal, double _logStep);
|
||||||
|
|
||||||
double minVal; //!< Minimum value of the statmodel parameter. Default value is 0.
|
CV_PROP_RW double minVal; //!< Minimum value of the statmodel parameter. Default value is 0.
|
||||||
double maxVal; //!< Maximum value of the statmodel parameter. Default value is 0.
|
CV_PROP_RW double maxVal; //!< Maximum value of the statmodel parameter. Default value is 0.
|
||||||
/** @brief Logarithmic step for iterating the statmodel parameter.
|
/** @brief Logarithmic step for iterating the statmodel parameter.
|
||||||
|
|
||||||
The grid determines the following iteration sequence of the statmodel parameter values:
|
The grid determines the following iteration sequence of the statmodel parameter values:
|
||||||
@ -122,7 +122,15 @@ public:
|
|||||||
\f[\texttt{minVal} * \texttt{logStep} ^n < \texttt{maxVal}\f]
|
\f[\texttt{minVal} * \texttt{logStep} ^n < \texttt{maxVal}\f]
|
||||||
The grid is logarithmic, so logStep must always be greater then 1. Default value is 1.
|
The grid is logarithmic, so logStep must always be greater then 1. Default value is 1.
|
||||||
*/
|
*/
|
||||||
double logStep;
|
CV_PROP_RW double logStep;
|
||||||
|
|
||||||
|
/** @brief Creates a ParamGrid Ptr that can be given to the %SVM::trainAuto method
|
||||||
|
|
||||||
|
@param minVal minimum value of the parameter grid
|
||||||
|
@param maxVal maximum value of the parameter grid
|
||||||
|
@param logstep Logarithmic step for iterating the statmodel parameter
|
||||||
|
*/
|
||||||
|
CV_WRAP static Ptr<ParamGrid> create(double minVal=0., double maxVal=0., double logstep=1.);
|
||||||
};
|
};
|
||||||
|
|
||||||
/** @brief Class encapsulating training data.
|
/** @brief Class encapsulating training data.
|
||||||
@ -691,6 +699,46 @@ public:
|
|||||||
ParamGrid degreeGrid = getDefaultGrid(DEGREE),
|
ParamGrid degreeGrid = getDefaultGrid(DEGREE),
|
||||||
bool balanced=false) = 0;
|
bool balanced=false) = 0;
|
||||||
|
|
||||||
|
/** @brief Trains an %SVM with optimal parameters
|
||||||
|
|
||||||
|
@param samples training samples
|
||||||
|
@param layout See ml::SampleTypes.
|
||||||
|
@param responses vector of responses associated with the training samples.
|
||||||
|
@param kFold Cross-validation parameter. The training set is divided into kFold subsets. One
|
||||||
|
subset is used to test the model, the others form the train set. So, the %SVM algorithm is
|
||||||
|
@param Cgrid grid for C
|
||||||
|
@param gammaGrid grid for gamma
|
||||||
|
@param pGrid grid for p
|
||||||
|
@param nuGrid grid for nu
|
||||||
|
@param coeffGrid grid for coeff
|
||||||
|
@param degreeGrid grid for degree
|
||||||
|
@param balanced If true and the problem is 2-class classification then the method creates more
|
||||||
|
balanced cross-validation subsets that is proportions between classes in subsets are close
|
||||||
|
to such proportion in the whole train dataset.
|
||||||
|
|
||||||
|
The method trains the %SVM model automatically by choosing the optimal parameters C, gamma, p,
|
||||||
|
nu, coef0, degree. Parameters are considered optimal when the cross-validation
|
||||||
|
estimate of the test set error is minimal.
|
||||||
|
|
||||||
|
This function only makes use of SVM::getDefaultGrid for parameter optimization and thus only
|
||||||
|
offers rudimentary parameter options.
|
||||||
|
|
||||||
|
This function works for the classification (SVM::C_SVC or SVM::NU_SVC) as well as for the
|
||||||
|
regression (SVM::EPS_SVR or SVM::NU_SVR). If it is SVM::ONE_CLASS, no optimization is made and
|
||||||
|
the usual %SVM with parameters specified in params is executed.
|
||||||
|
*/
|
||||||
|
CV_WRAP bool trainAuto(InputArray samples,
|
||||||
|
int layout,
|
||||||
|
InputArray responses,
|
||||||
|
int kFold = 10,
|
||||||
|
Ptr<ParamGrid> Cgrid = SVM::getDefaultGridPtr(SVM::C),
|
||||||
|
Ptr<ParamGrid> gammaGrid = SVM::getDefaultGridPtr(SVM::GAMMA),
|
||||||
|
Ptr<ParamGrid> pGrid = SVM::getDefaultGridPtr(SVM::P),
|
||||||
|
Ptr<ParamGrid> nuGrid = SVM::getDefaultGridPtr(SVM::NU),
|
||||||
|
Ptr<ParamGrid> coeffGrid = SVM::getDefaultGridPtr(SVM::COEF),
|
||||||
|
Ptr<ParamGrid> degreeGrid = SVM::getDefaultGridPtr(SVM::DEGREE),
|
||||||
|
bool balanced=false);
|
||||||
|
|
||||||
/** @brief Retrieves all the support vectors
|
/** @brief Retrieves all the support vectors
|
||||||
|
|
||||||
The method returns all the support vectors as a floating-point matrix, where support vectors are
|
The method returns all the support vectors as a floating-point matrix, where support vectors are
|
||||||
@ -733,6 +781,16 @@ public:
|
|||||||
*/
|
*/
|
||||||
static ParamGrid getDefaultGrid( int param_id );
|
static ParamGrid getDefaultGrid( int param_id );
|
||||||
|
|
||||||
|
/** @brief Generates a grid for %SVM parameters.
|
||||||
|
|
||||||
|
@param param_id %SVM parameters IDs that must be one of the SVM::ParamTypes. The grid is
|
||||||
|
generated for the parameter with this ID.
|
||||||
|
|
||||||
|
The function generates a grid pointer for the specified parameter of the %SVM algorithm.
|
||||||
|
The grid may be passed to the function SVM::trainAuto.
|
||||||
|
*/
|
||||||
|
CV_WRAP static Ptr<ParamGrid> getDefaultGridPtr( int param_id );
|
||||||
|
|
||||||
/** Creates empty model.
|
/** Creates empty model.
|
||||||
Use StatModel::train to train the model. Since %SVM has several parameters, you may want to
|
Use StatModel::train to train the model. Since %SVM has several parameters, you may want to
|
||||||
find the best parameters for your problem, it can be done with SVM::trainAuto. */
|
find the best parameters for your problem, it can be done with SVM::trainAuto. */
|
||||||
|
@ -50,6 +50,10 @@ ParamGrid::ParamGrid(double _minVal, double _maxVal, double _logStep)
|
|||||||
logStep = std::max(_logStep, 1.);
|
logStep = std::max(_logStep, 1.);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ptr<ParamGrid> ParamGrid::create(double minval, double maxval, double logstep) {
|
||||||
|
return makePtr<ParamGrid>(minval, maxval, logstep);
|
||||||
|
}
|
||||||
|
|
||||||
bool StatModel::empty() const { return !isTrained(); }
|
bool StatModel::empty() const { return !isTrained(); }
|
||||||
|
|
||||||
int StatModel::getVarCount() const { return 0; }
|
int StatModel::getVarCount() const { return 0; }
|
||||||
|
@ -362,6 +362,12 @@ static void sortSamplesByClasses( const Mat& _samples, const Mat& _responses,
|
|||||||
|
|
||||||
//////////////////////// SVM implementation //////////////////////////////
|
//////////////////////// SVM implementation //////////////////////////////
|
||||||
|
|
||||||
|
Ptr<ParamGrid> SVM::getDefaultGridPtr( int param_id)
|
||||||
|
{
|
||||||
|
ParamGrid grid = getDefaultGrid(param_id); // this is not a nice solution..
|
||||||
|
return makePtr<ParamGrid>(grid.minVal, grid.maxVal, grid.logStep);
|
||||||
|
}
|
||||||
|
|
||||||
ParamGrid SVM::getDefaultGrid( int param_id )
|
ParamGrid SVM::getDefaultGrid( int param_id )
|
||||||
{
|
{
|
||||||
ParamGrid grid;
|
ParamGrid grid;
|
||||||
@ -1920,6 +1926,24 @@ public:
|
|||||||
bool returnDFVal;
|
bool returnDFVal;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
bool trainAuto_(InputArray samples, int layout,
|
||||||
|
InputArray responses, int kfold, Ptr<ParamGrid> Cgrid,
|
||||||
|
Ptr<ParamGrid> gammaGrid, Ptr<ParamGrid> pGrid, Ptr<ParamGrid> nuGrid,
|
||||||
|
Ptr<ParamGrid> coeffGrid, Ptr<ParamGrid> degreeGrid, bool balanced)
|
||||||
|
{
|
||||||
|
Ptr<TrainData> data = TrainData::create(samples, layout, responses);
|
||||||
|
return this->trainAuto(
|
||||||
|
data, kfold,
|
||||||
|
*Cgrid.get(),
|
||||||
|
*gammaGrid.get(),
|
||||||
|
*pGrid.get(),
|
||||||
|
*nuGrid.get(),
|
||||||
|
*coeffGrid.get(),
|
||||||
|
*degreeGrid.get(),
|
||||||
|
balanced);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
float predict( InputArray _samples, OutputArray _results, int flags ) const
|
float predict( InputArray _samples, OutputArray _results, int flags ) const
|
||||||
{
|
{
|
||||||
float result = 0;
|
float result = 0;
|
||||||
@ -2281,6 +2305,19 @@ Mat SVM::getUncompressedSupportVectors() const
|
|||||||
return this_->getUncompressedSupportVectors_();
|
return this_->getUncompressedSupportVectors_();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool SVM::trainAuto(InputArray samples, int layout,
|
||||||
|
InputArray responses, int kfold, Ptr<ParamGrid> Cgrid,
|
||||||
|
Ptr<ParamGrid> gammaGrid, Ptr<ParamGrid> pGrid, Ptr<ParamGrid> nuGrid,
|
||||||
|
Ptr<ParamGrid> coeffGrid, Ptr<ParamGrid> degreeGrid, bool balanced)
|
||||||
|
{
|
||||||
|
SVMImpl* this_ = dynamic_cast<SVMImpl*>(this);
|
||||||
|
if (!this_) {
|
||||||
|
CV_Error(Error::StsNotImplemented, "the class is not SVMImpl");
|
||||||
|
}
|
||||||
|
return this_->trainAuto_(samples, layout, responses,
|
||||||
|
kfold, Cgrid, gammaGrid, pGrid, nuGrid, coeffGrid, degreeGrid, balanced);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user