Tensorflow学习过程中tfrecord的简单理解
一般使用直接将数据加载到内存的方式来存储数据量较小的数据,然后再分batch输入网络进行训练。如果数据量太大,这种方法是十分消耗内存的,这时可以使用tensorflow提供的队列queue从文件中提取数据(比如csv文件等)。还有一种较为常用的,高效的读取方法,即使用tensorflow内定标准格式——TFRecords.作者也是刚接触tensorflow,对日常学习遇到的问题做简单记录,有不对地方需要指正。
TFRecord是谷歌推荐的一种常用的存储二进制序列数据的文件格式,理论上它可以保存任何格式的信息。下面是Tensorflow的官网给出的文档结构,整个文件由文件长度信息,长度校验码,数据,数据校验码组成。
uint64 length
uint32 masked_crc32_of_length
byte data[length]
uint32 masked_crc32_of_data
import numpy as np
import tensorflow as tf
writer = tf.python_io.TFRecordWriter(‘test.tfrecord‘)
TensorFlow经常使用 tf.Example 来写入,读取TFRecord数据。
通常tf.example有下面几种数据结构:
TFRecord支持写入三种格式的数据:string,int64,float32,以列表的形式分别通过tf.train.BytesList,tf.train.Int64List,tf.train.FloatList 写入 tf.train.Feature,如下所示:
#feature一般是多维数组,要先转为list
tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))
#tostring函数后feature的形状信息会丢失,把shape也写入
tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape)))
tf.train.Feature(float_list=tf.train.FloatList(value=[label]))
下面以一个具体的简单例子来介绍tf.example
for k in range(0, 3):
x = 0.1712 + k
y = [1+k, 2+k]
z = np.array([[1,2,3],[4,5,6]]) + k
z = z.astype(np.uint8)
z_raw = z.tostring()
example = tf.train.Example(
features = tf.train.Features(
feature = {‘x‘:tf.train.Feature(float_list = tf.train.FloatList(value = [x])),
‘y‘:tf.train.Feature(int64_list = tf.train.Int64List(value = y)),
‘z‘:tf.train.Feature(bytes_list = tf.train.BytesList(value = [z_raw]))}))
serialized = example.SerializeToString()
writer.write(serialized)
writer.close()
x,y,z分别是以float,int64和string的形式存储的,注意观察下面语句:
feature = {‘x‘:tf.train.Feature(float_list = tf.train.FloatList(value = [x])),
‘y‘:tf.train.Feature(int64_list = tf.train.Int64List(value = y)),
‘z‘:tf.train.Feature(bytes_list = tf.train.BytesList(value = [z_raw]))}
value的值是一个list形式,x定义的为一个数,value的值应为[x],同样y定义的格式就是一个list所以value的值直接为y即可,z_raw是由z转换过来的string形式,对应的value值与x的形式应该是一样的。
#output file name string to a queue
filename_queue = tf.train.string_input_producer([‘test.tfrecord‘], num_epochs = None)
#Create a reader from file queue
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
#Get feature from serialized example
features = tf.parse_single_example(serialized_example,
features = {‘x‘: tf.FixedLenFeature([],tf.float32),
‘y‘: tf.FixedLenFeature([2],tf.int64),
‘z‘: tf.FixedLenFeature([],tf.string)})
x_out = features[‘x‘]
y_out = features[‘y‘]
z_raw_out = features[‘z‘]
z_out = tf.decode_raw(z_raw_out,tf.uint8)
z_out = tf.reshape(z_out, [2,3])
print(x_out)
print(y_out)
print(z_out)
显示结果为:
Tensor("ParseSingleExample_2/ParseSingleExample:0", shape=(), dtype=float32)
Tensor("ParseSingleExample_2/ParseSingleExample:1", shape=(2,), dtype=int64)
Tensor("Reshape_1:0", shape=(2, 3), dtype=uint8)
主要参考:
TensorFlow学习记录-- 7.TensorFlow高效读取数据之tfrecord详细解读
tensorflow学习笔记——高效读取数据的方法(TFRecord
Tensorflow学习记录 --TensorFlow高效读取数据tfrecord
原文:https://www.cnblogs.com/ysfurh/p/14127941.html