Paper title: Attentive Generative Adversarial Network for Raindrop Removal from A Single Image

Paper download: arxiv.org/abs/1711.1009

GitHub code: github.com/MaybeShewill

Dataset download: drive.google.com/drive/


This is a paper that combines an attention mechanism with a GAN to remove raindrops from a single image. As the comparison figures in the paper show, the method removes raindrops remarkably well. The network is described below.

The network consists of three parts: an Attentive-Recurrent Network, a Contextual Autoencoder, and a Discriminator Network. The first two make up the generator.

The first part does the detection work (i.e., locating the raindrops in the image) and generates an attention map, with each step's map derived from the input image together with the map from the previous step. Each recurrence consists of several residual blocks, an LSTM, and a set of convolutions. Note that each "Block" in the figure is made up of 5 ResNet-style layers, whose job is to extract accurate features from the input image combined with the mask from the previous block. Concretely, residual blocks first extract features from the raindrop image, and convolutions then progressively detect the attentive regions. Because the images in the training set come in pairs, the corresponding mask M is easy to compute, and a loss function can be built from it; since the attention maps at different steps describe the features with different accuracy, each step's loss is given an exponentially decaying weight. The corresponding loss function is:
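In the paper's notation, with $A_t$ the attention map at step $t$, $M$ the binary raindrop mask, $N = 4$ recurrent steps, and decay $\theta = 0.8$, it reads roughly:

$$\mathcal{L}_{ATT}(\{A\}, M) = \sum_{t=1}^{N} \theta^{N-t}\,\mathcal{L}_{MSE}(A_t, M)$$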

The LSTM equations are as follows:
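This is the standard convolutional LSTM formulation, which the paper follows ($*$ denotes convolution, $\odot$ element-wise multiplication):

$$\begin{aligned}
i_t &= \sigma(W_{xi} * X_t + W_{hi} * H_{t-1} + W_{ci} \odot C_{t-1} + b_i)\\
f_t &= \sigma(W_{xf} * X_t + W_{hf} * H_{t-1} + W_{cf} \odot C_{t-1} + b_f)\\
C_t &= f_t \odot C_{t-1} + i_t \odot \tanh(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c)\\
o_t &= \sigma(W_{xo} * X_t + W_{ho} * H_{t-1} + W_{co} \odot C_t + b_o)\\
H_t &= o_t \odot \tanh(C_t)
\end{aligned}$$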

The attention map and the raindrop image are then fed together into the autoencoder, which generates the raindrop-free image. Guided by the attention map, this stage essentially replaces the regions with high attention values using new content synthesized from the surrounding image information, thereby restoring the image. The autoencoder is built from 16 conv–ReLU blocks. To avoid blur introduced by the network itself, the author adds skip connections, which help a great deal at the low-level layers. For the loss, besides the multi-scale term, a higher-level perceptual term is added as well. The two are the multi-scale loss and the perceptual loss.

Multi-scale loss:
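Here $S_i$ are the decoder outputs at different scales (the skip outputs in the code below), $T_i$ the ground truth resized to the same scales, and $\lambda_i$ the per-scale weights (0.6, 0.8 and 1.0 in the paper):

$$\mathcal{L}_M(\{S\},\{T\}) = \sum_{i} \lambda_i\,\mathcal{L}_{MSE}(S_i, T_i)$$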

Perceptual loss:
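With $O$ the autoencoder output, $T$ the clean ground truth, and $VGG(\cdot)$ features from a pretrained VGG-16:

$$\mathcal{L}_P(O, T) = \mathcal{L}_{MSE}(VGG(O), VGG(T))$$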

Combining the two parts gives the generator loss, as follows:
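As written in the paper (approximately), with $\mathcal{L}_{GAN}(O) = \log(1 - D(O))$:

$$\mathcal{L}_G = 10^{-2}\,\mathcal{L}_{GAN}(O) + \mathcal{L}_{ATT}(\{A\}, M) + \mathcal{L}_M(\{S\},\{T\}) + \mathcal{L}_P(O, T)$$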

The last component is the discriminator. This step has two variants: one judges using only the raindrop-free image produced by the autoencoder; the other additionally uses the attention map as guidance.

The loss function of the discriminator is:
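With $R$ a real clean image, $O$ the generated image, $A_N$ the final attention map, $D_{map}(\cdot)$ the attention map extracted inside the discriminator, and $\gamma = 0.05$ in the paper, it is roughly:

$$\mathcal{L}_D(O, R, A_N) = -\log(D(R)) - \log(1 - D(O)) + \gamma\,\mathcal{L}_{map}(O, R, A_N)$$

$$\mathcal{L}_{map}(O, R, A_N) = \mathcal{L}_{MSE}(D_{map}(O), A_N) + \mathcal{L}_{MSE}(D_{map}(R), \mathbf{0})$$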

Putting it all together, the overall loss is:
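Generator and discriminator are trained against each other with the standard adversarial objective, roughly:

$$\min_G \max_D\; \mathbb{E}_{R\sim p_{clean}}[\log D(R)] + \mathbb{E}_{O\sim p_{raindrop}}[\log(1 - D(O))]$$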

Next, let's look at the implementation.

Implementation of the residual block in the generator:

def _residual_block(self, input_tensor, name):
    """
    Residual block of the attentive recurrent net
    :param input_tensor:
    :param name:
    :return:
    """
    output = None
    with tf.variable_scope(name):
        inputs = input_tensor
        shortcut = input_tensor
        for i in range(5):
            if i == 0:
                # first layer: 3x3 conv lifting the input to 32 channels
                inputs = self.conv2d(inputdata=inputs, out_channel=32,
                                     kernel_size=3, padding='SAME', stride=1, use_bias=False,
                                     name='block_{:d}_conv_1'.format(i))
                # TODO reimplement residual block
                inputs = self.lrelu(inputdata=inputs, name='block_{:d}_relu_1'.format(i + 1))
                output = inputs
                shortcut = output
            else:
                # subsequent layers: two 1x1 convs plus a residual (shortcut) connection
                inputs = self.conv2d(inputdata=inputs, out_channel=32,
                                     kernel_size=1, padding='SAME', stride=1,
                                     use_bias=False, name='block_{:d}_conv_1'.format(i))
                inputs = self.lrelu(inputdata=inputs, name='block_{:d}_relu_1'.format(i + 1))
                inputs = self.conv2d(inputdata=inputs, out_channel=32,
                                     kernel_size=1, padding='SAME', stride=1,
                                     use_bias=False, name='block_{:d}_conv_2'.format(i))
                inputs = self.lrelu(inputdata=inputs, name='block_{:d}_relu_2'.format(i + 1))

                output = self.lrelu(inputdata=tf.add(inputs, shortcut),
                                    name='block_{:d}_add'.format(i))
                shortcut = output
    return output

Implementation of the attention mechanism:

def build_attentive_rnn(self, input_tensor, name, reuse=False):
    """
    Attentive recurrent part of the generator, used mainly to find the attention map
    :param input_tensor:
    :param name:
    :param reuse:
    :return:
    """
    [batch_size, tensor_h, tensor_w, _] = input_tensor.get_shape().as_list()
    with tf.variable_scope(name, reuse=reuse):
        init_attention_map = tf.constant(0.5, dtype=tf.float32,
                                         shape=[batch_size, tensor_h, tensor_w, 1])
        init_cell_state = tf.constant(0.0, dtype=tf.float32,
                                      shape=[batch_size, tensor_h, tensor_w, 32])
        init_lstm_feats = tf.constant(0.0, dtype=tf.float32,
                                      shape=[batch_size, tensor_h, tensor_w, 32])
        attention_map_list = []

        # 4 recurrent steps: each one refines the attention map from the previous step
        for i in range(4):
            attention_input = tf.concat((input_tensor, init_attention_map), axis=-1)
            conv_feats = self._residual_block(input_tensor=attention_input,
                                              name='residual_block_{:d}'.format(i + 1))
            lstm_ret = self._conv_lstm(input_tensor=conv_feats,
                                       input_cell_state=init_cell_state,
                                       name='conv_lstm_block_{:d}'.format(i + 1))
            init_attention_map = lstm_ret['attention_map']
            init_cell_state = lstm_ret['cell_state']
            init_lstm_feats = lstm_ret['lstm_feats']
            attention_map_list.append(lstm_ret['attention_map'])

        ret = {
            'final_attention_map': init_attention_map,
            'final_lstm_feats': init_lstm_feats,
            'attention_map_list': attention_map_list
        }
    return ret
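The `_conv_lstm` helper called above is not shown in this post. Below is a minimal sketch of what such a cell can look like, reusing the same `self.conv2d`/`self.sigmoid` wrappers and returning the keys `build_attentive_rnn` expects; the gate layout and layer names here are assumptions, not the repo's exact code.

def _conv_lstm(self, input_tensor, input_cell_state, name):
    """
    Minimal ConvLSTM cell sketch (assumed layout, not the repo's exact implementation).
    build_attentive_rnn only passes the current features and cell state, so the gates
    here depend on those alone.
    """
    with tf.variable_scope(name):
        # input, forget and output gates plus the candidate cell state, each a 3x3 conv
        gate_i = self.sigmoid(inputdata=self.conv2d(inputdata=input_tensor, out_channel=32, kernel_size=3,
                                                    padding='SAME', stride=1, use_bias=False, name='conv_i'),
                              name='gate_i')
        gate_f = self.sigmoid(inputdata=self.conv2d(inputdata=input_tensor, out_channel=32, kernel_size=3,
                                                    padding='SAME', stride=1, use_bias=False, name='conv_f'),
                              name='gate_f')
        gate_o = self.sigmoid(inputdata=self.conv2d(inputdata=input_tensor, out_channel=32, kernel_size=3,
                                                    padding='SAME', stride=1, use_bias=False, name='conv_o'),
                              name='gate_o')
        candidate = tf.nn.tanh(self.conv2d(inputdata=input_tensor, out_channel=32, kernel_size=3,
                                           padding='SAME', stride=1, use_bias=False, name='conv_c'))
        cell_state = gate_f * input_cell_state + gate_i * candidate  # C_t
        lstm_feats = gate_o * tf.nn.tanh(cell_state)                 # H_t
        # squeeze the hidden features into a single-channel attention map
        attention_map = self.sigmoid(inputdata=self.conv2d(inputdata=lstm_feats, out_channel=1, kernel_size=3,
                                                           padding='SAME', stride=1, use_bias=False,
                                                           name='attention_conv'),
                                     name='attention_map')
        ret = {
            'attention_map': attention_map,
            'cell_state': cell_state,
            'lstm_feats': lstm_feats
        }
    return ret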

Implementation of the autoencoder:

def build_autoencoder(self, input_tensor, name, reuse=False):
    """
    Autoencoder part of the generator, responsible for capturing image context
    :param input_tensor:
    :param name:
    :param reuse:
    :return:
    """
    with tf.variable_scope(name, reuse=reuse):
        conv_1 = self.conv2d(inputdata=input_tensor, out_channel=64, kernel_size=5, padding='SAME', stride=1, use_bias=False, name='conv_1')
        relu_1 = self.lrelu(inputdata=conv_1, name='relu_1')
        conv_2 = self.conv2d(inputdata=relu_1, out_channel=128, kernel_size=3, padding='SAME', stride=2, use_bias=False, name='conv_2')
        relu_2 = self.lrelu(inputdata=conv_2, name='relu_2')
        conv_3 = self.conv2d(inputdata=relu_2, out_channel=128, kernel_size=3, padding='SAME', stride=1, use_bias=False, name='conv_3')
        relu_3 = self.lrelu(inputdata=conv_3, name='relu_3')
        conv_4 = self.conv2d(inputdata=relu_3, out_channel=128, kernel_size=3, padding='SAME', stride=2, use_bias=False, name='conv_4')
        relu_4 = self.lrelu(inputdata=conv_4, name='relu_4')
        conv_5 = self.conv2d(inputdata=relu_4, out_channel=256, kernel_size=3, padding='SAME', stride=1, use_bias=False, name='conv_5')
        relu_5 = self.lrelu(inputdata=conv_5, name='relu_5')
        conv_6 = self.conv2d(inputdata=relu_5, out_channel=256, kernel_size=3, padding='SAME', stride=1, use_bias=False, name='conv_6')
        relu_6 = self.lrelu(inputdata=conv_6, name='relu_6')
        # dilated convolutions enlarge the receptive field without further downsampling
        dia_conv1 = self.dilation_conv(input_tensor=relu_6, k_size=3, out_dims=256, rate=2, padding='SAME', use_bias=False, name='dia_conv_1')
        relu_7 = self.lrelu(dia_conv1, name='relu_7')
        dia_conv2 = self.dilation_conv(input_tensor=relu_7, k_size=3, out_dims=256, rate=4, padding='SAME', use_bias=False, name='dia_conv_2')
        relu_8 = self.lrelu(dia_conv2, name='relu_8')
        dia_conv3 = self.dilation_conv(input_tensor=relu_8, k_size=3, out_dims=256, rate=8, padding='SAME', use_bias=False, name='dia_conv_3')
        relu_9 = self.lrelu(dia_conv3, name='relu_9')
        dia_conv4 = self.dilation_conv(input_tensor=relu_9, k_size=3, out_dims=256, rate=16, padding='SAME', use_bias=False, name='dia_conv_4')
        relu_10 = self.lrelu(dia_conv4, name='relu_10')
        conv_7 = self.conv2d(inputdata=relu_10, out_channel=256, kernel_size=3, padding='SAME', use_bias=False, stride=1, name='conv_7')
        relu_11 = self.lrelu(inputdata=conv_7, name='relu_11')
        conv_8 = self.conv2d(inputdata=relu_11, out_channel=256, kernel_size=3, padding='SAME', use_bias=False, stride=1, name='conv_8')
        relu_12 = self.lrelu(inputdata=conv_8, name='relu_12')
        # decoder: deconvolutions upsample back to the input resolution
        deconv_1 = self.deconv2d(inputdata=relu_12, out_channel=128, kernel_size=4, stride=2, padding='SAME', use_bias=False, name='deconv_1')
        avg_pool_1 = self.avgpooling(inputdata=deconv_1, kernel_size=2, stride=1, padding='SAME', name='avg_pool_1')
        relu_13 = self.lrelu(inputdata=avg_pool_1, name='relu_13')
        # skip connection from relu_3 helps avoid blur at the low-level layers
        conv_9 = self.conv2d(inputdata=tf.add(relu_13, relu_3), out_channel=128, kernel_size=3, padding='SAME', stride=1, use_bias=False, name='conv_9')
        relu_14 = self.lrelu(inputdata=conv_9, name='relu_14')
        deconv_2 = self.deconv2d(inputdata=relu_14, out_channel=64, kernel_size=4, stride=2, padding='SAME', use_bias=False, name='deconv_2')
        avg_pool_2 = self.avgpooling(inputdata=deconv_2, kernel_size=2, stride=1, padding='SAME', name='avg_pool_2')
        relu_15 = self.lrelu(inputdata=avg_pool_2, name='relu_15')
        conv_10 = self.conv2d(inputdata=tf.add(relu_15, relu_1), out_channel=32, kernel_size=3, padding='SAME', stride=1, use_bias=False, name='conv_10')
        relu_16 = self.lrelu(inputdata=conv_10, name='relu_16')
        # side outputs at 1/4, 1/2 and full resolution, used by the multi-scale loss
        skip_output_1 = self.conv2d(inputdata=relu_12, out_channel=3, kernel_size=3, padding='SAME', stride=1, use_bias=False, name='skip_ouput_1')
        skip_output_2 = self.conv2d(inputdata=relu_14, out_channel=3, kernel_size=3, padding='SAME', stride=1, use_bias=False, name='skip_ouput_2')
        skip_output_3 = self.conv2d(inputdata=relu_16, out_channel=3, kernel_size=3, padding='SAME', stride=1, use_bias=False, name='skip_ouput_3')
        # conventional GANs use a tanh activation at the output layer
        skip_output_3 = tf.nn.tanh(skip_output_3, name='skip_output_3_tanh')

        ret = {
            'skip_1': skip_output_1,
            'skip_2': skip_output_2,
            'skip_3': skip_output_3
        }
    return ret
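The three skip outputs returned above are the multi-scale predictions $S_i$ used by the multi-scale loss. Below is a minimal, illustrative sketch of how they could be combined with the ground truth; the helper name and the use of `tf.image.resize_images` are assumptions, while the weights 0.6/0.8/1.0 follow the paper.

def compute_multi_scale_loss(skip_outputs, gt_image):
    """
    Multi-scale MSE between the autoencoder's side outputs and the resized ground truth.
    skip_outputs: [skip_1, skip_2, skip_3] at 1/4, 1/2 and full resolution.
    Illustrative sketch only, not the repo's exact loss code.
    """
    weights = [0.6, 0.8, 1.0]
    loss = tf.constant(0.0, dtype=tf.float32)
    for weight, skip in zip(weights, skip_outputs):
        height, width = skip.get_shape().as_list()[1:3]
        gt_resized = tf.image.resize_images(gt_image, size=(height, width))
        loss += weight * tf.losses.mean_squared_error(labels=gt_resized, predictions=skip)
    return loss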

Implementation of the discriminator (the snippet starts inside the function body; the signature below is assumed to mirror the generator builders):

def build_discriminator(self, input_tensor, name, reuse=False):  # signature assumed
    with tf.variable_scope(name, reuse=reuse):
        conv_stage_1 = self._conv_stage(input_tensor=input_tensor, k_size=5, stride=1, out_dims=8, group_size=0, name='conv_stage_1')
        conv_stage_2 = self._conv_stage(input_tensor=conv_stage_1, k_size=5, stride=1, out_dims=16, group_size=0, name='conv_stage_2')
        conv_stage_3 = self._conv_stage(input_tensor=conv_stage_2, k_size=5, stride=1, out_dims=32, group_size=0, name='conv_stage_3')
        conv_stage_4 = self._conv_stage(input_tensor=conv_stage_3, k_size=5, stride=1, out_dims=64, group_size=0, name='conv_stage_4')
        conv_stage_5 = self._conv_stage(input_tensor=conv_stage_4, k_size=5, stride=1, out_dims=128, group_size=0, name='conv_stage_5')
        conv_stage_6 = self._conv_stage(input_tensor=conv_stage_5, k_size=5, stride=1, out_dims=128, group_size=0, name='conv_stage_6')
        # interior attention map that guides the discriminator towards raindrop regions
        attention_map = self.conv2d(inputdata=conv_stage_6, out_channel=1, kernel_size=5, padding='SAME', stride=1, use_bias=False, name='attention_map')
        conv_stage_7 = self._conv_stage(input_tensor=attention_map * conv_stage_6, k_size=5, stride=4, out_dims=64, group_size=0, name='conv_stage_7')
        conv_stage_8 = self._conv_stage(input_tensor=conv_stage_7, k_size=5, stride=4, out_dims=64, group_size=0, name='conv_stage_8')
        conv_stage_9 = self._conv_stage(input_tensor=conv_stage_8, k_size=5, stride=4, out_dims=32, group_size=0, name='conv_stage_9')
        fc_1 = self.fullyconnect(inputdata=conv_stage_9, out_dim=1024, use_bias=False, name='fc_1')
        fc_2 = self.fullyconnect(inputdata=fc_1, out_dim=1, use_bias=False, name='fc_2')
        fc_out = self.sigmoid(inputdata=fc_2, name='fc_out')
        # nudge the sigmoid output away from exactly 0 or 1 so log() in the loss stays finite
        fc_out = tf.where(tf.not_equal(fc_out, 1.0), fc_out, fc_out - 0.0000001)
        fc_out = tf.where(tf.not_equal(fc_out, 0.0), fc_out, fc_out + 0.0000001)
    return fc_out, attention_map, fc_2
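With these three outputs, the discriminator loss above can be assembled roughly as follows. Tensor names are placeholders; both attention maps are assumed to share the same spatial size (otherwise one would be resized), and gamma = 0.05 follows the paper.

# fc_out_real, attention_map_real: discriminator outputs on a real clean image R
# fc_out_fake, attention_map_fake: discriminator outputs on the generated image O
# final_attention_map: A_N produced by the attentive recurrent network
l_map = tf.losses.mean_squared_error(labels=final_attention_map, predictions=attention_map_fake) + \
        tf.losses.mean_squared_error(labels=tf.zeros_like(attention_map_real), predictions=attention_map_real)
d_loss = -tf.reduce_mean(tf.log(fc_out_real)) \
         - tf.reduce_mean(tf.log(1.0 - fc_out_fake)) \
         + 0.05 * l_map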

For more code, please see the GitHub repository.
