# Convert an object-based TF2 checkpoint of `encoder` into a name-based TF1
# checkpoint whose keys are the weight names with the ":<index>" output
# suffix stripped (e.g. "dense/kernel:0" -> "dense/kernel").
#
# NOTE(review): `encoder`, `checkpoint_from_path` and `checkpoint_to_path`
# are defined outside this snippet; `encoder` is presumably a Keras model
# whose weights are already built -- confirm against the caller.

# Restore the TF2 weights into `encoder`. expect_partial() silences the
# warnings about checkpoint values that `encoder` does not consume.
ckpt = tf.train.Checkpoint(model=encoder)
ckpt.restore(checkpoint_from_path).expect_partial()

# Snapshot every weight as a NumPy array keyed by its TF1-style name.
np_var_dict = {}
for weight in encoder.weights:
    var_name = weight.name.split(":")[0]
    if var_name in np_var_dict:
        # Two distinct weights would map to the same checkpoint key and the
        # later one would silently overwrite the earlier one, producing a
        # checkpoint that is missing a weight.
        raise ValueError(
            "Duplicate variable name %r in encoder weights; cannot write an "
            "unambiguous TF1 checkpoint." % var_name)
    np_val = weight.numpy()
    np_var_dict[var_name] = np_val
    logging.info("Converted: %s --> %s %s", weight.name, var_name, np_val.shape)

# Recreate the variables in a fresh graph and save them with a name-based
# Saver. tf.compat.v1 is used explicitly so this works whether `tf` is the
# TF2 module or `tensorflow.compat.v1`.
with tf.Graph().as_default():
    tf1_var_dict = {}  # checkpoint key -> graph variable
    feed_dict = {}  # placeholder -> NumPy value used to initialise it
    for name, value in np_var_dict.items():
        # Initialise from a placeholder instead of the raw array: passing
        # the array directly embeds every weight as a constant in the
        # GraphDef, which duplicates all weights into the .meta file written
        # by saver.save() and can exceed the 2 GB GraphDef size limit.
        init_value = tf.compat.v1.placeholder(value.dtype, shape=value.shape)
        tf1_var_dict[name] = tf.Variable(init_value, name=name)
        feed_dict[init_value] = value
    saver = tf.compat.v1.train.Saver(tf1_var_dict)
    with tf.compat.v1.Session() as sess:
        sess.run(
            tf.compat.v1.global_variables_initializer(), feed_dict=feed_dict)
        logging.info("Writing checkpoint_to_path: %s", checkpoint_to_path)
        saver.save(sess, checkpoint_to_path)
微信
支付宝