Conditional GAN 自從出現以來,產生了很多種 discriminator。cGANs with Projection Discriminator 介紹了一種新的 conditional GAN 判別器設計方法。從實踐中來看,它還是比較好用的,被用在了 Spectral Normalization for Generative Adversarial Networks 和 Self-Attention Generative Adversarial Networks 中。
本文將詳細介紹該演算法的推導過程以及本人的理解,指出它其實是化簡版本的 AC-GANs。文中代碼來自:crcrpar/pytorch.sngan_projection,該代碼為 Spectral Normalization GAN 的 pytorch 版本。
下圖為論文中列舉的幾種常見的 conditional discriminator。
(a) 中,label 和 data 在輸入端拼接起來,一起輸入一個神經網路,輸出 data 為真實數據的概率。
(b) 和 (a) 很像,只不過是把 label 和 data 在神經網路的中間層拼接起來,輸出 data 屬於真實數據的概率。
(c) 圖顯示的 discriminator 有兩部分輸出,一部分和傳統的 GAN 一樣,表示輸入數據真實的概率,用 adversarial loss 指導;另一部分為 classification output,表示數據屬於各類的概率,用 classification loss 指導。
(d) 圖為本文提出的結構,輸入圖片首先經過網路 提取特徵,然後把特徵分成兩路:一路與經過編碼的類別標籤 y 做點乘,另一路再通過網路 映射成一維向量。最後兩路相加,作為神經網路最終的輸出。注意這個輸出類似於 W-GAN,不經過 sigmoid 函數映射,越大代表越真實。
原文中的推導有點跳躍,這裡給出我的補充。看起來很大塊的數學公式,其實都是基本操作。
首先回顧 Ian Goodfellow 在 GAN 的原始論文中,給出的 discriminator 損失函數:
把 (1) 對 D(x) 求導,令導數為0,得到最優 discriminator:
即最優的 discriminator 就是數據為 real 的概率比上數據為 real 與數據為 fake 的概率之和。
接下來我們看 conditional discriminator 的損失函數:
根據 (1) 與 (2) 的推導,可知 conditional GAN 的最優 discriminator 應為:
而傳統的 GAN 輸出一個代表數據真實程度的概率, 。其中 為 activation function,通常為 sigmoid 函數:f(x)=1/(1+exp(-x))。因此,公式 (4) 可化為:
進一步化簡 (5),得到原文公式:
我們在這裡先暫停一下推導,思考公式 (6) 說了一件什麼事。
首先看公式 (6) 倒數第二行第二項 :分子為 x 屬於真實數據的概率,分母為 x 屬於虛假數據的概率。對於一個最優分類器,當輸入的 x 為真實數據時,我們希望它的輸出值越大越好,而當輸入 x 為虛假數據時,我們希望它的輸出值越小越好,因此,這一項相當於在判斷 x 的真實性。
接著看公式 (6) 倒數第二行第一項 :分子為假如 x 真實,那麼 x 屬於類別 y 的概率,而分母為假如 x 虛假,x 屬於類別 y 的概率。顯然,這項在判斷 x 的類別是不是我們想要的,這項越大,代表 x 真實時,屬於類別 y 的概率越大。
接下來繼續推導:
作者研究公式 (6) 倒數第二行第一項 中的條件概率。對於一個 C 維輸出的分類器,我們通常用 softmax 函數計算 x 屬於各個類別的概率:
(7) 中的 為神經網路全連接層的輸出,我們可以把它分解成倒數第二層的 feature 乘以一個 C 行的矩陣:
(8) 中的 即為圖 2 中的 ,把 (8) 帶入 (7) 可得:
因此:
把公式 (10) 帶入公式 (6) 得:
令
得:
令矩陣 中各行向量為 , 為 one-hot label。則最終有:
可以把 理解成 label 的 embedding 層。這正是圖 2 中的網路結構。
如果跳出繁瑣的推導過程,直接看公式 (12),我們發現這個最優分類器包含兩部分: 和 。
對於 ,其實就起 vanilla GAN discriminator 的作用,用於判斷數據 x 是否為真實數據。
而 ,相當於神經網路的輸出 與 one-hot label 的點乘,從而取出來輸出部分對應的 target 類的值,這項越大,代表越逼真。我們回顧 multi-class crossentropy 的公式,如果某個數據的 one-hot label 為 ,網路的輸出的概率分佈為 ,則對該條數據的損失函數為:
注意 (13) 中的 大部分為0,只有 ground truth 是 1,因此,對於 multi-class crossentropy,相當於把神經網路輸出的概率分佈中,對應 ground truth 類的概率提了出來,求了個對數。而 中的 相當於一個特殊的分類網路,它輸出的數字沒有經過 softmax 映射成概率分佈,但仍然可以代表輸入數據屬於某個類的程度深淺,越大代表越屬於某類,越小則反之。
因此,這個文章所提出來的 discriminator,和 AC-GAN 非常類似,只不過
總的來看,感覺這篇文章有點借鑒了 vanilla GAN 到 W-GAN 之間的轉變思想,即不再拘泥於輸出概率,大膽拋棄 softmax 或者 sigmoid activation,用數字的大小直接代表屬於某類的程度,這樣反而比之前映射成概率的表現要好。
這裡只貼出關鍵部分,詳細代碼見鏈接:
def forward(self, x, y=None): h = x h = self.block1(h) h = self.block2(h) h = self.block3(h) h = self.block4(h) h = self.block5(h) h = self.activation(h) # Global pooling h = torch.sum(h, dim=(2, 3)) # 提取x特徵,送入兩路,一路判斷是否真實,一路判斷是否屬於label類 output = self.l6(h) # 相當於 vanilla GAN, 判斷 x 是否真實 if y is not None: # 相當於不加 softmax 的 classifier, 直接提取 classifier 在 label 對應的維度的輸出 class_out = torch.sum(self.l_y(y) * h, dim=1, keepdim=True) # 把兩部分加起來作為 discriminator 的 output output += class_out
return output
推薦閱讀: