首页 > 其他 > 详细

4.9 TF读入TFRecord

时间:2020-05-12 15:45:01      阅读:67      评论:0      收藏:0      [点我收藏+]
import tensorflow as tf

filelist = [‘data/train.tfrecord‘]
file_queue = tf.train.string_input_producer(filelist,  # 定义文件队列
                                            num_epochs=None,
                                            shuffle=True)
reader = tf.TFRecordReader()  # tensoeflow文件读取器从文件队列读取
_, ex = reader.read(file_queue)  # 原图-编码-序列化-打包,现在是反
# 向解析,ex是序列化之后的数据,所以还需要解码

feature = {  # 定义序列化格式
    ‘image‘: tf.FixedLenFeature([], tf.string),  # image是byte储存的,解码则直接解析为string型
    ‘label‘: tf.FixedLenFeature([], tf.int64)  # label本身就是int型
}
# 将队列中数据打乱后再读取出来
# batch_size:从队列中提取新的批量大小.
# capacity:队列容量.
# min_after_dequeue:最小队列容量.
batchsize = 2
batch = tf.train.shuffle_batch([ex], batchsize, capacity=batchsize * 10,
                               min_after_dequeue=batchsize * 5)

# 解码方法,features(字典型)有点像解析格式,返回的是字典型
example = tf.parse_example(batch, features=feature)
image = example[‘image‘]
label = example[‘label‘]

#image是string型,需要转换为uint8
image=tf.decode_raw(image, tf.uint8)

#这里的image其实是一串数字,按我们32*32*3的数据规
#模来重排序,可以制定这样的矩阵的大小,-1表示程序自动计算矩阵个数
#输出image:Tensor("DecodeRaw:0", shape=(2, ?), dtype=uint8)
image = tf.reshape(image, [-1,32, 32, 3])
#输出image:Tensor("Reshape:0", shape=(?, 32, 32, 3), dtype=uint8)


with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    tf.train.start_queue_runners(sess=sess)
    for i in range(1):
        image_bth,label=sess.run([image,label])
        import cv2
        cv2.imshow(str(label[0,...]),image_bth[0,...])
        cv2.waitKey(0)

4.9 TF读入TFRecord

原文:https://www.cnblogs.com/thgpddl/p/12876610.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!