人類的學習方法是半監督學習,他們能從大量的未標記數據和極少量的標記數據學習,迅速理解這個世界。半監督學習最近有沒有什麼大的突破呢?我的Twitter賬號被這篇 《The Quiet Semi-Supervised Revolution》【1】博客刷屏了。這篇博客介紹了 DeepMind 的 MixMatch 【2】方法,此方法僅用少量的標記數據,就使半監督學習的預測精度逼近監督學習。深度學習

領域的未來可能因此而刷新。

以前的半監督學習方案,一直以來表現其實都很差。你可能會想到 BERT 和 GPT,這兩個超強的自然語言預訓練模型。但這兩個模型的微調只能算遷移學習,而非半監督學習。因為它們最開始訓練的時候,使用了監督學習方法。比如通過語言模型,輸入前言,預測後語;輸入語境,完形填空;輸入前言和後語,預測是否前言不搭後語。這幾種方法,很難稱作無監督學習。

下面這幾種大家很容易想到的半監督學習方法,效果都不是很好。比如使用主成分分析PCA,提取數據中方差最大的特徵,再在少量標記數據上,做監督學習;又比如使用自編碼機 AutoEncoder,以重建輸入圖像的方式,獲得數據潛在表示,對小數據監督學習;再比如使用生成對抗網路 GAN,以生成以假亂真圖像的方式,獲得數據潛在表示,對小數據做監督學習。半監督訓練很久的精度,還比不上直接在小數據上做監督學習的精度!大家的猜測是,這些非監督方法學到的特徵可能並不是分類器真正需要的特徵。

什麼才是半監督學習的正確打開方式呢?近期的一些半監督學習方法,通過在損失函數中添加與未標記數據相關的項,來鼓勵模型舉一反三,增加對陌生數據的泛化能力。

第一種方案是自洽正則化(Consistency Regularization)【3,4】。以前遇到標記數據太少,監督學習泛化能力差的時候,人們一般進行訓練數據增廣,比如對圖像做隨機平移,縮放,旋轉,扭曲,剪切,改變亮度,飽和度,加雜訊等。數據增廣能產生無數的修改過的新圖像,擴大訓練數據集。自洽正則化的思路是,對未標記數據進行數據增廣,產生的新數據輸入分類器,預測結果應保持自洽。即同一個數據增廣產生的樣本,模型預測結果應保持一致。此規則被加入到損失函數中,有如下形式,

| mathrm{p}_{	ext { model }}(y | 	ext { Augment }(x) ; 	heta)-mathrm{p}_{	ext { model }}(y | 	ext { Augment }(x) ; 	heta) |_{2}^{2}

其中 x 是未標記數據,Augment(x) 表示對x做隨機增廣產生的新數據, 	heta 是模型參數,y 是模型預測結果。注意數據增廣是隨機操作,兩個 Augment(x) 的輸出不同。這個 L2 損失項,約束機器學習模型,對同一個圖像做增廣得到的所有新圖像,作出自洽的預測。

MixMatch 集成了自洽正則化。數據增廣使用了對圖像的隨機左右翻轉和剪切(Crop)。

第二種方案稱作 最小化熵(Entropy Minimization)【5】。許多半監督學習方法都基於一個共識,即分類器的分類邊界不應該穿過邊際分布的高密度區域。具體做法就是強迫分類器對未標記數據作出低熵預測。實現方法是在損失函數中簡單的增加一項,最小化 mathrm{p}_{	ext { model }}(y | x) 對應的熵。

MixMatch 使用 "sharpening" 函數,最小化未標記數據的熵。這一部分後面會介紹。

第三種方案稱作傳統正則化(Traditional Regularization)。為了讓模型泛化能力更好,一般的做法對模型參數做 L2 正則化,SGD下L2正則化等價於Weight Decay。MixMaxtch 使用了 Adam 優化器,而之前有篇文章發現 Adam 和 L2 正則化同時使用會有問題,因此 MixMatch 從諫如流使用了單獨的Weight decay。

最近發明的一種數據增廣方法叫 Mixup 【6】,從訓練數據中任意抽樣兩個樣本,構造混合樣本和混合標籤,作為新的增廣數據,

egin{array}{ll}{	ilde{x}=lambda x_{i}+(1-lambda) x_{j},} & {	ext { where } x_{i}, x_{j} 	ext { are raw input vectors }} \ {	ilde{y}=lambda y_{i}+(1-lambda) y_{j},} & {	ext { where } y_{i}, y_{j} 	ext { are one-hot label encodings }}end{array}

其中 lambda 是一個 0 到 1 之間的正數,代表兩個樣本的混合比例。MixMatch 將 Mixup 同時用在了標記數據和未標記數據中。

MixMatch 方案

MixMatch 偷學各派武功,取三家之長,補三家之短,最終成為天下第一高手 -- 最強半監督學習模型。這種 MixMatch 方法在小數據上做半監督學習的精度,遠超其他同類模型。比如,在 CIFAR-10 數據集上,只用250個標籤,他們就將誤差減小了4倍(從38%降到11%)。在STL-10數據集上,將誤差降低了兩倍。 方法示意圖如下,

MixMatch 實現方法:對無標籤數據,做數據增廣,得到 K 個新的數據。因為數據增廣引入雜訊,將這 K 個新的數據,輸入到同一個分類器,得到不同的預測分類概率。MinMax 利用演算法(Sharpen),使多個概率分布的平均(Average)方差更小,預測結果更加自洽,系統熵更小。

註:Google原文並未比較 MixMatch 和使用生成對抗網路GAN做半監督學習時的表現孰好孰壞。但從搜索到的資料來看,2016年 OpenAI 的 Improved GAN 【8】,使用4000張CIFAR10的標記數據,做半監督學習得到測試誤差18.6。2017年,GAN做半監督學習的測試誤差,在4000張CIFAR10標記數據上,將測試誤差降低到14.41 【10】。2018年,GAN + 流形正則化,得到測試誤差14.45。目前並沒有看到來自GAN的更好結果。對比 MixMatch 使用 250 張標記圖片,就可以將測試誤差降低到 11.08,使用4000張標記圖片,可以將測試誤差降低到 6.24,應該算是大幅度超越使用GAN做半監督學習的效果。

具體步驟:

  1. 使用 MixMatch 演算法,對一個 Batch 的標記數據 mathcal{X} 和一個 Batch 的未標記數據 mathcal{U} 做數據增廣,分別得到一個 Batch 的增廣數據  mathcal{X}^{prime} 和 K 個Batch的 mathcal{U}^{prime}

 mathcal{X}^{prime}, mathcal{U}^{prime} =operatorname{MixMatch}(mathcal{X}, mathcal{U}, T, K, alpha)

其中 T, K, alpha 是超參數,後面會介紹。MixMatch 數據增廣演算法如下,

MixMatch 演算法。

演算法描述:for 循環對一個Batch的標記圖片和未標記圖片做數據增廣。對標記圖片,只做一次增廣,標籤不變,記為 p 。對未標記數據,做 K 次隨機增廣(文章中超參數K=2),輸入分類器,得到平均分類概率,應用溫度Sharpen 演算法(T 是溫度參數,此演算法後面介紹),得到未標記數據的「猜測」標籤 q 。此時增廣後的標記數據 mathcal{hat{X}} 有一個Batch,增廣後的未標記數據 mathcal{hat{U}} 有 K 個Batch。將 mathcal{hat{X}}mathcal{hat{U}} 混合在一起,隨機重排得到數據集 mathcal{W} 。最終 MixMatch 增廣演算法輸出的,是將 mathcal{hat{X}}mathcal{W} 做了MixUp() 的一個 Batch 的標記數據 mathcal{X} ,以及 mathcal{hat{U}}mathcal{W} 做了MixUp() 的 K 個Batch 的無標記增廣數據 mathcal{U}

2. 對增廣後的標記數據 mathcal{X} ,和無標記增廣數據 mathcal{U} 分別計算損失項,

egin{aligned} mathcal{L}_{mathcal{X}} &=frac{1}{left|mathcal{X}^{prime}
ight|} sum_{x, p in mathcal{X}^{prime}} mathrm{H}left(p, mathrm{p}_{	ext { model }}(y | x ; 	heta)
ight) \ mathcal{L}_{mathcal{U}} &=frac{1}{Lleft|mathcal{U}^{prime}
ight|} sum_{u, q in mathcal{U}^{prime}}left|q-mathrm{p}_{	ext { model }}(y | u ; 	heta)
ight|_{2}^{2} \ end{aligned}

其中 left|mathcal{X}^{prime}
ight| 等於 Batch Size, left|mathcal{U}^{prime}
ight|等於 K 倍 Batch Size,L 是分類類別個數, H(p, p_{
m model}) 是簡單的 Cross Entropy 函數, x, p 是增廣的標記數據輸入和標籤, u, q 是增廣的未標記數據輸入以及猜測的標籤。

對未標記數據損失 mathcal{L_{U}} 使用 L2 Loss 而不是像 mathcal{L_{X}} 一樣使用 Cross Entropy Loss 的原因文章中沒有提到。但在引用的NVIDIA文章【3】第三頁提供了一個解釋。即 L2 Loss 比 Cross Entropy Loss 更加嚴格。原因是 Cross Entropy 計算是需要先使用 Softmax 函數,將Dense Layer輸出的類分數 z_i 轉化為類概率,

{
m softmax} (z_i) = frac{exp(z_i)}{sum_j exp (z_j)}

而 softmax 函數對於常數疊加不敏感,即如果將最後一個 Dense Layer 的所有輸出類分數 z_i 同時添加一個常數 c, 則類概率不發生改變,Cross Entropy Loss 不發生改變。

{
m softmax} (z_i + c) = frac{exp(z_i + c)}{sum_j exp (z_j + c)} = frac{exp(z_i )}{sum_j exp (z_j )} = {
m softmax} (z_i )

因此,如果對未標記數據使用 Cross Entropy Loss, 由同一張圖片增廣得到的兩張新圖片,最後一個Dense Layer的輸出被允許相差一個常數。使用 L2 Loss, 約束更加嚴格。

3. 最終的整體損失函數是兩者的加權,

mathcal{L} =mathcal{L}_{mathcal{X}}+lambda_{mathcal{U}} mathcal{L}_{mathcal{U}}

其中 lambda_{mathcal{U}} 是非監督學習損失函數的加權因子,這個超參數的數值可調,文章使用 lambda_{mathcal{U}} = 100

在上面的步驟描述中,還有另外兩個超參數,溫度 T 和 alpha 。T 被用在 Sharpening 過程中, alpha 是 Mixup 的超參數。下面分別解釋這兩個超參數的來歷。

不是說未標記數據沒標籤嗎?我們可以用分類器「猜測」一些標籤。演算法描述中的這一步,就是分類器對 K 次增廣的無標籤數據分類結果做平均,猜測的「偽」標籤。對應示意圖中 Average 分布。但這個平均預測分布比較平坦,就像在貓狗二分類中,分類器說,這張圖片中 50% 幾率是貓,50%幾率是狗一樣,對各類別分類概率預測比較平均。

overline{q}_{b}=frac{1}{K} sum_{k} mathrm{p}_{mathrm{model}}left(y | hat{u}_{b, k} ; 	heta
ight)

MixMatch 使用了 Sharpen,來使得「偽」標籤熵更低,即貓狗分類中,要麼百分之九十多是貓,要麼百分之九十多是狗。做法也是前人發明的,

	ext { Sharpen }(p, T)_{i} :=p_{i}^{frac{1}{T}} / sum_{j=1}^{L} p_{j}^{frac{1}{T}}

其中, p 是類別概率,在 MixMatch 中對應 overline{q}_{b} 。T 是溫度參數,可以調節分類熵。調節 T 趨於0, 	ext { Sharpen }(p, T)_{i} 趨近於 One-Hot 分布,即對某一類別輸出概率 1,其他所有類別輸出概率0,此時分類熵最低。註: 熵 = - sum_{i=1}^{c} p_i log p_i , 可以計算得到,在二分類中,兩個類的輸出概率是One-Hot時 (p_0=1, p_1=0) 的熵遠小於輸出概率比較平均 (p_0=0.5, p_1=0.5) 的熵。在 MixMatch 中,降低溫度T,可以鼓勵模型作出低熵預測。

最後一個尚未解釋的超參數 alpha 被用在 Mixup 數據增廣中。與之前的 Mixup 方法不同,MixMatch方法將標記數據與未標記數據做了混合,進行 Mixup。對應演算法描述中的混合與隨機重排。

MixMatch 修改了 Mixup 演算法。對於兩個樣本以及他們的標籤 (x_1, p_1)(x_2, p_2), 混合後的樣本為,

egin{array}{l}{x^{prime}=lambda^{prime} x_{1}+left(1-lambda^{prime}
ight) x_{2}} \ {p^{prime}=lambda^{prime} p_{1}+left(1-lambda^{prime}
ight) p_{2}}end{array}

其中,權重因子 lambda 使用超參數 alpha 通過 Beta 函數抽樣得到,

egin{aligned} lambda & sim operatorname{Beta}(alpha, alpha) \ lambda^{prime} &=max (lambda, 1-lambda) end{aligned}

文章使用超參數 alpha = 0.75 , 如果將此 Beta 分布畫圖表示,則如下圖所示,

權重因子的分布。根據此 Beta(0.75, 0.75) 分布抽樣,大部分數值落在接近 0 或 1 的區域。

原始的 Mixup 演算法中,第一步不變,第二步 lambda = lambda 。MixMatch 做了極小的修改,使用 lambda^{prime} =max (lambda, 1-lambda) 。如上圖所示,根據 {
m Beta} (alpha=0.75, alpha=0.75) 抽樣得到的 lambda 數值大部分落在 0 或 1 附近, lambda^{prime} =max (lambda, 1-lambda) 函數則使得 lambda^{prime} 數值接近 1 。這樣的好處是在 Mixup 標記數據 mathcal{hat{X}} 與混合數據 mathcal{W}時,增加 mathcal{hat{X}}的權重;在 Mixup 未標記數據 mathcal{hat{U}}mathcal{W}時,增加 mathcal{hat{U}}的權重。分別對應於演算法描述中的{
m Mixup}( mathcal{hat{X}},  mathcal{W}){
m Mixup}( mathcal{hat{U}},  mathcal{W})

細節:損失函數中使用了對未標記數據猜測的標籤 q , 此標籤依賴於模型參數 	heta 。遵循標準處理方案,不將 q	heta 的梯度做向後誤差傳遞。

半監督學習 MixMatch 訓練結果

在 CIFAR-10 數據集上,使用全部五萬個數據做監督學習,最低誤差能降到百分之4.13。使用 MixMatch,250 個數據就能將誤差降到百分之11,4000 個數據就能將誤差降到百分之 6.24。結果驚艷。

更直觀的效果對比

MixMatch 演算法測試誤差用黑色星號表示,監督學習演算法用虛線表示。觀察最底下,誤差最小的兩條線,可看到 MixMatch 測試誤差直逼監督學習演算法!

解剖各部分貢獻 (Ablation Test )

可以看到對結果貢獻最大的是對未標記數據的 MixUp,Average 以及 Sharpen。

結論:

半監督學習是深度學習裡面最可能接近人類智能的方法。這個方向的進展,這篇文章的突破,都是領域的極大進展。因未在其他公眾號看到這篇文章的介紹,特此作此解讀。

另有一篇文章,Unsupervised Data Augmentation,貌似在4000張標記圖片的CIFAR10上達到了 5.27 的測試誤差,超過了 MixMatch 方法。如有時間,會進一步解讀那篇文章。以觀察兩篇文章的方法是否可以一同使用。

參考文獻:

  1. The Quiet Semi-Supervised Revolution
  2. MixMatch: A Holistic Approach to Semi-Supervised Learning
  3. Temporal ensembling for semi-supervised learning. ICLR, 2017.
  4. Regularization with stochastic transformations and perturbations for deep semi-supervised learning. NIPS, 2016.
  5. Semi-supervised Learning by Entropy Minimization
  6. Mixup: Beyond empirical risk minimization
  7. Realistic Evaluation of Deep Semi-Supervised Learning Algorithms
  8. Improved Techniques for Training GANs ,OpenAI 2016, get 18.6 test error using 4000 labeled images in CIFAR10.
  9. SEMI-SUPERVISED LEARNING WITH GANS: REVISITING MANIFOLD REGULARIZATION , 2018, GAN + Manifold Regularization, get 14.45 test error using 4000 labeled images in CIFAR10.
  10. Good Semi-supervised Learning That Requires a Bad GAN , 2017, get 14.41 test error using 4000 labeled images in CIFAR10.
  11. [free online book] Semi Supervised Learning

推薦閱讀:

相关文章