Deleted default values for parameters in docs.

Added some asserts.
This commit is contained in:
Marina Noskova 2016-02-25 19:12:54 +03:00
parent d484893839
commit 53711ec29d
4 changed files with 38 additions and 59 deletions

View File

@@ -1542,7 +1542,7 @@ The other parameters may be described as follows:
Recommended value for SGD model is 0.0001, for ASGD model is 0.00001.
- Initial step size parameter is the initial value for the step size \f$\gamma(t)\f$.
You will have to find the best \f$\gamma_0\f$ for your problem.
You will have to find the best initial step for your problem.
- Step decreasing power is the power parameter for \f$\gamma(t)\f$ decreasing by the formula, mentioned above.
Recommended value for SGD model is 1, for ASGD model is 0.75.
@@ -1605,31 +1605,15 @@ public:
*/
CV_WRAP virtual float getShift() = 0;
/** @brief Creates empty model.
Use StatModel::train to train the model. Since %SVMSGD has several parameters, you may want to
find the best parameters for your problem or use setOptimalParameters() to set some default parameters.
* Use StatModel::train to train the model. Since %SVMSGD has several parameters, you may want to
* find the best parameters for your problem or use setOptimalParameters() to set some default parameters.
*/
CV_WRAP static Ptr<SVMSGD> create();
/** @brief Function sets optimal parameters values for chosen SVM SGD model.
* If chosen type is ASGD, function sets the following values for parameters of model:
* marginRegularization = 0.00001;
* initialStepSize = 0.05;
* stepDecreasingPower = 0.75;
* termCrit.maxCount = 100000;
* termCrit.epsilon = 0.00001;
*
* If SGD:
* marginRegularization = 0.0001;
* initialStepSize = 0.05;
* stepDecreasingPower = 1;
* termCrit.maxCount = 100000;
* termCrit.epsilon = 0.00001;
* @param svmsgdType is the type of SVMSGD classifier. Legal values are SVMSGD::SvmsgdType::SGD and SVMSGD::SvmsgdType::ASGD.
* Recommended value is SVMSGD::SvmsgdType::ASGD (by default).
* @param marginType is the type of margin constraint. Legal values are SVMSGD::MarginType::SOFT_MARGIN and SVMSGD::MarginType::HARD_MARGIN.
* Default value is SVMSGD::MarginType::SOFT_MARGIN.
* @param svmsgdType is the type of SVMSGD classifier.
* @param marginType is the type of margin constraint.
*/
CV_WRAP virtual void setOptimalParameters(int svmsgdType = SVMSGD::ASGD, int marginType = SVMSGD::SOFT_MARGIN) = 0;
@@ -1645,20 +1629,19 @@ public:
/** @copybrief getMarginType @see getMarginType */
CV_WRAP virtual void setMarginType(int marginType) = 0;
/** @brief Parameter marginRegularization of a %SVMSGD optimization problem. Default value is 0. */
/** @brief Parameter marginRegularization of a %SVMSGD optimization problem. */
/** @see setMarginRegularization */
CV_WRAP virtual float getMarginRegularization() const = 0;
/** @copybrief getMarginRegularization @see getMarginRegularization */
CV_WRAP virtual void setMarginRegularization(float marginRegularization) = 0;
/** @brief Parameter initialStepSize of a %SVMSGD optimization problem. Default value is 0. */
/** @brief Parameter initialStepSize of a %SVMSGD optimization problem. */
/** @see setInitialStepSize */
CV_WRAP virtual float getInitialStepSize() const = 0;
/** @copybrief getInitialStepSize @see getInitialStepSize */
CV_WRAP virtual void setInitialStepSize(float InitialStepSize) = 0;
/** @brief Parameter stepDecreasingPower of a %SVMSGD optimization problem. Default value is 0. */
/** @brief Parameter stepDecreasingPower of a %SVMSGD optimization problem. */
/** @see setStepDecreasingPower */
CV_WRAP virtual float getStepDecreasingPower() const = 0;
/** @copybrief getStepDecreasingPower @see getStepDecreasingPower */

View File

@@ -97,7 +97,7 @@ public:
CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)
private:
void updateWeights(InputArray sample, bool isPositive, float stepSize, Mat &weights);
void updateWeights(InputArray sample, bool positive, float stepSize, Mat &weights);
void writeParams( FileStorage &fs ) const;
@@ -111,8 +111,6 @@ private:
static void makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier);
// Vector with SVM weights
Mat weights_;
float shift_;
@@ -263,11 +261,12 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
RNG rng(0);
CV_Assert ((params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS) && (trainResponses.type() == CV_32FC1));
CV_Assert (params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS);
int maxCount = (params.termCrit.type & TermCriteria::COUNT) ? params.termCrit.maxCount : INT_MAX;
double epsilon = (params.termCrit.type & TermCriteria::EPS) ? params.termCrit.epsilon : 0;
double err = DBL_MAX;
CV_Assert (trainResponses.type() == CV_32FC1);
// Stochastic gradient descent SVM
for (int iter = 0; (iter < maxCount) && (err > epsilon); iter++)
{
@@ -288,8 +287,8 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
}
else
{
err = norm(extendedWeights - previousWeights);
extendedWeights.copyTo(previousWeights);
err = norm(extendedWeights - previousWeights);
extendedWeights.copyTo(previousWeights);
}
}
@@ -316,7 +315,6 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
return true;
}
float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) const
{
float result = 0;
@@ -417,17 +415,6 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const
fs << "iterations" << params.termCrit.maxCount;
fs << "}";
}
void SVMSGDImpl::read(const FileNode& fn)
{
clear();
readParams(fn);
fn["weights"] >> weights_;
fn["shift"] >> shift_;
}
void SVMSGDImpl::readParams( const FileNode& fn )
{
String svmsgdTypeStr = (String)fn["svmsgdType"];
@@ -443,7 +430,7 @@ void SVMSGDImpl::readParams( const FileNode& fn )
String marginTypeStr = (String)fn["marginType"];
int marginType =
marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN :
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1;
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1;
if( marginType < 0 )
CV_Error( CV_StsParseError, "Missing or invalid margin type" );
@@ -460,16 +447,22 @@ void SVMSGDImpl::readParams( const FileNode& fn )
params.stepDecreasingPower = (float)fn["stepDecreasingPower"];
FileNode tcnode = fn["term_criteria"];
if( !tcnode.empty() )
{
params.termCrit.epsilon = (double)tcnode["epsilon"];
params.termCrit.maxCount = (int)tcnode["iterations"];
params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
(params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
}
else
params.termCrit = TermCriteria( TermCriteria::EPS + TermCriteria::COUNT, 100000, FLT_EPSILON );
CV_Assert(!tcnode.empty());
params.termCrit.epsilon = (double)tcnode["epsilon"];
params.termCrit.maxCount = (int)tcnode["iterations"];
params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
(params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
CV_Assert ((params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS));
}
void SVMSGDImpl::read(const FileNode& fn)
{
clear();
readParams(fn);
fn["weights"] >> weights_;
fn["shift"] >> shift_;
}
void SVMSGDImpl::clear()
@@ -492,7 +485,7 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
case SGD:
params.svmsgdType = SGD;
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
params.marginRegularization = 0.0001f;
params.initialStepSize = 0.05f;
params.stepDecreasingPower = 1.f;
@@ -502,7 +495,7 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
case ASGD:
params.svmsgdType = ASGD;
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
params.marginRegularization = 0.00001f;
params.initialStepSize = 0.05f;
params.stepDecreasingPower = 0.75f;

View File

@@ -62,7 +62,7 @@ public:
private:
virtual void run( int start_from );
static float decisionFunction(const Mat &sample, const Mat &weights, float shift);
void makeData(int samplesCount, Mat weights, float shift, RNG rng, Mat &samples, Mat & responses);
void makeData(int samplesCount, const Mat &weights, float shift, RNG &rng, Mat &samples, Mat & responses);
void generateSameBorders(int featureCount);
void generateDifferentBorders(int featureCount);
@@ -112,7 +112,7 @@ float CV_SVMSGDTrainTest::decisionFunction(const Mat &sample, const Mat &weights
return static_cast<float>(sample.dot(weights)) + shift;
}
void CV_SVMSGDTrainTest::makeData(int samplesCount, Mat weights, float shift, RNG rng, Mat &samples, Mat & responses)
void CV_SVMSGDTrainTest::makeData(int samplesCount, const Mat &weights, float shift, RNG &rng, Mat &samples, Mat & responses)
{
int featureCount = weights.cols;
@@ -175,6 +175,7 @@ void CV_SVMSGDTrainTest::run( int /*start_from*/ )
int errCount = 0;
int testSamplesCount = testSamples.rows;
CV_Assert((responses.type() == CV_32FC1) && (testResponses.type() == CV_32FC1));
for (int i = 0; i < testSamplesCount; i++)
{
if (responses.at<float>(i) * testResponses.at<float>(i) < 0)

View File

@@ -91,6 +91,7 @@ bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<
int yMin = std::min(segment.first.y, segment.second.y);
int yMax = std::max(segment.first.y, segment.second.y);
CV_Assert(weights.type() == CV_32FC1);
CV_Assert(xMin == xMax || yMin == yMax);
if (xMin == xMax && weights.at<float>(1) != 0)
@@ -146,6 +147,7 @@ void redraw(Data data, const Point points[2])
Point center;
int radius = 3;
Scalar color;
CV_Assert((data.samples.type() == CV_32FC1) && (data.responses.type() == CV_32FC1));
for (int i = 0; i < data.samples.rows; i++)
{
center.x = static_cast<int>(data.samples.at<float>(i,0));
@@ -160,14 +162,14 @@ void addPointRetrainAndRedraw(Data &data, int x, int y, int response)
void addPointRetrainAndRedraw(Data &data, int x, int y, int response)
{
Mat currentSample(1, 2, CV_32F);
Mat currentSample(1, 2, CV_32FC1);
currentSample.at<float>(0,0) = (float)x;
currentSample.at<float>(0,1) = (float)y;
data.samples.push_back(currentSample);
data.responses.push_back(response);
Mat weights(1, 2, CV_32F);
Mat weights(1, 2, CV_32FC1);
float shift = 0;
if (doTrain(data.samples, data.responses, weights, shift))