前言

在講述RNN與LSTM的時候曾經說過, RNN一個最大的缺陷就是梯度消失與梯度爆炸問題, 由於這一缺陷,使得RNN在長文本中難以訓練, 這才誕生了LSTM及各種變體。

梯度消失與梯度爆炸問題不僅僅在RNN中存在, 在其餘深層網路中同樣普遍存在, 本文先從梯度消失與梯度爆炸談起, 然後再回到具體的RNN梯度問題中, 最後, 再探討LSTM是通過什麼技術來解決RNN中的這一問題的。

淺談神經網路中的梯度爆炸問題

機器學習總結(二):梯度消失和梯度爆炸

神經網路訓練中的梯度消失與梯度爆炸

1. 什麼是梯度消失,梯度爆炸?

首先,你得知道梯度是什麼, 不了解的可以看看我關於梯度下降的那篇文章。

從那篇文章中就可以看出對於神經網路的訓練,梯度在訓練中起到很關鍵的作用。 如果在訓練過程中發生了梯度消失,這也就意味著我們的權重無法被更新,最終導致訓練失敗。而梯度爆炸所帶來的梯度過大,從而大幅度更新網路參數,造成網路不穩定(可以理解為梯度步伐太大)。在極端情況下,權重的值變得特別大,以至於結果會溢出(NaN值)

注意,梯度消失和梯度爆炸只會造成神經網路中較淺的網路的權重無法更新(畢竟神經網路中是反向傳播)

2. 這會造成哪些問題?

  • 梯度消失會導致我們的神經網路中前面層的網路權重無法得到更新,也就停止了學習。
  • 梯度爆炸會使得學習不穩定, 參數變化太大導致無法獲取最優參數。
  • 在深度多層感知機網路中,梯度爆炸會導致網路不穩定,最好的結果是無法從訓練數據中學習,最壞的結果是由於權重值為NaN而無法更新權重。
  • 在循環神經網路(RNN)中,梯度爆炸會導致網路不穩定,使得網路無法從訓練數據中得到很好的學習,最好的結果是網路不能在長輸入數據序列上學習。

3. 原因何在?

讓我們以一個很簡單的例子分析一下,這樣便於理解。

如上圖,是一個每層只有一個神經元的神經網路,且每一層的激活函數為sigmoid,則有:

y_i = sigma(z_i) = sigma(w_ix_i + b_i) ( sigma 是sigmoid函數)。

我們根據反向傳播演算法有:

 frac{delta C}{delta b_1} = frac{delta C}{ delta y_4} frac{delta y_4}{delta z_4} frac{delta z_4}{delta x_4} frac{ delta x_4}{delta z_3} frac{delta z_3}{ delta x_3} frac{ delta x_3}{delta z_2} frac{delta z_2}{ delta x_2} frac{ delta x_2}{delta z_1} frac{delta z_1}{delta b_1} \ = frac{ delta C}{delta y_4} (sigma (z_4) w_4)( sigma(z_3) w_3)( sigma  (z_2) w_2)( sigma  (z_1))

而sigmoid函數的導數公式為:  S(x ) = frac{e^{-x}}{(1+ e^{-x})^2} = S(x)(1- S(x)) 它的圖形曲線為:

由上可見,sigmoid函數的導數 sigma(x) 的最大值為 frac{1}{4} ,通常我們會將權重初始值 |w| 初始化為為小於1的隨機值,因此我們可以得到 |sigma (z_4) w_4| < frac{1}{4} ,隨著層數的增多,那麼求導結果 frac{delta C}{delta b_1} 越小,這也就導致了梯度消失問題。

那麼如果我們設置初始權重 |w| 較大,那麼會有 |sigma (z_4) w_4| > 1 ,造成梯度太大(也就是下降的步伐太大),這也是造成梯度爆炸的原因。

總之,無論是梯度消失還是梯度爆炸,都是源於網路結構太深,造成網路權重不穩定,從本質上來講是因為梯度反向傳播中的連乘效應。

4. RNN中的梯度消失,爆炸問題

參考:RNN梯度消失和爆炸的原因, 這篇文章是我看到講的最清楚的了,在這裡添加一些我的思考, 若侵立刪。

我們給定一個三個時間的RNN單元,如下:

我們假設最左端的輸入 S_0 為給定值, 且神經元中沒有激活函數(便於分析), 則前向過程如下:

S_1 = W_xX_1 + W_sS_0 + b_1 qquad qquad qquad O_1 = W_oS_1 + b_2 \ S_2 = W_xX_2 + W_sS_1 + b_1 qquad qquad qquad O_2 = W_oS_2 + b_2 \ S_3 = W_xX_3 + W_sS_2 + b_1 qquad qquad qquad O_3 = W_oS_3 + b_2

t=3 時刻, 損失函數為 L_3 = frac{1}{2}(Y_3 - O_3)^2 ,那麼如果我們要訓練RNN時, 實際上就是是對 W_x, W_s, W_o,b_1,b_2 求偏導, 並不斷調整它們以使得 L_3 儘可能達到最小(參見反向傳播演算法與梯度下降演算法)。

那麼我們得到以下公式:

frac{delta L_3}{delta W_0} = frac{delta L_3}{delta O_3} frac{delta O_3}{delta W_0} \ frac{delta L_3}{delta W_x} = frac{delta L_3}{delta O_3} frac{delta O_3}{delta S_3} frac{delta S_3}{delta W_x} + frac{delta L_3}{delta O_3} frac{delta O_3}{delta S_3} frac{delta S_3}{delta S_2} frac{delta S_2}{delta W_x} + frac{delta L_3}{delta O_3} frac{delta O_3}{delta S_3} frac{delta S_3}{delta S_2} frac{delta S_2}{delta S_1}frac{delta S_1}{delta W_x} \ frac{delta L_3}{delta W_s} = frac{delta L_3}{delta O_3} frac{delta O_3}{delta S_3} frac{delta S_3}{delta W_s} + frac{delta L_3}{delta O_3} frac{delta O_3}{delta S_3} frac{delta S_3}{delta S_2} frac{delta S_2}{delta W_s} + frac{delta L_3}{delta O_3} frac{delta O_3}{delta S_3} frac{delta S_3}{delta S_2} frac{delta S_2}{delta S_1}frac{delta S_1}{delta W_s}

將上述偏導公式與第三節中的公式比較,我們發現, 隨著神經網路層數的加深對 W_0 而言並沒有什麼影響, 而對 W_x, W_s 會隨著時間序列的拉長而產生梯度消失和梯度爆炸問題。

根據上述分析整理一下公式可得, 對於任意時刻t對 W_x, W_s 求偏導的公式為:

frac{delta L_t}{delta W_x } = sum_{k=0}^t frac{delta L_t}{delta O_t} frac{delta O_t}{delta S_t}( prod_{j=k+1}^t frac{delta S_j}{delta S_{j-1}} ) frac{ delta S_k }{delta W_x} \ frac{delta L_t}{delta W_s } = sum_{k=0}^t frac{delta L_t}{delta O_t} frac{delta O_t}{delta S_t}( prod_{j=k+1}^t frac{delta S_j}{delta S_{j-1}} ) frac{ delta S_k }{delta W_s}

我們發現, 導致梯度消失和爆炸的就在於 prod_{j=k+1}^t frac{delta S_j}{delta S_{j-1}} , 而加上激活函數後的S的表達式為:

S_j = tanh(W_xX_j + W_sS_{j-1} + b_1)

那麼則有:

prod_{j=k+1}^t frac{delta S_j}{delta S_{j-1}} = prod_{j=k+1}^t tanh W_s

而在這個公式中, tanh的導數總是小於1 的, 如果 W_s 也是一個大於0小於1的值, 那麼隨著t的增大, 上述公式的值越來越趨近於0, 這就導致了梯度消失問題。 那麼如果 W_s 很大, 上述公式會越來越趨向於無窮, 這就產生了梯度爆炸。

5. 為什麼LSTM能解決梯度問題?

在閱讀此篇文章之前,確保自己對LSTM的三門機制有一定了解, 參見:LSTM:RNN最常用的變體

從上述中我們知道, RNN產生梯度消失與梯度爆炸的原因就在於 prod_{j=k+1}^t frac{delta S_j}{delta S_{j-1}} , 如果我們能夠將這一坨東西去掉, 我們的不就解決掉梯度問題了嗎。 LSTM通過門機制來解決了這個問題。

我們先從LSTM的三個門公式出發:

  • 遺忘門: f_t = sigma( W_f cdot [h_{t-1}, x_t] + b_f)
  • 輸入門: i_t = sigma(W_i cdot [h_{t-1}, x_t] + b_i)
  • 輸出門: o_t = sigma(W_o cdot [h_{t-1}, x_t ] + b_0 )
  • 當前單元狀態 c_t : c_t = f_t circ c_{t-1} + i_t circ tanh(W_c cdot [h_{t-1}, x_t] + b_c )
  • 當前時刻的隱層輸出: h_t = o_t circ tanh(c_t)

我們注意到, 首先三個門的激活函數是sigmoid, 這也就意味著這三個門的輸出要麼接近於0 , 要麼接近於1。這就使得 frac{delta c_t}{delta c_{t-1}} = f_t, frac{delta h_t}{delta h_{t-1}} = o_t 是非0即1的,當門為1時, 梯度能夠很好的在LSTM中傳遞,很大程度上減輕了梯度消失發生的概率, 當門為0時,說明上一時刻的信息對當前時刻沒有影響, 我們也就沒有必要傳遞梯度回去來更新參數了。所以, 這就是為什麼通過門機制就能夠解決梯度的原因: 使得單元間的傳遞 frac{delta S_j}{delta S_{j-1}} 為0 或 1。

最後

理解基本的神經網路單元是必要的, 但在上層task的使用中, 基本的神經單元通常只是拿來即用, 更多的在deep, 遷移學習, embedding , attention上做文章,這需要多看paper, 多寫代碼, 慢慢積累。 最後, 願我早日超神。

推薦閱讀:

相关文章