一、前言

在遷移學習領域,除了領域自適應Domain Adaptation(DA),還有一種情形叫領域泛化Domain Generalization(DG)。我們知道DA需要源域和目標域都有數據,甚至有些方法要求的數據量還很大。DG主要做的就是通過多個標註好的源域數據,學習一個通用的特徵表示,並希望該表示也能應用於未見過的相似樣本,即目標域數據,即使目標域一個數據都沒有。

二、簡介

本文介紹一篇南洋理工大學和港城大的工作[1],作者是Haoliang Li、Sinno Jialin Pan、Shiqi Wang、Alex C. Kot。其中第二位作者想必大家也有一些了解,之前和楊強教授一起寫了遷移學習survey,是那篇survey的一作。該論文發表在CVPR2018上。

文章提出了一種新穎的機遇對抗自編碼器的框架,藉助於自編碼器的特徵提取能力,學習一個泛化的隱空間特徵表示。在隱空間中,利用MMD(最大均值差異)度量,將不同域的源域數據特徵分布,對齊到一起去。然後通過基於對抗的特徵學習,將對齊上的分布再與一個事先給定的先驗分布匹配上。此外,生成的隱空間特徵,還將被用於分類任務。最後變成一個多目標優化問題。

該方法與之前的方法不同之處就在於最後還得匹配一個先驗分布。之前完全靠數據驅動的方法可能有過擬合的問題。學習到的特徵表示,有了源域獨有的信息,這些信息不是一般化的特徵,會對網路在目標域上的表現不利。文章認為加上一個先驗分布,會減少特徵表示的過擬合風險。

自編碼器學習到的隱空間特徵分布必須滿足以下兩個性質:

  1. 所有源域的數據映射到特徵空間後,條件概率分布應該一致。
  2. 特徵分布應該包含能對學習分類器有用的判別信息。

文章用先驗分布滿足第一條,用同時訓練分類器滿足第二條。

三、方法

首先定義一些標記:

假設有K個標註好的源域,每個源域的數據是 X_l=[x_{l_1},...,x_{l_{n_l}}]^	oplin{1,...,K} ,其中 x_{l_i}inmathbb R^{d	imes1}n_l 是第 l 個源域的樣本數,數據對應的標籤是 Y_l=[y_{l_1},...,y_{l_{n_l}}]^	op ,其中 y_{l_i}in mathbb R^{m	imes 1} 是one-hot編碼。

3.1 對抗自編碼器AAE

對抗自編碼器能利用類似於GAN的對抗過程,將隱空間的特徵分布匹配一個任意的先驗分布。例如輸入 x ,隱空間輸出 h,那麼編碼解碼就是 q(h|x)q(x|h) 。我們希望隱空間特徵 h 的分布 q(h) 去匹配某個先驗 p(h) , 其中 q(h)=int_xq(h|x)p(x)dx ,這個目標就很類似於GAN,就是利用對抗網路去做的。

對抗網路損失:

mathcal J_{gan}=mathbb E_{hsim p(h)}[log D(h)]+mathbb E_{xsim p(x)}[log(1-D(Q(x))]

其中, D(cdot) 是判別器。 x 是所有源域的數據。作者通過實驗發現這邊 hsim Laplace(frac{1}{sqrt2}) 效果最好。

為了解決GAN收斂困難的問題,這邊將log損失改成了最小均方:

mathcal J_{gan}=mathbb E_{hsim p(h)}[D(h)^2]+mathbb E_{xsim p(x)}[(1-D(Q(x))^2]

相當於將優化從KL散度變成了 Pearson-mathcal X^2 散度[2]。

3.2 MMD-AAE

這邊假設編碼器是 Q(x) ,解碼器是 P(h) ,那麼自編碼器的損失是:

mathcal L_{ae}=sum_{l=1}^{K}|hat{X}_l-X_l|^2_2

其中 hat{X}_l=P(H_l)H_l=Q(X_l)

這邊要對K個源域的特徵分布做MMD,MMD我之前也介紹過,是遷移學習裡面常用的分布匹配方法。大致就是將數據投影到希爾伯特核空間,在那個空間將均值匹配上。一般用的是高斯核 k(x,x)=exp(-frac{1}{2sigma|x-x|^2}) 。MMD( H_i,H_j )可以寫為:

MMD(H_i,H_j)=|mu_{P_i}-mu_{P_j}|_{mathcal H}\ mu_P:=mu(P)=mathbb E_{xsim P}[phi(x)]=mathbb E_{xsim P}[k(x,cdot)]

作者通過一個引理證明了不同源域之間分布方差的上界:

frac{1}{K^2}sum_{1leq i,jleq K} MMD(H_i,H_j)

所以MMD正則化損失為:

mathcal R_{mmd}(H_1,...,H_K)=frac{1}{K^2}sum_{1leq i,jleq K} MMD(H_i,H_j)

實際計算就是將數據帶入如下式子:

MMD(H_l,H_t)=|frac{1}{n_l}sum_{i=1}^{n_l}phi(h_{l_i})-frac{1}{n_t}sum_{i=1}^{n_t}phi(h_{t_i})|^2_{mathcal H}\ =frac{1}{n_l^2}sum_{i=1}^{n_l}sum_{i=1}^{n_l}k(h_{l_i},h_{l_{i}})+frac{1}{n_t^2}sum_{j=1}^{n_t}sum_{j=1}^{n_t}k(h_{t_j},h_{t_{j}})-frac{2}{n_ln_t}sum_{i=1}^{n_l}sum_{j=1}^{n_t}k(h_{l_i},h_{t_{j}})

3.3 訓練過程

目標函數:

min_{C,Q,P}max_{D}mathcal L_{err}+lambda_0mathcal L_{ae}+lambda_1mathcal R_{mmd}+lambda_2 mathcal J_{gan}

其中, mathcal L_{err} 是分類器損失, C 是分類器,Q是編碼器,P是解碼器,D是判別器。

優化步驟如下:

四、實驗

作者做了大量的對比實驗,這邊以手寫數字體為例,將MNSIT數字體旋轉不同角度作為不同的源域,實驗結果如下:

可以看出大部分情況下,超過了其他對比的演算法。

五、參考文獻

[1]. Haoliang Li, Sinno Jialin Pan, Shiqi Wang, Alex C. Kot, "Domain Generalization with Adversarial Feature Learning", Computer Vision and Pattern Recognition (CVPR) 2018 IEEE/CVF Conference on, pp. 5400-5409, 2018.

[2]. X. Mao, Q. Li, H. Xie, R. Y. Lau, Z. Wang, and S. P. Smolley. Least squares generative adversarial networks. arXiv preprint ArXiv:1611.04076, 2016.

如有錯誤,歡迎交流指正。^_^

本文鏈接:zhuanlan.zhihu.com/p/68

轉載請註明出處!!^_^

作者:(知乎)種豆南山下, 中科院數學與系統科學研究院博士研究生

推薦閱讀:

相关文章