首先,這篇博文整理自谷歌開源的神經機器翻譯項目Neural Machine Translation (seq2seq) Tutorial。如果你直接克隆這個項目按照Tutorial中的說明操作即可,那麼也就不用再往下看了。
而之所以寫這篇博文的目的是,雖然Seq2Seq的原理並不太難,但是在用Tensorflow實現起來的時候卻不那麼容易。即使谷歌開源了源碼,但是對於初學者來說面對複雜的工程結構文件,依舊是一頭霧水(看來好幾天,源碼也沒弄懂)。於是筆者就根據Tutorial中的說明以及各種摸索,終於搭建出了一個簡單的翻譯模型。下面就來大致介紹整個模型的搭建過程,數據的預處理,以及一些重要參數的說明等等。
由於筆者本身不搞自然語言這方面的內容,只是想學習這方面技術在Tensorflow中的使用,所以對於如何。
Seq2Seq模型的主要原理如圖p0105所示,先是一個Sequence通過RNN網路結構編碼(左邊藍色部分)後得到 「thought vector"(圖中白色矩形框,後稱為中間向量),也就是說此時的中間向量包含了輸入向量的所有信息,可以將其視為一個「加密」的過程。緊接著就是將中間向量再次餵給另外一個RNN網路對其進行解碼(右邊棕色部分),然後得到解碼後的輸出,可將其視為一個「解密」的過程。可能這個圖太抽象了,我們再來進一步細化這個圖:?
?如圖p0106所示,輸入部分以每個單詞作為RNN對應每個時刻的輸入,而輸出部分呢則以RNN上一時刻的輸出作為下一時刻的輸入,直到輸出為終止符"</s>"為止。有沒有發現,這同我們之間介紹的用LSTM來生成唐詩的原理一模一樣? 但是呢請注意這個問題:在訓練的時候我們並不能保證解碼部分每個時刻的輸出就是正確的。換句話說就是,假設第一個時刻的輸出為「我",然後接著將」我「餵給下一時刻,但此時預測的結果為」你好「,然後再把」你好「餵給下一時刻預測出」明天「,最後將」明天「餵給下一此時預測出」</s>「結束。也就是最終預測的序列為?" 我 你好 明天 </s>" ,雖然這樣也能同正確標籤?"我 是 一個 學生 </s> "做交叉熵然後訓練網路,但這就導致訓練出來的網路可能效果不好。而再翻譯模型中,普遍的做法就是在訓練時,解碼部分每個時刻的輸入就是正確標籤,然後再將預測結果同正確標籤做交叉熵;而在預測(inference)時再採取上一時刻的輸出作為下一時刻的輸入(此時也沒有所謂的正確標籤)這一策略,如圖p0107所示:
1.2 前期準備
為了方便後面在介紹Tensorflow時一些函數的使用方法(參數的左右),在這裡首先來大致介紹以下幾個重要又不容易理解的變數。
從圖p0108可知,整個網路模型至少需要三個placeholder,即encoder_inputs,decoder_inputs,decoder_outputs,其分別為source input wors,target input words,target output words三個部分的輸入或輸出。同時,由於每個sequence的長短都是不一樣的,因此在NMT這個模型中,這三個地方的變數的shape都不是固定的。有人可能會說了,將所有的句子都Padding成一個長度不久行了嗎? 雖然來說理論上可以這樣,但是由於sequences之間長短相差太大(至少是在NMT中),如果所有sequence都padding成一個長度,效果肯定不好,所以NMT採取的做法是:只在同一個batch中保持所有sequence的長度一樣(不夠的以最長的為標準再padding),也就是說同一batch保持一致,不同batch之間可以不同。
placeholder
encoder_inputs,decoder_inputs,decoder_outputs
假設現在source input words中有一個batch,batch中有5個sequence,其長度分別為5,7,3,8,6,則:
encoder_inputs.shape=[8,5]; 指定了time_major=True(不明白time_major戳此處見第3點)
encoder_inputs.shape=[8,5]
time_major=True
time_major
source_lengths=[5,7,3,8,6]; 記錄每個sequence的長度
source_lengths=[5,7,3,8,6]
max_source_length=8; 記錄最長sequence的長度
max_source_length=8
以下以3個樣本為例來主要介紹一下數據預處理部分。
漢:[[你 是 誰 ?], [你 從 哪裡 來?],[你 要 到 哪裡 去 ?]]
英: [[who are you ?],[where are you from ?],[where are you going ?]]
src_vocab_table,tgt_vocab_table
UNK,SOS,EOS,PAD
source_inputs=[[4,5,6,7,3,3],[4,8,9,10,7,3],[4,11,12,9,13,7]] source_lengths=[4,5,6] max_source_length = 6 ? target_inputs=[[1,4,5,6,7,3],[1,8,5,6,9,7],[1,8,5,6,10,7]] target_lengths=[5,6,6] max_target_length=6 target_outputs=[[4,5,6,7,2,3],[8,5,6,9,7,2],[8,5,6,10,7,2]] ?
注意,對於target_inputs,target_outputs來說,一定是先加上起始符和終止符再padding.
target_inputs,target_outputs
encoder編碼部分和寫LSTM這種網路結構幾乎一樣,都是通過dynamic_rnn這個函數來完成的。目前發現唯一的區別在於此處多了一個參數sorce_lengths,其原因是因為每個sequence的長度不一樣(儘管每個baatch裏padding成一樣了),所以要告訴dynamic_rnn展開的時間維度。
LSTM
dynamic_rnn
sorce_lengths
def _build_encoder(self): def get_encoder_cell(rnn_size): lstm_cell = tf.nn.rnn_cell.LSTMCell(rnn_size) return lstm_cell ? encoder_cell = tf.nn.rnn_cell.MultiRNNCell([get_encoder_cell(self.encoder_rnn_size) for _ in range(self.encoder_rnn_layer)]) self.encoder_outputs, self.encoder_final_state = tf.nn.dynamic_rnn(cell=encoder_cell, nputs=self.encoder_emb_inp, sequence_length=self.source_lengths, time_major=True, dtype=tf.float32)
Note that sentences have different lengths to avoid wasting computation, we tell dynamic_rnn the exact source sentence lengths through source_sequence_length.
2.2 解碼decoder
在上面的1.1節中我們說到,NMT的解碼部分在實現的時候分為訓練和推斷(預測)兩個部分,因此對於這兩個部分也要分開來寫。
在訓練時通過TrainingHelper這個函數來構造一個輔助對象,達到給每個時刻輸入正確label的目的,然後通過BasicDecoder和dynamic_decoder進行解碼;而在預測時則通過GreedyEmbeddingHelper這個輔助對象來完成將上以時刻的輸出作為下一時刻的輸入這一步驟,然後同樣通過BasicDecoder和dynamic_decoder進行解碼。
TrainingHelper
BasicDecoder
dynamic_decoder
GreedyEmbeddingHelper
由於這部分代碼貼出來排版看起來很亂影響閱讀體驗,所以就不貼了,直接參考文末貼出的代碼即可。當然,接下來的就是構造損失函數等其它步驟了,參照代碼中的注釋即可。
總體來說對於使用Tensorflow來完成這個示例的難點在於一些參數的理解上,也就是1.1節中提到的幾個參數。只要把這幾個參數的含義弄明白了,照葫蘆畫瓢相對來說還是不那麼困難。
源碼戳此處