超強半監督學習 MixMatch
人類的學習方法是半監督學習,他們能從大量的未標記數據和極少量的標記數據學習,迅速理解這個世界。半監督學習最近有沒有什麼大的突破呢?我的Twitter賬號被這篇 《The Quiet Semi-Supervised Revolution》【1】博客刷屏了。這篇博客介紹了 DeepMind 的 MixMatch 【2】方法,此方法僅用少量的標記數據,就使半監督學習的預測精度逼近監督學習。深度學習領域的未來可能因此而刷新。
以前的半監督學習方案,一直以來表現其實都很差。你可能會想到 BERT 和 GPT,這兩個超強的自然語言預訓練模型。但這兩個模型的微調只能算遷移學習,而非半監督學習。因為它們最開始訓練的時候,使用了監督學習方法。比如通過語言模型,輸入前言,預測後語;輸入語境,完形填空;輸入前言和後語,預測是否前言不搭後語。這幾種方法,很難稱作無監督學習。
下面這幾種大家很容易想到的半監督學習方法,效果都不是很好。比如使用主成分分析PCA,提取數據中方差最大的特徵,再在少量標記數據上,做監督學習;又比如使用自編碼機 AutoEncoder,以重建輸入圖像的方式,獲得數據潛在表示,對小數據監督學習;再比如使用生成對抗網路 GAN,以生成以假亂真圖像的方式,獲得數據潛在表示,對小數據做監督學習。半監督訓練很久的精度,還比不上直接在小數據上做監督學習的精度!大家的猜測是,這些非監督方法學到的特徵可能並不是分類器真正需要的特徵。
什麼才是半監督學習的正確打開方式呢?近期的一些半監督學習方法,通過在損失函數中添加與未標記數據相關的項,來鼓勵模型舉一反三,增加對陌生數據的泛化能力。
第一種方案是自洽正則化(Consistency Regularization)【3,4】。以前遇到標記數據太少,監督學習泛化能力差的時候,人們一般進行訓練數據增廣,比如對圖像做隨機平移,縮放,旋轉,扭曲,剪切,改變亮度,飽和度,加雜訊等。數據增廣能產生無數的修改過的新圖像,擴大訓練數據集。自洽正則化的思路是,對未標記數據進行數據增廣,產生的新數據輸入分類器,預測結果應保持自洽。即同一個數據增廣產生的樣本,模型預測結果應保持一致。此規則被加入到損失函數中,有如下形式,
其中 x 是未標記數據,Augment(x) 表示對x做隨機增廣產生的新數據, 是模型參數,y 是模型預測結果。注意數據增廣是隨機操作,兩個 Augment(x) 的輸出不同。這個 L2 損失項,約束機器學習模型,對同一個圖像做增廣得到的所有新圖像,作出自洽的預測。
MixMatch 集成了自洽正則化。數據增廣使用了對圖像的隨機左右翻轉和剪切(Crop)。
第二種方案稱作 最小化熵(Entropy Minimization)【5】。許多半監督學習方法都基於一個共識,即分類器的分類邊界不應該穿過邊際分布的高密度區域。具體做法就是強迫分類器對未標記數據作出低熵預測。實現方法是在損失函數中簡單的增加一項,最小化 對應的熵。
MixMatch 使用 "sharpening" 函數,最小化未標記數據的熵。這一部分後面會介紹。
第三種方案稱作傳統正則化(Traditional Regularization)。為了讓模型泛化能力更好,一般的做法對模型參數做 L2 正則化,SGD下L2正則化等價於Weight Decay。MixMaxtch 使用了 Adam 優化器,而之前有篇文章發現 Adam 和 L2 正則化同時使用會有問題,因此 MixMatch 從諫如流使用了單獨的Weight decay。
最近發明的一種數據增廣方法叫 Mixup 【6】,從訓練數據中任意抽樣兩個樣本,構造混合樣本和混合標籤,作為新的增廣數據,
其中 是一個 0 到 1 之間的正數,代表兩個樣本的混合比例。MixMatch 將 Mixup 同時用在了標記數據和未標記數據中。
MixMatch 方案
MixMatch 偷學各派武功,取三家之長,補三家之短,最終成為天下第一高手 -- 最強半監督學習模型。這種 MixMatch 方法在小數據上做半監督學習的精度,遠超其他同類模型。比如,在 CIFAR-10 數據集上,只用250個標籤,他們就將誤差減小了4倍(從38%降到11%)。在STL-10數據集上,將誤差降低了兩倍。 方法示意圖如下,