本文截取自《PyTorch 模型訓練實用教程》,獲取全文pdf請點擊:https://github.com/tensor-yu/PyTorch_Tutorial

@[toc]

我們知道一個良好的權值初始化,可以使收斂速度加快,甚至可以獲得更好的精度。而在實際應用中,我們通常採用一個已經訓練模型的模型的權值參數作為我們模型的初始化參數,也稱之為Finetune,更寬泛的稱之為遷移學習。遷移學習中的Finetune技術,本質上就是讓我們新構建的模型,擁有一個較好的權值初始值。

finetune權值初始化三步曲,finetune就相當於給模型進行初始化,其流程共用三步:

第一步:保存模型,擁有一個預訓練模型; 第二步:載入模型,把預訓練模型中的權值取出來; 第三步:初始化,將權值對應的「放」到新模型中

一、Finetune之權值初始化

在進行finetune之前我們需要擁有一個模型或者是模型參數,因此需要了解如何保存模型。官方文檔中介紹了兩種保存模型的方法,一種是保存整個模型,另外一種是僅保存模型參數(官方推薦用這種方法),這裡採用官方推薦的方法。

第一步:保存模型參數

若擁有模型參數,可跳過這一步。
假設創建了一個net = Net(),並且經過訓練,通過以下方式保存:
torch.save(net.state_dict(), net_params.pkl)

第二步:載入模型

進行三步曲中的第二步,載入模型,這裡只是載入模型的參數:
pretrained_dict = torch.load(net_params.pkl)

第三步:初始化

進行三步曲中的第三步,將取到的權值,對應的放到新模型中:
首先我們創建新模型,並且獲取新模型的參數字典net_state_dict:
net = Net() # 創建net
net_state_dict = net.state_dict() # 獲取已創建net的state_dict

接著將pretrained_dict裏不屬於net_state_dict的鍵剔除掉:
pretrained_dict_1 = {k: v for k, v in pretrained_dict.items() if k in net_state_dict}

然後,用預訓練模型的參數字典 對 新模型的參數字典net_state_dict 進行更新:
net_state_dict.update(pretrained_dict_1)

最後,將更新了參數的字典 「放」回到網路中:
net.load_state_dict(net_state_dict)

這樣,利用預訓練模型參數對新模型的權值進行初始化過程就做完了。

採用finetune的訓練過程中,有時候希望前面層的學習率低一些,改變不要太大,而後面的全連接層的學習率相對大一些。這時就需要對不同的層設置不同的學習率,下面就介紹如何為不同層配置不同的學習率。

二、不同層設置不同的學習率

在利用pre-trained model的參數做初始化之後,我們可能想讓fc層更新相對快一些,而希望前面的權值更新小一些,這就可以通過為不同的層設置不同的學習率來達到此目的。

為不同層設置不同的學習率,主要通過優化器對多個參數組進行設置不同的參數。所以,只需要將原始的參數組,劃分成兩個,甚至更多的參數組,然後分別進行設置學習率。 這裡將原始參數「切分」成fc3層參數和其餘參數,為fc3層設置更大的學習率。

請看代碼:

ignored_params = list(map(id, net.fc3.parameters())) # 返回的是parameters的 內存地址
base_params = filter(lambda p: id(p) not in ignored_params, net.parameters())
optimizer = optim.SGD([
{params: base_params},
{params: net.fc3.parameters(), lr: 0.001*10}], 0.001, momentum=0.9, weight_decay=1e-4)

第一行+ 第二行的意思就是,將fc3層的參數net.fc3.parameters()從原始參數net.parameters()中剝離出來 base_params就是剝離了fc3層的參數的其餘參數,然後在優化器中為fc3層的參數單獨設定學習率。

optimizer = optim.SGD(......)這裡的意思就是 base_params中的層,用 0.001, momentum=0.9, weight_decay=1e-4 fc3層設定學習率為: 0.001*10

完整代碼位於 github.com/tensor-yu/Py

補充:

挑選出特定的層的機制是利用內存地址作為過濾條件,將需要單獨設定的那部分參數,從總的參數中剔除。 base_params 是一個list,每個元素是一個Parameter 類 net.fc3.parameters() 是一個

ignored_params = list(map(id, net.fc3.parameters())) net.fc3.parameters() 是一個 所以迭代的返回其中的parameter,這裡有weight 和 bias 最終返回weight和bias所在內存的地址


轉載請註明出處:blog.csdn.net/u01199571

推薦閱讀:

相關文章