代碼鏈接

MagaretJi/TFRecord?

github.com

TensorFlow現在有三種方法讀取數據:

  • 直接載入數據 適合數據量小的
  • 文件讀取數據:從文件讀取數據,如CSV文件格式
  • TFRecord:適合數據量大

現在我們來講述最後一種方法

什麼是TFRecord?

TRecord數據文件是一種將圖像數據和標籤統一存儲的二進位文件,能更好的利用內存,在TensorFlow中快速的複製,移動,讀取,存儲等。對於我們普通開發者而言,我們並不需要關心這些,Tensorflow 提供了豐富的 API 可以幫助我們輕鬆讀寫 TFRecord文件。我們只關心如何使用Tensorflow生成TFRecord,並且讀取它。

如何使用TFRecord?

因為深度學習很多都是與圖片打交道,那麼,我們可以嘗試下把一張張的圖片轉換成 TFRecord 文件。

不說很多原理,直接看代碼,代碼全部親測可用。TensorFlow小白,有任何問題請及時指出。

數據集

本數據集採用kaggle的貓狗大戰數據集中的的訓練集(即train)。

鏈接: https://pan.baidu.com/s/1AgHPMMkLZzR4HrEdWfuNhw 密碼: vtg6

生成TFRecord文件

step1 數據集準備工作

我們將一個文件下的所有貓狗圖片的位置和對應的標籤分別存放到兩個list中。

def get_files(file_dir,is_random=True):
image_list=[]
label_list=[]
dog_count=0
cat_count=0
for file in os.listdir(file_dir):
name=file.split(sep=.)
if(name[0]==cat):
image_list.append(file_dir+file)
label_list.append(0)
cat_count+=1
else:
image_list.append(file_dir+file)
label_list.append(1)
dog_count+=1
print(%d cats and %d dogs%(cat_count,dog_count))

image_list=np.asarray(image_list)
label_list=np.asarray(label_list) if is_random:
rnd_index=np.arange(len(image_list))
np.random.shuffle(rnd_index)
image_list=image_list[rnd_index]
label_list=label_list[rnd_index]
return image_list,label_list

How to use?

get_files(file_dir,is_random=True)
<!--file_dir:圖片文件中的所在位置-->

step2 TFRecord數據類型轉換

在保存圖片信息的時候,需要先將這些圖片的信息轉換為byte數據才能寫入到tfrecord文件中。屬性的取值可以為字元串(BytesList)、實數列表(FloatList)或者整數列表(Int64List)可以看見TFRecord是以字典的形式存儲的,這裡我們存儲了image、label、width、height的信息。

def int64_feature(values):
if not isinstance(values,(tuple,list)):
values = [values]
return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def bytes_feature(values):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))

def float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def image_to_tfexample(image_data, label,size):
return tf.train.Example(features=tf.train.Features(feature={
image: bytes_feature(image_data),
label: int64_feature(label),
image_width:int64_feature(size[0]),
image_height:int64_feature(size[1])
}))

step3 數據存儲

將之前的兩個list中的信息轉化為我們需要的TFRecord數據類型文件

def _convert_dataset(image_list, label_list, tfrecord_dir):
""" Convert data to TFRecord format. """
with tf.Graph().as_default():
with tf.Session() as sess:
if not os.path.exists(tfrecord_dir):
os.makedirs(tfrecord_dir)
output_filename = os.path.join(tfrecord_dir, "train.tfrecord")
tfrecord_writer = tf.python_io.TFRecordWriter(output_filename)
length = len(image_list)
for i in range(length): # 圖像數據
image_data = Image.open(image_list[i],r)
size = image_data.size
image_data = image_data.tobytes()
label = label_list[i]
example = image_to_tfexample(image_data, label,size)
tfrecord_writer.write(example.SerializeToString())
sys.stdout.write(
>> Converting image %d/%d % (i + 1, length))
sys.stdout.flush()

sys.stdout.write(
)
sys.stdout.flush()

How to use?

_convert_dataset(image_list, label_list, tfrecord_dir)
<!--image_list,label_list:上述產生的兩個list-->
<!--tfrecord_dir:你要保存TFRecord文件的位置-->

step4 解析TFRecord數據

你們不禁要問了:怎麼解析這麼複雜的數據呢?我們使用tf.parse_single_example() 將存儲為字典形式的TFRecord數據解析出來。這樣我們就將image、label、width、height的信息就原封不動「拿」出來了。

def read_and_decode(tfrecord_path):
data_files = tf.gfile.Glob(tfrecord_path) #data_path為TFRecord格式數據的路徑
filename_queue = tf.train.string_input_producer(data_files,shuffle=True)
reader = tf.TFRecordReader()
_,serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,
features={
label:tf.FixedLenFeature([],tf.int64),
image:tf.FixedLenFeature([],tf.string),
image_width: tf.FixedLenFeature([],tf.int64),
image_height: tf.FixedLenFeature([],tf.int64),
})

image = tf.decode_raw(features[image],tf.uint8)
image_width = tf.cast(features[image_width],tf.int32)
image_height = tf.cast(features[image_height],tf.int32)
image = tf.reshape(image,[image_height,image_width,3])
label = tf.cast(features[label], tf.int32)
return image,label

How to use?

read_and_decode(tfrecord_path)
<!--tfrecord_path:就是你剛剛存放TFRecord的文件位置,我們將它取出來就好了。-->

step5 載入數據集

數據拿出來了,那我們就要用它來組成一個個batch,這樣我們就可以訓練模型了。

def batch(image,label):
# Load training set.
#一定要reshape一下image,不然會報錯。
image = tf.image.resize_images(image, [128, 128])
with tf.name_scope(input_train):
image_batch, label_batch = tf.train.shuffle_batch(
[image, label],
batch_size=30,
capacity=2000,
min_after_dequeue=1500)
return image_batch, label_batch

How to use?

batch(image,label)
<!--image,label:我們剛剛解析出來的圖片和標籤-->

總結

最後讓我們回顧一下數據轉換的步驟


推薦閱讀:
相关文章