CVPR2019年的一篇文章 RePr: Improved Training of Convolutional Filters
這篇文章初看abstract和introduction,差點以為是model pruning,看到後面發現是針對卷積神經網路的訓練方法,而且這個方法比較簡單,但文章通過大量的分析和實驗,驗證了提出的訓練方法非常有效,在cifar、ImageNet、VQA、object detection上漲點很多,個人覺得paper writing/representation做的非常好,ablation study做的非常充分,是我目前看過的CVPR2019中最好的一篇文章。
遺憾的是個人水平有限,無法深入的分析這篇文章,就談談自己的淺見,也算是自己的閱讀筆記。
先來看效果圖,用新的訓練方法,測試準確率遠遠超過了標準的訓練方法,是不是amazing?
下面看看在object detection上的測試結果,RePr在ResNet50-FPN上有4.1個點的提升,在ResNet101-FPN有2.8個點的提升,可以說是非常明顯了。
【Introduction】
卷積神經網路在視覺任務中取得了SOTA性能,我們會為不同的任務單獨設計不同的網路結構,雖然網路結構不同,但使用的優化方法都是一樣的,而且這些優化方法將網路權重視為單獨的個體,沒有考慮彼此之前的相關性。而事實上,網路權重之間是有很大聯繫的。為了獲取最好的性能,網路經常過參數化(over-parameterized)。然而即使是過參數化的網路,也會存在很多冗餘的參數。model pruning證明了一個大的網路可以通過丟棄一部分參數權重得到一個性能損失不大的小網路,從而實現網路壓縮和加速。
因此文章提出了一個新的訓練方法。既然網路中有些參數權重是多餘,那我們訓練的時候把他們丟棄(pruning),接著訓練剩下的網路,為了不損失模型的capacity,然後再把丟棄的參數拿回來,效果是不是會好一點呢?基於這個想法,文章作者任務有幾個重要的點:一是pruning哪些權重,而是如何再把丟棄的權重拿回來讓他們發揮更大的作用。本文的一個貢獻在於提出了一個metric,用於選擇哪些filters丟棄。同時作者指出,即使是一個參數很少( under-parameterized )的網路,也會出現學到冗餘的參數的情況,這不僅僅在多參數的網路中存在,原因就在於訓練過程低效。
【Motivation】
特徵之間的相關性越高,其泛化性能越差。為了降低特徵之間的相關性,有人提出了各種方法,比如在損失函數中加入特徵相關性的項,在優化目標函數的時候使的模型自動學習出低相關的特徵,然而並沒有什麼效果。還有通過分析不同層的特徵並進行聚類,這種方法在實際中不可行因為計算量巨大。還有人嘗試過loss添加正則項讓模型學習到相互正交的權重,最後也發現收效甚微。實驗發現,僅僅通過正則化項讓網路自動學習到正交的權重是不夠的,文章提出的訓練方法其實已經隱式地起到了正則化效果,並且對模型的收斂沒有任何影響。
為了說明即使是參數少的模型,由於訓練的低效,也會存在大量冗餘的卷積核,通過一個比較小的卷積神經網路,作者可視化了不同層的卷積核對性能的影響,如下圖右,layer2的大部分卷積核對性能的影響僅僅只有1%,即使是淺層的layer1,也存在大量不重要的權重。
【新的訓練方法:RePr】
訓練過程如下:方式比較簡單,先訓練整個網路,根據metric drop掉30%的filter,再訓練剩下的網路,再把drop的filter拿回來,用於現有filters正交的方式初始化。迭代這個過程N次。
演算法中最重要的其實這個metric,即如何選出需要drop的filters。
文章寫的很明白,一個layer的多個卷積核可以用一個matrix表示,也就是 ,先對 歸一化,再與自己的轉置相乘得到 ,這是一個 x 大小的matrix,第i行表示其他filter對第i個filter的projection,可以看成是相關性,如果是正交性的越大,那麼這個值就越小,一行的數值之和越小,說明其他filter與這個filter相關性越低。因此可以通過這個sum來對filter進行rank。
同時文章還說明了,計算這個metric是在一個layer內,但rank是在所有layer進行的,目的是為了不讓layer這個因數影響filter的rank,避開layer的差異性,同時也不引入過多的超參。
文章一個值得稱讚的點就是ablation study部分做的非常詳細而充分。文章做了大量的對比實驗,對該方法涉及的參數進行了討論,並對比了不同的optimization的影響,同時也比較了dropout、knowledge distillation,指出該方法不僅和他們有很大區別,與他們結合還能得到更好的結果。
【Results】
作者做了大量實驗,驗證該方法在cifar10,cifar100,imagenets上都能取得很好的性能。
這是RePr早其他task上的表現,VQA上有不同程度的漲點,效果明顯。
object detection上漲點達到了ResNet50 4.1個點,ResNet101 2.8個點,可是說是非常明顯了。
總的來說,這是一篇看起來真正做work的paper,做法簡單有效,實驗充分合理,相信很多人會去復現這篇paper,有些超參還是需要調一調的,具體效果如何還需要看實際情況,特別是detection部分,如果真的work,未來會成為刷SOTA的一個標配。
幾個疑問的點:
推薦閱讀: