Background

The arrival of BERT announced to the world that unsupervised pretrained language models could become the "shoulders of giants" for a wide range of NLP tasks, and the GPT-2 and XLNet that followed have kept raising practitioners' expectations. Thanks to these models being open-sourced, once we understand the ideas in the papers we can piggyback on models pretrained with massive compute, quickly run experiments on our own datasets, and even apply them to real business problems.

GitHub already hosts pretrained-BERT implementations in several languages/frameworks, all following Google's original TensorFlow release and all reasonably well documented. This post shows how to call the PyTorch pretrained BERT with minimal code and fine-tune it for a text-classification task.

The code below follows the official pytorch-pretrained-BERT text-classification example; anyone interested can read it directly:

https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/run_classifier.py

The data

The data used in this post consists of titles and their categories; for example, the title 「中國的垃圾分類能走多遠」 ("How far can China's waste sorting go") belongs to the 「社會」 (society) category. There are 28 categories, each with 1,000 training and 1,000 test examples. The data has been uploaded to a cloud drive; feel free to download it:

https://pan.baidu.com/s/1r4SI6-IizlCcsyMGL7RU8Q

Extraction code: 6awx

Imports

import os
import sys
import pickle
import pandas as pd
import numpy as np
from concurrent.futures import ThreadPoolExecutor
import torch
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
from tqdm import tqdm_notebook, trange
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM, BertForSequenceClassification
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
from sklearn.metrics import precision_recall_curve, classification_report
import matplotlib.pyplot as plt
%matplotlib inline

Loading the data

# read the data with pandas
data = pd.read_pickle("title_category.pkl")
# rename the columns
data.columns = ["text", "label"]

Label encoding

Because the labels are Chinese strings, they have to be converted to IDs before they can be fed to the model; sklearn's LabelEncoder makes this conversion trivial.

le = LabelEncoder()
le.fit(data.label.tolist())
data["label"] = le.transform(data.label.tolist())
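The fitted encoder is also handy later for turning predicted IDs back into human-readable category names. A small optional sketch (my addition, not in the original post):

# le.classes_ holds the original Chinese category names, indexed by their encoded ID
id2label = dict(enumerate(le.classes_))
# and le.inverse_transform([3, 7]) maps encoded IDs back to their category names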

Inspecting the data
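To take a quick look at the data, a couple of one-liners suffice (a minimal sketch, assuming the column names set above):

# a few random (title, label) pairs
data.sample(5)
# with a balanced training set, each class should appear roughly 1,000 times
data["label"].value_counts()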

Preparing the training data

The pretrained BERT model used here is the character-level model trained on Chinese Wikipedia, listed under the name bert-base-chinese in the model list Google provides; the names of models trained on corpora in more languages can be found at: github.com/huggingface/

Also, the first time the code below runs there is no local cache, so the model files are downloaded automatically; in practice the download is quite fast. Note that the do_lower_case argument has to be set to False explicitly.

# tokenizer
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-chinese", do_lower_case=False)

# wrapper class
class DataPrecessForSingleSentence(object):
    """
    Turns raw text into model-ready inputs.
    """

    def __init__(self, bert_tokenizer, max_workers=10):
        """
        bert_tokenizer : the tokenizer
        dataset        : a pandas dataframe with a "text" and a "label" column
        """
        self.bert_tokenizer = bert_tokenizer
        # thread pool used for tokenization
        self.pool = ThreadPoolExecutor(max_workers=max_workers)

    def get_input(self, dataset, max_seq_len=30):
        """
        Tokenizes, converts to IDs, truncates and pads the input texts through a thread pool
        (multiprocessing is awkward inside a notebook) and returns sequences ready for the model.

        Args:
            dataset     : a pandas dataframe with two columns, the first holding the text and
                          the second the label (an integer class ID produced by the LabelEncoder).
            max_seq_len : target sequence length; it should be chosen from the length
                          distribution of the texts and must not exceed 512
                          (BERT's maximum sequence length).

        Returns:
            seqs         : token-ID sequences with CLS/SEP added at head and tail and
                           zero padding appended up to max_seq_len.
            seq_masks    : 0/1 sequences of the same length as seqs; 1 marks a real token,
                           0 marks a padding position.
            seq_segments : same shape as seqs; all zeros, since the input is a single sentence.
            labels       : the integer class labels.
        """
        sentences = dataset.iloc[:, 0].tolist()
        labels = dataset.iloc[:, 1].tolist()
        # tokenize (character level for Chinese)
        tokens_seq = list(
            self.pool.map(self.bert_tokenizer.tokenize, sentences))
        # build fixed-length sequences plus their masks and segment IDs
        result = list(
            self.pool.map(self.trunate_and_pad, tokens_seq,
                          [max_seq_len] * len(tokens_seq)))
        seqs = [i[0] for i in result]
        seq_masks = [i[1] for i in result]
        seq_segments = [i[2] for i in result]
        return seqs, seq_masks, seq_segments, labels

    def trunate_and_pad(self, seq, max_seq_len):
        """
        1. Because this class handles single sentences, BERT's input format requires the special
           tokens CLS and SEP at the head and tail. The sequence without those two tokens must
           therefore be at most max_seq_len - 2 long; longer sequences are truncated.
        2. The resulting [CLS, seq, SEP] sequence is zero-padded at the tail if it is still
           shorter than max_seq_len.

        Args:
            seq         : the input token sequence, here a single sentence.
            max_seq_len : the sequence length after the CLS and SEP tokens have been added.

        Returns:
            seq         : token-ID sequence with CLS/SEP added and zero padding at the tail.
            seq_mask    : 0/1 sequence of the same length as seq; 1 for real tokens, 0 for padding.
            seq_segment : same shape as seq; all zeros, since the input is a single sentence.
        """
        # truncate over-long sequences
        if len(seq) > (max_seq_len - 2):
            seq = seq[0:(max_seq_len - 2)]
        # add the special tokens at head and tail
        seq = ['[CLS]'] + seq + ['[SEP]']
        # convert tokens to IDs
        seq = self.bert_tokenizer.convert_tokens_to_ids(seq)
        # build the padding from max_seq_len and the current length
        padding = [0] * (max_seq_len - len(seq))
        # attention mask: 1 for real tokens, 0 for padding
        seq_mask = [1] * len(seq) + padding
        # segment IDs: all zeros for a single sentence
        seq_segment = [0] * len(seq) + padding
        # pad the ID sequence itself
        seq += padding
        assert len(seq) == max_seq_len
        assert len(seq_mask) == max_seq_len
        assert len(seq_segment) == max_seq_len
        return seq, seq_mask, seq_segment

DataPrecessForSingleSentence is a class that turns a pandas dataframe into model inputs; the arguments and return values of each method are documented fairly thoroughly above. The processing flow is roughly:

  • The texts are tokenized (at the character level) by calling tokenize through a thread pool; a quick sanity check follows this list.
  • A tokenized sequence longer than max_seq_len - 2 is truncated. BERT's maximum max_seq_len is 512, so an input can never exceed 512 characters; the "- 2" reserves two slots for the special tokens added in the next step.
  • [CLS] and [SEP] are attached at the head and tail. A mask sequence uses 0/1 values to mark whether each position is a padding symbol (0 for padding, 1 otherwise), and a segment sequence marks which input source each token belongs to (all zeros here, since there is only a single sentence); both are zero-padded up to max_seq_len as well.
  • The token sequence is converted to IDs with convert_tokens_to_ids and zero-padded to max_seq_len; the method finally returns the seq, seq_mask and seq_segment sequences.
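As a quick sanity check of the character-level tokenization (illustrative only; the exact output depends on the released vocabulary, and characters missing from it come back as [UNK]):

toks = bert_tokenizer.tokenize("中國的垃圾分類能走多遠")
print(toks)   # expected: one token per Chinese character, e.g. ['中', '國', '的', ...]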

# instantiate the processor
processor = DataPrecessForSingleSentence(bert_tokenizer=bert_tokenizer)
# build the model inputs
seqs, seq_masks, seq_segments, labels = processor.get_input(
    dataset=data, max_seq_len=30)

max_seq_len is set to 30 here: the length statistics of the titles show that 30 is roughly the 85th percentile, so it already covers the vast majority of samples.
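One way to arrive at such a number (a sketch of my own; since tokenization is character level, the character count of a title approximates its token count, minus the two special tokens):

# length distribution of the raw titles
title_lens = data["text"].str.len()
print(title_lens.describe())
print(title_lens.quantile(0.85))   # ~30 according to the statistics quoted above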

Loading the pretrained BERT model

# load the pretrained BERT model
model = BertForSequenceClassification.from_pretrained(
    "bert-base-chinese", num_labels=28)

Again, the first run triggers an automatic download. Since this task has 28 categories, the num_labels argument has to be set to 28.
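A small optional safeguard (my suggestion, not from the original post) is to derive the label count from the fitted encoder rather than hard-coding it:

# equals 28 for this dataset; could be passed as num_labels above
num_labels = len(le.classes_)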

Formatting the data

Formatting here simply means converting the Python lists into torch tensors.

# convert to torch tensors
t_seqs = torch.tensor(seqs, dtype=torch.long)
t_seq_masks = torch.tensor(seq_masks, dtype=torch.long)
t_seq_segments = torch.tensor(seq_segments, dtype=torch.long)
t_labels = torch.tensor(labels, dtype=torch.long)

train_data = TensorDataset(t_seqs, t_seq_masks, t_seq_segments, t_labels)
train_sampler = RandomSampler(train_data)
train_dataloder = DataLoader(dataset=train_data, sampler=train_sampler, batch_size=256)

TensorDataset, RandomSampler and DataLoader are used to wrap the inputs, which takes far less code than writing a generator by hand; the batch size here is 256.
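If you want to confirm what the loader yields, peeking at one batch is enough (a hypothetical check, not in the original code):

batch_seqs, batch_seq_masks, batch_seq_segments, batch_labels = next(iter(train_dataloder))
print(batch_seqs.shape)     # torch.Size([256, 30]) -- batch_size x max_seq_len
print(batch_labels.shape)   # torch.Size([256])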

# put the model in train mode
model.train()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): BertLayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
        ... (1) through (11): eleven more BertLayer blocks identical to (0) ...
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1)
  (classifier): Linear(in_features=768, out_features=28, bias=True)
)

The printed network structure shows that the classifier layer's out_features has indeed been set to the 28 mentioned above. The BertPooler layer is also worth a look: if the [CLS] token prepended in the earlier steps seemed mysterious, reading BertPooler's code makes its purpose clear.

# link : https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py
class BertPooler(nn.Module):
    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

The code above is the implementation of BertPooler: in forward, hidden_states[:, 0] takes only the hidden state of the first token. Thanks to the bidirectional encoder's representational power, the [CLS] token aggregates information from the entire sequence and can therefore serve as a compact, low-dimensional representation of it.
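If you ever need that pooled [CLS] vector by itself (say, to use BERT as a fixed feature extractor), the underlying BertModel returns it directly. A hedged sketch reusing the tensors built above; to the best of my knowledge, the pytorch-pretrained-bert forward signature is (input_ids, token_type_ids, attention_mask, output_all_encoded_layers):

bert = BertModel.from_pretrained("bert-base-chinese")
bert.eval()
with torch.no_grad():
    # encoded_layers: last layer's hidden states; pooled: the transformed [CLS] vector described above
    encoded_layers, pooled = bert(t_seqs[:8], t_seq_segments[:8], t_seq_masks[:8],
                                  output_all_encoded_layers=False)
print(pooled.shape)   # torch.Size([8, 768])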

# parameters to optimize
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

optimizer_grouped_parameters = [
    {
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    },
    {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }
]

optimizer = BertAdam(optimizer_grouped_parameters,
                     lr=2e-05,
                     warmup=0.1,
                     t_total=2000)

device = torch.device('cpu')
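A note on the schedule arguments, as I read the library (not from the original post): warmup=0.1 means the learning rate is warmed up linearly over the first 10% of t_total optimization steps and then decayed. t_total is hard-coded to 2000 above; it could instead be derived from the data, e.g.:

# steps per epoch x number of epochs (~110 x 10 here, with 28,000 samples and batch size 256)
num_epochs = 10
t_total = len(train_dataloder) * num_epochs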

I remember that when reading Dive into Deep Learning (section 3.12), Mu Li points out that weight decay is equivalent to L2 regularization. Following the official BERT code, the bias terms and the LayerNorm.bias / LayerNorm.weight parameters are exempted from weight decay, which is what the parameter grouping above reproduces.

fine-tuning

# collect the loss of every batch
loss_collect = []
for i in trange(10, desc="Epoch"):
    for step, batch_data in enumerate(
            tqdm_notebook(train_dataloder, desc="Iteration")):
        batch_data = tuple(t.to(device) for t in batch_data)
        batch_seqs, batch_seq_masks, batch_seq_segments, batch_labels = batch_data
        # one-hot encode the labels
        # (not actually used below: CrossEntropyLoss takes class indices directly)
        one_hot = torch.zeros(batch_labels.size(0), 28).long()
        one_hot_batch_labels = one_hot.scatter_(
            dim=1,
            index=torch.unsqueeze(batch_labels, dim=1),
            src=torch.ones(batch_labels.size(0), 28).long())

        # with labels=None the model returns the raw classification logits
        logits = model(
            batch_seqs, batch_seq_masks, batch_seq_segments, labels=None)
        # CrossEntropyLoss applies log-softmax internally, so the raw logits are passed in
        loss_function = CrossEntropyLoss()
        loss = loss_function(logits, batch_labels)
        loss.backward()
        loss_collect.append(loss.item())
        print("\n%f" % loss, end='')
        optimizer.step()
        optimizer.zero_grad()

Training ran for 10 epochs in total; each batch's loss was appended to loss_collect, which is visualized below.

Visualizing the loss

plt.figure(figsize=(12, 8))
plt.plot(range(len(loss_collect)), loss_collect, 'g.')
plt.grid(True)
plt.show()

The plot shows that the loss drops quickly over roughly the first 200 batches and then flattens out. Judging by the overall trend and the absolute values on the y-axis, it is still some way from converging, so with more training data and more epochs the loss would keep decreasing.

Testing

Saving the model

torch.save(model, "fine_tuned_chinese_bert.bin")
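Since the whole model object is pickled here, it can be restored the same way; saving only the state_dict is a common alternative (my suggestion, not from the post):

# reload the saved model object
model = torch.load("fine_tuned_chinese_bert.bin")
# or persist / restore just the weights instead:
# torch.save(model.state_dict(), "fine_tuned_chinese_bert_state.bin")
# model.load_state_dict(torch.load("fine_tuned_chinese_bert_state.bin"))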

Loading the test data

test_data = pd.read_pickle("title_category_valid.pkl")
test_data.columns = ["text", "label"]
# encode the labels with the same encoder
test_data["label"] = le.transform(test_data.label.tolist())
# convert to tensors
test_seqs, test_seq_masks, test_seq_segments, test_labels = processor.get_input(
    dataset=test_data, max_seq_len=30)
test_seqs = torch.tensor(test_seqs, dtype=torch.long)
test_seq_masks = torch.tensor(test_seq_masks, dtype=torch.long)
test_seq_segments = torch.tensor(test_seq_segments, dtype=torch.long)
test_labels = torch.tensor(test_labels, dtype=torch.long)
test_data = TensorDataset(test_seqs, test_seq_masks, test_seq_segments, test_labels)
test_dataloder = DataLoader(dataset=test_data, batch_size=256)
# collect the predicted and true labels
true_labels = []
pred_labels = []
model.eval()
# predict
with torch.no_grad():
    for batch_data in tqdm_notebook(test_dataloder, desc="TEST"):
        batch_data = tuple(t.to(device) for t in batch_data)
        batch_seqs, batch_seq_masks, batch_seq_segments, batch_labels = batch_data
        logits = model(
            batch_seqs, batch_seq_masks, batch_seq_segments, labels=None)
        logits = logits.softmax(dim=1).argmax(dim=1)
        pred_labels.append(logits.detach().numpy())
        true_labels.append(batch_labels.detach().numpy())
# per-class precision and recall
print(classification_report(np.concatenate(true_labels), np.concatenate(pred_labels)))

              precision    recall  f1-score   support

           0       0.93      0.95      0.94      1000
           1       0.88      0.90      0.89      1000
           2       0.91      0.92      0.91      1000
           3       0.88      0.95      0.92      1000
           4       0.88      0.92      0.90      1000
           5       0.91      0.91      0.91      1000
           6       0.85      0.84      0.84      1000
           7       0.93      0.97      0.95      1000
           8       0.88      0.94      0.91      1000
           9       0.77      0.86      0.81      1000
          10       0.97      0.94      0.96      1000
          11       0.85      0.90      0.88      1000
          12       0.91      0.97      0.94      1000
          13       0.75      0.86      0.80      1000
          14       0.84      0.90      0.87      1000
          15       0.77      0.87      0.82      1000
          16       0.91      0.95      0.93      1000
          17       0.96      0.95      0.95      1000
          18       0.91      0.93      0.92      1000
          19       0.92      0.94      0.93      1000
          20       0.94      0.93      0.93      1000
          21       0.80      0.80      0.80      1000
          22       0.93      0.97      0.95      1000
          23       0.82      0.86      0.84      1000
          24       0.00      0.00      0.00      1000
          25       0.92      0.93      0.93      1000
          26       0.89      0.90      0.89      1000
          27       0.89      0.89      0.89      1000

   micro avg       0.88      0.88      0.88     28000
   macro avg       0.85      0.88      0.86     28000
weighted avg       0.85      0.88      0.86     28000

Overall, precision and recall look fairly good. Bear in mind, though, that both training and testing use a balanced dataset, so precision and recall on the true class distribution will differ somewhat from these numbers.
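The report above indexes classes by their encoded IDs; to print the original Chinese category names instead, classification_report accepts a target_names argument, e.g.:

print(classification_report(np.concatenate(true_labels),
                            np.concatenate(pred_labels),
                            target_names=le.classes_))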

Summary

This post essentially simplified the code in run_classifier.py and fine-tuned the model on a Chinese dataset. The dataset and the code are provided and shown above; comments and discussion are welcome!

