數據集在一個tsv文件中存放,大小約為120G,一共有300萬條,訓練的機器內存只有32G,怎麼用pytorch載入數據進行訓練,求指導。


只能分批讀入,沒有其他的辦法。但是,分批歸分批,具體操作還是有不少講究。

  1. 總體思路:將訓練數據分成多個子文件,讀入一個就訓練一個,然後刪除當前這個讀下一個,以此降低內存開銷。
  2. 不要讀入原始的tsv文件:與將全部數據放在內存中相比,分批訓練不可避免地需要對同一個文件做多次讀入,造成額外的耗時。有什麼優化的方法嗎?當然。不要直接處理 tsv 或者任何文本格式的訓練數據,增加一個單獨的預處理步驟,將所有數據轉換成 python 支持的格式(比如 list)後,然後直接使用 torch.save() 進行序列化落盤。載入的時候,使用對應的 torch.load()。這對 api 組合與普通文本格式相比,快得飛起。
  3. 數據打亂:在做 shuffle 的時候,要從兩個維度去打亂。在每個 epoch 內部,子文件的順序打亂;在每個子文件內部,樣本的順序打亂;
  4. 真正地刪除:前面說到,要訓練一個就刪除一個,別忘了使用下面的代碼確保真的從內存裏抹去了:

gc.collect()
del self.cur_dataset
gc.collect()


在你 implement dataset 的時候,不要在 __init__ 裡面讀取文件,而是在 __getitem__ 裡面使用 linecache module 的 getline。你可以直接把 index pass 進去讀就行了。


不能一行一行讀嗎?不太可能會有需要一定要同時把所有數據存到內存裏。


可以做個預處理,把數據放進資料庫裏,sqlite就行。


  1. 將數據轉換成二進位存儲 然後可以使用numpy的mmap 記得好像fairseq之類的有實現這樣的功能 可以看看源碼學一學
  2. 可以用iterdataset 然後在iter方法中一行一行 或者幾行幾行讀 而不是在init方法中一次性讀取


pandas.read_csv(filepath, chunksize=n)


先了解一下迭代器,然後torch.utils.data可以解決你的需求


分批載入,每次讀幾個batch


最簡單的辦法就是創建並掛載一個體積大於你數據大小的swapfile

速度不會比其他回答裏的方案慢


看一下 torchtext 的文檔?他提供的 torchtext.data.TabularDataset 似乎就可以滿足你的需求啊。


重寫 迭代器。

dataloader實際上是一個list。每次讀滿了就隨機換下一個dataloader

問題是loss損失函數會有起伏。每次換dataloader loss都會增加。

不過看起來並沒有損失性能。


推薦閱讀:
相關文章