訓練完一個模型后,為了以后重復使用,通常我們需要對模型的結果進行保存。如果用Tensorflow去實現神經網絡,所要保存的就是神經網絡中的各項權重值。建議可以使用Saver類保存和加載模型的結果。
1、使用tf.train.Saver.save()方法保存模型
tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True, write_state=True)
- sess: 用于保存變量操作的會話。
- save_path: String類型,用于指定訓練結果的保存路徑。
- global_step: 如果提供的話,這個數字會添加到save_path后面,用于構建checkpoint文件。這個參數有助于我們區分不同訓練階段的結果。
2、使用tf.train.Saver.restore方法價值模型
tf.train.Saver.restore(sess, save_path)
- sess: 用于加載變量操作的會話。
- save_path: 同保存模型是用到的的save_path參數。
下面通過一個代碼演示這兩個函數的使用方法
import tensorflow as tf
import numpy as npx = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + bloss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)isTrain = False
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))with tf.Session() as sess:sess.run(tf.initialize_all_variables())if isTrain:for i in xrange(train_steps):sess.run(train, feed_dict={x: x_data})if (i + 1) % checkpoint_steps == 0:saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)else:ckpt = tf.train.get_checkpoint_state(checkpoint_dir)if ckpt and ckpt.model_checkpoint_path:saver.restore(sess, ckpt.model_checkpoint_path)else:passprint(sess.run(w))print(sess.run(b))