这节是关于tensorflow的Freezing,字面意思是冷冻,可理解为整合合并;整合什么呢,就是将模型文件和权重文件整合合并为一个文件,主要用途是便于发布。
tensorflow在训练过程中,通常不会将权重数据保存的格式文件里(这里我理解是模型文件),反而是分开保存在一个叫checkpoint的检查点文件里,当初始化时,再通过模型文件里的变量Op节点来从checkoupoint文件读取数据并初始化变量。这种模型和权重数据分开保存的情况,使得发布产品时不是那么方便,我们可以将tf的图和参数文件整合进一个后缀为pb的二进制文件中,由于整合过程回将变量转化为常量,所以我们在日后读取模型文件时不能够进行训练,仅能向前传播,而且我们在保存时需要指定节点名称。
将图变量转换为常量的API:tf.graph_util.convert_variables_to_constants
转换后的graph_def对象转换为二进制数据(graph_def.SerializeToString())后,写入pb即可。
1
2
3
4
5
6
7
8
9
10
11
12
13
|
import tensorflow as tf v1 = tf.Variable(tf.constant( 1.0 , shape = [ 1 ]), name = ‘v1‘ ) v2 = tf.Variable(tf.constant( 2.0 , shape = [ 1 ]), name = ‘v2‘ ) result = v1 + v2 saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) saver.save(sess, ‘./tmodel/test_model.ckpt‘ ) gd = tf.graph_util.convert_variables_to_constants(sess, tf.get_default_graph().as_graph_def(), [ ‘add‘ ]) with tf.gfile.GFile( ‘./tmodel/model.pb‘ , ‘wb‘ ) as f: f.write(gd.SerializeToString()) |
我们可以直接查看gd:
node {
name: "v1"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 1
}
}
float_val: 1.0
}
}
}
}
……
node {
name: "add"
op: "Add"
input: "v1/read"
input: "v2/read"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
library {
}
四、从图上读取张量
上面的代码实际上已经包含了本小节的内容,但是由于从图上读取特定的张量是如此的重要,所以我仍然单独的补充上这部分的内容。
无论如何,想要获取特定的张量我们必须要有张量的名称和图的句柄,比如 ‘import/pool_3/_reshape:0‘ 这种,有了张量名和图,索引就很简单了。
从二进制模型加载张量
第二小节的代码很好的展示了这种情况
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
BOTTLENECK_TENSOR_NAME = ‘pool_3/_reshape:0‘ # 瓶颈层输出张量名称 JPEG_DATA_TENSOR_NAME = ‘DecodeJpeg/contents:0‘ # 输入层张量名称 MODEL_DIR = ‘./inception_dec_2015‘ # 模型存放文件夹 MODEL_FILE = ‘tensorflow_inception_graph.pb‘ # 模型名 # 加载模型 # with gfile.FastGFile(os.path.join(MODEL_DIR,MODEL_FILE),‘rb‘) as f: # 阅读器上下文 with open (os.path.join(MODEL_DIR, MODEL_FILE), ‘rb‘ ) as f: # 阅读器上下文 graph_def = tf.GraphDef() # 生成图 graph_def.ParseFromString(f.read()) # 图加载模型 # 加载图上节点张量(按照句柄理解) bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def( # 从图上读取张量,同时导入默认图 graph_def, return_elements = [BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME]) |
从当前图中获取对应张量
这个就是很普通的情况,从我们当前操作的图中获取某个张量,用于feed啦或者用于输出等操作,API也很简单,用法如下:
g.get_tensor_by_name(‘import/pool_3/_reshape:0‘)
g表示当前图句柄,可以简单的使用 g = tf.get_default_graph() 获取。
从图中获取节点信息
有的时候我们对于模型中的节点并不够了解,此时我们可以通过图句柄来查询图的构造:
1
2
|
g = tf.get_default_graph() print (g.as_graph_def().node) |
这个操作将返回图的构造结构。从这里,对比前面的代码,我们也可以了解到:graph_def 实际就是图的结构信息存储形式,我们可以将之还原为图(二进制模型加载代码中展示了),也可以从图中将之提取出来(本部分代码)。