PyTorch數據集的讀取
class Mnist_data(Dataset):
    """MNIST-style image dataset.

    Expects ``root_dir`` to contain ten sub-directories named ``0`` .. ``9``,
    each holding the PNG images of that digit class.
    """

    def __init__(self, root_dir, pre_load=False, transform=None):
        """
        Args:
            root_dir: directory containing the per-class sub-directories.
            pre_load: if True, decode every image into memory up front.
            transform: optional callable applied to each image in __getitem__.
        """
        self.root_dir = root_dir
        self.transform = transform
        # list of (file path, label) pairs
        self.filenames = []
        # filled by _pre_load(); None means "read lazily from disk"
        self.images = None
        self.labels = None
        for label in range(10):
            # glob returns the full path of every PNG in the class folder
            for file in glob.glob(opj(self.root_dir, str(label), "*.png")):
                self.filenames.append((file, label))
        self.len = len(self.filenames)
        if pre_load:
            self._pre_load()

    def _pre_load(self):
        """Decode every image into memory so __getitem__ avoids disk I/O."""
        self.images = []
        self.labels = []
        for file, label in self.filenames:
            img = Image.open(file)
            # copy() forces a full decode so the file handle can be closed
            self.images.append(img.copy())
            img.close()
            self.labels.append(label)

    def __getitem__(self, index):
        """Return one (image, label) pair, with the transform applied."""
        if self.images is not None:  # cached in memory by _pre_load()
            img = self.images[index]
            label = self.labels[index]
        else:
            img_path, label = self.filenames[index]
            # BUG FIX: the original left the file handle open (PIL decodes
            # lazily); copy() decodes fully so the handle can be released.
            with Image.open(img_path) as f:
                img = f.copy()
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return self.len
看的cs231n 2018版的 pytorch 教程,已經更到了0.4版本,這是比較標準的數據集讀取方法了.
cs231n有空會上傳代碼,assignment2做完了結果代碼誤刪了,準備下載2018版的作業重來一遍
下面的是另一版本的讀取方式,對比之下發現上面的官方教程多了一步緩存數據,更快.但是它添加了數據集分割的功能,可以綜合一下.
class DogCat(data.Dataset):
def __init__(self,root,transforms=None,train=True,test=False):
主要目標: 獲取所有圖片的地址,並根據訓練,驗證,測試劃分數據
self.test = test
imgs = [os.path.join(root,img) for img in os.listdir(root)]
# test1: data/test1/8973.jpg
# train: data/train/cat.10004.jpg
if self.test:
imgs = sorted(imgs,key=lambda x:int(x.split(.)[-2].split(/)[-1]))
else:
imgs = sorted(imgs,key=lambda x:int(x.split(.)[-2]))
imgs_num = len(imgs)
if self.test:
self.imgs = imgs
elif train:
self.imgs = imgs[:int(0.7*imgs_num)]
else :
self.imgs = imgs[int(0.7*imgs_num):]
if transforms is None:
normalize = T.Normalize(mean = [0.485, 0.456, 0.406],
std = [0.229, 0.224, 0.225])
if self.test or not train:
self.transforms = T.Compose([
T.Scale(224),
T.CenterCrop(224),
T.ToTensor(),
normalize
])
else :
self.transforms = T.Compose([
T.Scale(256),
T.RandomSizedCrop(224),
T.RandomHorizontalFlip(),
T.ToTensor(),
normalize
])
def __getitem__(self,index):
一次返回一張圖片的數據
img_path = self.imgs[index]
if self.test: label = int(self.imgs[index].split(.)[-2].split(/)[-1])
else: label = 1 if dog in img_path.split(/)[-1] else 0
data = Image.open(img_path)
data = self.transforms(data)
return data, label
def __len__(self):
return len(self.imgs)
此外讀取圖片進行分類最快捷的是imagefolder方法
import torch
from torchvision import transforms, datasets

# Standard ImageNet-style training augmentation + normalization.
data_transform = transforms.Compose([
    transforms.RandomSizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
# ImageFolder infers each image's label from its sub-directory name.
# BUG FIX: the path must be a quoted string literal.
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)
下面的是可視化的代碼
# Build the MNIST train/test loaders (paths must be quoted string literals).
trainset = Mnist_data('mnist_png/training',
                      pre_load=True,
                      transform=transforms.ToTensor())
testset = Mnist_data('mnist_png/testing',
                     pre_load=True,
                     transform=transforms.ToTensor())
train_loader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=4)

from torchvision.utils import make_grid

# function to show an image
def imshow(img):
    # C x H x W tensor -> H x W x C array for matplotlib
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

# get some random training images and show them as a grid
dataiter = iter(train_loader)
images, labels = next(dataiter)  # BUG FIX: dataiter.next() was removed in newer PyTorch
imshow(make_grid(images))
另一個比較高效的數據集處理分割方式來自fastai論壇,只要稍微改一下基本就能應用於所有的數據集的清理.
import os
import random
import shutil
def organize_folder(folder):
    """Move flat files like ``cat.1.jpg`` into per-class sub-folders.

    The class of a file is everything before the first dot in its name;
    each file ends up in ``folder/<class>/<file>``.
    """
    _, _, filenames = next(os.walk(folder))
    unique_classes = {filename.split(".")[0] for filename in filenames}
    for _class in unique_classes:
        path = os.path.join(folder, _class)
        if not os.path.exists(path):
            os.makedirs(path)
        for filename in filenames:
            # BUG FIX: compare the exact class prefix instead of startswith();
            # startswith wrongly matched class "cat" against "category.1.jpg"
            # and could try to move the same file twice.
            if filename.split(".")[0] == _class:
                shutil.move(os.path.join(folder, filename), os.path.join(path, filename))
def create_sample_folder(_from, to, percentage=0.1, move=True):
    """Move or copy a random sample of each class folder into a mirror tree.

    For every class sub-folder of ``_from``, a ``percentage`` fraction of
    its files is sampled at random and transferred into the same-named
    sub-folder of ``to`` (both created on demand).
    """
    if not os.path.exists(to):
        os.makedirs(to)
    # pick the transfer operation once, outside the loops
    transfer = shutil.move if move else shutil.copyfile
    _, class_dirs, _ = next(os.walk(_from))
    for class_dir in class_dirs:
        dest_dir = os.path.join(to, class_dir)
        if not os.path.exists(dest_dir):
            os.makedirs(dest_dir)
        src_dir = os.path.join(_from, class_dir)
        _, _, all_files = next(os.walk(src_dir))
        picked = random.sample(all_files, int(len(all_files) * percentage))
        for name in picked:
            transfer(os.path.join(src_dir, name), os.path.join(dest_dir, name))
推薦閱讀: