很多問題最終歸結為一個最小二乘問題,如SLAM演算法中的Bundle Adjustment,位姿圖優化等等。求解最小二乘的方法有很多,高斯-牛頓法就是其中之一。

推導

對於一個非線性最小二乘問題:

x = mathrm{arg}min_{x}frac{1}{2}parallel f(x) parallel^2.    (1)

高斯牛頓的思想是把 f(x) 利用泰勒展開,取一階線性項近似。

f(x+Delta x)=f(x) +f(x)Delta x = f(x) +J(x)Delta x.    (2)

帶入到(1)式:

frac{1}{2}parallel f(x+Delta x) parallel^2 = frac{1}{2} { f(x)^Tf(x) + 2f(x)^TJ(x)Delta x +Delta x^TJ(x)^TJ(x)Delta x}.     (3)

對上式求導,令導數為0。

J(x)^TJ(x)Delta x = - J(x)^Tf(x).    (4)

H=J^TJB=-J^Tf , 式(4)即為

HDelta x = B.    (5)

求解式(5),便可以獲得調整增量 Delta x 。這要求 H可逆(正定),但實際情況並不一定滿足這個條件,因此可能發散,另外步長Delta x可能太大,也會導致發散。

綜上,高斯牛頓法的步驟為

STEP1. 給定初值 x_0

STEP2. 對於第k次迭代,計算 雅克比 J , 矩陣HB ;根據(5)式計算增量 Delta x_k ;STEP3. 如果 Delta x_k 足夠小,就停止迭代,否則,更新 x_{k+1}=x_k+Delta x_k .STEP4. 循環執行STEP2. SPTE3,直到達到最大循環次數,或者滿足STEP3的終止條件。

編程實現

問題:非線性方程: y=exp(ax^2+bx+c) ,給定n組觀測數據 {x,y} ,求係數 X=[ a,b,c ]^T .

分析:令 f(X)=y-exp(ax^2+bx+c) ,N組數據可以組成一個大的非線性方程組

F(X)=left[egin{array}[c]{c} y_1-exp(ax_1^2+bx_1+c)\ dots\ y_N-exp(ax_N^2+bx_N+c)end{array} 
ight]

我們可以構建一個最小二乘問題:

x = mathrm{arg}min_{x}frac{1}{2}parallel F(X) parallel^2 .

要求解這個問題,根據推導部分可知,需要求解雅克比。

J(X)=left[egin{matrix} -x_1^2exp(ax_1^2+bx_1+c) & -x_1exp(ax_1^2+bx_1+c) &-exp(ax_1^2+bx_1+c) \ dots & dots & dots \ -x_N^2exp(ax_N^2+bx_N+c) & -x_Nexp(ax_N^2+bx_N+c) &-exp(ax_N^2+bx_N+c) end{matrix} 
ight]

使用推導部分所述的步驟就可以進行解算。代碼實現:

ydsf16/Gauss_Newton_solver?

github.com


/**
* This file is part of Gauss-Newton Solver.
*
* Copyright (C) 2018-2020 Dongsheng Yang <[email protected]> (Beihang University)
* For more information see <https://github.com/ydsf16/Gauss_Newton_solver>
*
* Gauss_Newton_solver is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Gauss_Newton_solver is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Gauss_Newton_solver. If not, see <http://www.gnu.org/licenses/>.
*/

#include <iostream>
#include <eigen3/Eigen/Core>
#include <vector>
#include <opencv2/opencv.hpp>
#include <eigen3/Eigen/Cholesky>
#include <eigen3/Eigen/QR>
#include <eigen3/Eigen/SVD>
#include <chrono>

/* 計時類 */
class Runtimer{
public:
inline void start()
{
t_s_ = std::chrono::steady_clock::now();
}

inline void stop()
{
t_e_ = std::chrono::steady_clock::now();
}

inline double duration()
{
return std::chrono::duration_cast<std::chrono::duration<double>>(t_e_ - t_s_).count() * 1000.0;
}

private:
std::chrono::steady_clock::time_point t_s_; //start time ponit
std::chrono::steady_clock::time_point t_e_; //stop time point
};

/* 優化方程 */
class CostFunction{
public:
CostFunction(double* a, double* b, double* c, int max_iter, double min_step, bool is_out):
a_(a), b_(b), c_(c), max_iter_(max_iter), min_step_(min_step), is_out_(is_out)
{}

void addObservation(double x, double y)
{
std::vector<double> ob;
ob.push_back(x);
ob.push_back(y);
obs_.push_back(ob);
}

void calcJ_fx()
{
J_ .resize(obs_.size(), 3);
fx_.resize(obs_.size(), 1);

for ( size_t i = 0; i < obs_.size(); i ++)
{
std::vector<double>& ob = obs_.at(i);
double& x = ob.at(0);
double& y = ob.at(1);
double j1 = -x*x*exp(*a_ * x*x + *b_*x + *c_);
double j2 = -x*exp(*a_ * x*x + *b_*x + *c_);
double j3 = -exp(*a_ * x*x + *b_*x + *c_);
J_(i, 0 ) = j1;
J_(i, 1) = j2;
J_(i, 2) = j3;
fx_(i, 0) = y - exp( *a_ *x*x + *b_*x +*c_);
}
}

void calcH_b()
{
H_ = J_.transpose() * J_;
B_ = -J_.transpose() * fx_;
}

void calcDeltax()
{
deltax_ = H_.ldlt().solve(B_);
}

void updateX()
{
*a_ += deltax_(0);
*b_ += deltax_(1);
*c_ += deltax_(2);
}

double getCost()
{
Eigen::MatrixXd cost= fx_.transpose() * fx_;
return cost(0,0);
}

void solveByGaussNewton()
{
double sumt =0;
bool is_conv = false;
for( size_t i = 0; i < max_iter_; i ++)
{
Runtimer t;
t.start();
calcJ_fx();
calcH_b();
calcDeltax();
double delta = deltax_.transpose() * deltax_;
t.stop();
if( is_out_ )
{
std::cout << "Iter: " << std::left <<std::setw(3) << i << " Result: "<< std::left <<std::setw(10) << *a_ << " " << std::left <<std::setw(10) << *b_ << " " << std::left <<std::setw(10) << *c_ <<
" step: " << std::left <<std::setw(14) << delta << " cost: "<< std::left <<std::setw(14) << getCost() << " time: " << std::left <<std::setw(14) << t.duration() <<
" total_time: "<< std::left <<std::setw(14) << (sumt += t.duration()) << std::endl;
}
if( delta < min_step_)
{
is_conv = true;
break;
}
updateX();
}

if( is_conv == true)
std::cout << "
Converged
";
else
std::cout << "
Diverged

";
}

Eigen::MatrixXd fx_;
Eigen::MatrixXd J_; // 雅克比矩陣
Eigen::Matrix3d H_; // H矩陣
Eigen::Vector3d B_;
Eigen::Vector3d deltax_;
std::vector< std::vector<double> > obs_; // 觀測
double* a_, *b_, *c_;

int max_iter_;
double min_step_;
bool is_out_;
};//class CostFunction

int main(int argc, char **argv) {

const double aa = 0.1, bb = 0.5, cc = 2; // 實際方程的參數
double a =0.0, b=0.0, c=0.0; // 初值

/* 構造問題 */
CostFunction cost_func(&a, &b, &c, 50, 1e-10, true);

/* 製造數據 */
const size_t N = 100; //數據個數
cv::RNG rng(cv::getTickCount());
for( size_t i = 0; i < N; i ++)
{
/* 生產帶有高斯雜訊的數據 */
double x = rng.uniform(0.0, 1.0) ;
double y = exp(aa*x*x + bb*x + cc) + rng.gaussian(0.05);

/* 添加到觀測中 */
cost_func.addObservation(x, y);
}
/* 用高斯牛頓法求解 */
cost_func.solveByGaussNewton();
return 0;
}

推薦閱讀:

相關文章