理解交叉熵

關於樣本集的兩個概率分布p和q,設p為真實的分布,比如[1, 0, 0]表示當前樣本屬於第一類,q為擬合的分布,比如[0.7, 0.2, 0.1]。

按照真實分布p來衡量識別一個樣本所需的編碼長度的期望,即平均編碼長度(信息熵):

H(p)=-sum_{i=1}^{C}{p(x_i)log(p(x_i))}

如果使用擬合分布q來表示來自真實分布p的編碼長度的期望,即平均編碼長度(交叉熵):

H(p, q)=-sum_{i=1}^{C}{p(x_i)log(q(x_i))}

直觀上,用p來描述樣本是最完美的,用q描述樣本就不那麼完美,根據吉布斯不等式, H(p, q) geq H(p) 恆成立,當q為真實分布時取等,我們將由q得到的平均編碼長度比由p得到的平均編碼長度多出的bit數稱為相對熵,也叫KL散度:

D(p || q)=H(p, q) - H(p) =sum_{i=1}^{C}{p(x_i)log(frac{p(x_i)}{q(x_i)})}

在機器學習的分類問題中,我們希望縮小模型預測和標籤之間的差距,即KL散度越小越好,在這裡由於KL散度中的 H(p) 項不變(在其他問題中未必),故在優化過程中只需要關注交叉熵就可以了,因此一般使用交叉熵作為損失函數。

多分類任務中的交叉熵損失函數

Loss=-sum_{i=0}^{C-1}{y_i log(p_i)}= -log(p_c)

其中 p=[p_0, ..., p_{C-1}] 是一個概率分布,每個元素 p_i 表示樣本屬於第i類的概率; y=[y_0, ..., y_{C-1}] 是樣本標籤的onehot表示,當樣本屬於第類別i時y_i=1 ,否則 y_i=0 ;c是樣本標籤。

PyTorch中的交叉熵損失函數實現

PyTorch提供了兩個類來計算交叉熵,分別是CrossEntropyLoss() 和NLLLoss()。

  • torch.nn.CrossEntropyLoss()

類定義如下

torch.nn.CrossEntropyLoss(
weight=None,
ignore_index=-100,
reduction="mean",
)

z=[z_0, ..., z_{C-1}] 表示一個樣本的非softmax輸出,c表示該樣本的標籤,則損失函數公式描述如下,

loss(z, c) = -log(frac{exp(z[c])}{sum_{j=0}^{C-1}{exp(z[j])}}) =-z[c]+log(sum_{j=0}^{C-1}{exp(z[j])})

如果weight被指定,

loss(z, c)  =w cdot (-z[c]+log(sum_{j=0}^{C-1}{exp(z[j])}))

其中,w=weight[c] cdot 1{c
e ignore\_index}

import torch
import torch.nn as nn

model = nn.Linear(10, 3)
criterion = nn.CrossEntropyLoss()

x = torch.randn(16, 10)
y = torch.randint(0, 3, size=(16,)) # (16, )
logits = model(x) # (16, 3)

loss = criterion(logits, y)

  • torch.nn.NLLLoss()

類定義如下

torch.nn.NLLLoss(
weight=None,
ignore_index=-100,
reduction="mean",
)

a=[a_0, ..., a_{C-1}] 表示一個樣本對每個類別的對數似然(log-probabilities),c表示該樣本的標籤,損失函數公式描述如下,

loss(a, c) = -w cdot a[c] = -w cdot log(p_c)

其中, w=weight[c] cdot 1{c
e ignore\_index}

import torch
import torch.nn as nn

model = nn.Sequential(
nn.Linear(10, 3),
nn.LogSoftmax()
)
criterion = nn.NLLLoss()

x = torch.randn(16, 10)
y = torch.randint(0, 3, size=(16,)) # (16, )
out = model(x) # (16, 3)

loss = criterion(out, y)

  • 總結

torch.nn.CrossEntropyLoss在一個類中組合了nn.LogSoftmax和nn.NLLLoss,

This criterion combines nn.LogSoftmax() and nn.NLLLoss() in one single class. The input is expected to contain scores for each class.

推薦閱讀:

相关文章