Conditional GAN 自從出現以來,產生了很多種 discriminator。cGANs with Projection Discriminator 介紹了一種新的 conditional GAN 判別器設計方法。從實踐中來看,它還是比較好用的,被用在了 Spectral Normalization for Generative Adversarial Networks 和 Self-Attention Generative Adversarial Networks 中。

本文將詳細介紹該演算法的推導過程以及本人的理解,指出它其實是化簡版本的 AC-GANs。文中代碼來自:crcrpar/pytorch.sngan_projection,該代碼為 Spectral Normalization GAN 的 pytorch 版本。

背景

下圖為論文中列舉的幾種常見的 conditional discriminator。

圖 1

(a) 中,label 和 data 在輸入端拼接起來,一起輸入一個神經網路,輸出 data 為真實數據的概率。

(b) 和 (a) 很像,只不過是把 label 和 data 在神經網路的中間層拼接起來,輸出 data 屬於真實數據的概率。

(c) 圖顯示的 discriminator 有兩部分輸出,一部分和傳統的 GAN 一樣,表示輸入數據真實的概率,用 adversarial loss 指導;另一部分為 classification output,表示數據屬於各類的概率,用 classification loss 指導。

(d) 圖為本文提出的結構,輸入圖片首先經過網路 phi 提取特徵,然後把特徵分成兩路:一路與經過編碼的類別標籤 y 做點乘,另一路再通過網路 psi 映射成一維向量。最後兩路相加,作為神經網路最終的輸出。注意這個輸出類似於 W-GAN,不經過 sigmoid 函數映射,越大代表越真實。

推導

原文中的推導有點跳躍,這裡給出我的補充。看起來很大塊的數學公式,其實都是基本操作。

首先回顧 Ian Goodfellow 在 GAN 的原始論文中,給出的 discriminator 損失函數:

egin{align*} maxlimits_{D} V(G, D)&=mathbb{E}_{xsim p_{data}(x)}[log D(x)]+mathbb{E}_{zsim p_{z}(z)}[log(1-D(G(z)))]\ &=int_xp_{data}(x)log(D(x))mathrm{d}x+int_z p_z(z)log(1-D(g(z)))mathrm{d}z\ &=int_xp_{data}(x)log(D(x))+p_g(x)log(1-D(x))mathrm{d}x	ag{1} end{align*}

把 (1) 對 D(x) 求導,令導數為0,得到最優 discriminator:

D_G^*(x)=frac{p_{data}(x)}{p_{data}(x)+p_g(x)}	ag{2}

即最優的 discriminator 就是數據為 real 的概率比上數據為 real 與數據為 fake 的概率之和。

接下來我們看 conditional discriminator 的損失函數:

egin{align*} minlimits_{D}mathcal{L}(D)=&-mathbb{E}_{ysim q_{data}(y)}left[mathbb{E}_{xsim q_{data}(x|y)}[log(D(x,y))]
ight]\&-mathbb{E}_{ysim p_g(y)}left[mathbb{E}_{xsim p_g(x|y)}[log(1-D(x,y))]
ight]	ag{3} end{align*}

根據 (1) 與 (2) 的推導,可知 conditional GAN 的最優 discriminator 應為:

egin{align*} D_G^*(x,y)&=frac{q_{data}(x,y)}{q_{data}(x,y)+p_g(x,y)}	ag{4}\ &=frac{q_{data}(x|y)q_{data}(y)}{q_{data}(x|y)q_{data}(y)+p_g(x|y)p_g(y)} end{align*}

而傳統的 GAN 輸出一個代表數據真實程度的概率, D(x,y;	heta)=mathcal{A(f(x,y;	heta))} 。其中 mathcal{A} 為 activation function,通常為 sigmoid 函數:f(x)=1/(1+exp(-x))。因此,公式 (4) 可化為:

frac{1}{1+mathrm{exp}(-f^*(x,y))}=frac{q_{data}(x|y)q_{data}(y)}{q_{data}(x|y)q_{data}(y)+p_g(x|y)p_g(y)}	ag{5}

進一步化簡 (5),得到原文公式:

egin{align*} f^*(x,y)&=logfrac{q_{data}(x|y)q_{data}(y)}{p_g(x|y)p_g(y)}=logfrac{q_{data}(x,y)}{p_g(x,y)}\ &=logfrac{q_{data}(y|x)q_{data}(x)}{p_g(y|x)p_g(x)}\ &=logfrac{q_{data}(y|x)}{p_g(y|x)}+logfrac{q_{data}(x)}{p_{g}(x)}\ &=r(y|x)+r(x)	ag{6} end{align*}

我們在這裡先暫停一下推導,思考公式 (6) 說了一件什麼事。

首先看公式 (6) 倒數第二行第二項 log [q_{data}(x)/p_g(x)]:分子為 x 屬於真實數據的概率,分母為 x 屬於虛假數據的概率。對於一個最優分類器,當輸入的 x 為真實數據時,我們希望它的輸出值越大越好,而當輸入 x 為虛假數據時,我們希望它的輸出值越小越好,因此,這一項相當於在判斷 x 的真實性。

接著看公式 (6) 倒數第二行第一項 log [q_{data}(y|x)/p_g(y|x)] :分子為假如 x 真實,那麼 x 屬於類別 y 的概率,而分母為假如 x 虛假,x 屬於類別 y 的概率。顯然,這項在判斷 x 的類別是不是我們想要的,這項越大,代表 x 真實時,屬於類別 y 的概率越大。

接下來繼續推導:

圖 2

作者研究公式 (6) 倒數第二行第一項 log [q_{data}(y|x)/p_g(y|x)] 中的條件概率。對於一個 C 維輸出的分類器,我們通常用 softmax 函數計算 x 屬於各個類別的概率:

p(y=c|x)=frac{mathrm{exp}(o_c)}{sum_{j=1}^{C}mathrm{exp}(o_j)}	ag{7}

(7) 中的 o_j 為神經網路全連接層的輸出,我們可以把它分解成倒數第二層的 feature 乘以一個 C 行的矩陣:

o_j=v_j^Tphi(x)	ag{8}

(8) 中的 phi(x) 即為圖 2 中的 phi ,把 (8) 帶入 (7) 可得:

egin{align*} log p(y=c|x)&=log frac{mathrm{exp}(v_c^{pT}phi(x))}{sum_{j=1}^Cmathrm{exp}(v_j^{pT}phi(x))}\ &=v_c^{pT}phi(x)-log left(sum_{j=1}^Cmathrm{exp}left(v_j^{pT}phi(x)
ight)
ight)\ &=v_c^{pT}phi(x)-log Z^p(phi(x)) end{align*}	ag{9}

因此:

egin{align*} log frac{q_{data}(y=c|x)}{p_g(y=c|x)}=&log q_{data}(y=c|x)-log p_g(y=c|x)\ =&v_c^{q_{data}T}phi(x)-log Z^{q_{data}}ig(phi(x)ig)-\ &Big(v_c^{p_{g}T}phi(x)-log Z^{p_{g}}ig(phi(x)ig)Big)\ =&(v_c^{q_{data}}-v_c^{p_g})^Tphi(x)-\ &Big(log Z^{q_{data}}ig(phi(x)ig)-log Z^{p_g}ig(phi(x)ig)Big) end{align*}	ag{10}

把公式 (10) 帶入公式 (6) 得:

egin{align*} f^*(x,y)&=logfrac{q_{data}(y|x)}{p_g(y|x)}+logfrac{q_{data}(x)}{p_{g}(x)}\ &=(v_c^{q_{data}}-v_c^{p_g})^Tphi(x)-\ &Big(log Z^{q_{data}}ig(phi(x)ig)-log Z^{p_g}ig(phi(x)ig)Big)+\ &logfrac{q_{data}(x)}{p_g(x)} end{align*}	ag{11}

 v_c^{q_{data}}-v_c^{p_g}=v_c

-ig(log Z^{q_{data}}(phi(x))-log Z^{p_g}(phi(x))ig)+logfrac{q_{data}(x)}{p_g(x)}=psi(phi(x))

得:

f^*(x,y=c)=v_c^Tphi(x)+psi(phi(x))

令矩陣 V 中各行向量為 v_j^Ty 為 one-hot label。則最終有:

f^*(x,y)=y^TVphi(x)+psi(phi(x))	ag{12}

可以把 V 理解成 label 的 embedding 層。這正是圖 2 中的網路結構。

理解

如果跳出繁瑣的推導過程,直接看公式 (12),我們發現這個最優分類器包含兩部分: y^TVphi(x)psi(phi(x))

對於psi(phi(x)) ,其實就起 vanilla GAN discriminator 的作用,用於判斷數據 x 是否為真實數據。

y^TVphi(x)=(Vphi(x))^Ty ,相當於神經網路的輸出 Vphi(x) 與 one-hot label y 的點乘,從而取出來輸出部分對應的 target 類的值,這項越大,代表越逼真。我們回顧 multi-class crossentropy 的公式,如果某個數據的 one-hot label 為 y ,網路的輸出的概率分佈為 p,則對該條數據的損失函數為:

mathcal{L}(p, y)=sum_{j=1}^c y_j log p_j	ag{13}

注意 (13) 中的 y_j 大部分為0,只有 ground truth 是 1,因此,對於 multi-class crossentropy,相當於把神經網路輸出的概率分佈中,對應 ground truth 類的概率提了出來,求了個對數。而 y^TVphi(x)=(Vphi(x))^Ty 中的 (Vphi(x)) 相當於一個特殊的分類網路,它輸出的數字沒有經過 softmax 映射成概率分佈,但仍然可以代表輸入數據屬於某個類的程度深淺,越大代表越屬於某類,越小則反之。

因此,這個文章所提出來的 discriminator,和 AC-GAN 非常類似,只不過

  1. AC-GAN 中 auxiliary classifier 輸出的是概率,這裡的輸出是任意數字;
  2. AC-GAN 的 classification loss 和 adversarial loss 之間權重可調,而這個 conditional discriminator 的權重為 1:1 的關係。
  3. 此外需要注意一個細節,如果看 AC-GAN 的原文,發現當訓練 discriminator 時,如果輸入數據為 fake,作者仍舊希望 auxiliary classifier 將其歸為 conditional 類,即希望 classifier 在 conditional label 對應的維度輸出一個較大的值;而本文中,當輸入數據為 fake 的時候,作者希望它在 conditional label 對應的維度上輸出的值越小越好。

總的來看,感覺這篇文章有點借鑒了 vanilla GAN 到 W-GAN 之間的轉變思想,即不再拘泥於輸出概率,大膽拋棄 softmax 或者 sigmoid activation,用數字的大小直接代表屬於某類的程度,這樣反而比之前映射成概率的表現要好。

代碼

這裡只貼出關鍵部分,詳細代碼見鏈接:

def forward(self, x, y=None):
h = x
h = self.block1(h)
h = self.block2(h)
h = self.block3(h)
h = self.block4(h)
h = self.block5(h)
h = self.activation(h)
# Global pooling
h = torch.sum(h, dim=(2, 3)) # 提取x特徵,送入兩路,一路判斷是否真實,一路判斷是否屬於label類
output = self.l6(h) # 相當於 vanilla GAN, 判斷 x 是否真實
if y is not None:
# 相當於不加 softmax 的 classifier, 直接提取 classifier 在 label 對應的維度的輸出
class_out = torch.sum(self.l_y(y) * h, dim=1, keepdim=True)
# 把兩部分加起來作為 discriminator 的 output
output += class_out

return output

推薦閱讀:

相關文章