import glob
from os.path import join as opj

from PIL import Image
from torch.utils.data import Dataset, DataLoader


class Mnist_data(Dataset):
    def __init__(self,
                 root_dir,
                 pre_load=False,
                 transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.filenames = []   # stores (filename, label) pairs
        self.images = None
        self.labels = None    # filled in by _pre_load, used below

        for j in range(10):
            all_files = glob.glob(opj(self.root_dir, str(j), "*.png"))
            # glob returns the full paths of the matching files

            for file in all_files:
                self.filenames.append((file, j))
        self.len = len(self.filenames)
        if pre_load:
            self._pre_load()

    def _pre_load(self):  # load every image into memory up front
        self.images = []
        self.labels = []

        for file, label in self.filenames:
            img = Image.open(file)
            self.images.append(img.copy())
            img.close()
            self.labels.append(label)

    def __getitem__(self, index):
        if self.images is not None:  # the in-memory cache is available
            img = self.images[index]
            label = self.labels[index]
        else:
            img_path, label = self.filenames[index]
            img = Image.open(img_path)

        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return self.len
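To see what pre_load buys you, here is a rough timing sketch of my own (not part of the tutorial): it times one full pass over __getitem__ with and without the cache. The up-front loading cost of pre_load=True is paid at construction time, outside the timer, and the mnist_png/training path simply matches the layout used further down.

import time

from torchvision import transforms

def time_one_pass(pre_load):
    # assumes the mnist_png/training folder layout used later in this post
    ds = Mnist_data('mnist_png/training',
                    pre_load=pre_load,
                    transform=transforms.ToTensor())
    start = time.time()
    for i in range(len(ds)):
        _ = ds[i]              # fetch (and, without the cache, decode) one sample
    return time.time() - start

print('pre_load=True :', time_one_pass(True))   # loading cost already paid at construction
print('pre_load=False:', time_one_pass(False))  # decodes from disk on every access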

I was following the cs231n 2018 PyTorch tutorial, which has been updated to version 0.4; this is a fairly standard way of reading a dataset.

I will upload my cs231n code when I have time. I had finished assignment 2 but accidentally deleted the code, so I plan to download the 2018 assignments and redo it.

Below is another way of reading the data. Comparing the two, the official tutorial above adds an extra data-caching step, which makes it faster, while this version adds train/validation/test splitting; the two ideas could be combined (see the sketch after the class below).

import os

from PIL import Image
from torch.utils import data
from torchvision import transforms as T


class DogCat(data.Dataset):

    def __init__(self, root, transforms=None, train=True, test=False):
        """
        Main goal: collect the paths of all images and split them
        into train / validation / test sets.
        """
        self.test = test
        imgs = [os.path.join(root, img) for img in os.listdir(root)]

        # test1: data/test1/8973.jpg
        # train: data/train/cat.10004.jpg
        if self.test:
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
        else:
            imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))

        imgs_num = len(imgs)

        if self.test:
            self.imgs = imgs
        elif train:
            self.imgs = imgs[:int(0.7 * imgs_num)]
        else:
            self.imgs = imgs[int(0.7 * imgs_num):]

        if transforms is None:
            normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

            # T.Scale / T.RandomSizedCrop are older names for
            # T.Resize / T.RandomResizedCrop
            if self.test or not train:
                self.transforms = T.Compose([
                    T.Scale(224),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    normalize
                ])
            else:
                self.transforms = T.Compose([
                    T.Scale(256),
                    T.RandomSizedCrop(224),
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    normalize
                ])

    def __getitem__(self, index):
        """
        Return the data for a single image.
        """
        img_path = self.imgs[index]
        if self.test:
            label = int(self.imgs[index].split('.')[-2].split('/')[-1])
        else:
            label = 1 if 'dog' in img_path.split('/')[-1] else 0
        data = Image.open(img_path)
        data = self.transforms(data)
        return data, label

    def __len__(self):
        return len(self.imgs)
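As a rough idea of how the two approaches could be combined, the sketch below adds an optional train/validation split on top of the cached Mnist_data class. The Mnist_split name, the 70/30 ratio and the seed are my own choices, not from either tutorial.

import random

class Mnist_split(Mnist_data):
    """Mnist_data plus a simple train/validation split (a sketch, not from the tutorials)."""

    def __init__(self, root_dir, train=True, split=0.7,
                 pre_load=False, transform=None, seed=0):
        # build the full (filename, label) list first, without caching yet
        super().__init__(root_dir, pre_load=False, transform=transform)

        # shuffle deterministically, then keep either the first `split`
        # fraction (train) or the remainder (validation)
        random.Random(seed).shuffle(self.filenames)
        cut = int(split * len(self.filenames))
        self.filenames = self.filenames[:cut] if train else self.filenames[cut:]
        self.len = len(self.filenames)

        # cache only the selected subset
        if pre_load:
            self._pre_load()

Usage then mirrors Mnist_data, e.g. Mnist_split('mnist_png/training', train=False, pre_load=True, transform=transforms.ToTensor()) for the validation part.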

Beyond that, the quickest way to read images for classification is the ImageFolder method.

import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
    transforms.RandomSizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)
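ImageFolder infers one class per sub-folder of root, so the folder names double as the labels. The mapping can be inspected directly; the example class names below are an assumption about the hymenoptera_data layout (ants/bees):

# class names are the sub-folder names, in sorted order
print(hymenoptera_dataset.classes)        # e.g. ['ants', 'bees']
print(hymenoptera_dataset.class_to_idx)   # e.g. {'ants': 0, 'bees': 1}

# each batch is (images, labels), just like the hand-written Datasets above
images, labels = next(iter(dataset_loader))
print(images.shape, labels)               # torch.Size([4, 3, 224, 224]) plus 4 label indices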

The code below is for visualization.

trainset = Mnist_data('mnist_png/training',
                      pre_load=True,
                      transform=transforms.ToTensor())
testset = Mnist_data('mnist_png/testing',
                     pre_load=True,
                     transform=transforms.ToTensor())
train_loader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=4)

import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

# function to show an image
def imshow(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

# get some random training images
dataiter = iter(train_loader)
images, labels = next(dataiter)

# show images
imshow(make_grid(images))
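To check that the grid matches its labels, the first few labels can be printed alongside; plt.show() is only needed when running as a plain script rather than in a notebook:

print(' '.join(str(labels[j].item()) for j in range(8)))  # labels of the first 8 images in the grid
plt.show()  # needed outside notebooks / interactive backends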

Another fairly efficient way of organizing and splitting a dataset comes from the fastai forums; with minor changes it can be applied to cleaning up almost any dataset.

import os
import random
import shutil

def organize_folder(folder):
    # move files named like "<class>.<id>.<ext>" into one sub-folder per class
    _, _, filenames = next(os.walk(folder))
    unique_classes = {filename.split(".")[0] for filename in filenames}
    for _class in unique_classes:
        path = os.path.join(folder, _class)
        if not os.path.exists(path):
            os.makedirs(path)
        for filename in filenames:
            if filename.startswith(_class):
                shutil.move(os.path.join(folder, filename), os.path.join(path, filename))

def create_sample_folder(_from, to, percentage=0.1, move=True):
    # move (or copy) a random fraction of each class sub-folder into a new folder tree
    if not os.path.exists(to):
        os.makedirs(to)
    _, folders, _ = next(os.walk(_from))
    for folder in folders:
        if not os.path.exists(os.path.join(to, folder)):
            os.makedirs(os.path.join(to, folder))
        _, _, files = next(os.walk(os.path.join(_from, folder)))
        sample = random.sample(files, int(len(files) * percentage))
        for filename in sample:
            if move:
                shutil.move(os.path.join(_from, folder, filename), os.path.join(to, folder, filename))
            else:
                shutil.copyfile(os.path.join(_from, folder, filename), os.path.join(to, folder, filename))
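For the Dogs vs. Cats layout used above, a typical call sequence would look like the sketch below; the data/train and data/valid paths are just placeholders for wherever the unpacked images actually live:

# group data/train/cat.*.jpg and data/train/dog.*.jpg into data/train/cat/ and data/train/dog/
organize_folder('data/train')

# carve out 20% of every class as a validation set under data/valid
create_sample_folder('data/train', 'data/valid', percentage=0.2, move=True)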
