Zero-shot learning的一個關鍵技術是建立語義信息與圖像信息的對應關係,很多人會嘗試用生成模型來間接構造visual-semantic embedding space。有時間也看看相關的一些研究,整理一下。今天主要講講騰訊2017年發在IJCAI 上的一篇論文的理解和思考:Variational Deep embedding: An Unsupervised and Generative Approach to Clustering。
聚類是將相似的對象組合在一起的過程,是機器學習和人工智慧中最基本的任務之一。聚類方法可以分為兩大類:Similarity-based clustering 和 Feature-based clustering 。其中Similarity-based clustering比較出名的是Spectral Clustering (SC)方法;Feature-based clustering比較出名的方法有K-means和Gaussian Mixture Model (GMM)方法。最近,隨著deep learning在各大機器學習任務中取得巨大勝利,也有很多學者將deep learning應用於聚類的,deep learning可以提供好的表達,避免在原始數據上進行聚類。這篇文章的motivation就在於基於神經網路提出一種聚類方法,一是可以用deep learning學習一種好的表達,而這種表達可以獲取數據的統計結構;二是可以生成樣本。
高斯混合模型(Gaussian Mixture Model, GMM )可以通過最大期望(EM)優化演算法來進行聚類。 作者在這個基礎上引入了一個深度神經網路(a deep neural network, DNN )來建模數據生成的過程,從而可以用隨機梯度變分貝葉斯(Stochastic Gradient Variational Bayes, SGVB )來進行模型優化。換個角度看,也可以看作是在變分自編碼器(Variational Auto-Encoder, VAE )的基礎上,用GMM代替了之前的單個的高斯分布,因此提出的VaDE在優化的過程也可以用重參數技巧(reparameterization trick)。下面是文章摘要的表述:
Variational Deep embedding (VaDE ) 是一個基於變分自編碼器(Variational Auto-Encoder, VAE )的非監督的生成聚類(unsupervised generative clustering)方法。VaDE通過一個高斯混合模型(Gaussian Mixture Model, GMM )和一個深度神經網路(a deep neural network, DNN )來建模數據生成的過程。建模的過程分為3步:1)由GMM選擇一個cluster;2)根據cluster生成一個潛在的embedding;3)用DNN將embedding編碼為observable。VaDE的優化還是以VAE的形式,所以加了一個不同的DNN來將observable解碼為潛在embedding,這樣證據下限(evidence lower bound, ELBO )就能用隨機梯度變分貝葉斯(Stochastic Gradient Variational Bayes, SGVB )和重參數技巧(reparameterization trick)進行優化了。下圖是VaDE的結構圖。
The diagram of VaDE.
左上角的紅框即第一步,由高斯混合模型選取cluster;藍框即用該cluster生成潛在embedding ;黃框表示用DNN函數 對 進行編碼為 ;紫框表示用另一個DNN函數 將 解碼為 。從而最大化VaDE的證據下限ELBO。
VaDE generalizes VAE in that a Mixture-of-Gaussians prior replaces the single Gaussian prior.
從結構圖可以很直觀的看出,VaDE是用高斯混合先驗替代單一的高斯先驗後的VAE。作者提出將VAE和GMM結合起來,並通過使用SGVB和重參數trick來最大化ELBO從而優化VaDE。下面按照論文的結構,從生成過程、變分下屆以及理解變分下屆對VaDE的作用對VaDE進行介紹。為了更便於理解,我們結合下面的代碼來進行講解。
GuHongyang/VaDE-pytorch ?
github.com
1. The Generative Process
先假設這裡共有 類,下面我們以Mnist手寫數字為例,即 ,而觀測樣本 (Mnist的圖像維度是28*28,即 )的生成過程如下:
其中, 是預定義好的值,這裡我們知道 , 是聚類器 的先驗分布,則 ,那麼屬於每一類的概率和即為1( ), 是類別分布。選擇好cluster 後,則可以得到對應的高斯混合模型的均值 和方差 ,從而採樣到潛在變數 。用DNN 將 編碼為觀測樣本 了。聯合概率 可以分解為: 。因為他們是條件獨立的,所以概率又可以定義為:
在代碼里,也是基於mnist數據的(訓練集和測試集一同進行操作,共70000個數據),其先是構造了一個784-500-500-2000-10 -2000-500-500-784的autoencoder,通過最小化輸入圖片 和重構圖像 的均方誤差Mean Squared Error(MSE),來學習 的表達 。然後將學到的特徵向量 輸入到10個高斯混合的GMM中進行聚類操作。
值得一提的是,在計算聚類的準確率的時候,coder先是得到一個10*10的矩陣,矩陣每個元素 表示將第 類聚類為第 類的次數。下圖就是一個聚類結果矩陣:
因為是第一次接觸聚類,我以為聚類正確率的計算方式是將對角線的數加起來除以總數目。但實際上操作是先根據這個矩陣,找到最大值(6619),然後用6619減去矩陣的每個元素,這樣做是為了用指派問題 里的匈牙利法。 指派問題是有n項不同的任務,需要n個人分別完成其中的1項,每個人完成任務的時間不一樣,如何分配任務使得花費時間最少。這裡是要讓經可能多的數據找到組織,所以要先對這個矩陣反處理一下:
根據得到的結果[[0 7], [1 0], [2 2], [3 1], [4 3], [5 9], [6 8], [7 6], [8 5], [9 4]],就可以得到聚類結果啦。我理解的是,聚類與分類不同在於,分類你要明確的告訴我這是屬於哪一類,是正方形還是圓形,雖然正方形、圓形也是我們俗成約定的。而聚類只需要知道這些方方正正的是一類,那堆圈圈是是一類的,至於他們叫什麼,我不care,也是可以根據我們各自的喜好給他們賦予名稱的。在Zero-shot Learning裡面,我們在構建visual-semantic embedding的時候,我們一般也只比較關注這個圖像的語義特徵與那個屬性的語義特徵是一類的,至於這個屬性的語義特徵對應的標籤,就不是那麼重要了。
然後,這個GMM的權重,即我們要找到 ,GMM的均值和對數的方差即為 和 。然後coder在搭建那個autoencoder的時候,其實在中間接了兩個維度為10的layer,一個表示 ,一個表示 ,因為前面只用到了 層,且coder讓這兩層的參數一致,所以前面沒有提。根據 和 ,我們就可以根據 了。
2. Variational Lower Bound
通過最大化給定數據的似然,我們可以對VaDE進行調整。通過Jensens不等式,log-likelihood可以寫作:
其中, 是evidence lower bound, 是變分後驗,用於逼近真實後驗 ,其可以被分解為:
最後,根據上面的公式, 就可以被寫為:
與VAE類似,作者也用一個神經網路 來建模 :
最後,通過SGVB和reparameterization trick, 可以被寫為:
同樣結合代碼來看這個loss function。公式12行的第一行是重構誤差。將前面第一步得到的 輸入到decoder里,從而得到重構圖片 ,即公式里的 (對應公式13、14,吐槽一下,這個符號用的太容易混淆了)。 是Monte Carlo Samples的個數,代碼里是設置為1的。 是輸入、重構向量的維度,即784。這裡用Binary Cross Entropy計算 和 之間的重構誤差。然後根據 和前面得到的GMM的 、 和 ,可以計算公式12裡面第二行的 。公式16給出了其計算公式:
公式12裡面的 , 。第三行的計算也就很容易了。 代表的都是cluster的個數,這裡為10。通過優化這個 誤差,可以更新整個VaDE模型。而最終的聚類結果則通過 得到。
3. Understanding the ELBO of VaDE
公式 又可以寫成下面的樣子:
第一項是重構誤差,通過這項優化,可以讓VaDE更好的對數據進行表達。第二項是KL散度,通過約束混合高斯Mixture-of-Gaussians (MoG)的先驗 和變分先驗 的距離,來使latent embedding 分布在一個MoG的平均場上。為了證明第二項KL散度的重要性,作者還對比了各種類似的模型,如autoencoder+GMM,VAE+GMM等。下圖是一個在MNIST數據上的聚類準確率的對比結果:
下圖是生成結果,效果還是很nice的:
這篇論文看的還不是很明白,特別是 相關的推導那裡。如果只是按照公式進行代碼復現,也許沒有問題,但如果想要借鑒這種思想應用於ZSL中,還是需要再仔細的揣摩揣摩的。對聚類和各種概率模型的了解不深,可能有很多錯誤,還請大家指正。也歡迎大家一起討論呀~
推薦閱讀:
Please enable JavaScript to view the comments powered by Disqus.