【神經網絡擴展】:斷點續訓和參數提取

課程來源:人工智能實踐:Tensorflow筆記2

文章目錄

  • 前言
  • 斷點續訓主要步驟
  • 參數提取主要步驟
  • 總結


前言

本講目標:斷點續訓,存取最優模型;保存可訓練參數至文本


斷點續訓主要步驟

讀取模型:

先定義出存放模型的路徑和文件名,命名為.ckpt文件。
生成ckpt文件的時候會同步生成索引表,所以通過判斷是否存在索引表來知曉是不是已經保存過模型參數。
如果有了索引表就利用load_weights函數讀取已經保存的模型參數。

code:


checkpoint_save_path = "./checkpoint/fashion.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)

在這里插入圖片描述

保存模型:

保存模型參數可以使用TensorFlow給出的回調函數,直接保存訓練出來的模型參數
tf.keras.callbacks.ModelCheckpoint( filepath=路徑文件名(文件存儲路徑),
save_weights_only=True/False,(是否只保留參數模型)
save_best_only=True/False(是否只保留最優結果)) 執行訓練過程中時,加入callbacks選項:
history=model.fit(callbacks=[cp_callback])

code:

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])

第一次運行:
在這里插入圖片描述
第二次運行:可以發現模型并不是從初始訓練,而是在基于保存的模型開始訓練的(這一點可以從準確率和損失看出):
在這里插入圖片描述
全部代碼:

import tensorflow as tf
import osfashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])checkpoint_save_path = "./checkpoint/fashion.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])
model.summary()

參數提取主要步驟

設置打印的格式,使所有參數都打印出來

np.set_printoptions(threshold=np.inf)
print(model.trainable_variables)

將所有可訓練參數存入文本:

file = open('./weights.txt', 'w')
for v in model.trainable_variables:file.write(str(v.name) + '\n')file.write(str(v.shape) + '\n')file.write(str(v.numpy()) + '\n')
file.close()

完整代碼:

import tensorflow as tf
import os
import numpy as npnp.set_printoptions(threshold=np.inf)fashion = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])checkpoint_save_path = "./checkpoint/fashion.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):print('-------------load the model-----------------')model.load_weights(checkpoint_save_path)cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,save_weights_only=True,save_best_only=True)history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,callbacks=[cp_callback])
model.summary()print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:file.write(str(v.name) + '\n')file.write(str(v.shape) + '\n')file.write(str(v.numpy()) + '\n')
file.close()

效果:
在這里插入圖片描述

總結

課程鏈接:MOOC人工智能實踐:TensorFlow筆記2

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/378275.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/378275.shtml
英文地址,請注明出處:http://en.pswp.cn/news/378275.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

開發DBA(APPLICATION DBA)的重要性

開發DBA是干什么的? 1. 審核開發人員寫的SQL,并且糾正存在性能問題的SQL ---非常重要 2. 編寫復雜業務邏輯SQL,因為復雜業務邏輯SQL開發人員寫出的SQL基本上都是有性能問題的,與其讓開發人員寫,不如DBA自己寫。---非常…

javascript和var之間的區別?

You can define your variables in JavaScript using two keywords - the let keyword and the var keyword. The var keyword is the oldest way of defining and declaring variables in JavaScript whereas the let is fairly new and was introduced by ES15. 您可以使用兩…

小米手環6NFC安裝太空人表盤

以前看我室友峰哥、班長都有手環,一直想買個手環,不舍得,然后今年除夕的時候降價,一狠心,入手了,配上除夕的打年獸活動還有看春晚京東敲鼓領的紅包和這幾年攢下來的京東豆豆,原價279的小米手環6…

計算機二級c語言題庫縮印,計算機二級C語言上機題庫(可縮印做考試小抄資料)...

小抄,答案,形成性考核冊,形成性考核冊答案,參考答案,小抄資料,考試資料,考試筆記第一套1.程序填空程序通過定義學生結構體數組,存儲了若干個學生的學號、姓名和三門課的成績。函數fun 的功能是將存放學生數據的結構體數組,按照姓名的字典序(從小到大排序…

為什么兩層3*3卷積核效果比1層5*5卷積核效果要好?

目錄1、感受野2、2層3 * 3卷積與1層5 * 5卷積3、2層3 * 3卷積與1層5 * 5卷積的計算量比較4、2層3 * 3卷積與1層5 * 5卷積的非線性比較5、2層3 * 3卷積與1層5 * 5卷積的參數量比較1、感受野 感受野:卷積神經網絡各輸出特征像素點,在原始圖片映射區域大小。…

算法正確性和復雜度分析

算法正確性——循環不變式 算法復雜度的計算 方法一 代換法 —局部代換 這里直接對n變量進行代換 —替換成對數或者指數的情形 n 2^m —整體代換 這里直接對遞推項進行代換 —替換成內部遞推下標的形式 T(2^n) S(n) 方法二 遞歸樹法 —用實例說明 —分析每一層的內容 —除了…

第十五章 Python和Web

第十五章 Python和Web 本章討論Python Web編程的一些方面。 三個重要的主題:屏幕抓取、CGI和mod_python。 屏幕抓取 屏幕抓取是通過程序下載網頁并從中提取信息的過程。 下載數據并對其進行分析。 從Python Job Board(http://python.org/jobs&#x…

array_chunk_PHP array_chunk()函數與示例

array_chunkPHP array_chunk()函數 (PHP array_chunk() Function) array_chunk() function is an array function, it is used to split a given array in number of array (chunks of arrays). array_chunk()函數是一個數組函數,用于將給定數組拆分為多個數組(數組…

raise

raise - Change a windows position in the stacking order button .b -text "Hi there!"pack [frame .f -background blue]pack [label .f.l1 -text "This is above"]pack .b -in .fpack [label .f.l2 -text "This is below"]raise .b轉載于:ht…

c語言輸出最大素數,for語句計算輸出10000以內最大素數怎么搞最簡單??各位大神們...

該樓層疑似違規已被系統折疊 隱藏此樓查看此樓#include #include int* pt NULL; // primes_tableint pt_size 0; // primes_table 數量大小int init_primes_table(void){FILE* pFile;pFile fopen("primes_table.bin", "rb");if (pFile NULL) {fputs(&q…

【數據結構基礎筆記】【圖】

代碼參考《妙趣橫生的算法.C語言實現》 文章目錄前言1、圖的概念2、圖的存儲形式1、鄰接矩陣:2、鄰接表3、代碼定義鄰接表3、圖的創建4、深度優先搜索DFS5、廣度優先搜索BFS6、實例分析前言 本章總結:圖的概念、圖的存儲形式、鄰接表定義、圖的創建、圖…

第十六章 測試基礎

第十六章 測試基礎 在編譯型語言中,需要不斷重復編輯、編譯、運行的循環。 在Python中,不存在編譯階段,只有編輯和運行階段。測試就是運行程序。 先測試再編碼 極限編程先鋒引入了“測試一點點,再編寫一點點代碼”的理念。 換而…

如何蹭網

引言蹭網,在普通人的眼里,是一種很高深的技術活,總覺得肯定很難,肯定很難搞。還沒開始學,就已經敗給了自己的心里,其實,蹭網太過于簡單。我可以毫不夸張的說,只要你會windows的基本操…

android對象緩存,Android簡單實現 緩存數據

前言1、每一種要緩存的數據都是有對應的versionCode,通過versionCode請求網絡獲取是否需要更新2、提前將要緩存的數據放入assets文件夾中,打包上線。緩存設計代碼實現/*** Created by huangbo on 2017/6/19.** 主要是緩存的工具類** 緩存設計&#xff1a…

通信原理.緒論

今天剛上通信原理的第一節課,沒有涉及過多的講解,只是講了下大概的知識框架。現記錄如下: 目錄1、基本概念消息、信息與信號2、通信系統模型1、信息源2、發送設備3、信道4、接收設備5、信宿6、模擬通信系統模型7、數字通信系統模型8、信源編…

Android版本演進中的兼容性問題

原文:http://android.eoe.cn/topic/summary Android 3.0 的主要變化包括: 不再使用硬件按鍵進行導航 (返回、菜單、搜索和主屏幕),而是采用虛擬按鍵 (返回,主屏幕和最近的應用)。在操作欄固定菜單。 Android 4.0 則把這些變化帶到了手機平臺。…

css rgba透明_rgba()函數以及CSS中的示例

css rgba透明Introduction: 介紹: Functions are used regularly while we are developing a web page or website. Therefore, to be a good developer you need to master as many functions as you can. This way your coding knowledge will increase as well …

第十七章 擴展Python

第十七章 Python什么都能做,真的是這樣。這門語言功能強大,但有時候速度有點慢。 魚和熊掌兼得 本章討論確實需要進一步提升速度的情形。在這種情況下,最佳的解決方案可能不是完全轉向C語言(或其他中低級語言)&…

android studio資源二進制,無法自動檢測ADB二進制文件 – Android Studio

我嘗試在Android Studio上測試我的應用程序,但我遇到了困難"waiting for AVD to come online..."我讀過Android設備監視器重置adb會做到這一點,它確實……對于1次測試,當我第二天重新啟動電腦時,我不僅僅是:"waiting for AVD to come online..."…

犀牛腳本:仿迅雷的增強批量下載

迅雷的批量下載滿好用。但是有兩點我不太中意。在這個腳本里會有所增強 1、不能設置保存的文件名。2、不能單獨設置這批下載的線程限制。 使用方法 // 下載從編號001到編號020的圖片,保存名為貓咪寫真*.jpg 使用6個線程 jdlp http://bizhi.zhuoku.com/bizhi/200804/…