opencv/modules/ml/src/svmsgd.cpp

595 lines
18 KiB
C++
Raw Normal View History

2015-09-03 00:14:40 +08:00
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// License Agreement
// For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2016, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of the copyright holders may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
#include "precomp.hpp"
2016-01-20 17:59:44 +08:00
#include "limits"
#include <iostream>
using std::cout;
using std::endl;
2015-09-03 00:14:40 +08:00
/****************************************************************************************\
* Stochastic Gradient Descent SVM Classifier *
\****************************************************************************************/
2016-01-20 17:59:44 +08:00
namespace cv
{
namespace ml
{
2015-09-03 00:14:40 +08:00
2016-01-20 17:59:44 +08:00
class SVMSGDImpl : public SVMSGD
{
2015-09-03 00:14:40 +08:00
2016-01-20 17:59:44 +08:00
public:
SVMSGDImpl();
2015-09-03 00:14:40 +08:00
2016-01-20 17:59:44 +08:00
virtual ~SVMSGDImpl() {}
2015-09-03 00:14:40 +08:00
2016-01-20 17:59:44 +08:00
virtual bool train(const Ptr<TrainData>& data, int);
2015-09-03 00:14:40 +08:00
2016-01-20 17:59:44 +08:00
virtual float predict( InputArray samples, OutputArray results=noArray(), int flags = 0 ) const;
2015-09-03 00:14:40 +08:00
virtual bool isClassifier() const;
2015-09-03 00:14:40 +08:00
2016-01-20 17:59:44 +08:00
virtual bool isTrained() const;
2015-09-03 00:14:40 +08:00
2016-01-20 17:59:44 +08:00
virtual void clear();
2015-09-03 00:14:40 +08:00
virtual void write(FileStorage &fs) const;
2015-09-03 00:14:40 +08:00
virtual void read(const FileNode &fn);
2015-09-03 00:14:40 +08:00
2016-01-20 17:59:44 +08:00
virtual Mat getWeights(){ return weights_; }
2015-09-03 00:14:40 +08:00
2016-01-20 17:59:44 +08:00
virtual float getShift(){ return shift_; }
2015-09-03 00:14:40 +08:00
2016-01-20 17:59:44 +08:00
virtual int getVarCount() const { return weights_.cols; }
2015-09-03 00:14:40 +08:00
2016-01-20 17:59:44 +08:00
virtual String getDefaultName() const {return "opencv_ml_svmsgd";}
2015-09-03 00:14:40 +08:00
virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN);
2015-09-03 00:14:40 +08:00
virtual int getSvmsgdType() const;
2015-09-03 00:14:40 +08:00
virtual void setSvmsgdType(int svmsgdType);
virtual int getMarginType() const;
virtual void setMarginType(int marginType);
2016-01-20 17:59:44 +08:00
CV_IMPL_PROPERTY(float, Lambda, params.lambda)
CV_IMPL_PROPERTY(float, Gamma0, params.gamma0)
CV_IMPL_PROPERTY(float, C, params.c)
CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)
2015-09-03 00:14:40 +08:00
private:
2016-02-15 19:35:36 +08:00
void updateWeights(InputArray sample, bool isPositive, float gamma, Mat &weights);
2016-01-20 17:59:44 +08:00
std::pair<bool,bool> areClassesEmpty(Mat responses);
void writeParams( FileStorage &fs ) const;
void readParams( const FileNode &fn );
2016-02-15 19:35:36 +08:00
static inline bool isPositive(float val) { return val > 0; }
2016-01-20 17:59:44 +08:00
static void normalizeSamples(Mat &matrix, Mat &average, float &multiplier);
float calcShift(InputArray _samples, InputArray _responses) const;
static void makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier);
2016-01-20 17:59:44 +08:00
// Vector with SVM weights
Mat weights_;
float shift_;
// Parameters for learning
struct SVMSGDParams
{
float lambda; //regularization
float gamma0; //learning rate
float c;
TermCriteria termCrit;
SvmsgdType svmsgdType;
MarginType marginType;
2016-01-20 17:59:44 +08:00
};
SVMSGDParams params;
};
// Factory for the SVMSGD interface: returns a default-constructed implementation.
Ptr<SVMSGD> SVMSGD::create()
{
    Ptr<SVMSGD> model = makePtr<SVMSGDImpl>();
    return model;
}
std::pair<bool,bool> SVMSGDImpl::areClassesEmpty(Mat responses)
2016-01-20 17:59:44 +08:00
{
CV_Assert(responses.cols == 1 || responses.rows == 1);
std::pair<bool,bool> emptyInClasses(true, true);
int limit_index = responses.rows;
2016-01-20 17:59:44 +08:00
for(int index = 0; index < limit_index; index++)
{
2016-02-15 19:35:36 +08:00
if (isPositive(responses.at<float>(index)))
emptyInClasses.first = false;
else
emptyInClasses.second = false;
2015-09-03 00:14:40 +08:00
if (!emptyInClasses.first && ! emptyInClasses.second)
break;
}
2016-01-20 17:59:44 +08:00
return emptyInClasses;
}
2016-01-20 17:59:44 +08:00
// Centres and rescales `samples` in place.
//
// samples:    one sample per row; elements are accessed as float.
// average:    (out) 1 x featuresCount row holding the per-column mean that was subtracted.
// multiplier: (out) factor the centred samples were multiplied by,
//             sqrt(element count) / norm(centred samples), or 1 when that norm is zero.
void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &average, float &multiplier)
{
    int featuresCount = samples.cols;
    int samplesCount = samples.rows;

    average = Mat(1, featuresCount, samples.type());
    for (int featureIndex = 0; featureIndex < featuresCount; featureIndex++)
    {
        Scalar scalAverage = mean(samples.col(featureIndex));
        average.at<float>(featureIndex) = static_cast<float>(scalAverage[0]);
    }

    for (int sampleIndex = 0; sampleIndex < samplesCount; sampleIndex++)
    {
        samples.row(sampleIndex) -= average;
    }

    double normValue = norm(samples);

    if (normValue > 0)
    {
        // The explicit cast to double avoids an ambiguous sqrt() overload for size_t.
        multiplier = static_cast<float>(sqrt(static_cast<double>(samples.total())) / normValue);
    }
    else
    {
        // Every centred sample is zero; dividing by the norm would give an
        // infinite multiplier and turn the samples into NaN.
        multiplier = 1.f;
    }

    samples *= multiplier;
}
2016-01-20 17:59:44 +08:00
void SVMSGDImpl::makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier)
{
2016-02-15 19:35:36 +08:00
Mat normalizedTrainSamples = trainSamples.clone();
int samplesCount = normalizedTrainSamples.rows;
2016-01-20 17:59:44 +08:00
2016-02-15 19:35:36 +08:00
normalizeSamples(normalizedTrainSamples, average, multiplier);
2016-01-20 17:59:44 +08:00
Mat onesCol = Mat::ones(samplesCount, 1, CV_32F);
2016-02-15 19:35:36 +08:00
cv::hconcat(normalizedTrainSamples, onesCol, extendedTrainSamples);
2015-09-03 00:14:40 +08:00
}
// Performs one stochastic gradient step on `weights` for a single sample.
//
// _sample:    the (extended) sample row.
// firstClass: true when the sample's response is positive.
// gamma:      learning rate for this step.
// weights:    (in/out) weight vector being optimized.
void SVMSGDImpl::updateWeights(InputArray _sample, bool firstClass, float gamma, Mat& weights)
{
    const Mat x = _sample.getMat();

    // Map the class flag onto a label of +1 / -1.
    const int label = firstClass ? 1 : -1;

    const bool beyondMargin = x.dot(weights) * label > 1;
    if (!beyondMargin)
    {
        // It's a support vector: decay the weights and pull them toward the sample.
        weights -= (gamma * params.lambda) * weights - (gamma * label) * x;
        return;
    }

    // Not a support vector: only the regularization decay is applied.
    weights *= (1.f - gamma * params.lambda);
}
2015-09-03 00:14:40 +08:00
2016-01-20 17:59:44 +08:00
// Computes the shift from the per-class minimum signed projections onto weights_.
//
// _samples:   samples, one per row.
// _responses: per-sample responses accessed as float; positive means first class.
// Returns -(minFirst - minSecond) / 2, where minFirst is the smallest
// dot(sample, weights_) over the first class and minSecond the smallest
// -dot(sample, weights_) over the second class.
float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
{
    const Mat samples = _samples.getMat();
    const Mat responses = _responses.getMat();

    float minFirst = std::numeric_limits<float>::max();
    float minSecond = std::numeric_limits<float>::max();

    const int count = samples.rows;
    for (int i = 0; i < count; i++)
    {
        const float projection = static_cast<float>(samples.row(i).dot(weights_));

        if (isPositive(responses.at<float>(i)))
        {
            const float distance = projection * 1.f;
            if (distance < minFirst)
                minFirst = distance;
        }
        else
        {
            const float distance = projection * -1.f;
            if (distance < minSecond)
                minSecond = distance;
        }
    }

    return -(minFirst - minSecond) / 2.f;
}
bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
{
clear();
CV_Assert( isClassifier() ); //toDo: consider
Mat trainSamples = data->getTrainSamples();
int featureCount = trainSamples.cols;
Mat trainResponses = data->getTrainResponses(); // (trainSamplesCount x 1) matrix
std::pair<bool,bool> areEmpty = areClassesEmpty(trainResponses);
if ( areEmpty.first && areEmpty.second )
{
return false;
}
if ( areEmpty.first || areEmpty.second )
{
weights_ = Mat::zeros(1, featureCount, CV_32F);
2016-02-11 00:46:24 +08:00
shift_ = areEmpty.first ? -1.f : 1.f;
return true;
2016-02-10 20:40:09 +08:00
}
Mat extendedTrainSamples;
Mat average;
float multiplier = 0;
makeExtendedTrainSamples(trainSamples, extendedTrainSamples, average, multiplier);
int extendedTrainSamplesCount = extendedTrainSamples.rows;
int extendedFeatureCount = extendedTrainSamples.cols;
Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); // Initialize extendedWeights vector with zeros
Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); //extendedWeights vector for calculating terminal criterion
Mat averageExtendedWeights; //average extendedWeights vector for ASGD model
if (params.svmsgdType == ASGD)
{
averageExtendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
}
RNG rng(0);
CV_Assert (params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS);
int maxCount = (params.termCrit.type & TermCriteria::COUNT) ? params.termCrit.maxCount : INT_MAX;
double epsilon = (params.termCrit.type & TermCriteria::EPS) ? params.termCrit.epsilon : 0;
double err = DBL_MAX;
// Stochastic gradient descent SVM
for (int iter = 0; (iter < maxCount) && (err > epsilon); iter++)
{
int randomNumber = rng.uniform(0, extendedTrainSamplesCount); //generate sample number
Mat currentSample = extendedTrainSamples.row(randomNumber);
float gamma = params.gamma0 * std::pow((1 + params.lambda * params.gamma0 * (float)iter), (-params.c)); //update gamma
2016-02-15 19:35:36 +08:00
updateWeights( currentSample, isPositive(trainResponses.at<float>(randomNumber)), gamma, extendedWeights );
//average weights (only for ASGD model)
if (params.svmsgdType == ASGD)
{
averageExtendedWeights = ((float)iter/ (1 + (float)iter)) * averageExtendedWeights + extendedWeights / (1 + (float) iter);
err = norm(averageExtendedWeights - previousWeights);
averageExtendedWeights.copyTo(previousWeights);
}
else
{
err = norm(extendedWeights - previousWeights);
extendedWeights.copyTo(previousWeights);
}
}
if (params.svmsgdType == ASGD)
{
extendedWeights = averageExtendedWeights;
}
Rect roi(0, 0, featureCount, 1);
weights_ = extendedWeights(roi);
weights_ *= multiplier;
CV_Assert(params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN);
if (params.marginType == SOFT_MARGIN)
{
2016-02-11 00:46:24 +08:00
shift_ = extendedWeights.at<float>(featureCount) - static_cast<float>(weights_.dot(average));
}
else
{
shift_ = calcShift(trainSamples, trainResponses);
}
return true;
}
2016-01-20 17:59:44 +08:00
// Classifies every row of `_samples` as +1 or -1 from the sign of
// dot(sample, weights_) + shift_ (a zero criterion maps to +1).
//
// _samples: CV_32F matrix, one sample per row, as many columns as weights_.
// _results: optional output; when requested it is created as an
//           (nSamples x 1) matrix of labels.
// Returns the label of the single input sample when `_results` is not
// requested (exactly one sample is then required); returns 0 when `_results`
// is requested.
float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) const
{
    const cv::Mat samples = _samples.getMat();
    const int nSamples = samples.rows;

    CV_Assert( samples.cols == weights_.cols && samples.type() == CV_32F );

    float result = 0;
    cv::Mat results;

    if( !_results.needed() )
    {
        // No output array: wrap the local return value so the loop writes straight into it.
        CV_Assert( nSamples == 1 );
        results = Mat(1, 1, CV_32F, &result);
    }
    else
    {
        _results.create( nSamples, 1, samples.type() );
        results = _results.getMat();
    }

    for (int row = 0; row < nSamples; row++)
    {
        const float criterion = static_cast<float>(samples.row(row).dot(weights_)) + shift_;
        results.at<float>(row) = (criterion >= 0) ? 1.f : -1.f;
    }

    return result;
}
bool SVMSGDImpl::isClassifier() const
2016-01-20 17:59:44 +08:00
{
return (params.svmsgdType == SGD || params.svmsgdType == ASGD)
&&
(params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN)
&&
(params.lambda > 0) && (params.gamma0 > 0) && (params.c >= 0);
2016-01-20 17:59:44 +08:00
}
// The model counts as trained as soon as it holds a non-empty weight vector.
bool SVMSGDImpl::isTrained() const
{
    const bool noWeights = weights_.empty();
    return !noWeights;
}
// Serializes the model: learning parameters first, then "weights" and "shift".
// Raises CV_StsParseError when the model has not been trained.
void SVMSGDImpl::write(FileStorage& fs) const
{
    const bool trained = isTrained();
    if( !trained )
    {
        CV_Error( CV_StsParseError, "SVMSGD model data is invalid, it hasn't been trained" );
    }

    writeParams( fs );

    fs << "weights" << weights_;
    fs << "shift" << shift_;
}
// Writes the learning parameters to `fs`: "svmsgdType", "marginType",
// "lambda", "gamma0", "c" and a "term_criteria" map holding "epsilon" and/or
// "iterations", depending on which criteria are enabled.
// A type value that is neither of the two known ones is stored as
// "Unknown_<numeric value>".
void SVMSGDImpl::writeParams( FileStorage& fs ) const
{
    String SvmsgdTypeStr;

    switch (params.svmsgdType)
    {
    case SGD:
        SvmsgdTypeStr = "SGD";
        break;
    case ASGD:
        SvmsgdTypeStr = "ASGD";
        break;
    default:
        // Covers ILLEGAL_SVMSGD_TYPE and any other value. The previous code
        // fell through from the ILLEGAL case into a console print and left
        // the string empty for other values.
        SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType);
    }

    fs << "svmsgdType" << SvmsgdTypeStr;

    String marginTypeStr;

    switch (params.marginType)
    {
    case SOFT_MARGIN:
        marginTypeStr = "SOFT_MARGIN";
        break;
    case HARD_MARGIN:
        marginTypeStr = "HARD_MARGIN";
        break;
    default:
        // Same handling as for the algorithm type above.
        marginTypeStr = format("Unknown_%d", params.marginType);
    }

    fs << "marginType" << marginTypeStr;

    fs << "lambda" << params.lambda;
    fs << "gamma0" << params.gamma0;
    fs << "c" << params.c;

    // Only the criteria that are actually enabled are stored.
    fs << "term_criteria" << "{:";
    if( params.termCrit.type & TermCriteria::EPS )
        fs << "epsilon" << params.termCrit.epsilon;
    if( params.termCrit.type & TermCriteria::COUNT )
        fs << "iterations" << params.termCrit.maxCount;
    fs << "}";
}
void SVMSGDImpl::read(const FileNode& fn)
{
clear();
readParams(fn);
fn["weights"] >> weights_;
fn["shift"] >> shift_;
2016-01-20 17:59:44 +08:00
}
void SVMSGDImpl::readParams( const FileNode& fn )
{
String svmsgdTypeStr = (String)fn["svmsgdType"];
SvmsgdType svmsgdType =
svmsgdTypeStr == "SGD" ? SGD :
svmsgdTypeStr == "ASGD" ? ASGD : ILLEGAL_SVMSGD_TYPE;
2016-01-20 17:59:44 +08:00
if( svmsgdType == ILLEGAL_SVMSGD_TYPE )
2016-01-20 17:59:44 +08:00
CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );
params.svmsgdType = svmsgdType;
String marginTypeStr = (String)fn["marginType"];
MarginType marginType =
marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN :
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE;
if( marginType == ILLEGAL_MARGIN_TYPE )
CV_Error( CV_StsParseError, "Missing or invalid margin type" );
params.marginType = marginType;
2016-01-20 17:59:44 +08:00
CV_Assert ( fn["lambda"].isReal() );
params.lambda = (float)fn["lambda"];
CV_Assert ( fn["gamma0"].isReal() );
params.gamma0 = (float)fn["gamma0"];
CV_Assert ( fn["c"].isReal() );
params.c = (float)fn["c"];
FileNode tcnode = fn["term_criteria"];
if( !tcnode.empty() )
{
params.termCrit.epsilon = (double)tcnode["epsilon"];
params.termCrit.maxCount = (int)tcnode["iterations"];
params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
(params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
2015-09-03 00:14:40 +08:00
}
2016-01-20 17:59:44 +08:00
else
params.termCrit = TermCriteria( TermCriteria::EPS + TermCriteria::COUNT, 100000, FLT_EPSILON );
2016-01-20 17:59:44 +08:00
2015-09-03 00:14:40 +08:00
}
2016-01-20 17:59:44 +08:00
void SVMSGDImpl::clear()
{
weights_.release();
2016-02-10 21:44:16 +08:00
shift_ = 0;
2015-09-03 00:14:40 +08:00
}
2016-01-20 17:59:44 +08:00
// Creates an untrained model whose parameters are deliberately unusable:
// both type fields are set to their ILLEGAL values and the numeric
// parameters to zero.
SVMSGDImpl::SVMSGDImpl()
{
    clear();

    params.svmsgdType = ILLEGAL_SVMSGD_TYPE;
    params.marginType = ILLEGAL_MARGIN_TYPE;

    // Parameters for learning
    params.lambda = 0;  // regularization
    params.gamma0 = 0;  // learning rate (ideally should be large at beginning and decay each iteration)
    params.c = 0;

    params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 0, 0);
}
void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
2016-01-20 17:59:44 +08:00
{
switch (svmsgdType)
2016-01-20 17:59:44 +08:00
{
case SGD:
params.svmsgdType = SGD;
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
(marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE;
2016-02-11 00:46:24 +08:00
params.lambda = 0.0001f;
params.gamma0 = 0.05f;
params.c = 1.f;
params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
2016-01-20 17:59:44 +08:00
break;
case ASGD:
params.svmsgdType = ASGD;
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
(marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE;
2016-02-11 00:46:24 +08:00
params.lambda = 0.00001f;
params.gamma0 = 0.05f;
params.c = 0.75f;
params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
2016-01-20 17:59:44 +08:00
break;
default:
CV_Error( CV_StsParseError, "SVMSGD model data is invalid" );
}
}
void SVMSGDImpl::setSvmsgdType(int type)
2016-01-20 17:59:44 +08:00
{
switch (type)
{
case SGD:
params.svmsgdType = SGD;
break;
case ASGD:
params.svmsgdType = ASGD;
break;
default:
params.svmsgdType = ILLEGAL_SVMSGD_TYPE;
2016-01-20 17:59:44 +08:00
}
}
int SVMSGDImpl::getSvmsgdType() const
2016-01-20 17:59:44 +08:00
{
return params.svmsgdType;
2015-09-03 00:14:40 +08:00
}
// Stores the margin type; anything other than HARD_MARGIN or SOFT_MARGIN is
// recorded as ILLEGAL_MARGIN_TYPE.
void SVMSGDImpl::setMarginType(int type)
{
    if (type == HARD_MARGIN)
        params.marginType = HARD_MARGIN;
    else if (type == SOFT_MARGIN)
        params.marginType = SOFT_MARGIN;
    else
        params.marginType = ILLEGAL_MARGIN_TYPE;
}
// Returns the stored margin type as an int.
int SVMSGDImpl::getMarginType() const
{
    const int storedType = params.marginType;
    return storedType;
}
2016-01-20 17:59:44 +08:00
} //ml
} //cv