Image-to-image translation for cross-domain disentanglement論文解讀及代碼分析

Image-to-image translation for cross-domain disentanglement?

arxiv.org

代碼:

agonzgarc/cross-domain-disen?

github.com
圖標

該篇文章發表於2018NIPS上,與上一篇文章採用unpaired的圖像數據不同,該文採用的是pair數據,但同樣採用的是分解的思想(disentanglement)

提出一種跨域分解(cross-domain disentanglement)的方法,將兩個圖像域中對應圖像中域共有的東西從域獨有的東西分離出來,那怎樣才能保證生成的部分是兩個域共有的,而另外一部分是域間獨有的呢?且看作者是怎麼分析的。

一、先來看看整體的模型框圖:

同樣採用解開表示的方法,將兩個域中的圖像分解為兩部分:(1)域內獨有的部分和(2)域間共有的部分。如Figure 1 右邊圖所示。兩張圖都是表示5,所有共有的部分是數字5,但5的顏色和背景每個圖像域是不同,這就是域獨有的部分。

為保證域獨有和域共有,作者做了以下的工作(X域和Y域是相同的,採用X域來進行說明):

(1)Exclusive representation. E^{x} 是X域分解後得到的域獨有的信息,因此由E^{x}應該無法重構出該圖像在Y域對應的圖像,為達到這個目的,作者提出了利用Gradient Reversal Layer (GRL)。具體地就是在 G_{e}生成E^{x}後面再接一個小解碼器(a small decoder) G_{d}^{X} ,該小解碼器希望可以重構出X域對應的Y域照片,採用的訓練方式是對抗損失。但我們並不是真的需要由E^{x}可以生成對應的Y域圖像,反而更加需要的是生成不了這張對應的圖像,所以就有了GRL發揮的作用的地方。GRL反轉了反向傳播到編碼器Ge的梯度符號,僅影響生成域獨享特徵E^{x} 所涉及的那些單元。

這部分的代碼,作者給出的代碼是:

# Create generators/discriminators for exclusive representation

## X域
with tf.variable_scope("generator_exclusiveX2Y_decoder"):
outputs_exclusiveX2Y = create_generator_decoder_exclusive(eR_X2Y, out_channels, a)

with tf.name_scope("real_discriminator_exclusiveX2Y"):
with tf.variable_scope("discriminator_exclusiveX2Y"):
predict_real_exclusiveX2Y = create_discriminator(inputsX, targetsX, a)

with tf.name_scope("fake_discriminator_exclusiveX2Y"):
with tf.variable_scope("discriminator_exclusiveX2Y", reuse=True):
predict_fake_exclusiveX2Y = create_discriminator(inputsX, outputs_exclusiveX2Y, a)
## Y域
with tf.variable_scope("generator_exclusiveY2X_decoder"):
outputs_exclusiveY2X = create_generator_decoder_exclusive(eR_Y2X, out_channels, a)

with tf.name_scope("real_discriminator_exclusiveY2X"):
with tf.variable_scope("discriminator_exclusiveY2X"):
predict_real_exclusiveY2X = create_discriminator(inputsY, targetsY, a)

with tf.name_scope("fake_discriminator_exclusiveY2Y"):
with tf.variable_scope("discriminator_exclusiveY2X", reuse=True):
predict_fake_exclusiveY2X = create_discriminator(inputsY, outputs_exclusiveY2X, a)

對抗損失部分的代碼如下:

with tf.name_scope("generator_exclusiveX2Y_loss"):
gen_exclusiveX2Y_loss_GAN = -tf.reduce_mean(predict_fake_exclusiveX2Y)
gen_exclusiveX2Y_loss = gen_exclusiveX2Y_loss_GAN * a.gan_exclusive_weight

with tf.name_scope("discriminator_exclusiveX2Y_loss"):
discrim_exclusiveX2Y_loss = tf.reduce_mean(predict_fake_exclusiveX2Y) - tf.reduce_mean(predict_real_exclusiveX2Y)
alpha = tf.random_uniform(shape=[a.batch_size,1], minval=0., maxval=1.)
differences = tf.reshape(outputs_exclusiveX2Y,[-1,OUTPUT_DIM])-tf.reshape(targetsX,[-1,OUTPUT_DIM])
interpolates = tf.reshape(targetsX,[-1,OUTPUT_DIM]) + (alpha*differences)
with tf.variable_scope("discriminator_exclusiveX2Y", reuse=True):
gradients = tf.gradients(create_discriminator(inputsX,tf.reshape(interpolates,[-1,IMAGE_SIZE,IMAGE_SIZE,3]),a),
[interpolates])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients),
reduction_indices=[1]))
gradient_penalty = tf.reduce_mean((slopes-1.)**2)
discrim_exclusiveX2Y_loss += LAMBDA*gradient_penalty

with tf.name_scope("generator_exclusiveY2X_loss"):
gen_exclusiveY2X_loss_GAN = -tf.reduce_mean(predict_fake_exclusiveY2X)
gen_exclusiveY2X_loss = gen_exclusiveY2X_loss_GAN * a.gan_exclusive_weight

with tf.name_scope("discriminator_exclusiveY2X_loss"):
discrim_exclusiveY2X_loss = tf.reduce_mean(predict_fake_exclusiveY2X) - tf.reduce_mean(predict_real_exclusiveY2X)
alpha = tf.random_uniform(shape=[a.batch_size,1], minval=0., maxval=1.)
differences = tf.reshape(outputs_exclusiveY2X,[-1,OUTPUT_DIM])-tf.reshape(targetsX,[-1,OUTPUT_DIM])
interpolates = tf.reshape(targetsX,[-1,OUTPUT_DIM]) + (alpha*differences)
with tf.variable_scope("discriminator_exclusiveY2X", reuse=True):
gradients = tf.gradients(create_discriminator(inputsX,tf.reshape(interpolates,[-1,IMAGE_SIZE,IMAGE_SIZE,3]),a),
[interpolates])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients),
reduction_indices=[1]))
gradient_penalty = tf.reduce_mean((slopes-1.)**2)
discrim_exclusiveY2X_loss += LAMBDA*gradient_penalty

但是我好像沒找到GRL那部分對應的代碼(是我忽略了哪一部分嗎?要是有哪位大神知道的話,麻煩告訴我,謝謝。)

(2)Shared representation.這篇文章用的是paired的數據,所以它們共享部分的信息應該是相同的(如,都是表示數字5),所以 S^{X}和S^{Y} 應該是相似的,作者這裡直接採用的是L1損失:

但(1)中的損失會促使模型產生小的信號,作者解決方法是加入小的雜訊:

這個加雜訊的代碼我在原代碼中也沒找到,是我忽略了什麼嗎?哪位大神知道,麻煩告知。

(3)Reconstructing the latent space.這部分就是進行重構,和BicycleGan的思想類似,我對這部分的理解更多是保留互信息,不至於加進去的雜訊被生成器忽略。同時共享內容部分同樣也是要進行重構。這樣的損失函數在 文章Multimodal UnsupervisedImage-to-Image Translations中也有用到(該文章中的公式(2)和(3)),感興趣的可以看我上一篇分解表示的文章。這麼多文章都用到了這樣的重構損失,可見這樣的損失函數應該挺實用的。

代碼:

with tf.name_scope("code_recon_loss"):
code_sR_X2Y_recon_loss = tf.reduce_mean(tf.abs(sR_X2Y_recon-sR_X2Y)) ###公式(2) 共享部分的重構
code_sR_Y2X_recon_loss = tf.reduce_mean(tf.abs(sR_Y2X_recon-sR_Y2X))
code_eR_X2Y_recon_loss = tf.reduce_mean(tf.abs(eR_X2Y_recon-z))
code_eR_Y2X_recon_loss = tf.reduce_mean(tf.abs(eR_Y2X_recon-z)) ##雜訊重構
code_recon_loss = a.l1_weight*(code_sR_X2Y_recon_loss + code_sR_Y2X_recon_loss
+code_eR_X2Y_recon_loss + code_eR_Y2X_recon_loss)()

(4)WGAN-GP loss. 就是使用WAN-GP作為跨域重構的損失。

with tf.name_scope("generatorX2Y_loss"):
genX2Y_loss_GAN = -tf.reduce_mean(predict_fakeX2Y) ##公式(4)
genX2Y_loss = genX2Y_loss_GAN * a.gan_weight

with tf.name_scope("discriminatorX2Y_loss"):
discrimX2Y_loss = tf.reduce_mean(predict_fakeX2Y) - tf.reduce_mean(predict_realX2Y)
alpha = tf.random_uniform(shape=[a.batch_size,1], minval=0., maxval=1.)
differences = tf.reshape(outputsX2Y,[-1,OUTPUT_DIM])-tf.reshape(targetsX,[-1,OUTPUT_DIM])
interpolates = tf.reshape(targetsX, [-1,OUTPUT_DIM]) + (alpha*differences)
with tf.variable_scope("discriminatorX2Y", reuse=True):
gradients = tf.gradients(create_discriminator(inputsX,tf.reshape(interpolates,[-1,IMAGE_SIZE,IMAGE_SIZE,3]),a),
[interpolates])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients),
reduction_indices=[1]))
gradient_penalty = tf.reduce_mean((slopes-1.)**2) ##公式(3)

(5)Cross-domain autoencoders。這部分很好理解,跨域自編碼。由於域共有的部分是域間都是一樣的,所以可以進行交換,進行交換的一個好處就是額外地激勵把域獨有的信息放在E^{x}中,最後得到:

代碼:

###跨域自編碼
######### CROSS-DOMAIN AUTOENCODERS
with tf.name_scope("autoencoderX"):
# Use here decoder Y2X but with shared input from X2Y encoder
with tf.variable_scope("generatorY2X_decoder", reuse=True):
out_channels = int(inputsX.get_shape()[-1])
auto_outputX = create_generator_decoder(sR_Y2X, eR_X2Y, out_channels, a) ##sR_Y2X是Y2X的共享部分,這裡是把這個共享部分替代了X2Y中的共享部分

with tf.name_scope("autoencoderY"):
# Use here decoder X2Y but with input from Y2X encoder
with tf.variable_scope("generatorX2Y_decoder", reuse=True):
out_channels = int(inputsY.get_shape()[-1])
auto_outputY = create_generator_decoder(sR_X2Y, eR_Y2X, out_channels, a) ##sR_X2Y是X2Y的共享部分,這裡是把這個共享部分替代了Y2X中的共享部分

with tf.name_scope("autoencoderX_loss"):
autoencoderX_loss = a.l1_weight*tf.reduce_mean(tf.abs(auto_outputX-inputsX))

with tf.name_scope("autoencoderY_loss"):
autoencoderY_loss = a.l1_weight*tf.reduce_mean(tf.abs(auto_outputY-inputsY)) ##公式(5)

最後感謝作者的開源代碼,幫助快速理解

歡迎討論交流!!

你的點贊會是我繼續分享的動力!

參考文獻:

[1]Image-to-image translation for cross-domain disentanglement


推薦閱讀:
相關文章