摘要: 本文講解了梯度下降的基本概念,並以線性回歸為例詳細講解梯度下降演算法,主要以圖的形式講解,清晰簡單明了。

敏捷在軟體開發過程中是一個非常著名的術語,它背後的基本思想很簡單:快速構建一些東西,然後得到一些反饋,根據反饋做出改變,重複此過程。目標是讓產品更貼合用,讓用戶做出反饋,以獲得設計開發出的產品與優秀的產品二者之間誤差最小,梯度下降演算法背後的原理和這基本一樣。

目的

梯度下降演算法是一個迭代過程,它將獲得函數的最小值。下面的公式將整個梯度下降演算法匯總在一行中。

但是這個公式是如何得出的呢?實際上很簡單,只需要具備一些高中的數學知識即可理解。本文將嘗試講解這個公式,並以線性回歸模型為例,構建此類公式。

機器學習模型

  • 考慮二維空間中的一堆數據點。假設數據與一組學生的身高和體重有關。試圖預測這些數量之間的某種關係,以便我們可以預測一些新生的體重。這本質上是一種有監督學習的簡單例子。
  • 現在在空間中繪製一條穿過其中一些數據點的任意直線,該直線方程的形如Y=mX+b,其中m是斜率,b是其在Y軸的截距。

預測

給定一組已知的輸入及其相應的輸出,機器學習模型試圖對一組新的輸入做出一些預測。

兩個預測之間的差異即為錯誤。

這涉及成本函數或損失函數的概念(cost function or loss function)。

成本函數

成本函數/損失函數用來評估機器學習演算法的性能。二者的區別在於,損失函數計算單個訓練示例的錯誤,而成本函數是整個訓練集上錯誤的平均值。

成本函數基本上能告訴我們模型在給定m和b的值時,其預測能「有多好」。

比方說,數據集中總共有N個點,我們想要最小化所有N個數據點的誤差。因此,成本函數將是總平方誤差,即

為什麼採取平方差而不是絕對差?因為平方差使得導出回歸線更容易。實際上,為了找到這條直線,我們需要計算成本函數的一階導數,而計算絕對值的導數比平方值更難。

最小化成本函數

任何機器學習演算法的目標都是最小化成本函數。

這是因為實際值和預測值之間的誤差對應著表示演算法在學習方面的性能。由於希望誤差值最小,因此盡量使得那些mb值能夠產生儘可能小的誤差。

如何最小化一個任意函數?

仔細觀察上述的成本函數,其形式為Y=X2。在笛卡爾坐標系中,這是一個拋物線方程,用圖形表示如下:

為了最小化上面的函數,需要找到一個x,函數在該點能產生小值Y,即圖中的紅點。由於這是一個二維圖像,因此很容易找到其最小值,但是在維度比較大的情況下,情況會更加複雜。對於種情況,需要設計一種演算法來定位最小值,該演算法稱為梯度下降演算法(Gradient Descent)。

梯度下降

梯度下降是優化模型的方法中最流行的演算法之一,也是迄今為止優化神經網路的最常用方法。它本質上是一種迭代優化演算法,用於查找函數的最小值。

表示

假設你是沿著下面的圖表走,目前位於曲線綠點處,而目標是到達最小值,即點位置,但你是無法看到該最低點。

可能採取的行動:
  • 可能向上或向下;
  • 如果決定走哪條路,可能會採取更大的步伐或小的步伐來到達目的地;

從本質上講,你應該知道兩件事來達到最小值,即走哪條和走多遠。

梯度下降演算法通過使用導數幫助我們有效地做出這些決策。導數是來源於積分,用於計算曲線特定點處的斜率。通過在該點處繪製圖形的切線來描述斜率。因此,如果能夠計算出這條切線,可能就能夠計算達到最小值的所需方向。

最小值

在下圖中,在綠點處繪製切線,如果向上移動,就將遠離最小值,反之亦然。此外,切線也能讓我們感覺到斜坡的陡峭程度。

藍點處的斜率比綠點處的斜率低,這意味著從藍點到綠點所需的步長要小得多。

成本函數的數學解釋

現在將上述內容納入數學公式中。在等式y=mX+b中,mb是其參數。在訓練過程中,其值也會發生微小變化,用δ表示這個小的變化。參數值將分別更新為m = m-δmb = b-δb。最終目標是找到mb的值,以使得y=mx+b 的誤差最小,即最小化成本函數。

重寫成本函數:

想法是,通過計算函數的導數/斜率,就可以找到函數的最小值。

學習率

達到最小值或最低值所採取的步長大小稱為學習率。學習率可以設置的比較大,但有可能會錯過最小值。而另一方面,小的學習率將花費大量時間訓練以達到最低點。

下面的可視化給出了學習率的基本概念。在第三個圖中,以最小步數達到最小點,這表明該學習率是此問題的最佳學習率。

從上圖可以看到,當學習率太低時,需要花費很長訓練時間才能收斂。而另一方面,當學習率太高時,梯度下降未達到最小值,如下面所示:

導數

機器學習在優化問題中使用導數。梯度下降等優化演算法使用導數來決定是增加還是減少權重,進而增加或減少目標函數。

如果能夠計算出函數的導數,就可以知道在哪個方向上能到達最小化。

主要處理方法源自於微積分中的兩個基本概念:

  • 指數法則指數法則求導公式:
  • 鏈式法則鏈式法則用於計算複合函數的導數,如果變數z取決於變數y,且它本身也依賴於變數x,因此y和z是因變數,那麼z對x的導數也與y有,這稱為鏈式法則,在數學上寫為:

舉個例子加強理解:

使用指數法則和鏈式發規,計算成本函數相對於m和c的變化方式。這涉及偏導數的概念,即如果存在兩個變數的函數,那麼為了找到該函數對其中一個變數的偏導數,需將另一個變數視為常數。舉個例子加強理解:

計算梯度下降

現在將這些微積分法則的知識應用到原始方程中,並找到成本函數的導數,即mb。修改成本函數方程:

為簡單起見,忽略求和符號。求和部分其實很重要,尤其是隨機梯度下降(SGD)與批量梯度下降的概念。在批量梯度下降期間,我們一次查看所有訓練樣例的錯誤,而在SGD中一次只查看其中的一個錯誤。這裡為了簡單起見,假設一次只查看其中的一個錯誤:

現在計算誤差對m和b的梯度:

將值對等到成本函數中並將其乘以學習率:

其中這個等式中的係數項2是一個常數,求導時並不重要,這裡將其忽略。因此,最終,整篇文章歸結為兩個簡單的方程式,它們代表了梯度下降的方程。

其中m1b1是下一個位置的參數;m?b?是當前位置的參數。

因此,為了求解梯度,使用新的mb值迭代數據點並計算偏導數。這個新的梯度會告訴我們當前位置的成本函數的斜率以及我們應該更新參數的方向。另外更新參數的步長由學習率控制。

結論

本文的重點是展示梯度下降的基本概念,並以線性回歸為例講解梯度下降演算法。通過繪製最佳擬合線來衡量學生身高和體重之間的關係。但是,這裡為了簡單起見,舉的例子是機器學習演算法中較簡單的線性回歸模型,讀者也可以將其應用到其它機器學習方法中。

以上為譯文,由阿里云云棲社區組織翻譯。

譯文鏈接

文章原標題《Understanding the Mathematics behind Gradient Descent》

譯者:海棠,審校:Uncle_LLD。

文章簡譯,更為詳細的內容,請查看原文。

更多技術乾貨敬請關注云棲社區知乎機構號:阿里云云棲社區 - 知乎

本文為雲棲社區原創內容,未經允許不得轉載。


推薦閱讀:
相关文章