mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 17:44:04 +08:00
Return uncompressed support vectors for getSupportVectors on linear SVM (Bug #4096)
This commit is contained in:
parent
544990e377
commit
0d706f6796
@ -675,11 +675,19 @@ public:
|
|||||||
|
|
||||||
/** @brief Retrieves all the support vectors
|
/** @brief Retrieves all the support vectors
|
||||||
|
|
||||||
The method returns all the support vector as floating-point matrix, where support vectors are
|
The method returns all the support vectors as a floating-point matrix, where support vectors are
|
||||||
stored as matrix rows.
|
stored as matrix rows.
|
||||||
*/
|
*/
|
||||||
CV_WRAP virtual Mat getSupportVectors() const = 0;
|
CV_WRAP virtual Mat getSupportVectors() const = 0;
|
||||||
|
|
||||||
|
/** @brief Retrieves all the uncompressed support vectors of a linear %SVM
|
||||||
|
|
||||||
|
The method returns all the uncompressed support vectors of a linear %SVM that the compressed
|
||||||
|
support vector, used for prediction, was derived from. They are returned in a floating-point
|
||||||
|
matrix, where the support vectors are stored as matrix rows.
|
||||||
|
*/
|
||||||
|
CV_WRAP Mat getUncompressedSupportVectors() const;
|
||||||
|
|
||||||
/** @brief Retrieves the decision function
|
/** @brief Retrieves the decision function
|
||||||
|
|
||||||
@param i the index of the decision function. If the problem solved is regression, 1-class or
|
@param i the index of the decision function. If the problem solved is regression, 1-class or
|
||||||
|
@ -1241,6 +1241,12 @@ public:
|
|||||||
df_alpha.clear();
|
df_alpha.clear();
|
||||||
df_index.clear();
|
df_index.clear();
|
||||||
sv.release();
|
sv.release();
|
||||||
|
uncompressed_sv.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
Mat getUncompressedSupportVectors_() const
|
||||||
|
{
|
||||||
|
return uncompressed_sv;
|
||||||
}
|
}
|
||||||
|
|
||||||
Mat getSupportVectors() const
|
Mat getSupportVectors() const
|
||||||
@ -1538,6 +1544,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
optimize_linear_svm();
|
optimize_linear_svm();
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1588,6 +1595,7 @@ public:
|
|||||||
|
|
||||||
setRangeVector(df_index, df_count);
|
setRangeVector(df_index, df_count);
|
||||||
df_alpha.assign(df_count, 1.);
|
df_alpha.assign(df_count, 1.);
|
||||||
|
sv.copyTo(uncompressed_sv);
|
||||||
std::swap(sv, new_sv);
|
std::swap(sv, new_sv);
|
||||||
std::swap(decision_func, new_df);
|
std::swap(decision_func, new_df);
|
||||||
}
|
}
|
||||||
@ -2056,6 +2064,21 @@ public:
|
|||||||
}
|
}
|
||||||
fs << "]";
|
fs << "]";
|
||||||
|
|
||||||
|
if ( !uncompressed_sv.empty() )
|
||||||
|
{
|
||||||
|
// write the joint collection of uncompressed support vectors
|
||||||
|
int uncompressed_sv_total = uncompressed_sv.rows;
|
||||||
|
fs << "uncompressed_sv_total" << uncompressed_sv_total;
|
||||||
|
fs << "uncompressed_support_vectors" << "[";
|
||||||
|
for( i = 0; i < uncompressed_sv_total; i++ )
|
||||||
|
{
|
||||||
|
fs << "[:";
|
||||||
|
fs.writeRaw("f", uncompressed_sv.ptr(i), uncompressed_sv.cols*uncompressed_sv.elemSize());
|
||||||
|
fs << "]";
|
||||||
|
}
|
||||||
|
fs << "]";
|
||||||
|
}
|
||||||
|
|
||||||
// write decision functions
|
// write decision functions
|
||||||
int df_count = (int)decision_func.size();
|
int df_count = (int)decision_func.size();
|
||||||
|
|
||||||
@ -2096,7 +2119,7 @@ public:
|
|||||||
svm_type_str == "NU_SVR" ? NU_SVR : -1;
|
svm_type_str == "NU_SVR" ? NU_SVR : -1;
|
||||||
|
|
||||||
if( svmType < 0 )
|
if( svmType < 0 )
|
||||||
CV_Error( CV_StsParseError, "Missing of invalid SVM type" );
|
CV_Error( CV_StsParseError, "Missing or invalid SVM type" );
|
||||||
|
|
||||||
FileNode kernel_node = fn["kernel"];
|
FileNode kernel_node = fn["kernel"];
|
||||||
if( kernel_node.empty() )
|
if( kernel_node.empty() )
|
||||||
@ -2168,14 +2191,31 @@ public:
|
|||||||
FileNode sv_node = fn["support_vectors"];
|
FileNode sv_node = fn["support_vectors"];
|
||||||
|
|
||||||
CV_Assert((int)sv_node.size() == sv_total);
|
CV_Assert((int)sv_node.size() == sv_total);
|
||||||
sv.create(sv_total, var_count, CV_32F);
|
|
||||||
|
|
||||||
|
sv.create(sv_total, var_count, CV_32F);
|
||||||
FileNodeIterator sv_it = sv_node.begin();
|
FileNodeIterator sv_it = sv_node.begin();
|
||||||
for( i = 0; i < sv_total; i++, ++sv_it )
|
for( i = 0; i < sv_total; i++, ++sv_it )
|
||||||
{
|
{
|
||||||
(*sv_it).readRaw("f", sv.ptr(i), var_count*sv.elemSize());
|
(*sv_it).readRaw("f", sv.ptr(i), var_count*sv.elemSize());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int uncompressed_sv_total = (int)fn["uncompressed_sv_total"];
|
||||||
|
|
||||||
|
if( uncompressed_sv_total > 0 )
|
||||||
|
{
|
||||||
|
// read uncompressed support vectors
|
||||||
|
FileNode uncompressed_sv_node = fn["uncompressed_support_vectors"];
|
||||||
|
|
||||||
|
CV_Assert((int)uncompressed_sv_node.size() == uncompressed_sv_total);
|
||||||
|
uncompressed_sv.create(uncompressed_sv_total, var_count, CV_32F);
|
||||||
|
|
||||||
|
FileNodeIterator uncompressed_sv_it = uncompressed_sv_node.begin();
|
||||||
|
for( i = 0; i < uncompressed_sv_total; i++, ++uncompressed_sv_it )
|
||||||
|
{
|
||||||
|
(*uncompressed_sv_it).readRaw("f", uncompressed_sv.ptr(i), var_count*uncompressed_sv.elemSize());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// read decision functions
|
// read decision functions
|
||||||
int df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
|
int df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
|
||||||
FileNode df_node = fn["decision_functions"];
|
FileNode df_node = fn["decision_functions"];
|
||||||
@ -2207,7 +2247,7 @@ public:
|
|||||||
SvmParams params;
|
SvmParams params;
|
||||||
Mat class_labels;
|
Mat class_labels;
|
||||||
int var_count;
|
int var_count;
|
||||||
Mat sv;
|
Mat sv, uncompressed_sv;
|
||||||
vector<DecisionFunc> decision_func;
|
vector<DecisionFunc> decision_func;
|
||||||
vector<double> df_alpha;
|
vector<double> df_alpha;
|
||||||
vector<int> df_index;
|
vector<int> df_index;
|
||||||
@ -2221,6 +2261,14 @@ Ptr<SVM> SVM::create()
|
|||||||
return makePtr<SVMImpl>();
|
return makePtr<SVMImpl>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Mat SVM::getUncompressedSupportVectors() const
|
||||||
|
{
|
||||||
|
const SVMImpl* this_ = dynamic_cast<const SVMImpl*>(this);
|
||||||
|
if(!this_)
|
||||||
|
CV_Error(Error::StsNotImplemented, "the class is not SVMImpl");
|
||||||
|
return this_->getUncompressedSupportVectors_();
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,3 +118,51 @@ TEST(ML_SVM, trainAuto_regression_5369)
|
|||||||
EXPECT_EQ(0., result0);
|
EXPECT_EQ(0., result0);
|
||||||
EXPECT_EQ(1., result1);
|
EXPECT_EQ(1., result1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class CV_SVMGetSupportVectorsTest : public cvtest::BaseTest {
|
||||||
|
public:
|
||||||
|
CV_SVMGetSupportVectorsTest() {}
|
||||||
|
protected:
|
||||||
|
virtual void run( int startFrom );
|
||||||
|
};
|
||||||
|
void CV_SVMGetSupportVectorsTest::run(int /*startFrom*/ )
|
||||||
|
{
|
||||||
|
int code = cvtest::TS::OK;
|
||||||
|
|
||||||
|
// Set up training data
|
||||||
|
int labels[4] = {1, -1, -1, -1};
|
||||||
|
float trainingData[4][2] = { {501, 10}, {255, 10}, {501, 255}, {10, 501} };
|
||||||
|
Mat trainingDataMat(4, 2, CV_32FC1, trainingData);
|
||||||
|
Mat labelsMat(4, 1, CV_32SC1, labels);
|
||||||
|
|
||||||
|
Ptr<SVM> svm = SVM::create();
|
||||||
|
svm->setType(SVM::C_SVC);
|
||||||
|
svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6));
|
||||||
|
|
||||||
|
|
||||||
|
// Test retrieval of SVs and compressed SVs on linear SVM
|
||||||
|
svm->setKernel(SVM::LINEAR);
|
||||||
|
svm->train(trainingDataMat, cv::ml::ROW_SAMPLE, labelsMat);
|
||||||
|
|
||||||
|
Mat sv = svm->getSupportVectors();
|
||||||
|
CV_Assert(sv.rows == 1); // by default compressed SV returned
|
||||||
|
sv = svm->getUncompressedSupportVectors();
|
||||||
|
CV_Assert(sv.rows == 3);
|
||||||
|
|
||||||
|
|
||||||
|
// Test retrieval of SVs and compressed SVs on non-linear SVM
|
||||||
|
svm->setKernel(SVM::POLY);
|
||||||
|
svm->setDegree(2);
|
||||||
|
svm->train(trainingDataMat, cv::ml::ROW_SAMPLE, labelsMat);
|
||||||
|
|
||||||
|
sv = svm->getSupportVectors();
|
||||||
|
CV_Assert(sv.rows == 3);
|
||||||
|
sv = svm->getUncompressedSupportVectors();
|
||||||
|
CV_Assert(sv.rows == 0); // inapplicable for non-linear SVMs
|
||||||
|
|
||||||
|
|
||||||
|
ts->set_failed_test_info(code);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST(ML_SVM, getSupportVectors) { CV_SVMGetSupportVectorsTest test; test.safe_run(); }
|
||||||
|
@ -65,7 +65,7 @@ int main(int, char**)
|
|||||||
//! [show_vectors]
|
//! [show_vectors]
|
||||||
thickness = 2;
|
thickness = 2;
|
||||||
lineType = 8;
|
lineType = 8;
|
||||||
Mat sv = svm->getSupportVectors();
|
Mat sv = svm->getUncompressedSupportVectors();
|
||||||
|
|
||||||
for (int i = 0; i < sv.rows; ++i)
|
for (int i = 0; i < sv.rows; ++i)
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user