顺便一提,上节定义的网络结构有问题,现已修改,之后会陆续整理上来。
两种常用(我会的)的加载方式:
1.
‘‘‘
使用原网络保存的模型加载到自己重新定义的图上
可以使用python变量名加载模型,也可以使用节点名
‘‘‘
import AlexNet as Net
import AlexNet_train as train
import random
import tensorflow as tf
IMAGE_PATH = ‘./flower_photos/daisy/5673728_71b8cb57eb.jpg‘
with tf.Graph().as_default() as g:
x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3])
y = Net.inference_1(x, N_CLASS=5, train=False)
with tf.Session() as sess:
# 程序前面得有 Variable 供 save or restore 才不报错
# 否则会提示没有可保存的变量
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(‘./model/‘)
img_raw = tf.gfile.FastGFile(IMAGE_PATH, ‘rb‘).read()
img = sess.run(tf.expand_dims(tf.image.resize_images(
tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0))
if ckpt and ckpt.model_checkpoint_path:
print(ckpt.model_checkpoint_path)
saver.restore(sess,‘./model/model.ckpt-0‘)
global_step = ckpt.model_checkpoint_path.split(‘/‘)[-1].split(‘-‘)[-1]
res = sess.run(y, feed_dict={x: img})
print(global_step,sess.run(tf.argmax(res,1)))
2.
‘‘‘
直接使用使用保存好的图
无需加载python定义的结构,直接使用节点名称加载模型
由于节点形状已经定下来了,所以有不便之处,placeholder定义batch后单张传会报错
现阶段不推荐使用,以后如果理解深入了可能会找到使用方法
‘‘‘
import AlexNet_train as train
import random
import tensorflow as tf
IMAGE_PATH = ‘./flower_photos/daisy/5673728_71b8cb57eb.jpg‘
# x = tf.placeholder(
# tf.float32, [1, train.INPUT_SIZE[0],train.INPUT_SIZE[1], 3], name=‘Placeholder‘)
ckpt = tf.train.get_checkpoint_state(‘./model/‘)
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +‘.meta‘)
with tf.Session() as sess:
saver.restore(sess,ckpt.model_checkpoint_path)
img_raw = tf.gfile.FastGFile(IMAGE_PATH, ‘rb‘).read()
img = sess.run(tf.image.resize_images(
tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)))
imgs = []
for i in range(128):
imgs.append(img)
print(sess.run(tf.get_default_graph().get_tensor_by_name(‘fc3:0‘),feed_dict={‘Placeholder:0‘: imgs}))
‘‘‘
img = sess.run(tf.expand_dims(tf.image.resize_images(
tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)), 0))
print(img)
imgs = []
for i in range(128):
imgs.append(img)
print(sess.run(tf.get_default_graph().get_tensor_by_name(‘conv1:0‘),
feed_dict={‘Placeholder:0‘:img}))
注意,在所有两种方式中都可以通过调用节点名称使用节点输出张量,节点.name属性返回节点名称。
『TensorFlow』徒手装高达_战斗数据收集模块原型_save&restore
原文:http://www.cnblogs.com/hellcat/p/6925757.html