import tensorflow as tf from bert import modeling import os # 这里是下载下来的bert配置文件 bert_config = modeling.BertConfig.from_json_file("chinese_L-12_H-768_A-12/bert_config.json") # 创建bert的输入 #placeholder占位符,在tensorflow中类似于函数参数,运行时必须传入值 #dtype:数据类型,常用的是tf.float32,tf.float64等数值类型 #shape:数据形状,默认是None,就是一维值,也可以是多维,比如[2,3]两行三列,[None,3]行不固定三列 #name:名称 input_ids=tf.placeholder (shape=[64,128],dtype=tf.int32,name="input_ids") input_mask=tf.placeholder (shape=[64,128],dtype=tf.int32,name="input_mask") segment_ids=tf.placeholder (shape=[64,128],dtype=tf.int32,name="segment_ids") # 创建bert模型 model = modeling.BertModel( config=bert_config, is_training=True, input_ids=input_ids, input_mask=input_mask, token_type_ids=segment_ids, use_one_hot_embeddings=False # 这里如果使用TPU 设置为True,速度会快些。使用CPU 或GPU 设置为False ,速度会快些。 ) #bert模型参数初始化的地方 init_checkpoint = "chinese_L-12_H-768_A-12/bert_model.ckpt" use_tpu = False # 获取模型中所有的训练参数。 tvars = tf.trainable_variables() # 加载BERT模型 (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) tf.train.init_from_checkpoint(init_checkpoint, assignment_map) tf.logging.info("**** Trainable Variables ****") # 打印加载模型的参数 for var in tvars: init_string = "" if var.name in initialized_variable_names: init_string = ", *INIT_FROM_CKPT*" tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, init_string) #通过会话tf.Session().run()进行循环优化网络参数 with tf.Session() as sess: sess.run(tf.global_variables_initializer())
微信
支付宝