[普通]load and save tf weight

作者(passion) 阅读(1039次) 评论(0) 分类( 算法)

ckpt = tf.train.Checkpoint(model=encoder)
ckpt.restore(checkpoint_from_path).expect_partial()

tf1_var_dict = {}
for weight in encoder.weights:
var_name = weight.name.split(":")[0]
np_val = weight.numpy()
tf1_var_dict[var_name] = np_val
logging.info("Converted: %s --> %s %s", weight.name, var_name, np_val.shape)

with tf.Graph().as_default():
for name in tf1_var_dict:
tf1_var_dict[name] = tf.Variable(tf1_var_dict[name], name=name)
saver = tf.train.Saver(tf1_var_dict)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  logging.info("Writing checkpoint_to_path: %s", checkpoint_to_path)
  saver.save(sess, checkpoint_to_path)


« 上一篇:fastlabel 最强版标注神器,想你所想,做你想做
« 下一篇:推荐论文每周更新
在这里写下您精彩的评论
  • 微信

  • QQ

  • 支付宝

返回首页
返回首页 img
返回顶部~
返回顶部 img