終於下決心寫點東西,seqGAN

來自專欄機器學習小白從入門到...?4 人贊了文章

seqGAN小筆記,裡面可能充斥著一個學渣捉急的腦迴路,望各位大牛幫忙堪錯,Orz,thx。

一直以來都覺得NLP是下一個AI的突破點,語言作為承載人類文明、智慧乃至意識的載體,凝聚著人類對這個世界的理解,也是人類文明傳播的主要工具,如果有朝一日計算機能夠理解自然語言,理解其中的含義,並能夠做出正確的響應,那麼可以說它已經具有了一定智能。因此一直以來想看一些關於對話方面的研究進展,因為對話不僅需要具備理解能力,同時需要具備決策能力,這不正是智能的體現嗎?

趁最近工作不忙讀了兩篇關於diglog的paper,一篇名為《Adversarial Learning for Neural Dialogue Generation》的文章吸引了我的注意,如題Adversarial Learning自然指的就是GAN,不出所料文中也是採用RL來處理離散不可導的問題。額這不就是一個seqGAN的應用嗎。因此又翻出seqGAN的paper來學習一下。以下只是一個簡單的筆記。

首先考慮文本生成是如何解決的,我們需要一個生成模型,額自然想到超新星GAN,當然你也可以是其他任何生成模型例如VAE,或者新出現的GLOW等。最近GAN似乎已經佔據了頂會論文的半壁江山,它的意義被人們稱為下一個深度學習?有必要再深入重新學習一下哈哈。那麼如何利用GAN來解決文本生成的問題呢?不難想到利用Generator來生成文本樣本,通過Discriminator來區分是真實樣本還是生成樣本。想法很自然但是其中不免出現一個GAN的根本性問題:如何解決離散數據梯度無法回傳給Generator的問題。(這一塊可以參考Role of RL in Text Generation by GAN(強化學習在生成對抗網路文本生成中扮演的角色),可以直接移步至此,作者對seqGAN寫得很清楚詳細)剛開始的時候我也產生了為什麼不可以把softmax直接傳給Discriminator的問題,原因是如果直接將softmax結果傳給Discriminator,則根本無需判斷生成分佈是否與真實分佈是否接近,而只需要判斷分佈是不是one-hot形式就可以了。其次seqGAN作者還指出Discriminator只能對完整的句子進行判斷,而無法判斷部分句子的好壞,而實際上一個句子並不是全部都很差,而僅僅其中部分不好而已。因此原始seqGAN的出現主要解決以上兩個問題。

首先如何解決第一個問題?seqGAN給出的答案是:RL。將Generator的優化轉化為最大化rewards,然後利用policy gredient優化Generator就可以了。GAN+RL,簡直不能更cool了,Orz。對於第二個問題,如何對部分文本進行評價呢?顯然我們不能直接將部分文本直接拋給Discriminator,因為Discriminator僅具備對完整句子評價的能力,給出的部分文本自然無法獲得高分。seqGAN是如何解決的呢?MC Search,直接多次採樣到末尾,額,低效但好使。具體推導還請移駕原文,這裡貼出原文中的一個圖:

到這裡已經對seqGAN有了整體初步的瞭解,下面結合代碼對整體進行一個梳理。

第一步:預訓練Generator

這裡很簡單,首先利用一個完美訓練的神經網路target_lstm生成一些真實數據,利用這些數據通過MLE訓練Generator。示例代碼中Generator就是一個簡單的lstm。

第二步:預訓練Discriminator

將target_lstm生成的真實數據作為正樣本,Generator生成的數據作為負樣本,訓練一個二分類的神經網路。示例代碼中Discriminator是一個具有highway的CNN。

第三步:開始進行對抗訓練,首先訓練一輪Generator

首先利用當前的Generator生成batch_size大小數據,利用rollout網路進行MC search並採樣句子餵給Discriminator進行評價返回rewards,最後根據rewards更新Generator參數。細緻一點,當我們生成到第t個詞後利用MC搜索出後續的n條路徑餵給Discriminator,並將反饋的rewards取平均。

rollout是什麼?可以將rollout理解為就是Generator,通過rollout進行採樣同時餵給Discriminator獲得rewards,原文將這樣的採樣策略稱為roll-out policy。那麼為什麼不直接使用Generator呢?這個我也表示不很清楚,希望有大牛進行解答。可以看到後續rollout在進行參數更新的時候和Generator並不一致,難道是並不希望它訓練的太好導致採樣結果差別不大?

參數更新部分的代碼中可以看到Generator的更新是將句子的似然乘以了rewards,就這樣?回過來看一眼推導。對rewards求梯度推導結果如圖。

結果不正是句子的log似然乘以rewards麼。

第四步:更新rollout參數

第五步:訓練n輪Discriminator

採用更新後的Generator生成一些新的負樣本對Discriminator訓練n輪

三、四、五步需要迭代多輪,這樣便完成了seqGAN的訓練。可以說seqGAN的設計很巧妙了,其中還有很多需要深入理解、慢慢消化的部分。

參考:

SeqGAN Sequence Generative Adversarial Nets with Policy Gradient

Adversarial Learning for Neural Dialogue Generation

胡楊:Role of RL in Text Generation by GAN(強化學習在生成對抗網路文本生成中扮演的角色)


推薦閱讀:
相關文章