文本生成中的decoding strategy整理
最近基於Seq2Seq模型嘗試了各種不同的decoding strategy,在這裡總結記錄下各種演算法的特點和效果。
文本生成中的decoding strategy主要可以分為兩大類:
- Argmax Decoding: 主要包括beam search, class-factored softmax等
- Stochastic Decoding: 主要包括temperature sampling, top-k sampling等
下文將詳細介紹這兩類演算法的特點和應用。
1. 問題定義
在Seq2Seq模型中,RNN Encoder對輸入句子進行編碼,生成一個大小固定的hidden state ;基於輸入句子的hidden state 和先前生成的第1到t-1個詞 ,RNN Decoder會生成當前第t個詞的hidden state ,最後通過softmax函數得到第t個詞 的vocabulary probability distribution 。
兩類decoding strategy的主要區別就在於,如何從vocabulary probability distribution 中選取一個詞 :
- Argmax Decoding的做法是選擇詞表中probability最大的詞,即 ;
- Stochastic Decoding則是基於概率分佈 隨機sample一個詞 ,即 。
2. Argmax Decoding
在大多數文本生成任務中,大家都直接採用Argmax Decoding,最常見的就是beam search。但如果我們的vocabulary size較大,達到了50k甚至150k,在softmax層的運算量就會變得非常大。因為 ,計算分母時需要對vocabulary中的每一個詞都進行計算。
因此在這裡介紹兩種時間複雜度更低,效果更好的演算法:
(1) Class-factored Softmax
這裡我們將原本只有一層的softmax layer擴展為兩層:第一層為cluster層,每個cluster中包含一組語意相近的詞,每個詞只出現在一個cluster中;第二層為word層,輸出最後decode的詞 ,也就是 。儘管cluster層和word層分別包含一個softmax layer,但每一層softmax的分母部分的計算量都大大縮小了。詳見論文:Pragmatic Neural Language Modelling in Machine Translation
這裡需要注意的是,cluster的選取對decoding的效果有很大的影響,所以需要選擇合適的聚類演算法來pre-train高質量的cluster,論文中選用的是Brown cluster。
(2) Pointer-generator Network
這裡儘管只使用一層softmax layer,但引入了一個非常強大的copy network,模型訓練速度和生成句子的質量都顯著高於Seq2Seq + Standard Softmax。
簡單來說,我們首先建立一個很小(如5k)的高頻詞vocabulary,然後建立一個Attention layer,得到輸入句子的Attention distribution,在decoding階段,若vocabulary中不存在需要decode的詞 ,則直接從輸入句子的Attention distribution中copy 的attention weight作為 。詳見論文:Get To The Point: Summarization with Pointer-Generator Networks
3. Stochastic Decoding
但實際上Argmax Decoding常常會導致模型生成重複的句子,如"I dont know. I dont know. I dont know...."。因為在模型中: 。
一個可行的解決方案就是在decoding過程中引入randomness,但是The Curious Case of Neural Text Degeneration這篇論文指出,sampling from full vocabulary distribution生成的句子會非常的雜亂無章,因為當vocabulary size非常大時,每個詞的probability都會變得很小,這時模型會有非常高的可能性sample到一個tail distribution中的詞,一旦sample到了tail distribution中一個和前文非常不相關的詞,很有可能接下來的詞都受其影響,使得句子脫離原本的意思。
因此,我們需要sampling from truncated vocabulary distribution,比較常見的演算法主要有以下幾種:
(1) Temperature Sampling
在softmax中引入一個temperature t來改變vocabulary probability distribution,使其更偏向high probability words: ,其中 。
當 時,就變成了greedy decoding;當 時,就變成了uniform sampling。這樣通過調整t的大小,就可以避免sampling from tail distribution。詳見論文:The Curious Case of Neural Text Degeneration
(2) Top-k Sampling
這個方法比上一個更加簡單,也更為有效。在decoding過程中,從 中選取probability最高的前k個tokens,把它們的probability加總得到 ,然後將 調整為 ,其中 ,最後從 中sample一個token作為output token。詳見論文:Hierarchical Neural Story Generation
但Top-k Sampling存在的問題是,常數k是提前給定的值,對於長短大小不一,語境不同的句子,我們可能有時需要比k更多的tokens。