2024-06-22 04:39:50 +08:00
|
|
|
diff --git a/src/shogun/CMakeLists.txt b/src/shogun/CMakeLists.txt
|
|
|
|
index 31a0d2c7b..e700bd7c7 100644
|
|
|
|
--- a/src/shogun/CMakeLists.txt
|
|
|
|
+++ b/src/shogun/CMakeLists.txt
|
|
|
|
@@ -307,7 +307,7 @@ IF(NOT EIGEN3_FOUND)
|
|
|
|
)
|
|
|
|
ELSE()
|
|
|
|
# https://github.com/shogun-toolbox/shogun/issues/4870
|
|
|
|
- IF(${EIGEN3_VERSION_STRING} VERSION_GREATER 3.3.9)
|
|
|
|
+ IF(0)
|
|
|
|
MESSAGE(FATAL_ERROR "The system Eigen3 version ${EIGEN3_VERSION_STRING} isn't supported!")
|
|
|
|
ENDIF()
|
|
|
|
SHOGUN_INCLUDE_DIRS(SCOPE PUBLIC SYSTEM ${EIGEN3_INCLUDE_DIR})
|
|
|
|
diff --git a/src/shogun/machine/gp/MultiLaplaceInferenceMethod.cpp b/src/shogun/machine/gp/MultiLaplaceInferenceMethod.cpp
|
|
|
|
index a1677177e..0c9ca8f78 100644
|
2022-02-11 02:06:26 +08:00
|
|
|
--- a/src/shogun/machine/gp/MultiLaplaceInferenceMethod.cpp
|
|
|
|
+++ b/src/shogun/machine/gp/MultiLaplaceInferenceMethod.cpp
|
2024-06-22 04:39:50 +08:00
|
|
|
@@ -87,10 +87,10 @@ public:
|
2022-02-11 02:06:26 +08:00
|
|
|
float64_t result=0;
|
|
|
|
for(index_t bl=0; bl<C; bl++)
|
|
|
|
{
|
2024-06-22 04:39:50 +08:00
|
|
|
- eigen_f.block(bl * n, 0, n, 1) =
|
|
|
|
- K * alpha->block(bl * n, 0, n, 1) * std::exp(log_scale * 2.0);
|
2022-02-11 02:06:26 +08:00
|
|
|
- result+=alpha->block(bl*n,0,n,1).dot(eigen_f.block(bl*n,0,n,1))/2.0;
|
|
|
|
- eigen_f.block(bl*n,0,n,1)+=eigen_m;
|
2024-06-22 04:39:50 +08:00
|
|
|
+ eigen_f.segment(bl * n, n) =
|
|
|
|
+ K * alpha->segment(bl * n, n) * std::exp(log_scale * 2.0);
|
2022-02-11 02:06:26 +08:00
|
|
|
+ result+=alpha->segment(bl*n,n).dot(eigen_f.segment(bl*n,n))/2.0;
|
|
|
|
+ eigen_f.segment(bl*n,n)+=eigen_m;
|
|
|
|
}
|
|
|
|
|
|
|
|
// get first and second derivatives of log likelihood
|
2024-06-22 04:39:50 +08:00
|
|
|
@@ -278,9 +278,9 @@ void MultiLaplaceInferenceMethod::update_alpha()
|
2022-02-11 02:06:26 +08:00
|
|
|
{
|
|
|
|
Map<VectorXd> alpha(m_alpha.vector, m_alpha.vlen);
|
|
|
|
for(index_t bl=0; bl<C; bl++)
|
2024-06-22 04:39:50 +08:00
|
|
|
- eigen_mu.block(bl * n, 0, n, 1) = eigen_ktrtr *
|
|
|
|
+ eigen_mu.segment(bl * n, n) = eigen_ktrtr *
|
|
|
|
std::exp(m_log_scale * 2.0) *
|
|
|
|
- alpha.block(bl * n, 0, n, 1);
|
|
|
|
+ alpha.segment(bl * n, n);
|
2022-02-11 02:06:26 +08:00
|
|
|
|
|
|
|
//alpha'*(f-m)/2.0
|
|
|
|
Psi_New=alpha.dot(eigen_mu)/2.0;
|
2024-06-22 04:39:50 +08:00
|
|
|
@@ -324,7 +324,7 @@ void MultiLaplaceInferenceMethod::update_alpha()
|
2022-02-11 02:06:26 +08:00
|
|
|
|
|
|
|
for(index_t bl=0; bl<C; bl++)
|
|
|
|
{
|
|
|
|
- VectorXd eigen_sD=eigen_dpi.block(bl*n,0,n,1).cwiseSqrt();
|
|
|
|
+ VectorXd eigen_sD=eigen_dpi.segment(bl*n,n).cwiseSqrt();
|
2024-06-22 04:39:50 +08:00
|
|
|
LLT<MatrixXd> chol_tmp(
|
|
|
|
(eigen_sD * eigen_sD.transpose())
|
|
|
|
.cwiseProduct(eigen_ktrtr * std::exp(m_log_scale * 2.0)) +
|
|
|
|
@@ -351,14 +351,14 @@ void MultiLaplaceInferenceMethod::update_alpha()
|
2022-02-11 02:06:26 +08:00
|
|
|
VectorXd tmp2=m_tmp.array().rowwise().sum();
|
|
|
|
|
|
|
|
for(index_t bl=0; bl<C; bl++)
|
|
|
|
- eigen_b.block(bl*n,0,n,1)+=eigen_dpi.block(bl*n,0,n,1).cwiseProduct(eigen_mu.block(bl*n,0,n,1)-eigen_mean_bl-tmp2);
|
|
|
|
+ eigen_b.segment(bl*n,n)+=eigen_dpi.segment(bl*n,n).cwiseProduct(eigen_mu.segment(bl*n,n)-eigen_mean_bl-tmp2);
|
|
|
|
|
|
|
|
Map<VectorXd> &eigen_c=eigen_W;
|
|
|
|
for(index_t bl=0; bl<C; bl++)
|
2024-06-22 04:39:50 +08:00
|
|
|
- eigen_c.block(bl * n, 0, n, 1) =
|
|
|
|
+ eigen_c.segment(bl * n, n) =
|
|
|
|
eigen_E.block(0, bl * n, n, n) *
|
|
|
|
(eigen_ktrtr * std::exp(m_log_scale * 2.0) *
|
|
|
|
- eigen_b.block(bl * n, 0, n, 1));
|
|
|
|
+ eigen_b.segment(bl * n, n));
|
2022-02-11 02:06:26 +08:00
|
|
|
|
|
|
|
Map<MatrixXd> c_tmp(eigen_c.data(),n,C);
|
|
|
|
|
2024-06-22 04:39:50 +08:00
|
|
|
@@ -422,7 +422,7 @@ float64_t MultiLaplaceInferenceMethod::get_derivative_helper(SGMatrix<float64_t>
|
2022-02-11 02:06:26 +08:00
|
|
|
{
|
|
|
|
result+=((eigen_E.block(0,bl*n,n,n)-eigen_U.block(0,bl*n,n,n).transpose()*eigen_U.block(0,bl*n,n,n)).array()
|
|
|
|
*eigen_dK.array()).sum();
|
|
|
|
- result-=(eigen_dK*eigen_alpha.block(bl*n,0,n,1)).dot(eigen_alpha.block(bl*n,0,n,1));
|
|
|
|
+ result-=(eigen_dK*eigen_alpha.segment(bl*n,n)).dot(eigen_alpha.segment(bl*n,n));
|
|
|
|
}
|
|
|
|
|
|
|
|
return result/2.0;
|
2024-06-22 04:39:50 +08:00
|
|
|
@@ -504,7 +504,7 @@ SGVector<float64_t> MultiLaplaceInferenceMethod::get_derivative_wrt_mean(
|
2022-02-11 02:06:26 +08:00
|
|
|
result[i]=0;
|
|
|
|
//currently only compute the explicit term
|
|
|
|
for(index_t bl=0; bl<C; bl++)
|
|
|
|
- result[i]-=eigen_alpha.block(bl*n,0,n,1).dot(eigen_dmu);
|
|
|
|
+ result[i]-=eigen_alpha.segment(bl*n,n).dot(eigen_dmu);
|
|
|
|
}
|
|
|
|
|
|
|
|
return result;
|