1. Generating and reading the dataset (mnist_generateds.py)
tfrecords files
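1) tfrecords: a binary file format. The images and labels are first packed into a file in this format. Reading data through tfrecords improves memory utilization.
2) tf.train.Example: stores training data. Each feature of a training sample is expressed as a key-value pair, e.g. 'img_raw': value, 'label': value, where each value is a BytesList, FloatList, or Int64List.
3) SerializeToString(): serializes the data into a binary string for storage.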
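To make "binary file" more concrete, here is a minimal pure-Python sketch of how records are laid out inside a tfrecords file. It assumes the framing documented for TFRecord: an 8-byte little-endian length, a masked CRC-32C of that length, the payload bytes, and a masked CRC-32C of the payload. The payload is where the bytes returned by SerializeToString() would go. The helpers crc32c, masked_crc, write_record, and read_records are hypothetical names used only for illustration, not TensorFlow functions; in practice tf.python_io.TFRecordWriter does this for you.

```python
import io
import struct


def _make_crc32c_table():
    # CRC-32C (Castagnoli), reflected polynomial 0x82F63B78
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_CRC32C_TABLE = _make_crc32c_table()


def crc32c(data: bytes) -> int:
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _CRC32C_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    # Rotate the CRC right by 15 bits and add a constant (32-bit wraparound)
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_record(f, payload: bytes) -> None:
    header = struct.pack("<Q", len(payload))         # uint64 length
    f.write(header)
    f.write(struct.pack("<I", masked_crc(header)))   # CRC of the length bytes
    f.write(payload)                                 # e.g. example.SerializeToString()
    f.write(struct.pack("<I", masked_crc(payload)))  # CRC of the payload


def read_records(f):
    """Yield each payload in turn, checking both CRCs."""
    while True:
        header = f.read(8)
        if not header:
            return  # clean end of file
        (length,) = struct.unpack("<Q", header)
        (len_crc,) = struct.unpack("<I", f.read(4))
        if len_crc != masked_crc(header):
            raise ValueError("corrupted length field")
        payload = f.read(length)
        (data_crc,) = struct.unpack("<I", f.read(4))
        if data_crc != masked_crc(payload):
            raise ValueError("corrupted payload")
        yield payload


if __name__ == "__main__":
    buf = io.BytesIO()
    for p in [b"first", b"second"]:
        write_record(buf, p)
    buf.seek(0)
    print(list(read_records(buf)))  # [b'first', b'second']
```

Because every record carries its own length, a reader can stream records one at a time instead of loading the whole file into memory.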
First, generate the tfrecords files:
1) Define the dataset paths;
2) Read the training set and the test set.
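Reading and parsing the files:
a: First read in the file name and the paths.
b: Create a writer and a picture counter.
c: open() opens label_path in read-only mode by default. readlines() reads all lines (up to EOF) and returns them as a list, which can be processed with Python's for ... in ... loop. (It is readline() that returns an empty string when it reaches EOF; readlines() on an empty file returns an empty list.)
d: Use a for loop to go through every image and its label.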
e: Pack each image and its label into an Example:
example = tf.train.Example(features=tf.train.Features(feature={
    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
}))
This code is explained in the tf.train.Example part.
f: writer.write(example.SerializeToString())  # serialize the example and write it to the file
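Steps a to f boil down to reading the label file line by line and turning each line into an image path plus a one-hot label. The sketch below shows only that part with the standard library. It assumes each line of the label file holds an image file name followed by a digit, as the content.split() call in the code implies; the file names are made up, and parse_label_file and one_hot are hypothetical helpers. It uses os.path.join instead of string concatenation, so a missing trailing slash in image_path cannot break the path, and it skips blank lines, which the original loop does not do.

```python
import os
import tempfile


def one_hot(index, num_classes=10):
    # Same idea as: labels = [0] * 10; labels[int(value[1])] = 1
    labels = [0] * num_classes
    labels[index] = 1
    return labels


def parse_label_file(label_path, image_dir):
    """Return a list of (image path, one-hot label) tuples."""
    with open(label_path) as f:      # default mode 'r' is read-only text
        contents = f.readlines()     # every line, trailing '\n' included
    samples = []
    for content in contents:
        value = content.split()      # e.g. ['img_0.jpg', '5']
        if not value:                # skip blank lines (not in the original)
            continue
        img_path = os.path.join(image_dir, value[0])
        samples.append((img_path, one_hot(int(value[1]))))
    return samples


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        label_path = os.path.join(tmp, "labels.txt")
        with open(label_path, "w") as f:
            f.write("img_0.jpg 5\nimg_1.jpg 0\n")
        for path, label in parse_label_file(label_path, "./images"):
            print(path, label)
        with open(label_path) as f:
            f.readlines()
            print(repr(f.readline()))  # '' because readline() is already at EOF
```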
import os

import tensorflow as tf  # TensorFlow 1.x API (tf.python_io, tf.train.*)
from PIL import Image

image_train_path = './data/mnist_data_jpg/mnist_train_jpg_60000/'
label_train_path = './data/mnist_data_jpg/mnist_train_jpg_60000.txt'
tfRecord_train = './data/mnist_train.tfrecords'
image_test_path = './data/mnist_data_jpg/mnist_test_jpg_10000/'
label_test_path = './data/mnist_data_jpg/mnist_test_jpg_10000.txt'
tfRecord_test = './data/mnist_test.tfrecords'
data_path = './data'
resize_height = 28
resize_width = 28

# Generate a tfrecords file
def write_tfRecord(tfRecordName, image_path, label_path):
    # Create a writer
    writer = tf.python_io.TFRecordWriter(tfRecordName)
    num_pic = 0
    f = open(label_path, 'r')
    contents = f.readlines()
    f.close()
    # Loop over every image and its label
    for content in contents:
        value = content.split()
        img_path = image_path + value[0]
        img = Image.open(img_path)
        img_raw = img.tobytes()
        labels = [0] * 10
        labels[int(value[1])] = 1
        # Pack each image and its label into an Example
        example = tf.train.Example(features=tf.train.Features(feature={
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
        }))
        # Serialize the Example and write it
        writer.write(example.SerializeToString())
        num_pic += 1
        print("the number of picture:", num_pic)
    # Close the writer
    writer.close()
    print("write tfrecord successful")

def generate_tfRecord():
    isExists = os.path.exists(data_path)
    if not isExists:
        os.makedirs(data_path)
        print('The directory was created successfully')
    else:
        print('directory already exists')
    write_tfRecord(tfRecord_train, image_train_path, label_train_path)
    write_tfRecord(tfRecord_test, image_test_path, label_test_path)
Parsing the tfrecords files:
First, take a look at the corresponding paths and file names:
image_train_path = './data/mnist_data_jpg/mnist_train_jpg_60000/'
label_train_path = './data/mnist_data_jpg/mnist_train_jpg_60000.txt'
tfRecord_train = './data/mnist_train.tfrecords'
image_test_path = './data/mnist_data_jpg/mnist_test_jpg_10000/'
label_test_path = './data/mnist_data_jpg/mnist_test_jpg_10000.txt'
tfRecord_test = './data/mnist_test.tfrecords'
data_path = './data'
def main():
    generate_tfRecord()                       # (1)

def generate_tfRecord():                      # (2)
    isExists = os.path.exists(data_path)      # (3)
    if not isExists:
        os.makedirs(data_path)
        print('The directory was created successfully')
    else:
        print('directory already exists')
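The exists-then-makedirs check above can also be written with os.makedirs(..., exist_ok=True), which does nothing when the directory already exists. A small sketch, where ensure_data_dir is a hypothetical helper that keeps the original's two messages and reports whether it created the directory:

```python
import os
import tempfile


def ensure_data_dir(data_path):
    """Create data_path (and any missing parents) if needed.

    Returns True if this call created the directory, False if it already existed.
    """
    existed = os.path.isdir(data_path)
    # exist_ok=True: no FileExistsError when the directory is already there
    os.makedirs(data_path, exist_ok=True)
    if existed:
        print('directory already exists')
    else:
        print('The directory was created successfully')
    return not existed


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        data_path = os.path.join(tmp, "data")
        ensure_data_dir(data_path)  # prints: The directory was created successfully
        ensure_data_dir(data_path)  # prints: directory already exists
```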
def get_tfrecord(num, isTrain=True):                  # (1)
    if isTrain:
        tfRecord_path = tfRecord_train
    else:
        tfRecord_path = tfRecord_test
    img, label = read_tfRecord(tfRecord_path)          # (2)
    # ...

def read_tfRecord(tfRecord_path):                     # (3)
    filename_queue = tf.train.string_input_producer([tfRecord_path], shuffle=True)
    # Create a reader
    reader = tf.TFRecordReader()
    # ...
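tf.train.string_input_producer puts the given file names into a queue, emitting each name once per epoch (indefinitely when num_epochs is left at its default of None) and shuffling the order within each epoch when shuffle=True. The following is only a conceptual pure-Python model of that behavior, not how TensorFlow implements the queue; filename_queue is a hypothetical name. With a single file, as in read_tfRecord, shuffle=True cannot change the order.

```python
import itertools
import random


def filename_queue(filenames, num_epochs=None, shuffle=True, seed=None):
    """Conceptual model of an epoch-based filename queue.

    Yields every filename once per epoch, reshuffling each epoch when
    shuffle=True; num_epochs=None means cycle forever.
    """
    rng = random.Random(seed)
    epochs = itertools.count() if num_epochs is None else range(num_epochs)
    for _ in epochs:
        order = list(filenames)
        if shuffle:
            rng.shuffle(order)  # new permutation for this epoch
        yield from order


if __name__ == "__main__":
    # One file: shuffling has no visible effect
    print(list(filename_queue(["./data/mnist_train.tfrecords"], num_epochs=2)))
    # Several files, infinite epochs: take the first 6 names
    print(list(itertools.islice(filename_queue(["a", "b", "c"], seed=0), 6)))
```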
The markers (1), (2), and (3) above trace how the parameters are passed from one function to the next.
Original article: https://www.cnblogs.com/fcfc940503/p/11019441.html