这里用 p_	heta(x|z) 代表生成模型, q_phi(z|x) 代表编码模型。

首先:egin{align} mathop{argmin}_	heta mathsf{KL}(p parallel p_	heta) &= mathop{argmin}_	heta int p(x) log frac{p(x)}{p_	heta(x)}  , dx \&= mathop{argmin}_	heta Big[ int p(x) log p(x)  , dx  - int p(x) log p_	heta(x)  , dx Big] \&= mathop{argmin}_	heta Big[H(p(x))  - int p(x) log p_	heta(x)  , dx Big] \&= mathop{argmax}_	heta int p(x) log p_	heta(x)  , dx  \&=  mathop{argmax}_	heta mathbb{E}_{x sim p(x)}log p_	heta(x) end{align}

写成 MLE,比 KL 散度少一个常数,形式更简单。没必要时刻写出 KL 散度。


让我们 1 秒钟推导出 VAE。

思想是,匹配 p(x,z)p_	heta(x,z) ,就可以同时匹配 xz 的边缘分布。

MLE 如下:

mathop{argmax}_	heta mathbb{E}_{(x,z) sim p(x,z)}log p_	heta(x,z)

显然等价于:

mathop{argmax}_	heta mathbb{E}_{x sim p(x)} Big[ mathbb{E}_{z sim p(z|x)} Big[ log p_	heta(x|z) + log p_	heta(z) Big]  Big]

恭喜,推导出了 VAE。

此外,这个 MLE 显然也等价于:

mathop{argmin}_	heta mathsf{KL}(p(x,z) parallel p_	heta(x,z))


加 10 秒钟,把它变成更常见的样子。

展开:

mathop{argmax}_	heta mathbb{E}_{x sim p(x)} Big[ mathbb{E}_{z sim p(z|x)}  log p_	heta(x|z) + mathbb{E}_{z sim p(z|x)} log p_	heta(z) Big]

改变符号,显然等价于:

mathop{argmin}_	heta mathbb{E}_{x sim p(x)} Big[ mathbb{E}_{z sim p(z|x)} - log p_	heta(x|z) + mathsf{KL}ig( p(z|x) parallel  p_	heta(z)ig) Big]

再加上先验:

p(z|x) = N(mu_x, sigma_x^2)

p_	heta(z) = N(0, I)

并令 p_	heta(x|z) = N(G_	heta(z), sigma^2) 是固定 stdev 的 Gaussian 以造出 MSE:

- log p_	heta(x|z) = frac{1}{2 sigma^2} cdot |G_	heta(z)-x|^2 + log(sqrt{2pi sigma^2})

忽略常数,就和实际用的一模一样了。

注意:由此可见,VAE 一点儿也不模糊,真正的 VAE p_	heta(x|z) 有很多噪音(由于这里的概率模型是每点独立,因此噪音也是每点独立的噪音)。许多论文显示的模糊图像,是"平均图像"。

注意:我们完全可以用更复杂的先验,例如用 PixelXNN 生成图像,这样就完全没有"模糊"。

注意:似乎没人实验可变 stdev 的 Gaussian。所以我做了一些实验,见本文末尾。


最终结果( eta-	ext{VAE} ):

mathop{argmin}_{	heta,, mu_i,, sigma_i} mathbb{E}_{x sim p(x)} Big[ mathbb{E}_{z_i sim N(mu_i,sigma_i^2)} vert G_	heta({z_i}) - x vert^2 + eta cdot  frac{1}{2} sum_i ig(mu_i^2 + sigma_i^2 - log sigma_i^2 -1 ig) Big]

其中每个 z_i 来自独立的 N(mu_i,sigma_i^2) 采样。

定义 t_i=log sigma_i^2 ,并加入重参数化 trick:

mathop{argmin}_{	heta,, mu_i,, t_i} mathbb{E}_{x sim p(x)} Big[ mathbb{E}_{epsilon_i sim N(0,1)} vert G_	heta({mu_i + epsilon_i cdot exp(t_i/2)}) - x vert^2 + eta cdot  frac{1}{2} sum_i ig(mu_i^2 + exp(t_i) - t_i -1 ig) Big]

于是可求 LOSS 对 	heta,, mu_i,, t_i 的导数,进行 SGD。


补充传统的推导过程。如前所述,用 p_	heta(x|z) 代表生成模型, q_phi(z|x) 代表编码模型。

我们的目标是边缘分布的 MLE:

mathop{argmax}_	heta mathbb{E}_{x sim p(x)}log p_	heta(x)

注意到这里有 ELBO:

egin{align} log p_	heta(x) &= log int p_	heta(x,z) ,dz \&=  log int q_phi(z|x) frac{p_	heta(x,z)}{q_phi(z|x)} ,dz \&= log E_{ z sim q_phi(z|x)} frac{p_	heta(x,z)}{q_phi(z|x)} \&geq  E_{ z sim q_phi(z|x)} log frac{p_	heta(x,z)}{q_phi(z|x)} \&=  E_{ z sim q_phi(z|x)} log frac{p_	heta(x|z) , p_	heta(z)}{q_phi(z|x)} \&= E_{ z sim q_phi(z|x)} log p_	heta(x|z) - E_{ z sim q_phi(z|x)} log frac{q_phi(z|x)}{p_	heta(z)} \&=E_{ z sim q_phi(z|x)} log p_	heta(x|z) - mathsf{KL}(q_phi(z|x) parallel p_	heta(z)) end{align}

因此目标等价于:

mathop{argmax}_{	heta, , phi} mathbb{E}_{x sim p(x)} Big[ E_{ z sim q_phi(z|x)} log p_	heta(x|z) - mathsf{KL}(q_phi(z|x) parallel p_	heta(z)) Big]

即:

mathop{argmin}_{	heta,, phi} mathbb{E}_{x sim p(x)} Big[E_{ z sim q_phi(z|x)} - log p_	heta(x|z) + mathsf{KL}(q_phi(z|x) parallel p_	heta(z)) Big]

这与我们之前的推导相同,只是多了一个 q_phi(z|x) 去逼近之前的 p(z|x)


这里还有一种写法,注意到:

E_{ z sim q_phi(z|x)} log frac{p_	heta(x,z)}{q_phi(z|x)} = E_{ z sim q_phi(z|x)} log frac{p_	heta(x) , p_	heta(z|x)}{q_phi(z|x)} = log p_	heta(x) - mathsf{KL}( q_phi(z|x) parallel  p_	heta(z|x))

因此目标等价于:

mathop{argmin}_{	heta,, phi} mathbb{E}_{x sim p(x)} Big[ - log p_	heta(x) + E_{ z sim q_phi(z|x)} mathsf{KL}(q_phi(z|x) parallel p_	heta(z|x)) Big]


我在另一篇文章,简单实验了 p_	heta(x|z) 有可变 stdev 的情况:

PENG Bo:DGN v2:生成器应该输出分布,清晰图像并不是 GAN 的特权?

zhuanlan.zhihu.com
图标

推荐阅读:
相关文章