[普通]ckpt加载

作者(passion) 阅读(1366次) 评论(0) 分类( 算法)
import tensorflow as tf
from bert import modeling
import os
# 这里是下载下来的bert配置文件
bert_config = modeling.BertConfig.from_json_file("chinese_L-12_H-768_A-12/bert_config.json")
#  创建bert的输入
#placeholder占位符,在tensorflow中类似于函数参数,运行时必须传入值
#dtype:数据类型,常用的是tf.float32,tf.float64等数值类型
#shape:数据形状,默认是None,就是一维值,也可以是多维,比如[2,3]两行三列,[None,3]行不固定三列
#name:名称
input_ids=tf.placeholder (shape=[64,128],dtype=tf.int32,name="input_ids")
input_mask=tf.placeholder (shape=[64,128],dtype=tf.int32,name="input_mask")
segment_ids=tf.placeholder (shape=[64,128],dtype=tf.int32,name="segment_ids")
# 创建bert模型
model = modeling.BertModel(
    config=bert_config,
    is_training=True,
    input_ids=input_ids,
    input_mask=input_mask,
    token_type_ids=segment_ids,
    use_one_hot_embeddings=False # 这里如果使用TPU 设置为True,速度会快些。使用CPU 或GPU 设置为False ,速度会快些。
)
#bert模型参数初始化的地方
init_checkpoint = "chinese_L-12_H-768_A-12/bert_model.ckpt"
use_tpu = False
# 获取模型中所有的训练参数。
tvars = tf.trainable_variables()
# 加载BERT模型
(assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
tf.logging.info("**** Trainable Variables ****")
# 打印加载模型的参数
for var in tvars:
    init_string = ""
    if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
    tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                    init_string)
                    
#通过会话tf.Session().run()进行循环优化网络参数
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())


« 上一篇:fastlabel 最强版标注神器,想你所想,做你想做
« 下一篇:ubuntu时区的显示与设置
在这里写下您精彩的评论
  • 微信

  • QQ

  • 支付宝

返回首页
返回首页 img
返回顶部~
返回顶部 img