【深度學習】讀寫文件

讀寫文件

到目前為止,我們討論了如何處理數據,以及如何構建、訓練和測試深度學習模型。
然而,有時我們希望保存訓練的模型,以備將來在各種環境中使用(比如在部署中進行預測)。
此外,當運行一個耗時較長的訓練過程時,最佳的做法是定期保存中間結果,以確保在服務器電源被不小心斷掉時,我們不會損失幾天的計算結果。

因此,現在是時候學習如何加載和存儲權重向量和整個模型了。

(加載和保存張量)

對于單個張量,我們可以直接調用loadsave函數分別讀寫它們。
這兩個函數都要求我們提供一個名稱,save要求將要保存的變量作為輸入。

import torch
from torch import nn
from torch.nn import functional as F# 創建一個包含從 0 到 3 的整數的一維張量
x = torch.arange(4)
# 將張量 x 保存到名為 'x-file' 的文件中
torch.save(x, 'x-file')

通常 x-file的文件格式一般是.pt 或者 .pth ,用于保存 PyTorch 模型的狀態字典(state_dict)或者整個模型對象。

我們現在可以將存儲在文件中的數據讀回內存。

# 從名為 'x-file' 的文件中加載之前保存的張量,并將其賦值給變量 x2
x2 = torch.load('x-file')
# 打印加載得到的張量 x2
x2
tensor([0, 1, 2, 3])

我們可以[存儲一個張量列表,然后把它們讀回內存。]

# 創建一個包含 4 個零的一維張量
y = torch.zeros(4)
# 將張量 x 和 y 組成一個列表,并保存到名為 'x-files' 的文件中
torch.save([x, y], 'x-files')
# 從 'x-files' 文件中加載保存的張量,并將它們分別賦值給 x2 和 y2
x2, y2 = torch.load('x-files')
# 打印加載得到的張量元組 (x2, y2)
(x2, y2)
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))

我們甚至可以(寫入或讀取從字符串映射到張量的字典)。當我們要讀取或寫入模型中的所有權重時,這很方便。

mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')
mydict2 = torch.load('mydict')
mydict2
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}

[加載和保存模型參數]

保存單個權重向量(或其他張量)確實有用,但是如果我們想保存整個模型,并在以后加載它們,單獨保存每個向量則會變得很麻煩。
畢竟,我們可能有數百個參數散布在各處。因此,深度學習框架提供了內置函數來保存和加載整個網絡。需要注意的一個重要細節是,這將保存模型的參數而不是保存整個模型
例如,如果我們有一個3層多層感知機,我們需要單獨指定架構。因為模型本身可以包含任意代碼,所以模型本身難以序列化。因此,為了恢復模型,我們需要用代碼生成架構,然后從磁盤加載參數。
讓我們從熟悉的多層感知機開始嘗試一下。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass MLP(nn.Module):"""定義一個多層感知機(MLP)模型,繼承自 nn.Module。該模型包含一個隱藏層和一個輸出層。"""def __init__(self):"""初始化 MLP 模型的各層。"""# 調用父類 nn.Module 的構造函數super().__init__()# 定義隱藏層,輸入維度為 20,輸出維度為 256self.hidden = nn.Linear(20, 256)# 定義輸出層,輸入維度為 256,輸出維度為 10self.output = nn.Linear(256, 10)def forward(self, x):"""定義模型的前向傳播過程。參數:x (torch.Tensor): 輸入張量。返回:torch.Tensor: 模型的輸出張量。"""# 對隱藏層的輸出應用 ReLU 激活函數hidden_output = F.relu(self.hidden(x))# 通過輸出層得到最終輸出return self.output(hidden_output)# 創建 MLP 模型的實例
net = MLP()
# 生成一個形狀為 (2, 20) 的隨機輸入張量
X = torch.randn(size=(2, 20))
# 將輸入張量傳入模型進行前向傳播,得到輸出
Y = net(X)

接下來,我們[將模型的參數存儲在一個叫做“mlp.params”的文件中。]

torch.save(net.state_dict(), 'mlp.params')

為了恢復模型,我們[實例化了原始多層感知機模型的一個備份。]
這里我們不需要隨機初始化模型參數,而是(直接讀取文件中存儲的參數。)

# 創建一個新的 MLP 模型實例,用于加載預訓練的參數
clone = MLP()
# 從 'mlp.params' 文件中加載保存的模型參數狀態字典,并將其加載到 clone 模型中
clone.load_state_dict(torch.load('mlp.params'))
# 將模型設置為評估模式,這會影響一些特定層(如 Dropout、BatchNorm)的行為,確保在推理時使用正確的參數
clone.eval()

load_state_dict 方法可以將一個保存好的狀態字典加載到當前的模型實例中,從而實現模型參數的恢復或遷移。狀態字典是一個 Python 字典對象,它包含了模型中所有可學習參數(如權重和偏置)的張量。

clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()

由于兩個實例具有相同的模型參數,在輸入相同的X時,兩個實例的計算結果應該相同。讓我們來驗證一下。

Y_clone = clone(X)
Y_clone == Y
tensor([[True, True, True, True, True, True, True, True, True, True],[True, True, True, True, True, True, True, True, True, True]])

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/73121.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/73121.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/73121.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

仿Manus一

復制 ┌───────────────┐ ┌─────────────┐ │ 主界面UI │?─────?│ 會話管理模塊 │ └───────┬───────┘ └──────┬──────┘│ │▼ ▼ ┌─…

VS Code C++ 開發環境配置

VS Code 是當前非常流行的開發工具. 本文講述如何配置 VS Code 作為 C開發環境. 本文將按照如下步驟來介紹如何配置 VS Code 作為 C開發環境. 安裝編譯器安裝插件配置工作區 第一個步驟的具體操作會因為系統不同或者方案不同而有不同的選擇. 環境要求 首先需要立即 VS Code…

Flutter 學習之旅 之 flutter 不使用插件,實現簡單帶加載動畫的 LoadingToast 功能

Flutter 學習之旅 之 flutter 不使用插件,實現簡單帶加載動畫的 LoadingToast 功能 目錄 Flutter 學習之旅 之 flutter 不使用插件,實現簡單帶加載動畫的 LoadingToast 功能 一、簡單介紹 二、LoadingToast 三、簡單案例實現 四、關鍵代碼 一、簡單…

Spring (八)AOP-切面編程的使用

目錄 實現步驟&#xff1a; 1 導入AOP依賴 2 編寫切面Aspect 3 編寫通知方法 4 指定切入點表達式 5 測試AOP動態織入 圖示&#xff1a; 一 實現步驟&#xff1a; 1 導入AOP依賴 <!-- Spring Boot AOP依賴 --><dependency><groupId>org.springframewor…

開源數字人模型Heygem

一、Heygem是什么 Heygem 是硅基智能推出的開源數字人模型&#xff0c;專為 Windows 系統設計。基于先進的AI技術&#xff0c;僅需1秒視頻或1張照片&#xff0c;能在30秒內完成數字人形象和聲音克隆&#xff0c;在60秒內合成4K超高清視頻。Heygem支持多語言輸出、多表情動作&a…

uniapp開通開屏廣告后動態開啟或關閉開屏廣告

近期使用uniapp開發的APP有uniad的廣告對接&#xff0c;并且要求會員用戶不顯示包含開屏廣告在內的廣告&#xff0c;除開屏廣告外的廣告都可以通過uniapp廣告組件控制是否顯示 因uniad的開屏廣告無需代碼開發&#xff0c;經過uniad客服指點可在App.vue中的onLaunch生命周期中執…

神經網絡為什么要用 ReLU 增加非線性?

在神經網絡中使用 ReLU&#xff08;Rectified Linear Unit&#xff09; 作為激活函數的主要目的是引入非線性&#xff0c;這是神經網絡能夠學習復雜模式和解決非線性問題的關鍵。 1. 為什么需要非線性&#xff1f; 1.1 線性模型的局限性 如果神經網絡只使用線性激活函數&…

使用SSH密鑰連接本地git 和 github

目錄 配置本地SSH&#xff0c;添加到github首先查看本地是否有SSH密鑰生成SSH密鑰&#xff0c;和郵箱綁定將 SSH 密鑰添加到 ssh-agent&#xff1a;顯示本地公鑰*把下面這一串生成的公鑰存到github上* 驗證SSH配置是否成功終端跳轉到本地倉庫把http協議改為SSH&#xff08;如果…

關于AI數據分析可行性的初步評估

一、結論&#xff1a;可在部分環節嵌入&#xff0c;無法直接處理大量數據 1.非本地部署的AI應用處理非機密文件沒問題&#xff0c;內部文件要注意數據安全風險。 2.AI&#xff08;指高規格大模型&#xff09;十分適合探索性研究分析&#xff0c;對復雜報告無法全流程執行&…

矩陣分析-淺要理解(深度學習方向)

梯度分析與最優化 在深度學習的任務中&#xff0c;我們所期望的是訓練一個神經網絡&#xff0c;使得預測結果與真實標簽之間的誤差最小化&#xff0c;這可以近似看作是一個提供梯度下降等優化找到全局最優解的凸優化問題。 奇異值分解 在信息工程領域&#xff0c;對數據處理的…

使用DeepSeek+藍耘快速設計網頁簡易版《我的世界》小游戲

前言&#xff1a;如今&#xff0c;借助先進的人工智能模型與便捷的云平臺&#xff0c;即便是新手開發者&#xff0c;也能開啟創意游戲的設計之旅。DeepSeek 作為前沿的人工智能模型&#xff0c;具備強大的功能與潛力&#xff0c;而藍耘智算云平臺則為其提供了穩定高效的運行環境…

固定表頭、首列 —— uniapp、vue 項目

項目實地&#xff1a;也可以在 【微信小程序】搜索體驗&#xff1a;xny.handbook 另一個體驗項目&#xff1a;官網 一、效果展示 二、代碼展示 &#xff08;1&#xff09;html 部分 <view class"table"><view class"tr"><view class&quo…

【學習筆記】Numpy和Tensor的區別

1. NumPy 和 PyTorch Tensor 的格式對比 NumPy 使用的是 numpy.ndarray&#xff0c;而 PyTorch 使用的是 torch.Tensor&#xff0c;兩者的格式在數據存儲和計算方式上有所不同。 NumPy (numpy.ndarray) import numpy as np array np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.…

每天一道算法題【藍橋杯】【在排序數組中查找元素的第一個位置和最后一個位置】

思路 本題為查找左邊界和右邊界的標準模型 查找左邊界 int left 0, right nums.size() - 1, mid 0; //查找左邊界 while (left < right) { mid left (right - left) / 2; if (nums[mid] < target) left mid 1; else right mid; } 查找右邊界 int left 0, r…

Python數據分析之機器學習基礎

Python 數據分析重點知識點 本系列不同其他的知識點講解&#xff0c;力求通過例子讓新同學學習用法&#xff0c;幫助老同學快速回憶知識點 可視化系列&#xff1a; Python基礎數據分析工具數據處理與分析數據可視化機器學習基礎 五、機器學習基礎 了解機器學習概念、分類及…

我與DeepSeek讀《大型網站技術架構》(10)- 維基百科的高性能架構設計分析

目錄 網站整體架構核心組件請求處理流程圖關鍵環節說明 性能優化策略前端優化&#xff1a;攔截 80% 以上請求服務端優化&#xff1a;高性能 PHP 集群后端優化&#xff1a;存儲與緩存極致設計Memcached 持久化連接 性能優化策略對比表 網站整體架構 核心組件 Wikipedia 的架構…

Excel多級聯動下拉菜單設置

1.問題描述 現有數據表如下圖所示&#xff1a; 該表中包括省、市、縣三級目錄。 現要將其整理成數據表模板&#xff0c;如下圖所示&#xff1a; 要求制作成下拉菜單的形式&#xff0c;且每一級目錄的下拉菜單列表要根據上一級目錄的內容來確定。 如上圖所示&#xff0c;只有…

智駕技術全鏈條解析

智駕技術全鏈條解析&#xff08;2025年最新版&#xff09; 智駕技術涵蓋從環境感知到車輛控制的完整閉環&#xff0c;涉及硬件、算法、數據與系統集成等多個領域。以下結合行業最新進展&#xff08;截至2025年3月&#xff09;進行深度拆解&#xff1a; 一、感知技術&#xff1…

SpringMVC執行的流程

SpringMVC 基于 MVC 架構模式&#xff0c;核心流程時前端控制室 DispathcherServlet 統一調度&#xff0c;通過組件協作完成 http 的請求與響應。 對于 dispatchServlet 作為前端請求的控制器&#xff0c;全局的訪問點&#xff0c;首先將根據 URL 調用 HandlerMapping 獲取 Han…

Linux學習(十五)(故障排除(ICMP,Ping,Traceroute,網絡統計,數據包分析))

故障排除是任何 Linux 用戶或管理員的基本技能。這涉及識別和解決 Linux 系統中的問題。這些問題的范圍包括常見的系統錯誤、硬件或軟件問題、網絡連接問題以及系統資源的管理。Linux 中的故障排除過程通常涉及使用命令行工具、檢查系統和應用程序日志文件、了解系統進程&#…