雪花臺灣

RNN 的梯度消失問題

前言

在講述RNN與LSTM的時候曾經說過, RNN一個最大的缺陷就是梯度消失與梯度爆炸問題, 由於這一缺陷,使得RNN在長文本中難以訓練, 這才誕生了LSTM及各種變體。

梯度消失與梯度爆炸問題不僅僅在RNN中存在, 在其餘深層網路中同樣普遍存在, 本文先從梯度消失與梯度爆炸談起, 然後再回到具體的RNN梯度問題中, 最後, 再探討LSTM是通過什麼技術來解決RNN中的這一問題的。

淺談神經網路中的梯度爆炸問題

機器學習總結(二):梯度消失和梯度爆炸

神經網路訓練中的梯度消失與梯度爆炸

1. 什麼是梯度消失,梯度爆炸?

首先,你得知道梯度是什麼, 不瞭解的可以看看我關於梯度下降的那篇文章。

從那篇文章中就可以看出對於神經網路的訓練,梯度在訓練中起到很關鍵的作用。 如果在訓練過程中發生了梯度消失,這也就意味著我們的權重無法被更新,最終導致訓練失敗。而梯度爆炸所帶來的梯度過大,從而大幅度更新網路參數,造成網路不穩定(可以理解為梯度步伐太大)。在極端情況下,權重的值變得特別大,以至於結果會溢出(NaN值)

注意,梯度消失和梯度爆炸只會造成神經網路中較淺的網路的權重無法更新(畢竟神經網路中是反向傳播)

2. 這會造成哪些問題?

3. 原因何在?

讓我們以一個很簡單的例子分析一下,這樣便於理解。

如上圖,是一個每層只有一個神經元的神經網路,且每一層的激活函數為sigmoid,則有:

( 是sigmoid函數)。

我們根據反向傳播演算法有:

而sigmoid函數的導數公式為: 它的圖形曲線為:

由上可見,sigmoid函數的導數 的最大值為 ,通常我們會將權重初始值 初始化為為小於1的隨機值,因此我們可以得到 ,隨著層數的增多,那麼求導結果 越小,這也就導致了梯度消失問題。

那麼如果我們設置初始權重 較大,那麼會有 ,造成梯度太大(也就是下降的步伐太大),這也是造成梯度爆炸的原因。

總之,無論是梯度消失還是梯度爆炸,都是源於網路結構太深,造成網路權重不穩定,從本質上來講是因為梯度反向傳播中的連乘效應。

4. RNN中的梯度消失,爆炸問題

參考:RNN梯度消失和爆炸的原因, 這篇文章是我看到講的最清楚的了,在這裡添加一些我的思考, 若侵立刪。

我們給定一個三個時間的RNN單元,如下:

我們假設最左端的輸入 為給定值, 且神經元中沒有激活函數(便於分析), 則前向過程如下:

時刻, 損失函數為 ,那麼如果我們要訓練RNN時, 實際上就是是對 求偏導, 並不斷調整它們以使得 儘可能達到最小(參見反向傳播演算法與梯度下降演算法)。

那麼我們得到以下公式:

將上述偏導公式與第三節中的公式比較,我們發現, 隨著神經網路層數的加深對 而言並沒有什麼影響, 而對 會隨著時間序列的拉長而產生梯度消失和梯度爆炸問題。

根據上述分析整理一下公式可得, 對於任意時刻t對 求偏導的公式為:

我們發現, 導致梯度消失和爆炸的就在於 , 而加上激活函數後的S的表達式為:

那麼則有:

而在這個公式中, tanh的導數總是小於1 的, 如果 也是一個大於0小於1的值, 那麼隨著t的增大, 上述公式的值越來越趨近於0, 這就導致了梯度消失問題。 那麼如果 很大, 上述公式會越來越趨向於無窮, 這就產生了梯度爆炸。

5. 為什麼LSTM能解決梯度問題?

在閱讀此篇文章之前,確保自己對LSTM的三門機制有一定了解, 參見:LSTM:RNN最常用的變體

從上述中我們知道, RNN產生梯度消失與梯度爆炸的原因就在於 , 如果我們能夠將這一坨東西去掉, 我們的不就解決掉梯度問題了嗎。 LSTM通過門機制來解決了這個問題。

我們先從LSTM的三個門公式出發:

  • 遺忘門:
  • 輸入門:
  • 輸出門:
  • 當前單元狀態 :
  • 當前時刻的隱層輸出:

我們注意到, 首先三個門的激活函數是sigmoid, 這也就意味著這三個門的輸出要麼接近於0 , 要麼接近於1。這就使得 是非0即1的,當門為1時, 梯度能夠很好的在LSTM中傳遞,很大程度上減輕了梯度消失發生的概率, 當門為0時,說明上一時刻的信息對當前時刻沒有影響, 我們也就沒有必要傳遞梯度回去來更新參數了。所以, 這就是為什麼通過門機制就能夠解決梯度的原因: 使得單元間的傳遞 為0 或 1。

最後

理解基本的神經網路單元是必要的, 但在上層task的使用中, 基本的神經單元通常只是拿來即用, 更多的在deep, 遷移學習, embedding , attention上做文章,這需要多看paper, 多寫代碼, 慢慢積累。 最後, 願我早日超神。

推薦閱讀:

相關文章