論文題目:Attentive Generative Adversarial Network for Raindrop Removal from A Single Image
論文下載地址:https://arxiv.org/abs/1711.10098
Github代碼:https://github.com/MaybeShewill-CV/attentive-gan-derainnet
數據集下載地址: https://drive.google.com/drive/folders/1e7R76s6vwUJxILOcAsthgDLPSnOrQ49K
這是一篇結合attention機制和GAN去除雨滴的文章。從上圖可以看出這個方法去除雨滴的效果非常好。下面介紹一下這個網路。
該網路共包含三個部分,分別為: Attention-recurrent Network,Context Autoencoder 和 Discriminator Network。前兩個部分構成了生成器。
第一部分主要的工作是做檢測(即檢測雨滴在圖片中的位置),然後生成 attention map,它由先前的圖片中得出,每一層由若干個Residual Block,一個LSTM和一個Convs組成。需要說明的是,圖中的每一個「Block」由5層ResNet組成,其作用是得到輸入圖片的精確特徵與前一個Block的模。具體做法是首先使用 Residual block 從雨滴圖片中抽取 feature,漸進式地使用 Convs 來檢測 attentive 的區域。訓練數據集中圖片都是成對的,所以可以很容易計算出相應的 mask(M),由此可以構建出 Loss 函數;由於不同的 attention 網路刻畫 feature 的準確度不同,所以給每個 loss 一個指數的衰減。相應的 loss 函數如下:
LSTM公式如下:
隨後將 attention map 和雨滴圖像一起送給 autoencoder,生成去雨滴的高解析度圖像。由於attention map的給出,這一部分的工作類似於將attention map中attention值較高的部分通過該部分周圍的圖片信息形成的新的色塊進行替換,從而實現圖片信息的還原。autoencoder 的結構用了 16 個 Conv 和 ReLU。為了避免網路本身造成的 blur,作者使用了 skip connection,因為在低級層次會帶來很好的效果。在構建 loss 方面,除了多尺度的考慮,還加上了一個高精度的 loss,分別是:Multi-scale loss 和 perceptual loss。
Multi-scale loss:
Perceptual loss:
結合兩個部分,得到生成器的損失,如下:
最後一個是 discriminator。這個步驟有兩部分,一種是只使用 autoencoder 生成的無雨滴圖像,進行判斷;另一種則是加入 attention map 作為指導。
判別器的損失函數是:
綜上,整體損失為:
下面說一下實驗部分。
生成器中殘差塊的實現:
def _residual_block(self, input_tensor, name):
    """Residual block of the attentive recurrent network.

    Runs 5 conv + LeakyReLU stages over the input.  Stage 0 is a 3x3 conv
    that lifts the input to 32 channels; stages 1-4 are pairs of 1x1 convs
    whose result is added to the running shortcut before the activation,
    so each stage refines the previous stage's output residually.

    :param input_tensor: input feature-map tensor
                         (NHWC assumed — TODO confirm against caller)
    :param name: tf variable scope name for this block
    :return: output feature-map tensor with 32 channels
    """
    output = None
    with tf.variable_scope(name):
        inputs = input_tensor
        shortcut = input_tensor
        for i in range(5):
            if i == 0:
                # First stage: plain 3x3 conv, no residual add yet —
                # its output seeds the shortcut for later stages.
                inputs = self.conv2d(inputdata=inputs, out_channel=32,
                                     kernel_size=3, padding='SAME',
                                     stride=1, use_bias=False,
                                     name='block_{:d}_conv_1'.format(i))
                # TODO reimplement residual block
                inputs = self.lrelu(inputdata=inputs,
                                    name='block_{:d}_relu_1'.format(i + 1))
                output = inputs
                shortcut = output
            else:
                # Later stages: two 1x1 conv + lrelu, then residual add
                # with the previous stage's output.
                inputs = self.conv2d(inputdata=inputs, out_channel=32,
                                     kernel_size=1, padding='SAME',
                                     stride=1, use_bias=False,
                                     name='block_{:d}_conv_1'.format(i))
                inputs = self.lrelu(inputdata=inputs,
                                    name='block_{:d}_conv_1'.format(i + 1))
                inputs = self.conv2d(inputdata=inputs, out_channel=32,
                                     kernel_size=1, padding='SAME',
                                     stride=1, use_bias=False,
                                     name='block_{:d}_conv_2'.format(i))
                inputs = self.lrelu(inputdata=inputs,
                                    name='block_{:d}_conv_2'.format(i + 1))
                output = self.lrelu(inputdata=tf.add(inputs, shortcut),
                                    name='block_{:d}_add'.format(i))
                shortcut = output
    return output
def build_attentive_rnn(self, input_tensor, name, reuse=False):
    """Attentive-recurrent part of the generator: locates the raindrops.

    Runs 4 time steps; each step concatenates the rainy image with the
    current attention map, extracts features through a residual block,
    and updates the attention map / cell state through a conv-LSTM.

    :param input_tensor: rainy input image tensor; static shape must be
                         fully known (``get_shape().as_list()`` is used)
    :param name: tf variable scope name
    :param reuse: whether to reuse variables in the scope
    :return: dict with keys 'final_attention_map', 'final_lstm_feats'
             and 'attention_map_list' (one map per time step)
    """
    [batch_size, tensor_h, tensor_w, _] = input_tensor.get_shape().as_list()
    with tf.variable_scope(name, reuse=reuse):
        # Initial attention map is a uniform 0.5 (no prior knowledge of
        # raindrop locations); LSTM state starts at zero.
        init_attention_map = tf.constant(
            0.5, dtype=tf.float32,
            shape=[batch_size, tensor_h, tensor_w, 1])
        init_cell_state = tf.constant(
            0.0, dtype=tf.float32,
            shape=[batch_size, tensor_h, tensor_w, 32])
        init_lstm_feats = tf.constant(
            0.0, dtype=tf.float32,
            shape=[batch_size, tensor_h, tensor_w, 32])
        attention_map_list = []

        for i in range(4):
            # Feed the image together with the previous attention map so
            # each step progressively refines the raindrop mask.
            attention_input = tf.concat(
                (input_tensor, init_attention_map), axis=-1)
            conv_feats = self._residual_block(
                input_tensor=attention_input,
                name='residual_block_{:d}'.format(i + 1))
            lstm_ret = self._conv_lstm(
                input_tensor=conv_feats,
                input_cell_state=init_cell_state,
                name='conv_lstm_block_{:d}'.format(i + 1))
            init_attention_map = lstm_ret['attention_map']
            init_cell_state = lstm_ret['cell_state']
            init_lstm_feats = lstm_ret['lstm_feats']
            attention_map_list.append(lstm_ret['attention_map'])

    ret = {
        'final_attention_map': init_attention_map,
        'final_lstm_feats': init_lstm_feats,
        'attention_map_list': attention_map_list
    }
    return ret
def build_autoencoder(self, input_tensor, name, reuse=False):
    """Contextual autoencoder part of the generator.

    Encoder: strided convs (conv_1..conv_6) plus four dilated convs with
    rates 2/4/8/16 to enlarge the receptive field.  Decoder: two
    deconv + avg-pool stages with skip connections from relu_3 and
    relu_1 (added, not concatenated) to avoid blur.  Three side outputs
    (skip_1/2/3) at increasing resolution feed the multi-scale loss;
    skip_3 is the full-resolution derained image, tanh-activated.

    NOTE(review): the 'skip_ouput_*' scope names keep the upstream
    repo's spelling so pretrained checkpoints still load — do not "fix".

    :param input_tensor: rainy image concatenated with the attention map
                         (presumably — confirm against caller)
    :param name: tf variable scope name
    :param reuse: whether to reuse variables in the scope
    :return: dict with keys 'skip_1', 'skip_2', 'skip_3'
    """
    with tf.variable_scope(name, reuse=reuse):
        # ----- encoder -----
        conv_1 = self.conv2d(inputdata=input_tensor, out_channel=64,
                             kernel_size=5, padding='SAME', stride=1,
                             use_bias=False, name='conv_1')
        relu_1 = self.lrelu(inputdata=conv_1, name='relu_1')
        conv_2 = self.conv2d(inputdata=relu_1, out_channel=128,
                             kernel_size=3, padding='SAME', stride=2,
                             use_bias=False, name='conv_2')
        relu_2 = self.lrelu(inputdata=conv_2, name='relu_2')
        conv_3 = self.conv2d(inputdata=relu_2, out_channel=128,
                             kernel_size=3, padding='SAME', stride=1,
                             use_bias=False, name='conv_3')
        relu_3 = self.lrelu(inputdata=conv_3, name='relu_3')
        conv_4 = self.conv2d(inputdata=relu_3, out_channel=128,
                             kernel_size=3, padding='SAME', stride=2,
                             use_bias=False, name='conv_4')
        relu_4 = self.lrelu(inputdata=conv_4, name='relu_4')
        conv_5 = self.conv2d(inputdata=relu_4, out_channel=256,
                             kernel_size=3, padding='SAME', stride=1,
                             use_bias=False, name='conv_5')
        relu_5 = self.lrelu(inputdata=conv_5, name='relu_5')
        conv_6 = self.conv2d(inputdata=relu_5, out_channel=256,
                             kernel_size=3, padding='SAME', stride=1,
                             use_bias=False, name='conv_6')
        relu_6 = self.lrelu(inputdata=conv_6, name='relu_6')

        # Dilated convs widen the receptive field without losing
        # resolution — the surrounding context fills in the raindrops.
        dia_conv1 = self.dilation_conv(input_tensor=relu_6, k_size=3,
                                       out_dims=256, rate=2,
                                       padding='SAME', use_bias=False,
                                       name='dia_conv_1')
        relu_7 = self.lrelu(dia_conv1, name='relu_7')
        dia_conv2 = self.dilation_conv(input_tensor=relu_7, k_size=3,
                                       out_dims=256, rate=4,
                                       padding='SAME', use_bias=False,
                                       name='dia_conv_2')
        relu_8 = self.lrelu(dia_conv2, name='relu_8')
        dia_conv3 = self.dilation_conv(input_tensor=relu_8, k_size=3,
                                       out_dims=256, rate=8,
                                       padding='SAME', use_bias=False,
                                       name='dia_conv_3')
        relu_9 = self.lrelu(dia_conv3, name='relu_9')
        dia_conv4 = self.dilation_conv(input_tensor=relu_9, k_size=3,
                                       out_dims=256, rate=16,
                                       padding='SAME', use_bias=False,
                                       name='dia_conv_4')
        relu_10 = self.lrelu(dia_conv4, name='relu_10')

        conv_7 = self.conv2d(inputdata=relu_10, out_channel=256,
                             kernel_size=3, padding='SAME', stride=1,
                             use_bias=False, name='conv_7')
        relu_11 = self.lrelu(inputdata=conv_7, name='relu_11')
        conv_8 = self.conv2d(inputdata=relu_11, out_channel=256,
                             kernel_size=3, padding='SAME', stride=1,
                             use_bias=False, name='conv_8')
        relu_12 = self.lrelu(inputdata=conv_8, name='relu_12')

        # ----- decoder with skip connections -----
        deconv_1 = self.deconv2d(inputdata=relu_12, out_channel=128,
                                 kernel_size=4, stride=2, padding='SAME',
                                 use_bias=False, name='deconv_1')
        avg_pool_1 = self.avgpooling(inputdata=deconv_1, kernel_size=2,
                                     stride=1, padding='SAME',
                                     name='avg_pool_1')
        relu_13 = self.lrelu(inputdata=avg_pool_1, name='relu_13')
        # Skip connection from the encoder (additive) to fight blur.
        conv_9 = self.conv2d(inputdata=tf.add(relu_13, relu_3),
                             out_channel=128, kernel_size=3,
                             padding='SAME', stride=1, use_bias=False,
                             name='conv_9')
        relu_14 = self.lrelu(inputdata=conv_9, name='relu_14')
        deconv_2 = self.deconv2d(inputdata=relu_14, out_channel=64,
                                 kernel_size=4, stride=2, padding='SAME',
                                 use_bias=False, name='deconv_2')
        avg_pool_2 = self.avgpooling(inputdata=deconv_2, kernel_size=2,
                                     stride=1, padding='SAME',
                                     name='avg_pool_2')
        relu_15 = self.lrelu(inputdata=avg_pool_2, name='relu_15')
        conv_10 = self.conv2d(inputdata=tf.add(relu_15, relu_1),
                              out_channel=32, kernel_size=3,
                              padding='SAME', stride=1, use_bias=False,
                              name='conv_10')
        relu_16 = self.lrelu(inputdata=conv_10, name='relu_16')

        # Side outputs at 1/4, 1/2 and full resolution for the
        # multi-scale loss.
        skip_output_1 = self.conv2d(inputdata=relu_12, out_channel=3,
                                    kernel_size=3, padding='SAME',
                                    stride=1, use_bias=False,
                                    name='skip_ouput_1')
        skip_output_2 = self.conv2d(inputdata=relu_14, out_channel=3,
                                    kernel_size=3, padding='SAME',
                                    stride=1, use_bias=False,
                                    name='skip_ouput_2')
        skip_output_3 = self.conv2d(inputdata=relu_16, out_channel=3,
                                    kernel_size=3, padding='SAME',
                                    stride=1, use_bias=False,
                                    name='skip_ouput_3')
        # Conventional GAN output activation: tanh on the final image.
        skip_output_3 = tf.nn.tanh(skip_output_3,
                                   name='skip_output_3_tanh')

        ret = {
            'skip_1': skip_output_1,
            'skip_2': skip_output_2,
            'skip_3': skip_output_3
        }
    return ret
def build_discriminator(self, input_tensor, name, reuse=False):
    """Attention-guided discriminator.

    Six conv stages extract features; a 1-channel conv over the last
    stage yields an internal attention map which gates (element-wise
    multiplies) those features before three strided stages and two FC
    layers produce the real/fake score.

    NOTE(review): the ``def`` line was missing from the pasted snippet;
    the signature is reconstructed to match the sibling ``build_*``
    methods — confirm against the upstream repo.

    :param input_tensor: image to judge (generator output or clean GT)
    :param name: tf variable scope name
    :param reuse: whether to reuse variables in the scope
    :return: (fc_out, attention_map, fc_2) — sigmoid probability,
             discriminator attention map, and the pre-sigmoid logit
    """
    with tf.variable_scope(name, reuse=reuse):
        conv_stage_1 = self._conv_stage(input_tensor=input_tensor,
                                        k_size=5, stride=1, out_dims=8,
                                        group_size=0, name='conv_stage_1')
        conv_stage_2 = self._conv_stage(input_tensor=conv_stage_1,
                                        k_size=5, stride=1, out_dims=16,
                                        group_size=0, name='conv_stage_2')
        conv_stage_3 = self._conv_stage(input_tensor=conv_stage_2,
                                        k_size=5, stride=1, out_dims=32,
                                        group_size=0, name='conv_stage_3')
        conv_stage_4 = self._conv_stage(input_tensor=conv_stage_3,
                                        k_size=5, stride=1, out_dims=64,
                                        group_size=0, name='conv_stage_4')
        conv_stage_5 = self._conv_stage(input_tensor=conv_stage_4,
                                        k_size=5, stride=1, out_dims=128,
                                        group_size=0, name='conv_stage_5')
        conv_stage_6 = self._conv_stage(input_tensor=conv_stage_5,
                                        k_size=5, stride=1, out_dims=128,
                                        group_size=0, name='conv_stage_6')
        # Internal attention map: guides the discriminator toward the
        # raindrop regions (supervised against the generator's map).
        attention_map = self.conv2d(inputdata=conv_stage_6, out_channel=1,
                                    kernel_size=5, padding='SAME',
                                    stride=1, use_bias=False,
                                    name='attention_map')
        conv_stage_7 = self._conv_stage(
            input_tensor=attention_map * conv_stage_6,
            k_size=5, stride=4, out_dims=64,
            group_size=0, name='conv_stage_7')
        conv_stage_8 = self._conv_stage(input_tensor=conv_stage_7,
                                        k_size=5, stride=4, out_dims=64,
                                        group_size=0, name='conv_stage_8')
        conv_stage_9 = self._conv_stage(input_tensor=conv_stage_8,
                                        k_size=5, stride=4, out_dims=32,
                                        group_size=0, name='conv_stage_9')
        fc_1 = self.fullyconnect(inputdata=conv_stage_9, out_dim=1024,
                                 use_bias=False, name='fc_1')
        fc_2 = self.fullyconnect(inputdata=fc_1, out_dim=1,
                                 use_bias=False, name='fc_2')
        fc_out = self.sigmoid(inputdata=fc_2, name='fc_out')
        # Nudge the probability away from exactly 0.0 / 1.0 so log() in
        # the adversarial loss never produces -inf.
        fc_out = tf.where(tf.not_equal(fc_out, 1.0),
                          fc_out, fc_out - 0.0000001)
        fc_out = tf.where(tf.not_equal(fc_out, 0.0),
                          fc_out, fc_out + 0.0000001)
        return fc_out, attention_map, fc_2
OUCMachineLearning/OUCML (github.com) 推薦閱讀: