首页 > 其他 > 详细

TensorFlow 存储与读取

时间:2017-12-12 18:20:04      阅读:308      评论:0      收藏:0      [点我收藏+]

之前通过CNN进行的MNIST训练识别成功率已经很高了,不过每次运行都需要消耗很多的时间。在实际使用的时候,每次都要选经过训练后在进行识别那就太不方便了。

所以我们学习一下如何将训练习得的参数保存起来,然后在需要用的时候直接使用这些参数进行快速的识别。

本章节代码来自《Tensorflow 实战Google深度学习框架》5.5 TensorFlow 最佳实践样例程序  针对书中的代码做了一点点的调整。

 

mnist_inference.py:

#coding=utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500

def get_weight_variable(shape, regularizer):
    weights = tf.get_variable("weights", shape, initializer = tf.truncated_normal_initializer(stddev=0.1))
    if regularizer != None:
        tf.add_to_collection(losses, regularizer(weights))
    return weights

def inference(input_tensor, regularizer):
    with tf.variable_scope(layer1):
        weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
        biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)

    with tf.variable_scope(layer2):
        weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
        biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
        layer2 = tf.matmul(layer1, weights) + biases

    return layer2

这里是向前传播的方法文件。这个方法在训练和测试的过程都需要用到,将它抽离出来既能使用起来更加方便,也能保证训练和测试时使用的方法保持一致。

get_variable

 weights = tf.get_variable("weights", shape, initializer = tf.truncated_normal_initializer(stddev=0.1))

源代码第十行使用get_variable函数获取变量。

在训练网络是会创建这些变量;

在测试时会通过训练时保存的模型加载这些变量的值。

 

(未完待续。。。。)

TensorFlow 存储与读取

原文:http://www.cnblogs.com/guolaomao/p/8028600.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!