雪花台湾

PyTorch學習筆記——多分類交叉熵損失函數

理解交叉熵

關於樣本集的兩個概率分布p和q,設p為真實的分布,比如[1, 0, 0]表示當前樣本屬於第一類,q為擬合的分布,比如[0.7, 0.2, 0.1]。

按照真實分布p來衡量識別一個樣本所需的編碼長度的期望,即平均編碼長度(信息熵):

如果使用擬合分布q來表示來自真實分布p的編碼長度的期望,即平均編碼長度(交叉熵):

直觀上,用p來描述樣本是最完美的,用q描述樣本就不那麼完美,根據吉布斯不等式, 恆成立,當q為真實分布時取等,我們將由q得到的平均編碼長度比由p得到的平均編碼長度多出的bit數稱為相對熵,也叫KL散度:

在機器學習的分類問題中,我們希望縮小模型預測和標籤之間的差距,即KL散度越小越好,在這裡由於KL散度中的 項不變(在其他問題中未必),故在優化過程中只需要關注交叉熵就可以了,因此一般使用交叉熵作為損失函數。

多分類任務中的交叉熵損失函數

其中 是一個概率分布,每個元素 表示樣本屬於第i類的概率; 是樣本標籤的onehot表示,當樣本屬於第類別i時 ,否則 ;c是樣本標籤。

PyTorch中的交叉熵損失函數實現

PyTorch提供了兩個類來計算交叉熵,分別是CrossEntropyLoss() 和NLLLoss()。

類定義如下

torch.nn.CrossEntropyLoss(
weight=None,
ignore_index=-100,
reduction="mean",
)

表示一個樣本的非softmax輸出,c表示該樣本的標籤,則損失函數公式描述如下,

如果weight被指定,

其中,

import torch
import torch.nn as nn

model = nn.Linear(10, 3)
criterion = nn.CrossEntropyLoss()

x = torch.randn(16, 10)
y = torch.randint(0, 3, size=(16,)) # (16, )
logits = model(x) # (16, 3)

loss = criterion(logits, y)

類定義如下

torch.nn.NLLLoss(
weight=None,
ignore_index=-100,
reduction="mean",
)

表示一個樣本對每個類別的對數似然(log-probabilities),c表示該樣本的標籤,損失函數公式描述如下,

其中,

import torch
import torch.nn as nn

model = nn.Sequential(
nn.Linear(10, 3),
nn.LogSoftmax()
)
criterion = nn.NLLLoss()

x = torch.randn(16, 10)
y = torch.randint(0, 3, size=(16,)) # (16, )
out = model(x) # (16, 3)

loss = criterion(out, y)

torch.nn.CrossEntropyLoss在一個類中組合了nn.LogSoftmax和nn.NLLLoss,

This criterion combines nn.LogSoftmax() and nn.NLLLoss() in one single class. The input is expected to contain scores for each class.

推薦閱讀:

相关文章