首页 > 其他 > 详细

Bert层数剪枝

时间:2020-07-24 17:56:18      阅读:153      评论:0      收藏:0      [点我收藏+]

模型精简的流程如下:pretrian model -> retrain with new data(fine tuning) -> pruning -> retrain -> model

对bert进行层数剪枝,保留第一层和第十二层参数,再用领域数据微调。代码如下:

"""
    test
"""
import tensorflow as tf
import os

sess = tf.Session()
last_name = bert_model.ckpt
model_path = bert_model/chinese_L-12_H-768_A-12
imported_meta = tf.train.import_meta_graph(os.path.join(model_path, last_name + .meta))
imported_meta.restore(sess, os.path.join(model_path, last_name))
init_op = tf.local_variables_initializer()
sess.run(init_op)

bert_dict = {}
# 获取待保存的层数节点
for var in tf.global_variables():
    # print(var)
    # 提取第0层和第11层和其它的参数,其余1-10层去掉,存储变量名的数值
    if var.name.startswith(bert/encoder/layer_) and not var.name.startswith(
            bert/encoder/layer_0) and not var.name.startswith(bert/encoder/layer_11):
        pass
    else:
        bert_dict[var.name] = sess.run(var).tolist()

# print(‘bert_dict:{}‘.format(bert_dict))
# 真是保存的变量信息
need_vars = []
for var in tf.global_variables():
    if var.name.startswith(bert/encoder/layer_) and not var.name.startswith(
            bert/encoder/layer_0/) and not var.name.startswith(bert/encoder/layer_1/):
        pass
    elif var.name.startswith(bert/encoder/layer_1/):
        # 寻找11层的var name,将11层的参数给第一层使用
        new_name = var.name.replace("bert/encoder/layer_1", "bert/encoder/layer_11")
        op = tf.assign(var, bert_dict[new_name])
        sess.run(op)
        need_vars.append(var)
        print(var)
    else:
        need_vars.append(var)
        print(####,var)

# 保存model
saver = tf.train.Saver(need_vars)
saver.save(sess, os.path.join(bert_model/chinese_L-12_H-768_A-12_pruning, bert_pruning_2_layer.ckpt))

 

Bert层数剪枝

原文:https://www.cnblogs.com/demo-deng/p/13372797.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!