簡單答案:會!而且效果會變的很差 !
詳細答案:
這其實是一個很有趣的研究領域:連續學習(continual learning)或者叫 終身學習(lifelong learning)
為了更清楚的 理解,以及解決 這個問題,我們可以來看一下 這個領域 裡面,比較重要的 兩篇 文章 。
論文: Overcoming catastrophic forgetting in neural networks
這是 Deepmind 的一篇文章,發表在了 Proceedings of the National Academy of Sciences of the United States of America (PNAS) 上面。這個期刊上很多都是自然科學的文章,可以說也很有趣了,不過按照這個期刊的描述:It is the official journal of the National Academy of Sciences, published since 1915, and publishes original research, scientific reviews, commentaries, and letters.,倒也算契合。
問題定義
先說要解決什麼問題。
現在有 dataset A,我們搭了一個模型,然後在 A 上也 train 出了滿意的結果。然後 dataset B 來了,我們那剛剛 train 好的模型在 B 上繼續 finetune,train 之後,發現模型在 B 上的結果也很滿意。
但是,這個時候你如果拿最終這個模型去 A 上重新測一遍,你會發現準確率已經慘不忍睹了。所以這個問題就是
「機器學習的很快,但是遺忘的也快」,英文叫做 「catastrophic forgetting」
為了更好定義這個問題,必須要指出的是,在上述例子中,一旦我們開始在 B 上訓練是,A 就永遠的不可見了。
方法
High-level idea:
find a solution to a new task in the neighbourhood of an older one
就是說,我們在 B 上進行 finetune 的時候,不要離之前的 model 太遠了,這就是基本思路 。
上圖裡面,兩個 set 的重合處,就是我們要找的,既對 B 效果好,又離之前的 model 的比較近。
理論依據:
Many configurations of θ will result in the same performance [1, 2]。
[1] Robert H. Nielsen. Theory of the backpropagation neural network. In Proceedings ofthe International Joint Conference on Neural Networks, volume I, pages 593–605. Piscataway, NJ: IEEE, 1989.
[2] Héctor J. Sussmann. Uniqueness of the weights for minimal feedforward nets with a given input- output map. Neural Networks, 5:589–593, 1992.
就是說,對於同一個 task,是有可能找到很多符合條件(也就是效果不錯)的 model 的,那麼我們就可以從中找到一個離之前 model 最近的一個 model。試想,如果沒有這個結論,那麼「既在 B 上效果好,又不至於離之前的 model 太遠」 可能就是一個不可能的任務。
具體做法
基本的出發點是,當我在 B 上面 finetune 的時候,並沒有必要調整所有的參數,我們儘可能只動那些對 A 影響比較小的參數,就足夠了。
對於 train 一個 model 來說,我們在更新參數的過程可以用一個貝葉斯公式來表達,
上面公式中,
就是我們常用的 loss 的相反數。後兩項分別是 weights 和 dataset 的先驗分佈。其中 dataset 的先驗分佈由於本身在 dataset 確定之後就固定了,所以一般也不考慮。而 weights 的先驗分佈則常常是一些 regularization
,比如 weights decay
等等。
那麼,當我們從 A 轉到 B 的時候,上面公式發生了如下變化,
最關鍵的就是理解等號右邊的中間這一項。
在本文的問題中,我再不是對 weight 加一個 regularization
那麼簡單了,而是我希望這個 weights 能夠對 A dataset 也有比較好的準確率。所以這裡替換成了 A dataset 的優化目標。
上面這個公式,此時不只是 B 的優化目標,而是 A 和 B 整體的優化目標。
直接求,
是沒辦法求的,因為當我們在 train B dataset 的時候,A dataset 已經看不到了。那麼這裡我們做一個假設,認為每個 theta 是符合一個高斯分佈的,
這個高斯分佈的均值就是 現在 A 上 train 後得到的那個值。這個很好理解,因為那個值就是我們想保持的值,所以把他設為分佈的中心(也就是均值)是合理的。
這個高斯分佈的 precision(就是方差的倒數)
可以用 Fisher information matrix, F
來估計。我們暫時不講為什麼 precision 可以用 Fisher information matrix 估計,先來講一講為什麼:
方差的倒數代表這個參數對 A 的重要性。倒數值越低,改變他對 A 的影響越小。
上面的綠色的線是一個方差比較大的分佈,我們稍微改變一下參數的值,概率的變化並不大。
上面的藍色的線是一個方差比較小的分佈,我們同樣改變相同的量,他的概率變化要大得多。
我們當然希望是條件概率越大越好,所以我們要找到那些方差比較大的參數,因為改變他們對 A 的概率項影響較小。
Fisher Information Matrix
總結自:
https://www.inference.vc/on-empirical-fisher-information/
https://wiseodd.github.io/techblog/2018/03/14/natural-gradient/
對於一個監督學習的問題,
Fisher Information Matrix (FIM) 的定義是,
一個比較直觀的解釋是,FIM是 兩個分佈 KL Divergence 的 Hessian 矩陣,而且這兩個分佈之間只有非常細微的差別,下面來推導一下。
首先來看兩個細微差別的分佈的 KL Divergence 的定義,
然後我們先來對
求一階導數,
然後再求二階導數,
然後就可以得到 Hessian 矩陣了,
論文:Memory Aware Synapses: Learning What (not) to Forget
這裡的 是 neural network,
值得注意的是,這裡的 並不是損失函數對參數的導數,而是網路對應的函數對參數的導數。一般來說,後者要複雜不少。因為損失函數是一個標量,而網路的輸出則是是一個向量。加入這個向量的長度為10(總共有10類),那麼算網路導數的計算量就是算損失函數導數的10倍,所以這個問題必須解決。
本文的解決辦法也比較直接,那就是我算輸出向量 L2 norm
關於參數的導數。關於這個妥協,作者後面也做了實驗,發現並不會讓效果打折。
這裡的 是可以作為一個重要性程度衡量標準的。因為自然 越大,變動參數帶來的函數的變化也就越大。所以綜合考慮所有的 data points,我們可以得到:
局部版本
之前的 是一個函數,這個函數代表整個網路。同樣的,我們認為網路中的每一層也是一個函數
,這裡的下標 表示第幾層。背後的 idea 是,只要我們能夠保證每一層變化足夠小,那麼整個網路的變化也會比較小。
這裡紅框這個公式其實是有問題的。
實際上應該是 是這一層第 -th 輸入,也就是 的輸出;
實際上應該是 是這一層第 -th 輸出
同樣的,最終每個參數的重要程度要綜合考慮所有的 data points,
實驗
前面講了講個版本:全局的,局部的。實驗部分,作者用的都是全局的這個方法(主要也是因為全局方法效果更好一些)。
「這個方法真正牛的地方在於,他沒有對 regularization 這一項前面乘的係數進行調參。使用默認值 1,也取得了比較好的效果」
而且,這篇文章只用到了一階導數信息,相對來說比較容易計算,而且不要「網路已經train到收斂」的假設條件。
總的來說,這篇文章可以算作是一篇 「simple but work」 的工作。
這就是傳說中的「災難性遺忘的問題」,你可以看看life long learning或continual learning就是專門研究這個問題。
這問題說難也難,說簡單也簡單。只要你保留全部的老數據,微調新數據的時候,加一個multi task的分支給老數據一起訓練,基本就是這個問題的upper bound了。
這個問題難在某些NLP任務中, 需要不斷擴展類別。每次加入新類別,需要重新訓練模型。尤其當體系大海量數據時,重新訓練模型代價很高。此外還要保證已有類型的穩定,避免對線上業務的影響。
如何高效地處理這個問題?
有一篇文章Learning without Forgetting,提供了很好的思路:
簡單說就是先用老模型過一遍新數據,保留老模型的輸出,然後訓練新模型,同時對新模型的輸出計算distillation loss對衝掉新數據對老模型參數的影響。這篇文章的idea真是棒棒噠!
ICLR2020上有一篇 Editable Neural Networks 就是解決你說的問題,針對error 樣本進行調整但是不影響其他數據。
另外最近一眾Life long learning/ incremental learning 還有關於catastrophic forgetting都是研究這個問題的。
當然會,所以有個現象叫做災難遺忘嘛。
怎麼解決?Continual learning可以看看。
Transfer就解決不了了,因為遷移的目的就是要求在目標領域的性能最好,不管之前的領域了。
多任務學習倒是也可以參考。
答案是肯定的,這叫做災難性遺忘現象(catastrophic forgetting) ,已經有許多相關的研究了。比如說Overcoming catastrophic forgetting in neural networks這篇文章已經達到700+的引用量。而針對這一問題提出的連續學習(continual learning) 近期是研究的大熱門,連續學習就是要解決在連續學習多個任務的情況下,保持前部任務的高準確度,當前仍然是一個充滿挑戰的領域。
推薦閱讀: