簡單答案:會!而且效果會變的很差

詳細答案:

這其實是一個很有趣的研究領域:連續學習(continual learning)或者叫 終身學習(lifelong learning)

為了更清楚的 理解,以及解決 這個問題,我們可以來看一下 這個領域 裡面,比較重要的 兩篇 文章

論文: Overcoming catastrophic forgetting in neural networks

這是 Deepmind 的一篇文章,發表在了 Proceedings of the National Academy of Sciences of the United States of America (PNAS) 上面。這個期刊上很多都是自然科學的文章,可以說也很有趣了,不過按照這個期刊的描述:It is the official journal of the National Academy of Sciences, published since 1915, and publishes original research, scientific reviews, commentaries, and letters.,倒也算契合。


問題定義

先說要解決什麼問題。

現在有 dataset A,我們搭了一個模型,然後在 A 上也 train 出了滿意的結果。然後 dataset B 來了,我們那剛剛 train 好的模型在 B 上繼續 finetune,train 之後,發現模型在 B 上的結果也很滿意。

但是,這個時候你如果拿最終這個模型去 A 上重新測一遍,你會發現準確率已經慘不忍睹了。所以這個問題就是

「機器學習的很快,但是遺忘的也快」,英文叫做 「catastrophic forgetting」

為了更好定義這個問題,必須要指出的是,在上述例子中,一旦我們開始在 B 上訓練是,A 就永遠的不可見了。

方法

High-level idea:

find a solution to a new task in the neighbourhood of an older one

就是說,我們在 B 上進行 finetune 的時候,不要離之前的 model 太遠了,這就是基本思路

上圖裡面,兩個 set 的重合處,就是我們要找的,既對 B 效果好,又離之前的 model 的比較近。

理論依據:

Many configurations of θ will result in the same performance [1, 2]。

[1] Robert H. Nielsen. Theory of the backpropagation neural network. In Proceedings ofthe International Joint Conference on Neural Networks, volume I, pages 593–605. Piscataway, NJ: IEEE, 1989.

[2] Héctor J. Sussmann. Uniqueness of the weights for minimal feedforward nets with a given input- output map. Neural Networks, 5:589–593, 1992.

就是說,對於同一個 task,是有可能找到很多符合條件(也就是效果不錯)的 model 的,那麼我們就可以從中找到一個離之前 model 最近的一個 model。試想,如果沒有這個結論,那麼「既在 B 上效果好,又不至於離之前的 model 太遠」 可能就是一個不可能的任務。

具體做法

基本的出發點是,當我在 B 上面 finetune 的時候,並沒有必要調整所有的參數,我們儘可能只動那些對 A 影響比較小的參數,就足夠了。

對於 train 一個 model 來說,我們在更新參數的過程可以用一個貝葉斯公式來表達,

上面公式中,

[公式]

就是我們常用的 loss 的相反數。後兩項分別是 weights 和 dataset 的先驗分佈。其中 dataset 的先驗分佈由於本身在 dataset 確定之後就固定了,所以一般也不考慮。而 weights 的先驗分佈則常常是一些 regularization,比如 weights decay 等等。

那麼,當我們從 A 轉到 B 的時候,上面公式發生了如下變化,

最關鍵的就是理解等號右邊的中間這一項。

在本文的問題中,我再不是對 weight 加一個 regularization 那麼簡單了,而是我希望這個 weights 能夠對 A dataset 也有比較好的準確率。所以這裡替換成了 A dataset 的優化目標。

上面這個公式,此時不只是 B 的優化目標,而是 A 和 B 整體的優化目標。

直接求,

[公式]

是沒辦法求的,因為當我們在 train B dataset 的時候,A dataset 已經看不到了。那麼這裡我們做一個假設,認為每個 theta 是符合一個高斯分佈的,

  1. 這個高斯分佈的均值就是 現在 A 上 train 後得到的那個值。這個很好理解,因為那個值就是我們想保持的值,所以把他設為分佈的中心(也就是均值)是合理的。
  2. 這個高斯分佈的 precision(就是方差的倒數)

[公式]

可以用 Fisher information matrix, F 來估計。我們暫時不講為什麼 precision 可以用 Fisher information matrix 估計,先來講一講為什麼:

方差的倒數代表這個參數對 A 的重要性。倒數值越低,改變他對 A 的影響越小。

上面的綠色的線是一個方差比較大的分佈,我們稍微改變一下參數的值,概率的變化並不大。

上面的藍色的線是一個方差比較小的分佈,我們同樣改變相同的量,他的概率變化要大得多。

我們當然希望是條件概率越大越好,所以我們要找到那些方差比較大的參數,因為改變他們對 A 的概率項影響較小。

Fisher Information Matrix

總結自:

https://www.inference.vc/on-empirical-fisher-information/

https://wiseodd.github.io/techblog/2018/03/14/natural-gradient/

對於一個監督學習的問題,

[公式]

Fisher Information Matrix (FIM) 的定義是,

[公式]

一個比較直觀的解釋是,FIM是 兩個分佈 KL Divergence 的 Hessian 矩陣,而且這兩個分佈之間只有非常細微的差別,下面來推導一下。

首先來看兩個細微差別的分佈的 KL Divergence 的定義,

[公式]

然後我們先來對

[公式]

求一階導數,

然後再求二階導數,

[公式]

然後就可以得到 Hessian 矩陣了,

論文:Memory Aware Synapses: Learning What (not) to Forget

這裡的 [公式] 是 neural network,

值得注意的是,這裡的 [公式] 並不是損失函數對參數的導數,而是網路對應的函數對參數的導數。一般來說,後者要複雜不少。因為損失函數是一個標量,而網路的輸出則是是一個向量。加入這個向量的長度為10(總共有10類),那麼算網路導數的計算量就是算損失函數導數的10倍,所以這個問題必須解決。

本文的解決辦法也比較直接,那就是我算輸出向量 L2 norm 關於參數的導數。關於這個妥協,作者後面也做了實驗,發現並不會讓效果打折。

這裡的 [公式]是可以作為一個重要性程度衡量標準的。因為自然 [公式]越大,變動參數帶來的函數的變化也就越大。所以綜合考慮所有的 data points,我們可以得到:

局部版本

之前的 [公式] 是一個函數,這個函數代表整個網路。同樣的,我們認為網路中的每一層也是一個函數

[公式],這裡的下標 [公式] 表示第幾層。背後的 idea 是,只要我們能夠保證每一層變化足夠小,那麼整個網路的變化也會比較小。

這裡紅框這個公式其實是有問題的。

[公式] 實際上應該是 [公式] 是這一層第 [公式]-th 輸入,也就是 [公式] 的輸出;

[公式] 實際上應該是 [公式] 是這一層第 [公式]-th 輸出

同樣的,最終每個參數的重要程度要綜合考慮所有的 data points,

實驗

前面講了講個版本:全局的,局部的。實驗部分,作者用的都是全局的這個方法(主要也是因為全局方法效果更好一些)。

「這個方法真正牛的地方在於,他沒有對 regularization 這一項前面乘的係數進行調參。使用默認值 1,也取得了比較好的效果」

而且,這篇文章只用到了一階導數信息,相對來說比較容易計算,而且不要「網路已經train到收斂」的假設條件。

總的來說,這篇文章可以算作是一篇 「simple but work」 的工作。


這就是傳說中的「災難性遺忘的問題」,你可以看看life long learning或continual learning就是專門研究這個問題。

這問題說難也難,說簡單也簡單。只要你保留全部的老數據,微調新數據的時候,加一個multi task的分支給老數據一起訓練,基本就是這個問題的upper bound了。

這個問題難在某些NLP任務中, 需要不斷擴展類別。每次加入新類別,需要重新訓練模型。尤其當體系大海量數據時,重新訓練模型代價很高。此外還要保證已有類型的穩定,避免對線上業務的影響。

如何高效地處理這個問題?

有一篇文章Learning without Forgetting,提供了很好的思路:

簡單說就是先用老模型過一遍新數據,保留老模型的輸出,然後訓練新模型,同時對新模型的輸出計算distillation loss對衝掉新數據對老模型參數的影響。這篇文章的idea真是棒棒噠!


ICLR2020上有一篇 Editable Neural Networks 就是解決你說的問題,針對error 樣本進行調整但是不影響其他數據。

另外最近一眾Life long learning/ incremental learning 還有關於catastrophic forgetting都是研究這個問題的。


當然會,所以有個現象叫做災難遺忘嘛。

怎麼解決?Continual learning可以看看。

Transfer就解決不了了,因為遷移的目的就是要求在目標領域的性能最好,不管之前的領域了。

多任務學習倒是也可以參考。


答案是肯定的,這叫做災難性遺忘現象(catastrophic forgetting),已經有許多相關的研究了。比如說Overcoming catastrophic forgetting in neural networks這篇文章已經達到700+的引用量。而針對這一問題提出的連續學習(continual learning)近期是研究的大熱門,連續學習就是要解決在連續學習多個任務的情況下,保持前部任務的高準確度,當前仍然是一個充滿挑戰的領域。


推薦閱讀:
相關文章