Conditional GAN 自从出现以来,产生了很多种 discriminator。cGANs with Projection Discriminator 介绍了一种新的 conditional GAN 判别器设计方法。从实践中来看,它还是比较好用的,被用在了 Spectral Normalization for Generative Adversarial Networks 和 Self-Attention Generative Adversarial Networks 中。
本文将详细介绍该演算法的推导过程以及本人的理解,指出它其实是化简版本的 AC-GANs。文中代码来自:crcrpar/pytorch.sngan_projection,该代码为 Spectral Normalization GAN 的 pytorch 版本。
下图为论文中列举的几种常见的 conditional discriminator。
(a) 中,label 和 data 在输入端拼接起来,一起输入一个神经网路,输出 data 为真实数据的概率。
(b) 和 (a) 很像,只不过是把 label 和 data 在神经网路的中间层拼接起来,输出 data 属于真实数据的概率。
(c) 图显示的 discriminator 有两部分输出,一部分和传统的 GAN 一样,表示输入数据真实的概率,用 adversarial loss 指导;另一部分为 classification output,表示数据属于各类的概率,用 classification loss 指导。
(d) 图为本文提出的结构,输入图片首先经过网路 提取特征,然后把特征分成两路:一路与经过编码的类别标签 y 做点乘,另一路再通过网路 映射成一维向量。最后两路相加,作为神经网路最终的输出。注意这个输出类似于 W-GAN,不经过 sigmoid 函数映射,越大代表越真实。
原文中的推导有点跳跃,这里给出我的补充。看起来很大块的数学公式,其实都是基本操作。
首先回顾 Ian Goodfellow 在 GAN 的原始论文中,给出的 discriminator 损失函数:
把 (1) 对 D(x) 求导,令导数为0,得到最优 discriminator:
即最优的 discriminator 就是数据为 real 的概率比上数据为 real 与数据为 fake 的概率之和。
接下来我们看 conditional discriminator 的损失函数:
根据 (1) 与 (2) 的推导,可知 conditional GAN 的最优 discriminator 应为:
而传统的 GAN 输出一个代表数据真实程度的概率, 。其中 为 activation function,通常为 sigmoid 函数:f(x)=1/(1+exp(-x))。因此,公式 (4) 可化为:
进一步化简 (5),得到原文公式:
我们在这里先暂停一下推导,思考公式 (6) 说了一件什么事。
首先看公式 (6) 倒数第二行第二项 :分子为 x 属于真实数据的概率,分母为 x 属于虚假数据的概率。对于一个最优分类器,当输入的 x 为真实数据时,我们希望它的输出值越大越好,而当输入 x 为虚假数据时,我们希望它的输出值越小越好,因此,这一项相当于在判断 x 的真实性。
接著看公式 (6) 倒数第二行第一项 :分子为假如 x 真实,那么 x 属于类别 y 的概率,而分母为假如 x 虚假,x 属于类别 y 的概率。显然,这项在判断 x 的类别是不是我们想要的,这项越大,代表 x 真实时,属于类别 y 的概率越大。
接下来继续推导:
作者研究公式 (6) 倒数第二行第一项 中的条件概率。对于一个 C 维输出的分类器,我们通常用 softmax 函数计算 x 属于各个类别的概率:
(7) 中的 为神经网路全连接层的输出,我们可以把它分解成倒数第二层的 feature 乘以一个 C 行的矩阵:
(8) 中的 即为图 2 中的 ,把 (8) 带入 (7) 可得:
因此:
把公式 (10) 带入公式 (6) 得:
令
得:
令矩阵 中各行向量为 , 为 one-hot label。则最终有:
可以把 理解成 label 的 embedding 层。这正是图 2 中的网路结构。
如果跳出繁琐的推导过程,直接看公式 (12),我们发现这个最优分类器包含两部分: 和 。
对于 ,其实就起 vanilla GAN discriminator 的作用,用于判断数据 x 是否为真实数据。
而 ,相当于神经网路的输出 与 one-hot label 的点乘,从而取出来输出部分对应的 target 类的值,这项越大,代表越逼真。我们回顾 multi-class crossentropy 的公式,如果某个数据的 one-hot label 为 ,网路的输出的概率分布为 ,则对该条数据的损失函数为:
注意 (13) 中的 大部分为0,只有 ground truth 是 1,因此,对于 multi-class crossentropy,相当于把神经网路输出的概率分布中,对应 ground truth 类的概率提了出来,求了个对数。而 中的 相当于一个特殊的分类网路,它输出的数字没有经过 softmax 映射成概率分布,但仍然可以代表输入数据属于某个类的程度深浅,越大代表越属于某类,越小则反之。
因此,这个文章所提出来的 discriminator,和 AC-GAN 非常类似,只不过
总的来看,感觉这篇文章有点借鉴了 vanilla GAN 到 W-GAN 之间的转变思想,即不再拘泥于输出概率,大胆抛弃 softmax 或者 sigmoid activation,用数字的大小直接代表属于某类的程度,这样反而比之前映射成概率的表现要好。
这里只贴出关键部分,详细代码见链接:
def forward(self, x, y=None): h = x h = self.block1(h) h = self.block2(h) h = self.block3(h) h = self.block4(h) h = self.block5(h) h = self.activation(h) # Global pooling h = torch.sum(h, dim=(2, 3)) # 提取x特征,送入两路,一路判断是否真实,一路判断是否属于label类 output = self.l6(h) # 相当于 vanilla GAN, 判断 x 是否真实 if y is not None: # 相当于不加 softmax 的 classifier, 直接提取 classifier 在 label 对应的维度的输出 class_out = torch.sum(self.l_y(y) * h, dim=1, keepdim=True) # 把两部分加起来作为 discriminator 的 output output += class_out
return output
推荐阅读: