首页 > 其他 > 详细

tensorflow(四)

时间:2019-12-23 10:06:21      阅读:112      评论:0      收藏:0      [点我收藏+]

tensorflow数据处理方法,

1.输入数据集

小数据集,可一次性加载到内存处理。

大数据集,一般由大量数据文件组成,因为数据集的规模太大,无法一次性加载到内存,只能每一步训练时加载数据,可以采用流水线并行读取数据。

流水线并行读取数据过程, (1)创建文件名列表(2)创建文件名队列(3)创建Reader和Decoder(4)创建样例队列

filename_queue = tf.train.string_input_producer([stat0.csv,stat1.csv])

reader = tf.TextLinerReader()
_,value = reader.read(filename_queue)

record_defaults = [[0],[0],[0.0],[0.0]]
id,age = tf.decode_csv(value,record_defaults=record_defaults)
features = tf.stack([id,age])
def get_my_example(filename_queue):
    reader = tf.SomeReader()
    _,value = reader.read(filename_queue)
    features = tf.decode_some(value)
    processed_example = some_processing(features)
    return processed_example

def input_pipeline(filenames,batch_size,num_epochs=None):
    filename_queue = tf.train.string_input_producer(filenames,num_epochs,shuffle=True)
    example = get_my_example(filename_queue)
    min_after_deque = 10000
    capacity = min_after_deque + 3*batch_size
    example_batch = tf.train.shuffle_batch([example],batch_size=batch_size,capacity=capacity,min_after_deque=min_after_deque)
    
    return example_batch

x_batch = input_pipeline([stat.tfrecord],batch_size=20)
sess = tf.Session()
init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())

sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
try:
    for _ in range(1000):
        if not coord.should_stop():
            sess.run(train_op)
            print(example)
except:
    print(catch exception)
finally:
    coord.request_stop()
coord.join(threads)
sess.close()

2.模型参数

模型参数指的是模型的权重值和偏置值,使用tf.Variable创建模型参数

W = tf.Variable(0.0,name=W)
double = tf.multiply(2.0,W)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(4):
        sess.run(tf.assign_add(W,1.0))
        print(sess.run(W))

3.保持和恢复模型参数

tf.train.Saver是辅助训练工具类,它实现了存储模型参数的变量和checkpoint文件间的读写操作。

W = tf.Variable(0.0,name=W)
double = tf.multiply(2.0,W)

saver = tf.train.Saver({weights:W})

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(4):
        sess.run(tf.assign_add(W,1.0))
        print(sess.run(W))
        saver.save(sess,/tmp/text/ckpt)

 

tensorflow(四)

原文:https://www.cnblogs.com/yangyang12138/p/12081892.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!