2015-09-03 00:14:40 +08:00
|
|
|
/*M///////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//
|
|
|
|
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
|
|
|
//
|
|
|
|
// By downloading, copying, installing or using the software you agree to this license.
|
|
|
|
// If you do not agree to this license, do not download, install,
|
|
|
|
// copy or use the software.
|
|
|
|
//
|
|
|
|
//
|
|
|
|
// License Agreement
|
|
|
|
// For Open Source Computer Vision Library
|
|
|
|
//
|
|
|
|
// Copyright (C) 2000, Intel Corporation, all rights reserved.
|
2016-02-15 19:35:36 +08:00
|
|
|
// Copyright (C) 2016, Itseez Inc, all rights reserved.
|
2015-09-03 00:14:40 +08:00
|
|
|
// Third party copyrights are property of their respective owners.
|
|
|
|
//
|
|
|
|
// Redistribution and use in source and binary forms, with or without modification,
|
|
|
|
// are permitted provided that the following conditions are met:
|
|
|
|
//
|
|
|
|
// * Redistribution's of source code must retain the above copyright notice,
|
|
|
|
// this list of conditions and the following disclaimer.
|
|
|
|
//
|
|
|
|
// * Redistribution's in binary form must reproduce the above copyright notice,
|
|
|
|
// this list of conditions and the following disclaimer in the documentation
|
|
|
|
// and/or other materials provided with the distribution.
|
|
|
|
//
|
|
|
|
// * The name of the copyright holders may not be used to endorse or promote products
|
|
|
|
// derived from this software without specific prior written permission.
|
|
|
|
//
|
|
|
|
// This software is provided by the copyright holders and contributors "as is" and
|
|
|
|
// any express or implied warranties, including, but not limited to, the implied
|
|
|
|
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
|
|
|
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
|
|
|
// indirect, incidental, special, exemplary, or consequential damages
|
|
|
|
// (including, but not limited to, procurement of substitute goods or services;
|
|
|
|
// loss of use, data, or profits; or business interruption) however caused
|
|
|
|
// and on any theory of liability, whether in contract, strict liability,
|
|
|
|
// or tort (including negligence or otherwise) arising in any way out of
|
|
|
|
// the use of this software, even if advised of the possibility of such damage.
|
|
|
|
//
|
|
|
|
//M*/
|
|
|
|
|
|
|
|
#include "precomp.hpp"
|
2016-01-20 17:59:44 +08:00
|
|
|
#include "limits"
|
2016-02-03 20:31:05 +08:00
|
|
|
|
|
|
|
#include <iostream>
|
|
|
|
|
|
|
|
using std::cout;
|
|
|
|
using std::endl;
|
2015-09-03 00:14:40 +08:00
|
|
|
|
|
|
|
/****************************************************************************************\
|
|
|
|
* Stochastic Gradient Descent SVM Classifier *
|
|
|
|
\****************************************************************************************/
|
|
|
|
|
2016-01-20 17:59:44 +08:00
|
|
|
namespace cv
|
|
|
|
{
|
|
|
|
namespace ml
|
|
|
|
{
|
2015-09-03 00:14:40 +08:00
|
|
|
|
2016-01-20 17:59:44 +08:00
|
|
|
class SVMSGDImpl : public SVMSGD
|
|
|
|
{
|
2015-09-03 00:14:40 +08:00
|
|
|
|
2016-01-20 17:59:44 +08:00
|
|
|
public:
|
|
|
|
SVMSGDImpl();
|
2015-09-03 00:14:40 +08:00
|
|
|
|
2016-01-20 17:59:44 +08:00
|
|
|
virtual ~SVMSGDImpl() {}
|
2015-09-03 00:14:40 +08:00
|
|
|
|
2016-01-20 17:59:44 +08:00
|
|
|
virtual bool train(const Ptr<TrainData>& data, int);
|
2015-09-03 00:14:40 +08:00
|
|
|
|
2016-01-20 17:59:44 +08:00
|
|
|
virtual float predict( InputArray samples, OutputArray results=noArray(), int flags = 0 ) const;
|
2015-09-03 00:14:40 +08:00
|
|
|
|
2016-02-03 20:31:05 +08:00
|
|
|
virtual bool isClassifier() const;
|
2015-09-03 00:14:40 +08:00
|
|
|
|
2016-01-20 17:59:44 +08:00
|
|
|
virtual bool isTrained() const;
|
2015-09-03 00:14:40 +08:00
|
|
|
|
2016-01-20 17:59:44 +08:00
|
|
|
virtual void clear();
|
2015-09-03 00:14:40 +08:00
|
|
|
|
2016-02-09 23:42:23 +08:00
|
|
|
virtual void write(FileStorage &fs) const;
|
2015-09-03 00:14:40 +08:00
|
|
|
|
2016-02-09 23:42:23 +08:00
|
|
|
virtual void read(const FileNode &fn);
|
2015-09-03 00:14:40 +08:00
|
|
|
|
2016-01-20 17:59:44 +08:00
|
|
|
virtual Mat getWeights(){ return weights_; }
|
2015-09-03 00:14:40 +08:00
|
|
|
|
2016-01-20 17:59:44 +08:00
|
|
|
virtual float getShift(){ return shift_; }
|
2015-09-03 00:14:40 +08:00
|
|
|
|
2016-01-20 17:59:44 +08:00
|
|
|
virtual int getVarCount() const { return weights_.cols; }
|
2015-09-03 00:14:40 +08:00
|
|
|
|
2016-01-20 17:59:44 +08:00
|
|
|
virtual String getDefaultName() const {return "opencv_ml_svmsgd";}
|
2015-09-03 00:14:40 +08:00
|
|
|
|
2016-02-09 23:42:23 +08:00
|
|
|
virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN);
|
2015-09-03 00:14:40 +08:00
|
|
|
|
2016-02-15 20:09:59 +08:00
|
|
|
CV_IMPL_PROPERTY(int, SvmsgdType, params.svmsgdType)
|
|
|
|
CV_IMPL_PROPERTY(int, MarginType, params.marginType)
|
2016-02-24 18:22:07 +08:00
|
|
|
CV_IMPL_PROPERTY(float, MarginRegularization, params.marginRegularization)
|
|
|
|
CV_IMPL_PROPERTY(float, InitialStepSize, params.initialStepSize)
|
|
|
|
CV_IMPL_PROPERTY(float, StepDecreasingPower, params.stepDecreasingPower)
|
2016-01-20 17:59:44 +08:00
|
|
|
CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)
|
2015-09-03 00:14:40 +08:00
|
|
|
|
2016-02-03 20:31:05 +08:00
|
|
|
private:
|
2016-02-26 00:12:54 +08:00
|
|
|
void updateWeights(InputArray sample, bool positive, float stepSize, Mat &weights);
|
2016-02-03 20:31:05 +08:00
|
|
|
|
2016-02-09 23:42:23 +08:00
|
|
|
void writeParams( FileStorage &fs ) const;
|
2016-02-03 20:31:05 +08:00
|
|
|
|
2016-02-09 23:42:23 +08:00
|
|
|
void readParams( const FileNode &fn );
|
2016-02-03 20:31:05 +08:00
|
|
|
|
2016-02-15 19:35:36 +08:00
|
|
|
static inline bool isPositive(float val) { return val > 0; }
|
2016-01-20 17:59:44 +08:00
|
|
|
|
2016-02-09 23:42:23 +08:00
|
|
|
static void normalizeSamples(Mat &matrix, Mat &average, float &multiplier);
|
2016-02-03 20:31:05 +08:00
|
|
|
|
|
|
|
float calcShift(InputArray _samples, InputArray _responses) const;
|
|
|
|
|
2016-02-09 23:42:23 +08:00
|
|
|
static void makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier);
|
2016-02-03 20:31:05 +08:00
|
|
|
|
2016-01-20 17:59:44 +08:00
|
|
|
// Vector with SVM weights
|
|
|
|
Mat weights_;
|
|
|
|
float shift_;
|
|
|
|
|
|
|
|
// Parameters for learning
|
|
|
|
struct SVMSGDParams
|
|
|
|
{
|
2016-02-24 18:22:07 +08:00
|
|
|
float marginRegularization;
|
|
|
|
float initialStepSize;
|
|
|
|
float stepDecreasingPower;
|
2016-01-20 17:59:44 +08:00
|
|
|
TermCriteria termCrit;
|
2016-02-15 20:09:59 +08:00
|
|
|
int svmsgdType;
|
|
|
|
int marginType;
|
2016-01-20 17:59:44 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
SVMSGDParams params;
|
|
|
|
};
|
|
|
|
|
|
|
|
// Factory: returns a fresh, untrained SVMSGD model backed by SVMSGDImpl.
Ptr<SVMSGD> SVMSGD::create()
{
    Ptr<SVMSGDImpl> model = makePtr<SVMSGDImpl>();
    return model;
}
|
|
|
|
|
2016-02-09 23:42:23 +08:00
|
|
|
// Centres every feature column of `samples` on its mean, then rescales the whole
// matrix so that the root-mean-square of its entries becomes 1.
//   samples    - in/out; modified in place (must be CV_32FC1, checked via `average`).
//   average    - out; 1 x cols row holding the per-feature means that were subtracted.
//   multiplier - out; scale factor applied after centring (1 when there is no spread).
void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &average, float &multiplier)
{
    int featuresCount = samples.cols;
    int samplesCount = samples.rows;

    average = Mat(1, featuresCount, samples.type());
    CV_Assert(average.type() == CV_32FC1);
    for (int featureIndex = 0; featureIndex < featuresCount; featureIndex++)
    {
        average.at<float>(featureIndex) = static_cast<float>(mean(samples.col(featureIndex))[0]);
    }

    for (int sampleIndex = 0; sampleIndex < samplesCount; sampleIndex++)
    {
        samples.row(sampleIndex) -= average;
    }

    double normValue = norm(samples);

    // When every centred entry is zero (e.g. all samples identical, or no samples at all)
    // there is no spread to normalise. Dividing by the zero norm would give an infinite
    // (or NaN) multiplier, and 0 * inf would turn every entry into NaN, poisoning training.
    multiplier = (normValue > 0)
        ? static_cast<float>(sqrt(static_cast<double>(samples.total())) / normValue)
        : 1.f;

    samples *= multiplier;
}
|
2016-01-20 17:59:44 +08:00
|
|
|
|
2016-02-09 23:42:23 +08:00
|
|
|
void SVMSGDImpl::makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier)
|
2016-02-03 20:31:05 +08:00
|
|
|
{
|
2016-02-15 19:35:36 +08:00
|
|
|
Mat normalizedTrainSamples = trainSamples.clone();
|
|
|
|
int samplesCount = normalizedTrainSamples.rows;
|
2016-01-20 17:59:44 +08:00
|
|
|
|
2016-02-15 19:35:36 +08:00
|
|
|
normalizeSamples(normalizedTrainSamples, average, multiplier);
|
2016-01-20 17:59:44 +08:00
|
|
|
|
2016-02-03 20:31:05 +08:00
|
|
|
Mat onesCol = Mat::ones(samplesCount, 1, CV_32F);
|
2016-02-15 19:35:36 +08:00
|
|
|
cv::hconcat(normalizedTrainSamples, onesCol, extendedTrainSamples);
|
2015-09-03 00:14:40 +08:00
|
|
|
}
|
|
|
|
|
2016-02-25 21:57:03 +08:00
|
|
|
// One stochastic step on the L2-regularised hinge loss, applied to `weights` in place.
//   _sample  - a single row with the same size/type as `weights`.
//   positive - class of the sample; mapped to the label +1 (true) or -1 (false).
//   stepSize - learning rate for this step.
// Samples already beyond the margin (label * <x, w> > 1) only shrink the weights;
// all others additionally pull the weights towards label * x.
void SVMSGDImpl::updateWeights(InputArray _sample, bool positive, float stepSize, Mat& weights)
{
    const Mat x = _sample.getMat();
    const int label = positive ? 1 : -1;
    const float decay = stepSize * params.marginRegularization;

    const bool beyondMargin = x.dot(weights) * label > 1;
    if (!beyondMargin)
    {
        // Support vector: weight decay plus a step towards the sample.
        weights -= decay * weights - (stepSize * label) * x;
        return;
    }

    // Not a support vector: weight decay only.
    weights *= (1.f - decay);
}
|
2015-09-03 00:14:40 +08:00
|
|
|
|
2016-01-20 17:59:44 +08:00
|
|
|
// Hard-margin bias: projects every sample onto weights_ and returns the shift that
// places the decision boundary halfway between the closest positive sample and the
// closest negative sample.
//   _samples   - one sample per row, compatible with weights_ for dot().
//   _responses - CV_32FC1 labels, one per row; isPositive() decides the class.
float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
{
    const Mat samples = _samples.getMat();
    const Mat trainResponses = _responses.getMat();
    CV_Assert(trainResponses.type() == CV_32FC1);

    // Smallest signed margin seen per class: slot 0 = positives, slot 1 = negatives.
    float closest[2] = { std::numeric_limits<float>::max(), std::numeric_limits<float>::max() };

    for (int row = 0; row < samples.rows; ++row)
    {
        const bool positive = isPositive(trainResponses.at<float>(row));
        const float projection = static_cast<float>(samples.row(row).dot(weights_));
        const float signedMargin = projection * (positive ? 1.f : -1.f);
        const int slot = positive ? 0 : 1;

        if (signedMargin < closest[slot])
            closest[slot] = signedMargin;
    }

    return -(closest[0] - closest[1]) / 2.f;
}
|
|
|
|
|
2016-02-03 20:31:05 +08:00
|
|
|
bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
|
|
|
{
|
|
|
|
clear();
|
|
|
|
CV_Assert( isClassifier() ); //toDo: consider
|
|
|
|
|
|
|
|
Mat trainSamples = data->getTrainSamples();
|
|
|
|
|
|
|
|
int featureCount = trainSamples.cols;
|
|
|
|
Mat trainResponses = data->getTrainResponses(); // (trainSamplesCount x 1) matrix
|
|
|
|
|
2016-02-25 20:31:07 +08:00
|
|
|
CV_Assert(trainResponses.rows == trainSamples.rows);
|
2016-02-03 20:31:05 +08:00
|
|
|
|
2016-02-25 20:31:07 +08:00
|
|
|
if (trainResponses.empty())
|
2016-02-03 20:31:05 +08:00
|
|
|
{
|
|
|
|
return false;
|
|
|
|
}
|
2016-02-25 20:31:07 +08:00
|
|
|
|
|
|
|
int positiveCount = countNonZero(trainResponses >= 0);
|
|
|
|
int negativeCount = countNonZero(trainResponses < 0);
|
|
|
|
|
|
|
|
if ( positiveCount <= 0 || negativeCount <= 0 )
|
2016-02-03 20:31:05 +08:00
|
|
|
{
|
|
|
|
weights_ = Mat::zeros(1, featureCount, CV_32F);
|
2016-02-25 20:31:07 +08:00
|
|
|
shift_ = (positiveCount > 0) ? 1.f : -1.f;
|
2016-02-03 20:31:05 +08:00
|
|
|
return true;
|
2016-02-10 20:40:09 +08:00
|
|
|
}
|
2016-02-03 20:31:05 +08:00
|
|
|
|
|
|
|
Mat extendedTrainSamples;
|
2016-02-09 23:42:23 +08:00
|
|
|
Mat average;
|
|
|
|
float multiplier = 0;
|
|
|
|
makeExtendedTrainSamples(trainSamples, extendedTrainSamples, average, multiplier);
|
2016-02-03 20:31:05 +08:00
|
|
|
|
|
|
|
int extendedTrainSamplesCount = extendedTrainSamples.rows;
|
|
|
|
int extendedFeatureCount = extendedTrainSamples.cols;
|
|
|
|
|
2016-02-15 20:09:59 +08:00
|
|
|
Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
|
|
|
|
Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
|
|
|
|
Mat averageExtendedWeights;
|
2016-02-03 20:31:05 +08:00
|
|
|
if (params.svmsgdType == ASGD)
|
|
|
|
{
|
|
|
|
averageExtendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
|
|
|
|
}
|
|
|
|
|
|
|
|
RNG rng(0);
|
|
|
|
|
2016-02-26 00:12:54 +08:00
|
|
|
CV_Assert (params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS);
|
2016-02-03 20:31:05 +08:00
|
|
|
int maxCount = (params.termCrit.type & TermCriteria::COUNT) ? params.termCrit.maxCount : INT_MAX;
|
|
|
|
double epsilon = (params.termCrit.type & TermCriteria::EPS) ? params.termCrit.epsilon : 0;
|
|
|
|
|
|
|
|
double err = DBL_MAX;
|
2016-02-26 00:12:54 +08:00
|
|
|
CV_Assert (trainResponses.type() == CV_32FC1);
|
2016-02-03 20:31:05 +08:00
|
|
|
// Stochastic gradient descent SVM
|
|
|
|
for (int iter = 0; (iter < maxCount) && (err > epsilon); iter++)
|
|
|
|
{
|
|
|
|
int randomNumber = rng.uniform(0, extendedTrainSamplesCount); //generate sample number
|
|
|
|
|
|
|
|
Mat currentSample = extendedTrainSamples.row(randomNumber);
|
|
|
|
|
2016-02-24 18:22:07 +08:00
|
|
|
float stepSize = params.initialStepSize * std::pow((1 + params.marginRegularization * params.initialStepSize * (float)iter), (-params.stepDecreasingPower)); //update stepSize
|
2016-02-03 20:31:05 +08:00
|
|
|
|
2016-02-24 18:22:07 +08:00
|
|
|
updateWeights( currentSample, isPositive(trainResponses.at<float>(randomNumber)), stepSize, extendedWeights );
|
2016-02-03 20:31:05 +08:00
|
|
|
|
|
|
|
//average weights (only for ASGD model)
|
|
|
|
if (params.svmsgdType == ASGD)
|
|
|
|
{
|
|
|
|
averageExtendedWeights = ((float)iter/ (1 + (float)iter)) * averageExtendedWeights + extendedWeights / (1 + (float) iter);
|
|
|
|
err = norm(averageExtendedWeights - previousWeights);
|
|
|
|
averageExtendedWeights.copyTo(previousWeights);
|
|
|
|
}
|
|
|
|
else
|
|
|
|
{
|
2016-02-26 00:12:54 +08:00
|
|
|
err = norm(extendedWeights - previousWeights);
|
|
|
|
extendedWeights.copyTo(previousWeights);
|
2016-02-03 20:31:05 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if (params.svmsgdType == ASGD)
|
|
|
|
{
|
|
|
|
extendedWeights = averageExtendedWeights;
|
|
|
|
}
|
|
|
|
|
|
|
|
Rect roi(0, 0, featureCount, 1);
|
|
|
|
weights_ = extendedWeights(roi);
|
2016-02-09 23:42:23 +08:00
|
|
|
weights_ *= multiplier;
|
2016-02-03 20:31:05 +08:00
|
|
|
|
2016-02-25 21:57:03 +08:00
|
|
|
CV_Assert((params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN) && (extendedWeights.type() == CV_32FC1));
|
2016-02-03 20:31:05 +08:00
|
|
|
|
2016-02-09 23:42:23 +08:00
|
|
|
if (params.marginType == SOFT_MARGIN)
|
|
|
|
{
|
2016-02-11 00:46:24 +08:00
|
|
|
shift_ = extendedWeights.at<float>(featureCount) - static_cast<float>(weights_.dot(average));
|
2016-02-09 23:42:23 +08:00
|
|
|
}
|
|
|
|
else
|
|
|
|
{
|
|
|
|
shift_ = calcShift(trainSamples, trainResponses);
|
|
|
|
}
|
2016-02-03 20:31:05 +08:00
|
|
|
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
2016-01-20 17:59:44 +08:00
|
|
|
// Classifies each row of `_samples` (CV_32FC1, weights_.cols columns) as +1 or -1
// according to the sign of weights_ . x + shift_ (zero maps to +1).
// If `_results` is requested, it receives an nSamples x 1 column of labels and 0 is
// returned. Otherwise exactly one sample is allowed and its label is returned directly.
float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) const
{
    const Mat samples = _samples.getMat();
    const int nSamples = samples.rows;

    CV_Assert( samples.cols == weights_.cols && samples.type() == CV_32FC1);

    float result = 0;
    Mat results;
    if( !_results.needed() )
    {
        // Single-sample mode: write straight into the returned scalar.
        CV_Assert( nSamples == 1 );
        results = Mat(1, 1, CV_32FC1, &result);
    }
    else
    {
        _results.create( nSamples, 1, samples.type() );
        results = _results.getMat();
    }

    for (int row = 0; row < nSamples; ++row)
    {
        const float decision = static_cast<float>(samples.row(row).dot(weights_)) + shift_;
        results.at<float>(row) = (decision >= 0) ? 1.f : -1.f;
    }

    return result;
}
|
|
|
|
|
2016-02-03 20:31:05 +08:00
|
|
|
bool SVMSGDImpl::isClassifier() const
|
2016-01-20 17:59:44 +08:00
|
|
|
{
|
2016-02-03 20:31:05 +08:00
|
|
|
return (params.svmsgdType == SGD || params.svmsgdType == ASGD)
|
2016-02-09 23:42:23 +08:00
|
|
|
&&
|
|
|
|
(params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN)
|
2016-02-03 20:31:05 +08:00
|
|
|
&&
|
2016-02-24 18:22:07 +08:00
|
|
|
(params.marginRegularization > 0) && (params.initialStepSize > 0) && (params.stepDecreasingPower >= 0);
|
2016-01-20 17:59:44 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
bool SVMSGDImpl::isTrained() const
|
|
|
|
{
|
|
|
|
return !weights_.empty();
|
|
|
|
}
|
|
|
|
|
|
|
|
// Serialises the hyper-parameters followed by the learned weights and shift.
// Raises CV_StsParseError if the model has not been trained yet.
void SVMSGDImpl::write(FileStorage& fs) const
{
    if( isTrained() )
    {
        writeParams( fs );
        fs << "weights" << weights_;
        fs << "shift" << shift_;
        return;
    }

    CV_Error( CV_StsParseError, "SVMSGD model data is invalid, it hasn't been trained" );
}
|
|
|
|
|
|
|
|
// Writes the hyper-parameters to `fs`. The SGD variant and margin type are stored as
// readable names ("Unknown_<n>" for unrecognised values). The term criteria are stored
// as an inline map containing only the fields enabled by termCrit.type.
void SVMSGDImpl::writeParams( FileStorage& fs ) const
{
    String variantName;
    if (params.svmsgdType == SGD)
        variantName = "SGD";
    else if (params.svmsgdType == ASGD)
        variantName = "ASGD";
    else
        variantName = format("Unknown_%d", params.svmsgdType);
    fs << "svmsgdType" << variantName;

    String marginName;
    if (params.marginType == SOFT_MARGIN)
        marginName = "SOFT_MARGIN";
    else if (params.marginType == HARD_MARGIN)
        marginName = "HARD_MARGIN";
    else
        marginName = format("Unknown_%d", params.marginType);
    fs << "marginType" << marginName;

    fs << "marginRegularization" << params.marginRegularization;
    fs << "initialStepSize" << params.initialStepSize;
    fs << "stepDecreasingPower" << params.stepDecreasingPower;

    const bool hasEps = (params.termCrit.type & TermCriteria::EPS) != 0;
    const bool hasCount = (params.termCrit.type & TermCriteria::COUNT) != 0;

    fs << "term_criteria" << "{:";
    if( hasEps )
        fs << "epsilon" << params.termCrit.epsilon;
    if( hasCount )
        fs << "iterations" << params.termCrit.maxCount;
    fs << "}";
}
|
|
|
|
void SVMSGDImpl::readParams( const FileNode& fn )
|
|
|
|
{
|
|
|
|
String svmsgdTypeStr = (String)fn["svmsgdType"];
|
2016-02-15 20:09:59 +08:00
|
|
|
int svmsgdType =
|
2016-01-20 17:59:44 +08:00
|
|
|
svmsgdTypeStr == "SGD" ? SGD :
|
2016-02-15 20:09:59 +08:00
|
|
|
svmsgdTypeStr == "ASGD" ? ASGD : -1;
|
2016-01-20 17:59:44 +08:00
|
|
|
|
2016-02-15 20:09:59 +08:00
|
|
|
if( svmsgdType < 0 )
|
2016-01-20 17:59:44 +08:00
|
|
|
CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );
|
|
|
|
|
|
|
|
params.svmsgdType = svmsgdType;
|
|
|
|
|
2016-02-09 23:42:23 +08:00
|
|
|
String marginTypeStr = (String)fn["marginType"];
|
2016-02-15 20:09:59 +08:00
|
|
|
int marginType =
|
2016-02-09 23:42:23 +08:00
|
|
|
marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN :
|
2016-02-26 00:12:54 +08:00
|
|
|
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1;
|
2016-02-09 23:42:23 +08:00
|
|
|
|
2016-02-15 20:09:59 +08:00
|
|
|
if( marginType < 0 )
|
2016-02-09 23:42:23 +08:00
|
|
|
CV_Error( CV_StsParseError, "Missing or invalid margin type" );
|
|
|
|
|
|
|
|
params.marginType = marginType;
|
|
|
|
|
2016-02-24 18:22:07 +08:00
|
|
|
CV_Assert ( fn["marginRegularization"].isReal() );
|
|
|
|
params.marginRegularization = (float)fn["marginRegularization"];
|
2016-01-20 17:59:44 +08:00
|
|
|
|
2016-02-24 18:22:07 +08:00
|
|
|
CV_Assert ( fn["initialStepSize"].isReal() );
|
|
|
|
params.initialStepSize = (float)fn["initialStepSize"];
|
2016-01-20 17:59:44 +08:00
|
|
|
|
2016-02-24 18:22:07 +08:00
|
|
|
CV_Assert ( fn["stepDecreasingPower"].isReal() );
|
|
|
|
params.stepDecreasingPower = (float)fn["stepDecreasingPower"];
|
2016-01-20 17:59:44 +08:00
|
|
|
|
|
|
|
FileNode tcnode = fn["term_criteria"];
|
2016-02-26 00:12:54 +08:00
|
|
|
CV_Assert(!tcnode.empty());
|
|
|
|
params.termCrit.epsilon = (double)tcnode["epsilon"];
|
|
|
|
params.termCrit.maxCount = (int)tcnode["iterations"];
|
|
|
|
params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
|
|
|
|
(params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
|
|
|
|
CV_Assert ((params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS));
|
|
|
|
}
|
2016-01-20 17:59:44 +08:00
|
|
|
|
2016-02-26 00:12:54 +08:00
|
|
|
void SVMSGDImpl::read(const FileNode& fn)
|
|
|
|
{
|
|
|
|
clear();
|
|
|
|
|
|
|
|
readParams(fn);
|
|
|
|
|
|
|
|
fn["weights"] >> weights_;
|
|
|
|
fn["shift"] >> shift_;
|
2015-09-03 00:14:40 +08:00
|
|
|
}
|
|
|
|
|
2016-01-20 17:59:44 +08:00
|
|
|
void SVMSGDImpl::clear()
|
|
|
|
{
|
|
|
|
weights_.release();
|
2016-02-10 21:44:16 +08:00
|
|
|
shift_ = 0;
|
2015-09-03 00:14:40 +08:00
|
|
|
}
|
2016-01-20 17:59:44 +08:00
|
|
|
|
|
|
|
|
|
|
|
// Creates an untrained model configured with the recommended ASGD / soft-margin parameters.
SVMSGDImpl::SVMSGDImpl()
{
    clear();
    setOptimalParameters(ASGD, SOFT_MARGIN);
}
|
|
|
|
|
2016-02-09 23:42:23 +08:00
|
|
|
void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
|
2016-01-20 17:59:44 +08:00
|
|
|
{
|
2016-02-09 23:42:23 +08:00
|
|
|
switch (svmsgdType)
|
2016-01-20 17:59:44 +08:00
|
|
|
{
|
|
|
|
case SGD:
|
|
|
|
params.svmsgdType = SGD;
|
2016-02-09 23:42:23 +08:00
|
|
|
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
|
2016-02-26 00:12:54 +08:00
|
|
|
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
|
2016-02-24 18:22:07 +08:00
|
|
|
params.marginRegularization = 0.0001f;
|
|
|
|
params.initialStepSize = 0.05f;
|
|
|
|
params.stepDecreasingPower = 1.f;
|
2016-02-09 23:42:23 +08:00
|
|
|
params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
|
2016-01-20 17:59:44 +08:00
|
|
|
break;
|
|
|
|
|
|
|
|
case ASGD:
|
|
|
|
params.svmsgdType = ASGD;
|
2016-02-09 23:42:23 +08:00
|
|
|
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
|
2016-02-26 00:12:54 +08:00
|
|
|
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
|
2016-02-24 18:22:07 +08:00
|
|
|
params.marginRegularization = 0.00001f;
|
|
|
|
params.initialStepSize = 0.05f;
|
|
|
|
params.stepDecreasingPower = 0.75f;
|
2016-02-09 23:42:23 +08:00
|
|
|
params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
|
2016-01-20 17:59:44 +08:00
|
|
|
break;
|
|
|
|
|
|
|
|
default:
|
|
|
|
CV_Error( CV_StsParseError, "SVMSGD model data is invalid" );
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} //ml
|
|
|
|
} //cv
|