torchtext是PyTorch中用於處理文本預處理的包,高度封裝使用起來非常簡單快捷。
本文介紹如何保存Datasets對象到本地並重新載入。
如果使用過torchtext,預處理階段使用的時間一定會讓你印象深刻,如果你需要預處理的文本數據較大的話。我們以自然語言推斷(NLI)中經典的數據集SNLI為例。SNLI數據大小為570K,本人在實驗室伺服器上(TiTan XP顯卡,128G內存,28Core)預處理花費的時間為300s左右。單次處理可以勉強接受,但是如果每一次調整參數或者debug都需要等待5分鐘,相信沒有人可以接受。
預處理SNLI的代碼如下代碼所示:
from torchtext import data
from torchtext import datasets
from torchtext.vocab import GloVe
from nltk import word_tokenize
import numpy as np
class SNLI():
def __init__(self, args):
self.TEXT = data.Field(batch_first=True, include_lengths=True, tokenize=word_tokenize, lower=True)
self.LABEL = data.Field(sequential=False, unk_token=None)
self.train, self.dev, self.test = datasets.SNLI.splits(self.TEXT, self.LABEL)
self.TEXT.build_vocab(self.train, self.dev, self.test, vectors=GloVe(name=840B, dim=300))
self.LABEL.build_vocab(self.train)
self.train_iter, self.dev_iter, self.test_iter =
data.BucketIterator.splits((self.train, self.dev, self.test),
batch_size=args.batch_size,
device=args.gpu)
如果不知道上述基礎代碼什麼意思,建議參考torchtext文檔或者其他基礎教程,本文不展開介紹。
實驗中約9成的時間花費在以下數據劃分中
self.train, self.dev, self.test = datasets.SNLI.splits(self.TEXT, self.LABEL)
得到的self.train等是不可以直接序列化的,pickle和dill都不可以。通過debug可以發現處理後的數據保存在datasets.examples中,而examples是可以通過dill保存到本地的列表。
為了能夠在代碼中復用保存到本地的examples,我們需要根據examples重建datasets,即是上述的self.train,self.dev,self.test。目前官方沒有給出解決方案,思路我們參考了下面的文章,基本的思路也是保存dataset的examples。如果有其他思路,可以和我分享一下。
Use torchtext to Load NLP Datasets?towardsdatascience.comtorchtext.dataset.SNLI類繼承自TabularDataset,而後者繼承自Dataset。Dataset新建時需要傳遞examples和fields,如下所示。
class torchtext.data.Dataset(examples, fields, filter_pred=None)
examples即是我們保存到本地的數據,使用dill再次載入即可。fields是一個字典,可以debug看具體信息,SNLI預處理中如下。
fields = {premise: self.TEXT, hypothesis: self.TEXT, label: self.LABEL}
按照如下代碼即可從本地載入example並構建datasets
# 從本地載入切分好的數據集
def load_split_datasets(self, fields):
# 載入examples
with open(snli_train_examples_path, rb)as f:
train_examples = dill.load(f)
with open(snli_dev_examples_path, rb)as f:
dev_examples = dill.load(f)
with open(snli_test_examples_path, rb)as f:
test_examples = dill.load(f)
# 恢複數據集
train = SNLIDataset(examples=train_examples, fields=fields)
dev = SNLIDataset(examples=dev_examples, fields=fields)
test = SNLIDataset(examples=test_examples, fields=fields)
return train, dev, test
總體的思路大概就是這樣,本文只是一個簡單的記錄,沒有展開過多的細節。如果遇到了相應的問題,上述對於解決問題基本也是足夠了。優化後處理時間縮短到6秒左右,本地保存的examples有200M左右。優化後的結果如下: