2010-05-12 01:44:00 +08:00
|
|
|
/*M///////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//
|
|
|
|
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
|
|
|
//
|
|
|
|
// By downloading, copying, installing or using the software you agree to this license.
|
|
|
|
// If you do not agree to this license, do not download, install,
|
|
|
|
// copy or use the software.
|
|
|
|
//
|
|
|
|
//
|
|
|
|
// Intel License Agreement
|
|
|
|
//
|
|
|
|
// Copyright (C) 2000, Intel Corporation, all rights reserved.
|
|
|
|
// Third party copyrights are property of their respective owners.
|
|
|
|
//
|
|
|
|
// Redistribution and use in source and binary forms, with or without modification,
|
|
|
|
// are permitted provided that the following conditions are met:
|
|
|
|
//
|
|
|
|
// * Redistribution's of source code must retain the above copyright notice,
|
|
|
|
// this list of conditions and the following disclaimer.
|
|
|
|
//
|
|
|
|
// * Redistribution's in binary form must reproduce the above copyright notice,
|
|
|
|
// this list of conditions and the following disclaimer in the documentation
|
|
|
|
// and/or other materials provided with the distribution.
|
|
|
|
//
|
|
|
|
// * The name of Intel Corporation may not be used to endorse or promote products
|
|
|
|
// derived from this software without specific prior written permission.
|
|
|
|
//
|
|
|
|
// This software is provided by the copyright holders and contributors "as is" and
|
|
|
|
// any express or implied warranties, including, but not limited to, the implied
|
|
|
|
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
|
|
|
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
|
|
|
// indirect, incidental, special, exemplary, or consequential damages
|
|
|
|
// (including, but not limited to, procurement of substitute goods or services;
|
|
|
|
// loss of use, data, or profits; or business interruption) however caused
|
|
|
|
// and on any theory of liability, whether in contract, strict liability,
|
|
|
|
// or tort (including negligence or otherwise) arising in any way out of
|
|
|
|
// the use of this software, even if advised of the possibility of such damage.
|
|
|
|
//
|
|
|
|
//M*/
|
|
|
|
|
|
|
|
#include "precomp.hpp"
|
|
|
|
#include <ctype.h>
|
2014-07-30 03:54:23 +08:00
|
|
|
#include <algorithm>
|
|
|
|
#include <iterator>
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
namespace cv { namespace ml {
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
static const float MISSED_VAL = TrainData::missingValue();
|
|
|
|
static const int VAR_MISSED = VAR_ORDERED;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
TrainData::~TrainData() {}
|
2012-10-17 15:12:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
Mat TrainData::getSubVector(const Mat& vec, const Mat& idx)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
if( idx.empty() )
|
|
|
|
return vec;
|
|
|
|
int i, j, n = idx.checkVector(1, CV_32S);
|
|
|
|
int type = vec.type();
|
|
|
|
CV_Assert( type == CV_32S || type == CV_32F || type == CV_64F );
|
|
|
|
int dims = 1, m;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( vec.cols == 1 || vec.rows == 1 )
|
|
|
|
{
|
|
|
|
dims = 1;
|
|
|
|
m = vec.cols + vec.rows - 1;
|
|
|
|
}
|
|
|
|
else
|
|
|
|
{
|
|
|
|
dims = vec.cols;
|
|
|
|
m = vec.rows;
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
Mat subvec;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( vec.cols == m )
|
|
|
|
subvec.create(dims, n, type);
|
|
|
|
else
|
|
|
|
subvec.create(n, dims, type);
|
|
|
|
if( type == CV_32S )
|
|
|
|
for( i = 0; i < n; i++ )
|
|
|
|
{
|
|
|
|
int k = idx.at<int>(i);
|
|
|
|
CV_Assert( 0 <= k && k < m );
|
|
|
|
if( dims == 1 )
|
|
|
|
subvec.at<int>(i) = vec.at<int>(k);
|
|
|
|
else
|
|
|
|
for( j = 0; j < dims; j++ )
|
|
|
|
subvec.at<int>(i, j) = vec.at<int>(k, j);
|
|
|
|
}
|
|
|
|
else if( type == CV_32F )
|
|
|
|
for( i = 0; i < n; i++ )
|
|
|
|
{
|
|
|
|
int k = idx.at<int>(i);
|
|
|
|
CV_Assert( 0 <= k && k < m );
|
|
|
|
if( dims == 1 )
|
|
|
|
subvec.at<float>(i) = vec.at<float>(k);
|
|
|
|
else
|
|
|
|
for( j = 0; j < dims; j++ )
|
|
|
|
subvec.at<float>(i, j) = vec.at<float>(k, j);
|
|
|
|
}
|
|
|
|
else
|
|
|
|
for( i = 0; i < n; i++ )
|
|
|
|
{
|
|
|
|
int k = idx.at<int>(i);
|
|
|
|
CV_Assert( 0 <= k && k < m );
|
|
|
|
if( dims == 1 )
|
|
|
|
subvec.at<double>(i) = vec.at<double>(k);
|
|
|
|
else
|
|
|
|
for( j = 0; j < dims; j++ )
|
|
|
|
subvec.at<double>(i, j) = vec.at<double>(k, j);
|
|
|
|
}
|
|
|
|
return subvec;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
class TrainDataImpl : public TrainData
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
public:
|
|
|
|
typedef std::map<String, int> MapType;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
TrainDataImpl()
|
|
|
|
{
|
|
|
|
file = 0;
|
|
|
|
clear();
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
virtual ~TrainDataImpl() { closeFile(); }
|
2012-10-17 15:12:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
int getLayout() const { return layout; }
|
|
|
|
int getNSamples() const
|
|
|
|
{
|
|
|
|
return !sampleIdx.empty() ? (int)sampleIdx.total() :
|
|
|
|
layout == ROW_SAMPLE ? samples.rows : samples.cols;
|
|
|
|
}
|
|
|
|
int getNTrainSamples() const
|
|
|
|
{
|
|
|
|
return !trainSampleIdx.empty() ? (int)trainSampleIdx.total() : getNSamples();
|
|
|
|
}
|
|
|
|
int getNTestSamples() const
|
|
|
|
{
|
|
|
|
return !testSampleIdx.empty() ? (int)testSampleIdx.total() : 0;
|
|
|
|
}
|
|
|
|
int getNVars() const
|
|
|
|
{
|
|
|
|
return !varIdx.empty() ? (int)varIdx.total() : getNAllVars();
|
|
|
|
}
|
|
|
|
int getNAllVars() const
|
|
|
|
{
|
|
|
|
return layout == ROW_SAMPLE ? samples.cols : samples.rows;
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
Mat getSamples() const { return samples; }
|
|
|
|
Mat getResponses() const { return responses; }
|
|
|
|
Mat getMissing() const { return missing; }
|
|
|
|
Mat getVarIdx() const { return varIdx; }
|
|
|
|
Mat getVarType() const { return varType; }
|
|
|
|
int getResponseType() const
|
|
|
|
{
|
|
|
|
return classLabels.empty() ? VAR_ORDERED : VAR_CATEGORICAL;
|
|
|
|
}
|
|
|
|
Mat getTrainSampleIdx() const { return !trainSampleIdx.empty() ? trainSampleIdx : sampleIdx; }
|
|
|
|
Mat getTestSampleIdx() const { return testSampleIdx; }
|
|
|
|
Mat getSampleWeights() const
|
|
|
|
{
|
|
|
|
return sampleWeights;
|
|
|
|
}
|
|
|
|
Mat getTrainSampleWeights() const
|
|
|
|
{
|
|
|
|
return getSubVector(sampleWeights, getTrainSampleIdx());
|
|
|
|
}
|
|
|
|
Mat getTestSampleWeights() const
|
|
|
|
{
|
|
|
|
Mat idx = getTestSampleIdx();
|
|
|
|
return idx.empty() ? Mat() : getSubVector(sampleWeights, idx);
|
|
|
|
}
|
|
|
|
Mat getTrainResponses() const
|
|
|
|
{
|
|
|
|
return getSubVector(responses, getTrainSampleIdx());
|
|
|
|
}
|
|
|
|
Mat getTrainNormCatResponses() const
|
|
|
|
{
|
|
|
|
return getSubVector(normCatResponses, getTrainSampleIdx());
|
|
|
|
}
|
|
|
|
Mat getTestResponses() const
|
|
|
|
{
|
|
|
|
Mat idx = getTestSampleIdx();
|
|
|
|
return idx.empty() ? Mat() : getSubVector(responses, idx);
|
|
|
|
}
|
|
|
|
Mat getTestNormCatResponses() const
|
|
|
|
{
|
|
|
|
Mat idx = getTestSampleIdx();
|
|
|
|
return idx.empty() ? Mat() : getSubVector(normCatResponses, idx);
|
|
|
|
}
|
|
|
|
Mat getNormCatResponses() const { return normCatResponses; }
|
|
|
|
Mat getClassLabels() const { return classLabels; }
|
|
|
|
Mat getClassCounters() const { return classCounters; }
|
|
|
|
int getCatCount(int vi) const
|
|
|
|
{
|
|
|
|
int n = (int)catOfs.total();
|
|
|
|
CV_Assert( 0 <= vi && vi < n );
|
|
|
|
Vec2i ofs = catOfs.at<Vec2i>(vi);
|
|
|
|
return ofs[1] - ofs[0];
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
Mat getCatOfs() const { return catOfs; }
|
|
|
|
Mat getCatMap() const { return catMap; }
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
Mat getDefaultSubstValues() const { return missingSubst; }
|
2012-05-19 22:34:36 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void closeFile() { if(file) fclose(file); file=0; }
|
|
|
|
void clear()
|
|
|
|
{
|
|
|
|
closeFile();
|
|
|
|
samples.release();
|
|
|
|
missing.release();
|
|
|
|
varType.release();
|
|
|
|
responses.release();
|
|
|
|
sampleIdx.release();
|
|
|
|
trainSampleIdx.release();
|
|
|
|
testSampleIdx.release();
|
|
|
|
normCatResponses.release();
|
|
|
|
classLabels.release();
|
|
|
|
classCounters.release();
|
|
|
|
catMap.release();
|
|
|
|
catOfs.release();
|
|
|
|
nameMap = MapType();
|
|
|
|
layout = ROW_SAMPLE;
|
|
|
|
}
|
2012-05-19 22:34:36 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
typedef std::map<int, int> CatMapHash;
|
2012-05-19 22:34:36 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void setData(InputArray _samples, int _layout, InputArray _responses,
|
|
|
|
InputArray _varIdx, InputArray _sampleIdx, InputArray _sampleWeights,
|
|
|
|
InputArray _varType, InputArray _missing)
|
2012-10-17 15:12:04 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
clear();
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
CV_Assert(_layout == ROW_SAMPLE || _layout == COL_SAMPLE );
|
|
|
|
samples = _samples.getMat();
|
|
|
|
layout = _layout;
|
|
|
|
responses = _responses.getMat();
|
|
|
|
varIdx = _varIdx.getMat();
|
|
|
|
sampleIdx = _sampleIdx.getMat();
|
|
|
|
sampleWeights = _sampleWeights.getMat();
|
|
|
|
varType = _varType.getMat();
|
|
|
|
missing = _missing.getMat();
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
int nsamples = layout == ROW_SAMPLE ? samples.rows : samples.cols;
|
|
|
|
int ninputvars = layout == ROW_SAMPLE ? samples.cols : samples.rows;
|
|
|
|
int i, noutputvars = 0;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
CV_Assert( samples.type() == CV_32F || samples.type() == CV_32S );
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( !sampleIdx.empty() )
|
|
|
|
{
|
|
|
|
CV_Assert( (sampleIdx.checkVector(1, CV_32S, true) > 0 &&
|
2015-12-11 01:17:17 +08:00
|
|
|
checkRange(sampleIdx, true, 0, 0, nsamples)) ||
|
2014-07-30 03:54:23 +08:00
|
|
|
sampleIdx.checkVector(1, CV_8U, true) == nsamples );
|
|
|
|
if( sampleIdx.type() == CV_8U )
|
|
|
|
sampleIdx = convertMaskToIdx(sampleIdx);
|
|
|
|
}
|
2012-10-17 15:12:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( !sampleWeights.empty() )
|
|
|
|
{
|
|
|
|
CV_Assert( sampleWeights.checkVector(1, CV_32F, true) == nsamples );
|
|
|
|
}
|
|
|
|
else
|
|
|
|
{
|
|
|
|
sampleWeights = Mat::ones(nsamples, 1, CV_32F);
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( !varIdx.empty() )
|
|
|
|
{
|
|
|
|
CV_Assert( (varIdx.checkVector(1, CV_32S, true) > 0 &&
|
|
|
|
checkRange(varIdx, true, 0, 0, ninputvars)) ||
|
|
|
|
varIdx.checkVector(1, CV_8U, true) == ninputvars );
|
|
|
|
if( varIdx.type() == CV_8U )
|
|
|
|
varIdx = convertMaskToIdx(varIdx);
|
|
|
|
varIdx = varIdx.clone();
|
|
|
|
std::sort(varIdx.ptr<int>(), varIdx.ptr<int>() + varIdx.total());
|
|
|
|
}
|
2012-10-17 15:12:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( !responses.empty() )
|
2013-07-24 21:53:18 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
CV_Assert( responses.type() == CV_32F || responses.type() == CV_32S );
|
|
|
|
if( (responses.cols == 1 || responses.rows == 1) && (int)responses.total() == nsamples )
|
|
|
|
noutputvars = 1;
|
|
|
|
else
|
|
|
|
{
|
|
|
|
CV_Assert( (layout == ROW_SAMPLE && responses.rows == nsamples) ||
|
|
|
|
(layout == COL_SAMPLE && responses.cols == nsamples) );
|
|
|
|
noutputvars = layout == ROW_SAMPLE ? responses.cols : responses.rows;
|
|
|
|
}
|
|
|
|
if( !responses.isContinuous() || (layout == COL_SAMPLE && noutputvars > 1) )
|
|
|
|
{
|
|
|
|
Mat temp;
|
|
|
|
transpose(responses, temp);
|
|
|
|
responses = temp;
|
|
|
|
}
|
2013-07-24 21:53:18 +08:00
|
|
|
}
|
2012-05-19 22:34:36 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
int nvars = ninputvars + noutputvars;
|
2012-04-14 05:50:59 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( !varType.empty() )
|
2012-04-14 05:50:59 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
CV_Assert( varType.checkVector(1, CV_8U, true) == nvars &&
|
|
|
|
checkRange(varType, true, 0, VAR_ORDERED, VAR_CATEGORICAL+1) );
|
2012-04-14 05:50:59 +08:00
|
|
|
}
|
|
|
|
else
|
2014-07-30 03:54:23 +08:00
|
|
|
{
|
|
|
|
varType.create(1, nvars, CV_8U);
|
|
|
|
varType = Scalar::all(VAR_ORDERED);
|
|
|
|
if( noutputvars == 1 )
|
2014-08-03 16:46:28 +08:00
|
|
|
varType.at<uchar>(ninputvars) = (uchar)(responses.type() < CV_32F ? VAR_CATEGORICAL : VAR_ORDERED);
|
2014-07-30 03:54:23 +08:00
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( noutputvars > 1 )
|
|
|
|
{
|
|
|
|
for( i = 0; i < noutputvars; i++ )
|
|
|
|
CV_Assert( varType.at<uchar>(ninputvars + i) == VAR_ORDERED );
|
|
|
|
}
|
2012-05-25 00:52:14 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
catOfs = Mat::zeros(1, nvars, CV_32SC2);
|
|
|
|
missingSubst = Mat::zeros(1, nvars, CV_32F);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
vector<int> labels, counters, sortbuf, tempCatMap;
|
|
|
|
vector<Vec2i> tempCatOfs;
|
|
|
|
CatMapHash ofshash;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
AutoBuffer<uchar> buf(nsamples);
|
|
|
|
Mat non_missing(layout == ROW_SAMPLE ? Size(1, nsamples) : Size(nsamples, 1), CV_8U, (uchar*)buf);
|
|
|
|
bool haveMissing = !missing.empty();
|
|
|
|
if( haveMissing )
|
|
|
|
{
|
|
|
|
CV_Assert( missing.size() == samples.size() && missing.type() == CV_8U );
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// we iterate through all the variables. For each categorical variable we build a map
|
|
|
|
// in order to convert input values of the variable into normalized values (0..catcount_vi-1)
|
|
|
|
// often many categorical variables are similar, so we compress the map - try to re-use
|
|
|
|
// maps for different variables if they are identical
|
|
|
|
for( i = 0; i < ninputvars; i++ )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
Mat values_i = layout == ROW_SAMPLE ? samples.col(i) : samples.row(i);
|
|
|
|
|
|
|
|
if( varType.at<uchar>(i) == VAR_CATEGORICAL )
|
|
|
|
{
|
|
|
|
preprocessCategorical(values_i, 0, labels, 0, sortbuf);
|
|
|
|
missingSubst.at<float>(i) = -1.f;
|
|
|
|
int j, m = (int)labels.size();
|
|
|
|
CV_Assert( m > 0 );
|
|
|
|
int a = labels.front(), b = labels.back();
|
|
|
|
const int* currmap = &labels[0];
|
|
|
|
int hashval = ((unsigned)a*127 + (unsigned)b)*127 + m;
|
|
|
|
CatMapHash::iterator it = ofshash.find(hashval);
|
|
|
|
if( it != ofshash.end() )
|
|
|
|
{
|
|
|
|
int vi = it->second;
|
|
|
|
Vec2i ofs0 = tempCatOfs[vi];
|
|
|
|
int m0 = ofs0[1] - ofs0[0];
|
|
|
|
const int* map0 = &tempCatMap[ofs0[0]];
|
|
|
|
if( m0 == m && map0[0] == a && map0[m0-1] == b )
|
|
|
|
{
|
|
|
|
for( j = 0; j < m; j++ )
|
|
|
|
if( map0[j] != currmap[j] )
|
|
|
|
break;
|
|
|
|
if( j == m )
|
|
|
|
{
|
|
|
|
// re-use the map
|
|
|
|
tempCatOfs.push_back(ofs0);
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
else
|
|
|
|
ofshash[hashval] = i;
|
|
|
|
Vec2i ofs;
|
|
|
|
ofs[0] = (int)tempCatMap.size();
|
|
|
|
ofs[1] = ofs[0] + m;
|
|
|
|
tempCatOfs.push_back(ofs);
|
|
|
|
std::copy(labels.begin(), labels.end(), std::back_inserter(tempCatMap));
|
|
|
|
}
|
2014-08-03 05:41:09 +08:00
|
|
|
else
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
tempCatOfs.push_back(Vec2i(0, 0));
|
|
|
|
/*Mat missing_i = layout == ROW_SAMPLE ? missing.col(i) : missing.row(i);
|
|
|
|
compare(missing_i, Scalar::all(0), non_missing, CMP_EQ);
|
|
|
|
missingSubst.at<float>(i) = (float)(mean(values_i, non_missing)[0]);*/
|
|
|
|
missingSubst.at<float>(i) = 0.f;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( !tempCatOfs.empty() )
|
|
|
|
{
|
|
|
|
Mat(tempCatOfs).copyTo(catOfs);
|
|
|
|
Mat(tempCatMap).copyTo(catMap);
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( varType.at<uchar>(ninputvars) == VAR_CATEGORICAL )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
preprocessCategorical(responses, &normCatResponses, labels, &counters, sortbuf);
|
|
|
|
Mat(labels).copyTo(classLabels);
|
|
|
|
Mat(counters).copyTo(classCounters);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// Turns a 1D CV_8U mask into a 1 x nz CV_32S row that lists, in ascending
// order, the positions of the mask's non-zero entries.
Mat convertMaskToIdx(const Mat& mask)
{
    int nonZero = countNonZero(mask);
    int length = mask.cols + mask.rows - 1;
    Mat positions(1, nonZero, CV_32S);
    int* out = positions.ptr<int>();
    for( int pos = 0; pos < length; pos++ )
    {
        if( mask.at<uchar>(pos) != 0 )
            *out++ = pos;
    }
    return positions;
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// Strict-weak-ordering predicate on element indices: index i sorts before
// index j when data[i*step] < data[j*step]. Used with std::sort to order an
// index array by the (possibly strided) values it refers to.
struct CmpByIdx
{
    CmpByIdx(const int* _data, int _step) : data(_data), step(_step) {}

    bool operator ()(int i, int j) const
    {
        const int lhs = data[i*step];
        const int rhs = data[j*step];
        return lhs < rhs;
    }

    const int* data;
    int step;
};
|
|
|
|
|
|
|
|
void preprocessCategorical(const Mat& data, Mat* normdata, vector<int>& labels,
|
|
|
|
vector<int>* counters, vector<int>& sortbuf)
|
|
|
|
{
|
|
|
|
CV_Assert((data.cols == 1 || data.rows == 1) && (data.type() == CV_32S || data.type() == CV_32F));
|
|
|
|
int* odata = 0;
|
|
|
|
int ostep = 0;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if(normdata)
|
|
|
|
{
|
|
|
|
normdata->create(data.size(), CV_32S);
|
|
|
|
odata = normdata->ptr<int>();
|
|
|
|
ostep = normdata->isContinuous() ? 1 : (int)normdata->step1();
|
|
|
|
}
|
2011-06-16 20:35:40 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
int i, n = data.cols + data.rows - 1;
|
|
|
|
sortbuf.resize(n*2);
|
|
|
|
int* idx = &sortbuf[0];
|
|
|
|
int* idata = (int*)data.ptr<int>();
|
|
|
|
int istep = data.isContinuous() ? 1 : (int)data.step1();
|
2011-06-17 18:11:52 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( data.type() == CV_32F )
|
|
|
|
{
|
|
|
|
idata = idx + n;
|
|
|
|
const float* fdata = data.ptr<float>();
|
|
|
|
for( i = 0; i < n; i++ )
|
|
|
|
{
|
|
|
|
if( fdata[i*istep] == MISSED_VAL )
|
|
|
|
idata[i] = -1;
|
|
|
|
else
|
|
|
|
{
|
|
|
|
idata[i] = cvRound(fdata[i*istep]);
|
|
|
|
CV_Assert( (float)idata[i] == fdata[i*istep] );
|
|
|
|
}
|
|
|
|
}
|
|
|
|
istep = 1;
|
|
|
|
}
|
2011-06-17 18:11:52 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
for( i = 0; i < n; i++ )
|
|
|
|
idx[i] = i;
|
2011-06-17 18:11:52 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
std::sort(idx, idx + n, CmpByIdx(idata, istep));
|
2011-06-16 20:35:40 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
int clscount = 1;
|
|
|
|
for( i = 1; i < n; i++ )
|
|
|
|
clscount += idata[idx[i]*istep] != idata[idx[i-1]*istep];
|
2011-06-17 18:11:52 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
int clslabel = -1;
|
|
|
|
int prev = ~idata[idx[0]*istep];
|
|
|
|
int previdx = 0;
|
2012-10-17 15:12:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
labels.resize(clscount);
|
|
|
|
if(counters)
|
|
|
|
counters->resize(clscount);
|
|
|
|
|
|
|
|
for( i = 0; i < n; i++ )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
int l = idata[idx[i]*istep];
|
|
|
|
if( l != prev )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
clslabel++;
|
|
|
|
labels[clslabel] = l;
|
|
|
|
int k = i - previdx;
|
|
|
|
if( clslabel > 0 && counters )
|
|
|
|
counters->at(clslabel-1) = k;
|
|
|
|
prev = l;
|
|
|
|
previdx = i;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
if(odata)
|
|
|
|
odata[idx[i]*ostep] = clslabel;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
if(counters)
|
|
|
|
counters->at(clslabel) = i - previdx;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
bool loadCSV(const String& filename, int headerLines,
|
|
|
|
int responseStartIdx, int responseEndIdx,
|
|
|
|
const String& varTypeSpec, char delimiter, char missch)
|
|
|
|
{
|
|
|
|
const int M = 1000000;
|
|
|
|
const char delimiters[3] = { ' ', delimiter, '\0' };
|
|
|
|
int nvars = 0;
|
|
|
|
bool varTypesSet = false;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
clear();
|
2011-06-16 20:35:40 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
file = fopen( filename.c_str(), "rt" );
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( !file )
|
|
|
|
return false;
|
2012-10-17 15:12:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
std::vector<char> _buf(M);
|
|
|
|
std::vector<float> allresponses;
|
|
|
|
std::vector<float> rowvals;
|
|
|
|
std::vector<uchar> vtypes, rowtypes;
|
|
|
|
bool haveMissed = false;
|
|
|
|
char* buf = &_buf[0];
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
int i, ridx0 = responseStartIdx, ridx1 = responseEndIdx;
|
|
|
|
int ninputvars = 0, noutputvars = 0;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
Mat tempSamples, tempMissing, tempResponses;
|
|
|
|
MapType tempNameMap;
|
|
|
|
int catCounter = 1;
|
2011-06-16 20:35:40 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// skip header lines
|
|
|
|
int lineno = 0;
|
|
|
|
for(;;lineno++)
|
|
|
|
{
|
|
|
|
if( !fgets(buf, M, file) )
|
|
|
|
break;
|
|
|
|
if(lineno < headerLines )
|
|
|
|
continue;
|
|
|
|
// trim trailing spaces
|
|
|
|
int idx = (int)strlen(buf)-1;
|
|
|
|
while( idx >= 0 && isspace(buf[idx]) )
|
|
|
|
buf[idx--] = '\0';
|
|
|
|
// skip spaces in the beginning
|
|
|
|
char* ptr = buf;
|
|
|
|
while( *ptr != '\0' && isspace(*ptr) )
|
|
|
|
ptr++;
|
|
|
|
// skip commented off lines
|
|
|
|
if(*ptr == '#')
|
|
|
|
continue;
|
|
|
|
rowvals.clear();
|
|
|
|
rowtypes.clear();
|
|
|
|
|
|
|
|
char* token = strtok(buf, delimiters);
|
|
|
|
if (!token)
|
|
|
|
break;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
for(;;)
|
|
|
|
{
|
|
|
|
float val=0.f; int tp = 0;
|
|
|
|
decodeElem( token, val, tp, missch, tempNameMap, catCounter );
|
|
|
|
if( tp == VAR_MISSED )
|
|
|
|
haveMissed = true;
|
|
|
|
rowvals.push_back(val);
|
2014-08-03 16:46:28 +08:00
|
|
|
rowtypes.push_back((uchar)tp);
|
2014-07-30 03:54:23 +08:00
|
|
|
token = strtok(NULL, delimiters);
|
|
|
|
if (!token)
|
|
|
|
break;
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( nvars == 0 )
|
|
|
|
{
|
|
|
|
if( rowvals.empty() )
|
|
|
|
CV_Error(CV_StsBadArg, "invalid CSV format; no data found");
|
|
|
|
nvars = (int)rowvals.size();
|
|
|
|
if( !varTypeSpec.empty() && varTypeSpec.size() > 0 )
|
|
|
|
{
|
|
|
|
setVarTypes(varTypeSpec, nvars, vtypes);
|
|
|
|
varTypesSet = true;
|
|
|
|
}
|
|
|
|
else
|
|
|
|
vtypes = rowtypes;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
ridx0 = ridx0 >= 0 ? ridx0 : ridx0 == -1 ? nvars - 1 : -1;
|
|
|
|
ridx1 = ridx1 >= 0 ? ridx1 : ridx0 >= 0 ? ridx0+1 : -1;
|
|
|
|
CV_Assert(ridx1 > ridx0);
|
|
|
|
noutputvars = ridx0 >= 0 ? ridx1 - ridx0 : 0;
|
|
|
|
ninputvars = nvars - noutputvars;
|
|
|
|
}
|
|
|
|
else
|
|
|
|
CV_Assert( nvars == (int)rowvals.size() );
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// check var types
|
|
|
|
for( i = 0; i < nvars; i++ )
|
|
|
|
{
|
|
|
|
CV_Assert( (!varTypesSet && vtypes[i] == rowtypes[i]) ||
|
|
|
|
(varTypesSet && (vtypes[i] == rowtypes[i] || rowtypes[i] == VAR_ORDERED)) );
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( ridx0 >= 0 )
|
|
|
|
{
|
|
|
|
for( i = ridx1; i < nvars; i++ )
|
|
|
|
std::swap(rowvals[i], rowvals[i-noutputvars]);
|
|
|
|
for( i = ninputvars; i < nvars; i++ )
|
|
|
|
allresponses.push_back(rowvals[i]);
|
|
|
|
rowvals.pop_back();
|
|
|
|
}
|
|
|
|
Mat rmat(1, ninputvars, CV_32F, &rowvals[0]);
|
|
|
|
tempSamples.push_back(rmat);
|
|
|
|
}
|
2011-06-17 18:11:52 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
closeFile();
|
2011-06-16 20:35:40 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
int nsamples = tempSamples.rows;
|
|
|
|
if( nsamples == 0 )
|
|
|
|
return false;
|
2012-10-17 15:12:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( haveMissed )
|
|
|
|
compare(tempSamples, MISSED_VAL, tempMissing, CMP_EQ);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( ridx0 >= 0 )
|
|
|
|
{
|
|
|
|
for( i = ridx1; i < nvars; i++ )
|
|
|
|
std::swap(vtypes[i], vtypes[i-noutputvars]);
|
|
|
|
if( noutputvars > 1 )
|
|
|
|
{
|
|
|
|
for( i = ninputvars; i < nvars; i++ )
|
|
|
|
if( vtypes[i] == VAR_CATEGORICAL )
|
|
|
|
CV_Error(CV_StsBadArg,
|
|
|
|
"If responses are vector values, not scalars, they must be marked as ordered responses");
|
|
|
|
}
|
|
|
|
}
|
2012-10-17 15:12:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( !varTypesSet && noutputvars == 1 && vtypes[ninputvars] == VAR_ORDERED )
|
|
|
|
{
|
|
|
|
for( i = 0; i < nsamples; i++ )
|
|
|
|
if( allresponses[i] != cvRound(allresponses[i]) )
|
|
|
|
break;
|
|
|
|
if( i == nsamples )
|
|
|
|
vtypes[ninputvars] = VAR_CATEGORICAL;
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2015-08-04 21:58:24 +08:00
|
|
|
//If there are responses in the csv file, save them. If not, responses matrix will contain just zeros
|
2015-08-04 22:45:39 +08:00
|
|
|
if (noutputvars != 0){
|
2015-08-04 22:50:55 +08:00
|
|
|
Mat(nsamples, noutputvars, CV_32F, &allresponses[0]).copyTo(tempResponses);
|
|
|
|
setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
|
|
|
|
noArray(), Mat(vtypes).clone(), tempMissing);
|
2015-08-04 21:58:24 +08:00
|
|
|
}
|
|
|
|
else{
|
|
|
|
Mat zero_mat(nsamples, 1, CV_32F, Scalar(0));
|
|
|
|
zero_mat.copyTo(tempResponses);
|
|
|
|
setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
|
|
|
|
noArray(), noArray(), tempMissing);
|
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
bool ok = !samples.empty();
|
|
|
|
if(ok)
|
|
|
|
std::swap(tempNameMap, nameMap);
|
|
|
|
return ok;
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void decodeElem( const char* token, float& elem, int& type,
|
|
|
|
char missch, MapType& namemap, int& counter ) const
|
|
|
|
{
|
|
|
|
char* stopstring = NULL;
|
|
|
|
elem = (float)strtod( token, &stopstring );
|
|
|
|
if( *stopstring == missch && strlen(stopstring) == 1 ) // missed value
|
|
|
|
{
|
|
|
|
elem = MISSED_VAL;
|
|
|
|
type = VAR_MISSED;
|
|
|
|
}
|
|
|
|
else if( *stopstring != '\0' )
|
|
|
|
{
|
|
|
|
MapType::iterator it = namemap.find(token);
|
|
|
|
if( it == namemap.end() )
|
|
|
|
{
|
|
|
|
elem = (float)counter;
|
|
|
|
namemap[token] = counter++;
|
|
|
|
}
|
|
|
|
else
|
|
|
|
elem = (float)it->second;
|
|
|
|
type = VAR_CATEGORICAL;
|
|
|
|
}
|
|
|
|
else
|
|
|
|
type = VAR_ORDERED;
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void setVarTypes( const String& s, int nvars, std::vector<uchar>& vtypes ) const
|
|
|
|
{
|
|
|
|
const char* errmsg = "type spec is not correct; it should have format \"cat\", \"ord\" or "
|
|
|
|
"\"ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\", where n's and m's are 0-based variable indices";
|
|
|
|
const char* str = s.c_str();
|
|
|
|
int specCounter = 0;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
vtypes.resize(nvars);
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
for( int k = 0; k < 2; k++ )
|
|
|
|
{
|
|
|
|
const char* ptr = strstr(str, k == 0 ? "ord" : "cat");
|
|
|
|
int tp = k == 0 ? VAR_ORDERED : VAR_CATEGORICAL;
|
|
|
|
if( ptr ) // parse ord/cat str
|
|
|
|
{
|
|
|
|
char* stopstring = NULL;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( ptr[3] == '\0' )
|
|
|
|
{
|
|
|
|
for( int i = 0; i < nvars; i++ )
|
|
|
|
vtypes[i] = (uchar)tp;
|
|
|
|
specCounter = nvars;
|
|
|
|
break;
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if ( ptr[3] != '[')
|
|
|
|
CV_Error( CV_StsBadArg, errmsg );
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
ptr += 4; // pass "ord["
|
|
|
|
do
|
|
|
|
{
|
|
|
|
int b1 = (int)strtod( ptr, &stopstring );
|
|
|
|
if( *stopstring == 0 || (*stopstring != ',' && *stopstring != ']' && *stopstring != '-') )
|
|
|
|
CV_Error( CV_StsBadArg, errmsg );
|
|
|
|
ptr = stopstring + 1;
|
|
|
|
if( (stopstring[0] == ',') || (stopstring[0] == ']'))
|
|
|
|
{
|
|
|
|
CV_Assert( 0 <= b1 && b1 < nvars );
|
|
|
|
vtypes[b1] = (uchar)tp;
|
|
|
|
specCounter++;
|
|
|
|
}
|
|
|
|
else
|
|
|
|
{
|
|
|
|
if( stopstring[0] == '-')
|
|
|
|
{
|
|
|
|
int b2 = (int)strtod( ptr, &stopstring);
|
|
|
|
if ( (*stopstring == 0) || (*stopstring != ',' && *stopstring != ']') )
|
|
|
|
CV_Error( CV_StsBadArg, errmsg );
|
|
|
|
ptr = stopstring + 1;
|
|
|
|
CV_Assert( 0 <= b1 && b1 <= b2 && b2 < nvars );
|
|
|
|
for (int i = b1; i <= b2; i++)
|
|
|
|
vtypes[i] = (uchar)tp;
|
|
|
|
specCounter += b2 - b1 + 1;
|
|
|
|
}
|
|
|
|
else
|
|
|
|
CV_Error( CV_StsBadArg, errmsg );
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
while(*stopstring != ']');
|
2012-10-17 15:12:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( stopstring[1] != '\0' && stopstring[1] != ',')
|
|
|
|
CV_Error( CV_StsBadArg, errmsg );
|
|
|
|
}
|
|
|
|
}
|
2012-10-17 15:12:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( specCounter != nvars )
|
|
|
|
CV_Error( CV_StsBadArg, "type of some variables is not specified" );
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2014-08-03 05:41:09 +08:00
|
|
|
void setTrainTestSplitRatio(double ratio, bool shuffle)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-08-03 05:41:09 +08:00
|
|
|
CV_Assert( 0. <= ratio && ratio <= 1. );
|
2014-07-30 03:54:23 +08:00
|
|
|
setTrainTestSplit(cvRound(getNSamples()*ratio), shuffle);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void setTrainTestSplit(int count, bool shuffle)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
int i, nsamples = getNSamples();
|
2014-08-03 07:08:25 +08:00
|
|
|
CV_Assert( 0 <= count && count < nsamples );
|
2014-07-30 03:54:23 +08:00
|
|
|
|
|
|
|
trainSampleIdx.release();
|
|
|
|
testSampleIdx.release();
|
2012-10-17 15:12:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( count == 0 )
|
|
|
|
trainSampleIdx = sampleIdx;
|
|
|
|
else if( count == nsamples )
|
|
|
|
testSampleIdx = sampleIdx;
|
|
|
|
else
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
Mat mask(1, nsamples, CV_8U);
|
2014-08-13 19:08:27 +08:00
|
|
|
uchar* mptr = mask.ptr();
|
2014-07-30 03:54:23 +08:00
|
|
|
for( i = 0; i < nsamples; i++ )
|
|
|
|
mptr[i] = (uchar)(i < count);
|
|
|
|
trainSampleIdx.create(1, count, CV_32S);
|
|
|
|
testSampleIdx.create(1, nsamples - count, CV_32S);
|
|
|
|
int j0 = 0, j1 = 0;
|
|
|
|
const int* sptr = !sampleIdx.empty() ? sampleIdx.ptr<int>() : 0;
|
|
|
|
int* trainptr = trainSampleIdx.ptr<int>();
|
|
|
|
int* testptr = testSampleIdx.ptr<int>();
|
|
|
|
for( i = 0; i < nsamples; i++ )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
int idx = sptr ? sptr[i] : i;
|
|
|
|
if( mptr[i] )
|
|
|
|
trainptr[j0++] = idx;
|
2010-05-12 01:44:00 +08:00
|
|
|
else
|
2014-07-30 03:54:23 +08:00
|
|
|
testptr[j1++] = idx;
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
if( shuffle )
|
|
|
|
shuffleTrainTest();
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2012-10-17 15:12:04 +08:00
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void shuffleTrainTest()
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
if( !trainSampleIdx.empty() && !testSampleIdx.empty() )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
int i, nsamples = getNSamples(), ntrain = getNTrainSamples(), ntest = getNTestSamples();
|
|
|
|
int* trainIdx = trainSampleIdx.ptr<int>();
|
|
|
|
int* testIdx = testSampleIdx.ptr<int>();
|
|
|
|
RNG& rng = theRNG();
|
|
|
|
|
|
|
|
for( i = 0; i < nsamples; i++)
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
int a = rng.uniform(0, nsamples);
|
|
|
|
int b = rng.uniform(0, nsamples);
|
|
|
|
int* ptra = trainIdx;
|
|
|
|
int* ptrb = trainIdx;
|
|
|
|
if( a >= ntrain )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
ptra = testIdx;
|
|
|
|
a -= ntrain;
|
|
|
|
CV_Assert( a < ntest );
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
if( b >= ntrain )
|
|
|
|
{
|
|
|
|
ptrb = testIdx;
|
|
|
|
b -= ntrain;
|
|
|
|
CV_Assert( b < ntest );
|
|
|
|
}
|
|
|
|
std::swap(ptra[a], ptrb[b]);
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
/** Returns the sample matrix in the requested layout, optionally reduced to a subset.
 *
 * @param _layout         ROW_SAMPLE or COL_SAMPLE for the returned matrix.
 * @param compressSamples When true, keep only getTrainSampleIdx() samples
 *                        (getNTrainSamples() of them). When false, keep every
 *                        stored sample.
 * @param compressVars    When true, keep only getVarIdx() variables
 *                        (getNVars() of them). When false, keep every stored
 *                        variable.
 * @return The stored matrix itself when no selection or transposition is
 *         needed, otherwise a newly allocated CV_32F matrix.
 */
Mat getTrainSamples(int _layout,
                    bool compressSamples,
                    bool compressVars) const
{
    if( samples.empty() )
        return samples;

    // Fast path: nothing to select along either axis and the layout already matches.
    if( (!compressSamples || (trainSampleIdx.empty() && sampleIdx.empty())) &&
        (!compressVars || varIdx.empty()) &&
        layout == _layout )
        return samples;

    // A disabled compression flag keeps the full extent of that axis, matching
    // what the fast path returns. Previously the copy path always compressed.
    int drows = layout == ROW_SAMPLE ? samples.rows : samples.cols;
    int dcols = layout == ROW_SAMPLE ? samples.cols : samples.rows;
    Mat sidx, vidx;
    if( compressSamples )
    {
        drows = getNTrainSamples();
        sidx = getTrainSampleIdx();
    }
    if( compressVars )
    {
        dcols = getNVars();
        vidx = getVarIdx();
    }

    const float* src0 = samples.ptr<float>();
    const int* sptr = !sidx.empty() ? sidx.ptr<int>() : 0;
    const int* vptr = !vidx.empty() ? vidx.ptr<int>() : 0;
    size_t sstep0 = samples.step/samples.elemSize();
    size_t sstep = layout == ROW_SAMPLE ? sstep0 : 1;
    size_t vstep = layout == ROW_SAMPLE ? 1 : sstep0;

    // For COL_SAMPLE output the roles of rows and columns are exchanged.
    if( _layout == COL_SAMPLE )
    {
        std::swap(drows, dcols);
        std::swap(sptr, vptr);
        std::swap(sstep, vstep);
    }

    Mat dsamples(drows, dcols, CV_32F);

    for( int i = 0; i < drows; i++ )
    {
        const float* src = src0 + (sptr ? sptr[i] : i)*sstep;
        float* dst = dsamples.ptr<float>(i);

        for( int j = 0; j < dcols; j++ )
            dst[j] = src[(vptr ? vptr[j] : j)*vstep];
    }

    return dsamples;
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void getValues( int vi, InputArray _sidx, float* values ) const
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
Mat sidx = _sidx.getMat();
|
2014-08-10 04:10:05 +08:00
|
|
|
int i, n = sidx.checkVector(1, CV_32S), nsamples = getNSamples();
|
2014-07-30 03:54:23 +08:00
|
|
|
CV_Assert( 0 <= vi && vi < getNAllVars() );
|
2014-08-10 04:10:05 +08:00
|
|
|
CV_Assert( n >= 0 );
|
2014-07-30 03:54:23 +08:00
|
|
|
const int* s = n > 0 ? sidx.ptr<int>() : 0;
|
|
|
|
if( n == 0 )
|
|
|
|
n = nsamples;
|
|
|
|
|
|
|
|
size_t step = samples.step/samples.elemSize();
|
|
|
|
size_t sstep = layout == ROW_SAMPLE ? step : 1;
|
|
|
|
size_t vstep = layout == ROW_SAMPLE ? 1 : step;
|
|
|
|
|
|
|
|
const float* src = samples.ptr<float>() + vi*vstep;
|
|
|
|
float subst = missingSubst.at<float>(vi);
|
|
|
|
for( i = 0; i < n; i++ )
|
|
|
|
{
|
|
|
|
int j = i;
|
|
|
|
if( s )
|
|
|
|
{
|
|
|
|
j = s[i];
|
2014-08-03 16:46:28 +08:00
|
|
|
CV_Assert( 0 <= j && j < nsamples );
|
2014-07-30 03:54:23 +08:00
|
|
|
}
|
|
|
|
values[i] = src[j*sstep];
|
|
|
|
if( values[i] == MISSED_VAL )
|
|
|
|
values[i] = subst;
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
void getNormCatValues( int vi, InputArray _sidx, int* values ) const
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
float* fvalues = (float*)values;
|
|
|
|
getValues(vi, _sidx, fvalues);
|
|
|
|
int i, n = (int)_sidx.total();
|
|
|
|
Vec2i ofs = catOfs.at<Vec2i>(vi);
|
|
|
|
int m = ofs[1] - ofs[0];
|
2011-06-16 20:35:40 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
CV_Assert( m > 0 ); // if m==0, vi is an ordered variable
|
|
|
|
const int* cmap = &catMap.at<int>(ofs[0]);
|
2014-12-16 23:15:50 +08:00
|
|
|
bool fastMap = (m == cmap[m - 1] - cmap[0] + 1);
|
2011-06-17 18:11:52 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
if( fastMap )
|
2010-05-12 01:44:00 +08:00
|
|
|
{
|
2014-07-30 03:54:23 +08:00
|
|
|
for( i = 0; i < n; i++ )
|
|
|
|
{
|
|
|
|
int val = cvRound(fvalues[i]);
|
|
|
|
int idx = val - cmap[0];
|
|
|
|
CV_Assert(cmap[idx] == val);
|
|
|
|
values[i] = idx;
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
2014-07-30 03:54:23 +08:00
|
|
|
else
|
|
|
|
{
|
|
|
|
for( i = 0; i < n; i++ )
|
|
|
|
{
|
|
|
|
int val = cvRound(fvalues[i]);
|
|
|
|
int a = 0, b = m, c = -1;
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
while( a < b )
|
|
|
|
{
|
|
|
|
c = (a + b) >> 1;
|
|
|
|
if( val < cmap[c] )
|
|
|
|
b = c;
|
|
|
|
else if( val > cmap[c] )
|
|
|
|
a = c+1;
|
|
|
|
else
|
|
|
|
break;
|
|
|
|
}
|
2012-10-17 15:12:04 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
CV_DbgAssert( c >= 0 && val == cmap[c] );
|
|
|
|
values[i] = c;
|
|
|
|
}
|
|
|
|
}
|
2010-05-12 01:44:00 +08:00
|
|
|
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
/** Copies selected variables of one sample into @p buf.
 *
 * @param _vidx Optional CV_32S vector of variable indices in [0, getNAllVars()).
 *              When empty, all getNAllVars() variables are copied in order.
 * @param sidx  Raw position of the sample in `samples`.
 * @param buf   Output buffer. It must hold _vidx.total() floats, or
 *              getNAllVars() floats when _vidx is empty.
 *
 * Values are copied verbatim; missing entries are not substituted.
 */
void getSample(InputArray _vidx, int sidx, float* buf) const
{
    // sidx is used as a raw offset into `samples`, so bound it by the matrix
    // extent. getNSamples() may count only a selected subset.
    const int nstored = layout == ROW_SAMPLE ? samples.rows : samples.cols;
    CV_Assert(buf != 0 && 0 <= sidx && sidx < nstored);
    Mat vidx = _vidx.getMat();
    // Handle an empty index list explicitly, without relying on checkVector().
    int i, n = vidx.empty() ? 0 : vidx.checkVector(1, CV_32S), nvars = getNAllVars();
    CV_Assert( n >= 0 );
    const int* vptr = n > 0 ? vidx.ptr<int>() : 0;
    if( n == 0 )
        n = nvars;

    size_t step = samples.step/samples.elemSize();
    size_t sstep = layout == ROW_SAMPLE ? step : 1;
    size_t vstep = layout == ROW_SAMPLE ? 1 : step;

    const float* src = samples.ptr<float>() + sidx*sstep;
    for( i = 0; i < n; i++ )
    {
        int j = i;
        if( vptr )
        {
            j = vptr[i];
            CV_Assert( 0 <= j && j < nvars );
        }
        buf[i] = src[j*vstep];
    }
}
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
// ---- Data members ---------------------------------------------------------
// NOTE(review): `file` is not referenced in this chunk. It is presumably the
// input stream used while parsing CSV in loadCSV() — confirm.
FILE* file;
// Orientation of `samples`: ROW_SAMPLE or COL_SAMPLE. It selects which stride
// (row step or 1) walks samples versus variables in getValues()/getSample().
int layout;
// samples: stored feature matrix, read as float via ptr<float>().
// missingSubst: per-variable float that replaces MISSED_VAL entries in getValues().
// missing, varType, varIdx, responses: not used in this chunk except varIdx's
// emptiness test in getTrainSamples(). Roles are presumed from the names
// (missing-value mask, variable types, active-variable subset, targets) — confirm.
Mat samples, missing, varType, varIdx, responses, missingSubst;
// Optional int index lists of sample positions in `samples` (read via ptr<int>()).
// trainSampleIdx/testSampleIdx are created as 1xN CV_32S by setTrainTestSplit().
Mat sampleIdx, trainSampleIdx, testSampleIdx;
// catOfs: per-variable Vec2i [begin, end) range into catMap. An empty range
// marks an ordered variable.
// catMap: per-variable sorted category values, binary-searched by getNormCatValues().
// sampleWeights: not used in this chunk (presumably per-sample weights — confirm).
Mat sampleWeights, catMap, catOfs;
// Not used in this chunk. Presumably response-derived data (normalized
// categorical responses, class labels, per-class counts) — confirm.
Mat normCatResponses, classLabels, classCounters;
// Not used in this chunk. MapType is declared elsewhere — confirm its key/value meaning.
MapType nameMap;
};
|
2010-05-12 01:44:00 +08:00
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
/** Creates a TrainData object from a CSV file.
 *
 * All arguments are forwarded unchanged to TrainDataImpl::loadCSV().
 * @return The loaded object, or an empty pointer when loadCSV() reports failure.
 */
Ptr<TrainData> TrainData::loadFromCSV(const String& filename,
                                      int headerLines,
                                      int responseStartIdx,
                                      int responseEndIdx,
                                      const String& varTypeSpec,
                                      char delimiter, char missch)
{
    Ptr<TrainDataImpl> impl = makePtr<TrainDataImpl>();
    const bool loaded = impl->loadCSV(filename, headerLines, responseStartIdx, responseEndIdx,
                                      varTypeSpec, delimiter, missch);
    if( !loaded )
        return Ptr<TrainData>();   // never hand out a partially loaded object
    return impl;
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
/** Creates a TrainData object from in-memory arrays.
 *
 * The arguments are forwarded to TrainDataImpl::setData() in the same order.
 * The trailing setData() argument is always noArray(); presumably that is the
 * missing-value mask — confirm against setData().
 */
Ptr<TrainData> TrainData::create(InputArray samples, int layout, InputArray responses,
                                 InputArray varIdx, InputArray sampleIdx, InputArray sampleWeights,
                                 InputArray varType)
{
    Ptr<TrainDataImpl> data = makePtr<TrainDataImpl>();
    data->setData(samples, layout, responses,
                  varIdx, sampleIdx, sampleWeights,
                  varType, noArray());
    return data;
}
|
|
|
|
|
2014-07-30 03:54:23 +08:00
|
|
|
}}
|
|
|
|
|
2010-05-12 01:44:00 +08:00
|
|
|
/* End of file. */
|