Bug #2214: Fixed a bug when calling the train method multiple times: the Eigenfaces and Fisherfaces methods now re-estimate the model, while the LBPH method appends the new data to the existing model.

This commit is contained in:
Philipp Wagner 2012-07-29 22:20:07 +02:00
parent f8facadc67
commit 79b45b1392


@@ -338,6 +338,9 @@ void Eigenfaces::train(InputArray _src, InputArray _local_labels) {
string error_message = format("The number of samples (src) must equal the number of labels (labels)! len(src)=%d, len(labels)=%d.", n, labels.total());
CV_Error(CV_StsBadArg, error_message);
}
// clear existing model data
_labels.release();
_projections.clear();
// clip number of components to be valid
if((_num_components <= 0) || (_num_components > n))
_num_components = n;
@@ -347,7 +350,8 @@ void Eigenfaces::train(InputArray _src, InputArray _local_labels) {
_mean = pca.mean.reshape(1,1); // store the mean vector
_eigenvalues = pca.eigenvalues.clone(); // eigenvalues by row
transpose(pca.eigenvectors, _eigenvectors); // eigenvectors by column
labels.copyTo(_labels); // store labels for prediction
// store labels for prediction
labels.copyTo(_labels);
// save projections
for(int sampleIdx = 0; sampleIdx < data.rows; sampleIdx++) {
Mat p = subspaceProject(_eigenvectors, _mean, data.row(sampleIdx));
@@ -443,7 +447,10 @@ void Fisherfaces::train(InputArray src, InputArray _lbls) {
string error_message = format("Expected the labels in a matrix with one row or column! Given dimensions are rows=%s, cols=%d.", labels.rows, labels.cols);
CV_Error(CV_StsBadArg, error_message);
}
// Get the number of unique classes (provide a cv::Mat overloaded version?)
// clear existing model data
_labels.release();
_projections.clear();
// get the number of unique classes (provide a cv::Mat overloaded version?)
vector<int> ll;
labels.copyTo(ll);
int C = (int) remove_dups(ll).size();
@@ -462,7 +469,7 @@ void Fisherfaces::train(InputArray src, InputArray _lbls) {
lda.eigenvalues().convertTo(_eigenvalues, CV_64FC1);
// Now calculate the projection matrix as pca.eigenvectors * lda.eigenvectors.
// Note: OpenCV stores the eigenvectors by row, so we need to transpose it!
gemm(pca.eigenvectors, lda.eigenvectors(), 1.0, Mat(), 0.0, _eigenvectors, CV_GEMM_A_T);
gemm(pca.eigenvectors, lda.eigenvectors(), 1.0, Mat(), 0.0, _eigenvectors, GEMM_1_T);
// store the projections of the original data
for(int sampleIdx = 0; sampleIdx < data.rows; sampleIdx++) {
Mat p = subspaceProject(_eigenvectors, _mean, data.row(sampleIdx));
@@ -525,6 +532,7 @@ void Fisherfaces::save(FileStorage& fs) const {
writeFileNodeList(fs, "projections", _projections);
fs << "labels" << _labels;
}
//------------------------------------------------------------------------------
// LBPH
//------------------------------------------------------------------------------
@@ -724,7 +732,8 @@ void LBPH::train(InputArray _src, InputArray _lbls) {
if(_src.kind() != _InputArray::STD_VECTOR_MAT && _src.kind() != _InputArray::STD_VECTOR_VECTOR) {
string error_message = "The images are expected as InputArray::STD_VECTOR_MAT (a std::vector<Mat>) or _InputArray::STD_VECTOR_VECTOR (a std::vector< vector<...> >).";
CV_Error(CV_StsBadArg, error_message);
} else if(_src.total() == 0) {
}
if(_src.total() == 0) {
string error_message = format("Empty training data was given. You'll need more than one sample to learn a model.");
CV_Error(CV_StsUnsupportedFormat, error_message);
} else if(_lbls.getMat().type() != CV_32SC1) {
@@ -734,22 +743,19 @@ void LBPH::train(InputArray _src, InputArray _lbls) {
// get the vector of matrices
vector<Mat> src;
_src.getMatVector(src);
for (vector<Mat>::const_iterator image = src.begin(); image != src.end(); ++image) {
if (image->channels() != 1) {
string error_message = format("The images must be single channel (grayscale), but an image has %d channels.", image->channels());
CV_Error(CV_StsUnsupportedFormat, error_message);
}
}
// turn the label matrix into a vector
// get the label matrix
Mat labels = _lbls.getMat();
CV_Assert( labels.type() == CV_32S && (labels.cols == 1 || labels.rows == 1));
// check if data is well aligned
if(labels.total() != src.size()) {
CV_Error(CV_StsUnsupportedFormat, "The number of labels must equal the number of samples.");
string error_message = format("The number of samples (src) must equal the number of labels (labels). Was len(samples)=%d, len(labels)=%d.", src.size(), _labels.total());
CV_Error(CV_StsBadArg, error_message);
}
// append labels to _labels matrix
for(int labelIdx = 0; labelIdx < labels.total(); labelIdx++) {
_labels.push_back(labels.at<int>(labelIdx));
}
// store given labels
labels.copyTo(_labels);
// store the spatial histograms of the original data
for(size_t sampleIdx = 0; sampleIdx < src.size(); sampleIdx++) {
for(int sampleIdx = 0; sampleIdx < src.size(); sampleIdx++) {
// calculate lbp image
Mat lbp_image = elbp(src[sampleIdx], _radius, _neighbors);
// get spatial histogram from this lbp image
@@ -765,11 +771,12 @@ void LBPH::train(InputArray _src, InputArray _lbls) {
}
void LBPH::predict(InputArray _src, int &minClass, double &minDist) const {
Mat src = _src.getMat();
if (src.channels() != 1) {
string error_message = format("The image must be single channel (grayscale), but the image has %d channels.", src.channels());
CV_Error(CV_StsUnsupportedFormat, error_message);
if(_histograms.empty()) {
// throw error if no data (or simply return -1?)
string error_message = "This LBPH model is not computed yet. Did you call the train method?";
CV_Error(CV_StsBadArg, error_message);
}
Mat src = _src.getMat();
// get the spatial histogram from input image
Mat lbp_image = elbp(src, _radius, _neighbors);
Mat query = spatial_histogram(
@@ -781,11 +788,11 @@ void LBPH::predict(InputArray _src, int &minClass, double &minDist) const {
// find 1-nearest neighbor
minDist = DBL_MAX;
minClass = -1;
for(size_t sampleIdx = 0; sampleIdx < _histograms.size(); sampleIdx++) {
for(int sampleIdx = 0; sampleIdx < _histograms.size(); sampleIdx++) {
double dist = compareHist(_histograms[sampleIdx], query, CV_COMP_CHISQR);
if((dist < minDist) && (dist < _threshold)) {
minDist = dist;
minClass = _labels.at<int>((int)sampleIdx);
minClass = _labels.at<int>(sampleIdx);
}
}
}
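
The behaviour described in the commit message can be illustrated with a short usage sketch. This is not part of the commit; it assumes the OpenCV 2.4 contrib API (cv::createEigenFaceRecognizer, cv::createLBPHFaceRecognizer) and uses hypothetical image file names of identical size.

// Usage sketch only (not part of the commit). Assumes the OpenCV 2.4 contrib
// API and hypothetical grayscale image files in the working directory.
#include <vector>
#include "opencv2/core/core.hpp"
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/contrib/contrib.hpp"

using namespace cv;

int main() {
    // First training batch: labels 0 and 1 (file names are placeholders).
    std::vector<Mat> batch1;
    std::vector<int> labels1;
    batch1.push_back(imread("s0_0.pgm", CV_LOAD_IMAGE_GRAYSCALE)); labels1.push_back(0);
    batch1.push_back(imread("s1_0.pgm", CV_LOAD_IMAGE_GRAYSCALE)); labels1.push_back(1);

    // Second training batch: labels 2 and 3.
    std::vector<Mat> batch2;
    std::vector<int> labels2;
    batch2.push_back(imread("s2_0.pgm", CV_LOAD_IMAGE_GRAYSCALE)); labels2.push_back(2);
    batch2.push_back(imread("s3_0.pgm", CV_LOAD_IMAGE_GRAYSCALE)); labels2.push_back(3);

    // Eigenfaces (Fisherfaces behaves the same way): the second train() call
    // discards the model estimated from batch1 and re-estimates it from batch2.
    Ptr<FaceRecognizer> eigen = createEigenFaceRecognizer();
    eigen->train(batch1, labels1);
    eigen->train(batch2, labels2); // model now contains labels 2 and 3 only

    // LBPH: the second train() call appends the histograms and labels of
    // batch2 to those already learned from batch1.
    Ptr<FaceRecognizer> lbph = createLBPHFaceRecognizer();
    lbph->train(batch1, labels1);
    lbph->train(batch2, labels2); // model now contains labels 0, 1, 2 and 3

    return 0;
}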