首页 > 其他 > 详细

Tensorflow学习记录 --TensorFlow高效读取数据tfrecord

时间:2020-12-13 11:58:35      阅读:23      评论:0      收藏:0      [点我收藏+]

Tensorflow学习过程中tfrecord的简单理解

一、TFRecord的介绍:

一般使用直接将数据加载到内存的方式来存储数据量较小的数据,然后再分batch输入网络进行训练。如果数据量太大,这种方法是十分消耗内存的,这时可以使用tensorflow提供的队列queue从文件中提取数据(比如csv文件等)。还有一种较为常用的,高效的读取方法,即使用tensorflow内定标准格式——TFRecords.作者也是刚接触tensorflow,对日常学习遇到的问题做简单记录,有不对地方需要指正。

什么是TFRecord?

TFRecord是谷歌推荐的一种常用的存储二进制序列数据的文件格式,理论上它可以保存任何格式的信息。下面是Tensorflow的官网给出的文档结构,整个文件由文件长度信息,长度校验码,数据,数据校验码组成。

uint64 length
uint32 masked_crc32_of_length
byte   data[length]
uint32 masked_crc32_of_data

二、代码及相关简介

1.构建写入数据的writer

import numpy as np 
import tensorflow as tf 
writer = tf.python_io.TFRecordWriter(‘test.tfrecord‘)

2. TFRecord

TensorFlow经常使用 tf.Example 来写入,读取TFRecord数据。

通常tf.example有下面几种数据结构:

  • tf.train.FloatList: 可以使用的类型包括 float和double
  • tf.train.Int64List: 可以使用的类型包括 enum,bool, int32, uint32, int64
  • f.train.BytesList: 可以使用的类型包括 string和byte

TFRecord支持写入三种格式的数据:string,int64,float32,以列表的形式分别通过tf.train.BytesList,tf.train.Int64List,tf.train.FloatList 写入 tf.train.Feature,如下所示:

#feature一般是多维数组,要先转为list
tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))
 
#tostring函数后feature的形状信息会丢失,把shape也写入
tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape))) 
 
tf.train.Feature(float_list=tf.train.FloatList(value=[label]))

下面以一个具体的简单例子来介绍tf.example

for k in range(0, 3):
    x = 0.1712 + k
    y = [1+k, 2+k]
    z = np.array([[1,2,3],[4,5,6]]) + k
    z = z.astype(np.uint8)
    z_raw = z.tostring()
    example = tf.train.Example(
        features = tf.train.Features(
            feature = {‘x‘:tf.train.Feature(float_list = tf.train.FloatList(value = [x])),
                       ‘y‘:tf.train.Feature(int64_list = tf.train.Int64List(value = y)),
                       ‘z‘:tf.train.Feature(bytes_list = tf.train.BytesList(value = [z_raw]))}))
    serialized = example.SerializeToString()
    writer.write(serialized)
writer.close()

x,y,z分别是以float,int64和string的形式存储的,注意观察下面语句:

feature = {‘x‘:tf.train.Feature(float_list = tf.train.FloatList(value = [x])),
           ‘y‘:tf.train.Feature(int64_list = tf.train.Int64List(value = y)),
           ‘z‘:tf.train.Feature(bytes_list = tf.train.BytesList(value = [z_raw]))}

value的值是一个list形式,x定义的为一个数,value的值应为[x],同样y定义的格式就是一个list所以value的值直接为y即可,z_raw是由z转换过来的string形式,对应的value值与x的形式应该是一样的。

3.创建文件读取队列并读取其中内容(字典格式)

#output file name string to a queue
filename_queue = tf.train.string_input_producer([‘test.tfrecord‘], num_epochs = None)
#Create a reader from file queue
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
#Get feature from serialized example
features = tf.parse_single_example(serialized_example,
                features = {‘x‘: tf.FixedLenFeature([],tf.float32),
                            ‘y‘: tf.FixedLenFeature([2],tf.int64),
                            ‘z‘: tf.FixedLenFeature([],tf.string)})

4.读取数据

x_out  = features[‘x‘]
y_out  = features[‘y‘]
z_raw_out = features[‘z‘]
z_out = tf.decode_raw(z_raw_out,tf.uint8)
z_out = tf.reshape(z_out, [2,3])
print(x_out)
print(y_out)
print(z_out)

显示结果为:

Tensor("ParseSingleExample_2/ParseSingleExample:0", shape=(), dtype=float32)
Tensor("ParseSingleExample_2/ParseSingleExample:1", shape=(2,), dtype=int64)
Tensor("Reshape_1:0", shape=(2, 3), dtype=uint8)

主要参考:
TensorFlow学习记录-- 7.TensorFlow高效读取数据之tfrecord详细解读
tensorflow学习笔记——高效读取数据的方法(TFRecord

Tensorflow学习记录 --TensorFlow高效读取数据tfrecord

原文:https://www.cnblogs.com/ysfurh/p/14127941.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!