diff --git a/modules/core/include/opencv2/core/optim.hpp b/modules/core/include/opencv2/core/optim.hpp index 4f1749ec97..18e733f47b 100644 --- a/modules/core/include/opencv2/core/optim.hpp +++ b/modules/core/include/opencv2/core/optim.hpp @@ -63,9 +63,11 @@ public: class CV_EXPORTS Function { public: - virtual ~Function() {} - virtual double calc(const double* x) const = 0; - virtual void getGradient(const double* /*x*/,double* /*grad*/) {} + virtual ~Function() {} + virtual int getDims() const = 0; + virtual double getGradientEps() const; + virtual double calc(const double* x) const = 0; + virtual void getGradient(const double* x,double* grad); }; /** @brief Getter for the optimized function. diff --git a/modules/core/src/conjugate_gradient.cpp b/modules/core/src/conjugate_gradient.cpp index 90353cc7fd..1259cc9756 100644 --- a/modules/core/src/conjugate_gradient.cpp +++ b/modules/core/src/conjugate_gradient.cpp @@ -46,6 +46,25 @@ namespace cv { + double MinProblemSolver::Function::getGradientEps() const { return 1e-3; } + void MinProblemSolver::Function::getGradient(const double* x, double* grad) + { + double eps = getGradientEps(); + int i, n = getDims(); + AutoBuffer x_buf(n); + double* x_ = x_buf; + for( i = 0; i < n; i++ ) + x_[i] = x[i]; + for( i = 0; i < n; i++ ) + { + x_[i] = x[i] + eps; + double y1 = calc(x_); + x_[i] = x[i] - eps; + double y0 = calc(x_); + grad[i] = (y1 - y0)/(2*eps); + x_[i] = x[i]; + } + } #define SEC_METHOD_ITERATIONS 4 #define INITIAL_SEC_METHOD_SIGMA 0.1 diff --git a/modules/core/src/downhill_simplex.cpp b/modules/core/src/downhill_simplex.cpp index 158dca6dbf..a0cc1320b8 100644 --- a/modules/core/src/downhill_simplex.cpp +++ b/modules/core/src/downhill_simplex.cpp @@ -235,6 +235,7 @@ protected: inline void createInitialSimplex( const Mat& x0, Mat& simplex, Mat& step ) { int i, j, ndim = step.cols; + CV_Assert( _Function->getDims() == ndim ); Mat x = x0; if( x0.empty() ) x = Mat::zeros(1, ndim, CV_64F); diff --git a/modules/core/test/test_conjugate_gradient.cpp b/modules/core/test/test_conjugate_gradient.cpp index d6b0908f85..a1fd6113c3 100644 --- a/modules/core/test/test_conjugate_gradient.cpp +++ b/modules/core/test/test_conjugate_gradient.cpp @@ -60,16 +60,19 @@ static void mytest(cv::Ptr solver,cv::Ptr solver,cv::Ptr