DeepMind最新研究解決災難性遺忘難題
【新智元導讀】當遇到序列任務時,神經網路會遭受災難性遺忘。DeepMind研究人員通過在函數空間中引入貝葉斯推理,使用誘導點稀疏GP方法和優化排練數據點來克服這個問題。今天和大家分享這篇Reddit高贊論文。
這篇由DeepMind研究團隊出品的論文名字叫「Functional Regularisation for Continual Learning」(持續學習的功能正規化)。研究人員引入了一個基於函數空間貝葉斯推理的持續學習框架,而不是深度神經網路的參數。該方法被稱為用於持續學習的函數正則化,通過在底層任務特定功能上構造和記憶一個近似的後驗信念,避免忘記先前的任務。
為了實現這一點,他們依賴於通過將神經網路的最後一層的權重視為隨機和高斯分佈而獲得的高斯過程。然後,訓練演算法依次遇到任務,並利用誘導點稀疏高斯過程方法構造任務特定函數的後驗信念。在每個步驟中,首先學習新任務,然後構建總結(summary),其包括(i)引入輸入和(ii)在這些輸入處的函數值上的後驗分佈。然後,這個總結通過Kullback-Leibler正則化術語規範學習未來任務,從而避免了對早期任務的災難性遺忘。他們在分類數據集中演示了自己的演算法,例如Split-MNIST,Permuted-MNIST和Omniglot。
通過函數正則化解決災難性遺忘
近年來,人們對持續學習(也稱為終身學習)的興趣再度興起,這是指以在線方式從可能與不斷增加的任務相關的數據中學習的系統。持續學習系統必須適應所有早期任務的良好表現,而無需對以前的數據進行大量的重新訓練。
持續學習的兩個主要挑戰是:
(i)避免災難性遺忘,比如記住如何解決早期任務;
(ii)任務數量的可擴展性。
其他可能的設計包括向前和向後轉移,比如更快地學習後面的任務和回顧性地改進前面的任務。值得注意的是,持續學習與元學習(meta-learning)或多任務學習有很大的不同。在後一種方法中,所有任務都是同時學習的,例如,訓練是通過對小批量任務進行二次抽樣,這意味著沒有遺忘的風險。
與許多最近關於持續學習的著作相似,他們關注的是理想化的情況,即一系列有監督的學習任務,具有已知的任務邊界,呈現給一個深度神經網路的持續學習系統。一個主要的挑戰是有效地規範化學習,使深度神經網路避免災難性的遺忘,即避免導致早期任務的預測性能差的網路參數配置。在不同的技術中,他們考慮了兩種不同的方法來管理災難性遺忘。
一方面,這些方法限制或規範網路的參數,使其與以前的任務中學習的參數沒有明顯的偏差。 這包括將持續學習構建為順序近似貝葉斯推理的方法,包括EWC和VCL。這種方法由於表徵漂移(representation drift)而具有脆弱性(brittleness)。也就是說,隨著參數適應新任務,其他參數被約束/正規化的值變得過時。
另一方面,他們有預演/回放緩衝方法,它使用過去觀察的記憶存儲來記住以前的任務。它們不會受到脆弱性的影響,但是它們不表示未知函數的不確定性(它們只存儲輸入-輸出),並且如果任務複雜且需要許多觀察來正確地表示,那麼它們的可擴展性會降低。優化存儲在重放緩衝區中的最佳觀察結果也是一個未解決的問題。
在論文中,研究人員發展了一種新的持續學習方法,解決了這兩個類別的缺點。它是基於近似貝葉斯推理,但基於函數空間而不是神經網路參數,因此不存在上述的脆弱性。這種方法通過記住對底層特定任務功能的近似後驗信念,避免忘記先前的任務。
為了實現這一點,他們考慮了高斯過程(GPs),並利用誘導點稀疏GP方法總結了使用少量誘導點的函數的後驗分佈。這些誘導點及其後驗分佈通過變分推理框架內的KullbackLeibler正則化項,來規範未來任務的持續學習,避免了對早期任務的災難性遺忘。因此,他們的方法與基於重播的方法相似,但有兩個重要的優勢。
首先,誘導點的近似後驗分佈捕獲了未知函數的不確定性,並總結了給定所有觀測值的全後驗分佈。其次,誘導點可以使用來自GP文獻的專門標準進行優化,實現比隨機選擇觀測更好的性能。
為了使他們的函數正則化方法能夠處理高維和複雜的數據集,他們使用具有神經網路參數化特徵的線性核。這樣的GPs可以理解為貝葉斯神經網路,其中只有最後一層的權重以貝葉斯方式處理,而早期層的權重是優化的。這種觀點允許在權重空間中進行更有效和準確的計算訓練程序,然後將近似轉換為函數空間,在函數空間中構造誘導點,然後用於規範未來任務的學習。他們在分類中展示了自己的方法,並證明它在Permuted-MNIST,Split-MNIST和Omniglot上具有最先進的性能。
實驗簡介
研究人員考慮了三個持續學習分類問題中的實驗:Split-MNIST,PermutedMNIST和Sequenn Omniglot。他們比較了其方法的兩種變體,稱為功能正則化持續學習(FRCL)。