模型的保存與恢復
我們來簡單實現一下模型的保存與恢復
訓練完TensorFlow模型后,可將其保存為文件,以便于預測新數據時直接加載使用。
TensorFlow模型主要包含網絡的設計或者圖以及已經訓練好的網絡參數的值。
TensorFlow提供的tf.train.Saver()函數可以建立一個saver對象,在會話中調用其save()函數,即可將模型保存起來
save()函數的用法
函數 | 說明 |
save( ????? sess, ????? sace_path, ????? global_step=None, ????? latest_filename=None, ????? meta_graph_suffix='meta', ????? write_meta_graph=True, ????? write_state=True ) | sess:保存模型,要求必須有一個加載了計算圖的會話,且所有變量已被初始化。 sace_path:模型保存路徑及保存名稱 global_step:如果提供,該數字會添加到save_path后,用于區分不同訓練階段的結果 latest_filename:檢查點文件的名稱,默認是checkpoint meta_graph_suffix= MetaGraphDef元圖后綴,默認為meta write_meta_graph=是否要保存元圖數據,默認為True write_state:是否要保存CheckpointStateProto,默認為True |
模型保存
import tensorflow as tf
m1 = tf.Variable(tf.constant([[1.0,3.0],[2.0,4.0]],shape=[2,2]),name='m1')
m2 = tf.Variable(tf.constant([[2.0,7.0],[3.0,8.0]],shape=[2,2]),name='m2')
result = m1 + m2
saver = tf.train.Saver()
with tf.Session() as sess:sess.run(tf.global_variables_initializer())print('resulit:',sess.run(result))saver.save(sess,'C:/model/model.ckpt')
運行程序,當前目錄的model文件夾下會產生4個文件:checkpoint,data-00000-of-00001,meta和index
checkpoint:保存模型的權重、偏置、梯度以及其他保護變量的二進制文件。
data:保存模型的所有變量的值
meta:保存計算圖的結構。當meta文件存在時,不在程序中定義模型,直接加載meta可以直接運行
index:保存string-string的鍵值對。其中的key值為張量名,value為BundleEntryProto
模型恢復
模型保存好了以后,載入發出方便。
在會話中調用saver的restore()函數,就會從指定的路徑找到模型文件,并覆蓋相關參數。
saver.restore()函數的形式如表
函數 | 說明 |
saver.restore( ??? sess, ??? save_path ) | 從指定的路徑恢復模型。 sess:用于恢復參數模型的會話 save_path:已保存模型的路徑,通常包含模型名字 |
import tensorflow as tf
tf.reset_default_graph()
v1 = tf.Variable(tf.constant([[5.0,6.0],[7.0,7.0]],shape=[2,2]),name='m1')
v2 = tf.Variable(tf.constant([[4.0,6.0],[7.0,8.0]],shape=[2,2]),name='m2')
result = v1 + v2
saver = tf.train.Saver()
with tf.Session() as sess:saver.restore(sess,'C:/model/model.ckpt')print(sess.run(result))
?