NO· 05 - 圖像分類
圖像分類
本指南訓練神經網路模型,對服裝圖像進行分類,如運動鞋和襯衫。 如果您不瞭解所有細節,那也沒關係;這是一個完整的TensorFlow程序的快節奏概述,詳細解釋了我們的目標。
本指南使用tf.keras,一個高級API,用於在TensorFlow中構建和訓練模型。
from __future__ import absolute_import, division, print_function, unicode_literals
# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
print(tf.__version__)
2.0.0-alpha0
1. 導入MNIST數據集
本指南使用Fashion MNIST數據集,該數據集包含10個類別中的70,000個灰度圖像。 圖像顯示了低解析度(28 x 28像素)的單件服裝,如下所示:
Fashion MNIST旨在替代經典的MNIST數據集 - 通常用作計算機視覺機器學習計劃的「Hello,World」。 MNIST數據集包含手寫數字(0,1,2等)的圖像,其格式與我們將在此處使用的服裝的格式相同。
本指南使用Fashion MNIST的多樣性,因為這是一個比普通MNIST稍微更具挑戰性的問題。這兩個數據集都相對較小,用於驗證演算法是否按預期工作。它們是測試和調試代碼的良好起點。
我們將使用60,000張圖像來訓練網路,10,000張圖像來評估網路模型的準確度。您可以直接從TensorFlow訪問Fashion MNIST,直接從TensorFlow導入並加重Fashion MNIST數據:
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz 32768/29515 [=================================] - 0s 0us/step Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz 26427392/26421880 [==============================] - 1s 0us/step Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz 8192/5148 [===============================================] - 0s 0us/step Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz 4423680/4422102 [==============================] - 0s 0us/step
載入數據返回4個NumPy數組:
- train_images和train_labels數組是訓練集 – 模型用於學習的數據。
- 模型將根據測試集test_images和test_labels數組進行測試。
圖像為28x28 NumPy數組,像素值範圍為0到255,標籤是一個整數數組,範圍為0到9,這些對應於圖像所代表的服裝類別:
| Label | Class | |-------|-------------| | 0 | T-shirt/top | | 1 | Trouser | | 2 | Pullover | | 3 | Dress | | 4 | Coat | | 5 | Sandal | | 6 | Shirt | | 7 | Sneaker | | 8 | Bag | | 9 | Ankle boot |
每個圖像都映射到一個標籤,由於類名不包含在數據集中,因此將它們存儲在此處以便在繪製圖像時使用:
class_names = [T-shirt/top, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle boot]
2. 探索數據
讓我們在訓練模型之前探索數據集的格式。以下顯示訓練集中有60,000個圖像,每個圖像表示為28 x 28像素:
print(train_images.shape)
(60000, 28, 28)
同樣,訓練集中有60,000個標籤:
print(len(train_labels)
60000
每個標籤都是0到9之間的整數:
print(train_labels)
array([9, 0, 0, ..., 3, 0, 5], dtype=uint8)
測試集中有10,000個圖像。同樣,每個圖像表示為28 x 28像素:
print(test_images.shape)
(10000, 28, 28)
測試集包含10,000個圖像標籤:
print(len(test_labels)
10000
3. 預處理數據
在訓練網路之前必須對數據進行預處理。 如果您檢查訓練集中的第一個圖像,您將看到像素值落在0到255的範圍內:
plt.figure()
plt.imshow(train_images[0]
plt.colorbar()
plt.grid(False)
plt.show()