Conditional GAN 自从出现以来,产生了很多种 discriminator。cGANs with Projection Discriminator 介绍了一种新的 conditional GAN 判别器设计方法。从实践中来看,它还是比较好用的,被用在了 Spectral Normalization for Generative Adversarial Networks 和 Self-Attention Generative Adversarial Networks 中。

本文将详细介绍该演算法的推导过程以及本人的理解,指出它其实是化简版本的 AC-GANs。文中代码来自:crcrpar/pytorch.sngan_projection,该代码为 Spectral Normalization GAN 的 pytorch 版本。

背景

下图为论文中列举的几种常见的 conditional discriminator。

图 1

(a) 中,label 和 data 在输入端拼接起来,一起输入一个神经网路,输出 data 为真实数据的概率。

(b) 和 (a) 很像,只不过是把 label 和 data 在神经网路的中间层拼接起来,输出 data 属于真实数据的概率。

(c) 图显示的 discriminator 有两部分输出,一部分和传统的 GAN 一样,表示输入数据真实的概率,用 adversarial loss 指导;另一部分为 classification output,表示数据属于各类的概率,用 classification loss 指导。

(d) 图为本文提出的结构,输入图片首先经过网路 phi 提取特征,然后把特征分成两路:一路与经过编码的类别标签 y 做点乘,另一路再通过网路 psi 映射成一维向量。最后两路相加,作为神经网路最终的输出。注意这个输出类似于 W-GAN,不经过 sigmoid 函数映射,越大代表越真实。

推导

原文中的推导有点跳跃,这里给出我的补充。看起来很大块的数学公式,其实都是基本操作。

首先回顾 Ian Goodfellow 在 GAN 的原始论文中,给出的 discriminator 损失函数:

egin{align*} maxlimits_{D} V(G, D)&=mathbb{E}_{xsim p_{data}(x)}[log D(x)]+mathbb{E}_{zsim p_{z}(z)}[log(1-D(G(z)))]\ &=int_xp_{data}(x)log(D(x))mathrm{d}x+int_z p_z(z)log(1-D(g(z)))mathrm{d}z\ &=int_xp_{data}(x)log(D(x))+p_g(x)log(1-D(x))mathrm{d}x	ag{1} end{align*}

把 (1) 对 D(x) 求导,令导数为0,得到最优 discriminator:

D_G^*(x)=frac{p_{data}(x)}{p_{data}(x)+p_g(x)}	ag{2}

即最优的 discriminator 就是数据为 real 的概率比上数据为 real 与数据为 fake 的概率之和。

接下来我们看 conditional discriminator 的损失函数:

egin{align*} minlimits_{D}mathcal{L}(D)=&-mathbb{E}_{ysim q_{data}(y)}left[mathbb{E}_{xsim q_{data}(x|y)}[log(D(x,y))]
ight]\&-mathbb{E}_{ysim p_g(y)}left[mathbb{E}_{xsim p_g(x|y)}[log(1-D(x,y))]
ight]	ag{3} end{align*}

根据 (1) 与 (2) 的推导,可知 conditional GAN 的最优 discriminator 应为:

egin{align*} D_G^*(x,y)&=frac{q_{data}(x,y)}{q_{data}(x,y)+p_g(x,y)}	ag{4}\ &=frac{q_{data}(x|y)q_{data}(y)}{q_{data}(x|y)q_{data}(y)+p_g(x|y)p_g(y)} end{align*}

而传统的 GAN 输出一个代表数据真实程度的概率, D(x,y;	heta)=mathcal{A(f(x,y;	heta))} 。其中 mathcal{A} 为 activation function,通常为 sigmoid 函数:f(x)=1/(1+exp(-x))。因此,公式 (4) 可化为:

frac{1}{1+mathrm{exp}(-f^*(x,y))}=frac{q_{data}(x|y)q_{data}(y)}{q_{data}(x|y)q_{data}(y)+p_g(x|y)p_g(y)}	ag{5}

进一步化简 (5),得到原文公式:

egin{align*} f^*(x,y)&=logfrac{q_{data}(x|y)q_{data}(y)}{p_g(x|y)p_g(y)}=logfrac{q_{data}(x,y)}{p_g(x,y)}\ &=logfrac{q_{data}(y|x)q_{data}(x)}{p_g(y|x)p_g(x)}\ &=logfrac{q_{data}(y|x)}{p_g(y|x)}+logfrac{q_{data}(x)}{p_{g}(x)}\ &=r(y|x)+r(x)	ag{6} end{align*}

我们在这里先暂停一下推导,思考公式 (6) 说了一件什么事。

首先看公式 (6) 倒数第二行第二项 log [q_{data}(x)/p_g(x)]:分子为 x 属于真实数据的概率,分母为 x 属于虚假数据的概率。对于一个最优分类器,当输入的 x 为真实数据时,我们希望它的输出值越大越好,而当输入 x 为虚假数据时,我们希望它的输出值越小越好,因此,这一项相当于在判断 x 的真实性。

接著看公式 (6) 倒数第二行第一项 log [q_{data}(y|x)/p_g(y|x)] :分子为假如 x 真实,那么 x 属于类别 y 的概率,而分母为假如 x 虚假,x 属于类别 y 的概率。显然,这项在判断 x 的类别是不是我们想要的,这项越大,代表 x 真实时,属于类别 y 的概率越大。

接下来继续推导:

图 2

作者研究公式 (6) 倒数第二行第一项 log [q_{data}(y|x)/p_g(y|x)] 中的条件概率。对于一个 C 维输出的分类器,我们通常用 softmax 函数计算 x 属于各个类别的概率:

p(y=c|x)=frac{mathrm{exp}(o_c)}{sum_{j=1}^{C}mathrm{exp}(o_j)}	ag{7}

(7) 中的 o_j 为神经网路全连接层的输出,我们可以把它分解成倒数第二层的 feature 乘以一个 C 行的矩阵:

o_j=v_j^Tphi(x)	ag{8}

(8) 中的 phi(x) 即为图 2 中的 phi ,把 (8) 带入 (7) 可得:

egin{align*} log p(y=c|x)&=log frac{mathrm{exp}(v_c^{pT}phi(x))}{sum_{j=1}^Cmathrm{exp}(v_j^{pT}phi(x))}\ &=v_c^{pT}phi(x)-log left(sum_{j=1}^Cmathrm{exp}left(v_j^{pT}phi(x)
ight)
ight)\ &=v_c^{pT}phi(x)-log Z^p(phi(x)) end{align*}	ag{9}

因此:

egin{align*} log frac{q_{data}(y=c|x)}{p_g(y=c|x)}=&log q_{data}(y=c|x)-log p_g(y=c|x)\ =&v_c^{q_{data}T}phi(x)-log Z^{q_{data}}ig(phi(x)ig)-\ &Big(v_c^{p_{g}T}phi(x)-log Z^{p_{g}}ig(phi(x)ig)Big)\ =&(v_c^{q_{data}}-v_c^{p_g})^Tphi(x)-\ &Big(log Z^{q_{data}}ig(phi(x)ig)-log Z^{p_g}ig(phi(x)ig)Big) end{align*}	ag{10}

把公式 (10) 带入公式 (6) 得:

egin{align*} f^*(x,y)&=logfrac{q_{data}(y|x)}{p_g(y|x)}+logfrac{q_{data}(x)}{p_{g}(x)}\ &=(v_c^{q_{data}}-v_c^{p_g})^Tphi(x)-\ &Big(log Z^{q_{data}}ig(phi(x)ig)-log Z^{p_g}ig(phi(x)ig)Big)+\ &logfrac{q_{data}(x)}{p_g(x)} end{align*}	ag{11}

 v_c^{q_{data}}-v_c^{p_g}=v_c

-ig(log Z^{q_{data}}(phi(x))-log Z^{p_g}(phi(x))ig)+logfrac{q_{data}(x)}{p_g(x)}=psi(phi(x))

得:

f^*(x,y=c)=v_c^Tphi(x)+psi(phi(x))

令矩阵 V 中各行向量为 v_j^Ty 为 one-hot label。则最终有:

f^*(x,y)=y^TVphi(x)+psi(phi(x))	ag{12}

可以把 V 理解成 label 的 embedding 层。这正是图 2 中的网路结构。

理解

如果跳出繁琐的推导过程,直接看公式 (12),我们发现这个最优分类器包含两部分: y^TVphi(x)psi(phi(x))

对于psi(phi(x)) ,其实就起 vanilla GAN discriminator 的作用,用于判断数据 x 是否为真实数据。

y^TVphi(x)=(Vphi(x))^Ty ,相当于神经网路的输出 Vphi(x) 与 one-hot label y 的点乘,从而取出来输出部分对应的 target 类的值,这项越大,代表越逼真。我们回顾 multi-class crossentropy 的公式,如果某个数据的 one-hot label 为 y ,网路的输出的概率分布为 p,则对该条数据的损失函数为:

mathcal{L}(p, y)=sum_{j=1}^c y_j log p_j	ag{13}

注意 (13) 中的 y_j 大部分为0,只有 ground truth 是 1,因此,对于 multi-class crossentropy,相当于把神经网路输出的概率分布中,对应 ground truth 类的概率提了出来,求了个对数。而 y^TVphi(x)=(Vphi(x))^Ty 中的 (Vphi(x)) 相当于一个特殊的分类网路,它输出的数字没有经过 softmax 映射成概率分布,但仍然可以代表输入数据属于某个类的程度深浅,越大代表越属于某类,越小则反之。

因此,这个文章所提出来的 discriminator,和 AC-GAN 非常类似,只不过

  1. AC-GAN 中 auxiliary classifier 输出的是概率,这里的输出是任意数字;
  2. AC-GAN 的 classification loss 和 adversarial loss 之间权重可调,而这个 conditional discriminator 的权重为 1:1 的关系。
  3. 此外需要注意一个细节,如果看 AC-GAN 的原文,发现当训练 discriminator 时,如果输入数据为 fake,作者仍旧希望 auxiliary classifier 将其归为 conditional 类,即希望 classifier 在 conditional label 对应的维度输出一个较大的值;而本文中,当输入数据为 fake 的时候,作者希望它在 conditional label 对应的维度上输出的值越小越好。

总的来看,感觉这篇文章有点借鉴了 vanilla GAN 到 W-GAN 之间的转变思想,即不再拘泥于输出概率,大胆抛弃 softmax 或者 sigmoid activation,用数字的大小直接代表属于某类的程度,这样反而比之前映射成概率的表现要好。

代码

这里只贴出关键部分,详细代码见链接:

def forward(self, x, y=None):
h = x
h = self.block1(h)
h = self.block2(h)
h = self.block3(h)
h = self.block4(h)
h = self.block5(h)
h = self.activation(h)
# Global pooling
h = torch.sum(h, dim=(2, 3)) # 提取x特征,送入两路,一路判断是否真实,一路判断是否属于label类
output = self.l6(h) # 相当于 vanilla GAN, 判断 x 是否真实
if y is not None:
# 相当于不加 softmax 的 classifier, 直接提取 classifier 在 label 对应的维度的输出
class_out = torch.sum(self.l_y(y) * h, dim=1, keepdim=True)
# 把两部分加起来作为 discriminator 的 output
output += class_out

return output

推荐阅读:

相关文章