saver()与restore()只是保存了session中的相关变量对应的值,并不涉及模型的结构。
saver()
可以选择global_step
参数来为ckpt文件名添加数字标记:saver.save(sess, ‘my-model‘, global_step=0) ==> filename: ‘my-model-0‘ ... saver.save(sess, ‘my-model‘, global_step=1000) ==> filename: ‘my-model-1000‘
max_to_keep
参数定义saver()
将自动保存的最近n个ckpt文件,默认n=5,即保存最近的5个检查点ckpt文件。若n=0或者None,则保存所有的ckpt文件。keep_checkpoint_every_n_hours
与max_to_keep
类似,定义每n小时保存一个ckpt文件。... # Create a saver. saver = tf.train.Saver(...variables...) # Launch the graph and train, saving the model every 1,000 steps. sess = tf.Session() for step in xrange(1000000): sess.run(..training_op..) if step % 1000 == 0: # Append the step number to the checkpoint name: saver.save(sess, ‘my-model‘, global_step=step)
一个简单的例子:
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) import tensorflow as tf import time time.clock() x = tf.placeholder(tf.float32 ,[None, 784]) W = tf.Variable(tf.zeros([784,10])) b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x,W) + b) # 为了计算交叉熵,我们需要添加一个新的占位符用于输入正确值。 y_ = tf.placeholder(tf.float32, [None,10]) cross_entropy = -tf.reduce_sum(y_*tf.log(y)) # 在此,我们要求TF使用梯度下降算法,并以0.01的学习速率最小化交叉熵。 train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) init = tf.global_variables_initializer() sess = tf.Session() sess.run(init) # 创建Saver节点,并设置自动保存最近n=1次模型 saver = tf.train.Saver(max_to_keep=1) saver_max_acc = 0 for i in range(100): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_:batch_ys}) correct_prediction = tf.equal(tf.argmax(y,1), tf.arg_max(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float")) if (i+1)%10 == 0: print(‘{0:0>2d}:{1:.4f}‘.format((i+1),accuracy.eval(session=sess, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))) # 添加判断语句,选择保存精度最高的模型 if accuracy > saver_max_acc: saver.save(sess,‘ckpt/mnist.ckpt‘,global_step=i+1) saver_max_acc = accuracy sess.close() print(time.clock())
restore(sess, save_path) # sess: A Session to use to restore the parameters. # save_path: Path where parameters were previously saved.
sess
: 保存参数的会话。save_path
: 保存参数的路径。tf.train.latest_checkpoint()
来自动获取最后一次保存的模型。如:model_file=tf.train.latest_checkpoint(‘ckpt/‘) saver.restore(sess,model_file)
原文:https://www.cnblogs.com/zhangly2020/p/14177428.html