PyTorch教程:如何讀寫張量與模型參數

本文演示了PyTorch中張量(Tensor)和模型參數的保存與加載方法,并提供完整的代碼示例及輸出結果,幫助讀者快速掌握數據持久化的核心操作。


1. 保存和加載單個張量

通過torch.savetorch.load可以直接保存和讀取張量。

import torch# 創建并保存張量
x = torch.arange(4)
torch.save(x, 'x-file')# 加載張量
x2 = torch.load('x-file')
print(x2)  # 輸出:tensor([0, 1, 2, 3])

輸出結果

tensor([0, 1, 2, 3])

2. 保存和加載張量列表

可以將多個張量存儲為列表,并一次性加載。

# 創建兩個張量并保存為列表
y = torch.zeros(4)
torch.save([x, y], 'x-files')# 加載列表
x2, y2 = torch.load('x-files')
print((x2, y2))

輸出結果

(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))

3. 保存和加載字典

通過字典可以更靈活地管理多個張量。

# 創建字典并保存
mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')# 加載字典
mydict2 = torch.load('mydict')
print(mydict2)

輸出結果

{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}

4. 定義神經網絡模型

以下是一個簡單的全連接神經網絡示例:

from torch import nn
from torch.nn import functional as Fclass Model(nn.Module):def __init__(self):super().__init__()self.hidden = nn.Linear(20, 256)  # 隱藏層self.output = nn.Linear(256, 10)   # 輸出層def forward(self, x):return self.output(F.relu(self.hidden(x)))# 實例化模型并進行前向傳播
net = Model()
x = torch.rand(size=(2, 20))
y = net(x)
print(y)

輸出結果(因隨機初始化可能不同):

tensor([[-0.0711, 0.1161, -0.1113, ..., 0.0787],[-0.0151, 0.0275, -0.1652, ..., 0.0109]], grad_fn=<AddmmBackward0>)

5. 保存模型參數

使用state_dict保存模型參數:

torch.save(net.state_dict(), 'net.params')

6. 加載模型參數并驗證

加載參數到新模型實例,并驗證一致性:

# 創建新模型并加載參數
clone = Model()
clone.load_state_dict(torch.load('net.params'))
clone.eval()  # 設置為評估模式(關閉Dropout/BatchNorm等)# 比較輸出結果
Y_clone = clone(x)
print(Y_clone == y)

輸出結果

tensor([[True, True, ..., True],[True, True, ..., True]])

總結

  1. 張量讀寫:直接使用torch.savetorch.load,支持列表和字典。

  2. 模型參數保存:通過state_dict保存模型狀態,加載時需重新實例化模型。

  3. 驗證一致性:加載參數后,輸出與原模型一致表明操作成功。

通過本文的代碼示例,讀者可以快速掌握PyTorch中數據和模型參數的持久化方法,為模型訓練和部署提供便利。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/75061.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/75061.shtml
英文地址,請注明出處:http://en.pswp.cn/web/75061.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

持續集成:GitLab CI/CD 與 Jenkins CI/CD 的全面剖析

一、引言 在當今快速迭代的軟件開發領域,持續集成(Continuous Integration,CI)已成為保障軟件質量、加速開發流程的關鍵實踐。通過頻繁地將代碼集成到共享倉庫,并自動進行構建和測試,持續集成能夠盡早發現并解決代碼沖突和缺陷。而 GitLab CI/CD 和 Jenkins CI/CD 作為兩…

Python 序列構成的數組(序列的增量賦值)

序列的增量賦值 增量賦值運算符 和 * 的表現取決于它們的第一個操作對象。簡單起 見&#xff0c;我們把討論集中在增量加法&#xff08;&#xff09;上&#xff0c;但是這些概念對 * 和其他 增量運算符來說都是一樣的。 背后的特殊方法是 iadd &#xff08;用于“就地加法”&…

GEO, TCGA 等將被禁用?!這40個公開數據庫可能要小心使用了

GEO, TCGA 等將被禁用&#xff1f;&#xff01;這40個公開數據庫可能要小心使用了 最近NIH公共數據庫開始對中國禁用的消息鬧得風風火火&#xff1a; 你認為研究者上傳到 GEO 數據庫上的數據會被禁用嗎&#xff1f; 單選 會&#xff0c;畢竟占用存儲資源 不會&#xff0c;不…

【如何自建MCP服務器?從協議原理到實踐的全流程指南】

文章目錄 如何自建MCP服務器&#xff1f;從協議原理到實踐的全流程指南一、MCP協議是什么&#xff1f;核心架構 二、為什么要自建MCP服務器&#xff1f;1. 突破LLM的固有局限2. 實現個性化功能擴展3. 確保數據隱私安全 三、手把手搭建MCP服務器&#xff08;Python示例&#xff…

鴻蒙開發_ARKTS快速入門_語法說明_渲染控制---純血鴻蒙HarmonyOS5.0工作筆記012

然后我們再來看渲染控制 首先看條件渲染,其實就是根據不同的狀態,渲染不同的UI界面 比如下面這個暫停 開啟播放的 可以看到就是通過if 這種條件語句 修改狀態變量的值 然后我們再來看這個, 下面點擊哪個,上面橫線就讓讓他顯示哪個 去看一下代碼 可以看到,有兩個狀態變量opt…

【Java設計模式】第3章 軟件設計七大原則

3-1 本章導航 學習開辟原則(基礎原則)依賴倒置原則單一職責原則接口隔離原則迪米特法則(最少知道原則)里氏替換原則合成復用原則(組合復用原則)核心思想: 設計原則需結合實際場景平衡,避免過度設計。設計模式中可能部分遵循原則,需靈活取舍。3-2 開閉原則講解 定義 軟…

JVM即時編譯(JIT)

JVM基礎回顧 Java 作為一門高級程序語言&#xff0c;由于它自身的語言特性&#xff0c;它并非直接在硬件上運行&#xff0c;而是通過編譯器(前端編譯器)將 Java 程序轉換成該虛擬機所能識別的指令序列&#xff0c;也就是字節碼&#xff0c;然后運行在虛擬機之上的&#xff1b;…

剛體碰撞檢測與響應(C++實現)

本文實現一個經典的物理算法&#xff1a;剛體碰撞檢測與響應。這個算法用于檢測兩個剛體&#xff08;如矩形或圓形&#xff09;是否發生碰撞&#xff0c;并在碰撞時更新它們的速度和位置。我們將使用C來實現這個算法&#xff0c;并結合**邊界框&#xff08;Bounding Box&#x…

常用的國內鏡像源

常見的 pip 鏡像源 阿里云鏡像&#xff1a;https://mirrors.aliyun.com/pypi/simple/ 清華大學鏡像&#xff1a;https://pypi.tuna.tsinghua.edu.cn/simple 中國科學技術大學鏡像&#xff1a;https://pypi.mirrors.ustc.edu.cn/simple/ 豆瓣鏡像&#xff1a;https://pypi.doub…

鴻蒙小案例-京東登錄

效果 代碼實現 Entry Component struct Index {build() {Column() {Row() {Image($r(app.media.jd_cancel)).width(20)Text(幫助).fontSize(16).fontColor(#666)}.width(100%).justifyContent(FlexAlign.SpaceBetween)Image($r(app.media.jd_logo)).height(250).width(250)// …

《 Scikit-learn與MySQL的深度協同:構建智能數據生態系統的架構哲學》

在機器學習工程實踐中&#xff0c;數據存儲與模型訓練的割裂始終是制約算法效能的關鍵瓶頸。Scikit-learn作為經典機器學習庫&#xff0c;其與MySQL的深度協同并非簡單的數據管道連接&#xff0c;而是構建了一個具備自組織能力的智能數據生態系統。這種集成突破了傳統ETL流程的…

華為AI-agent新作:使用自然語言生成工作流

論文標題 WorkTeam: Constructing Workflows from Natural Language with Multi-Agents 論文地址 https://arxiv.org/pdf/2503.22473 作者背景 華為&#xff0c;北京大學 動機 當下AI-agent產品百花齊放&#xff0c;盡管有ReAct、MCP等框架幫助大模型調用工具&#xff0…

關于軟件bug描述

軟件缺陷&#xff08;Defect&#xff09;&#xff0c;常常又被叫做Bug。 所謂軟件缺陷&#xff0c;即為計算機軟件或程序中存在的某種破壞正常運行能力的問題、錯誤&#xff0c;或者隱藏的功能缺陷。缺陷的存在會導致軟件產品在某種程度上不能滿足用戶的需要。IEEE729-1983對缺…

【元表 vs 元方法】

元表 vs 元方法 —— 就像“魔法書”和“咒語”的關系 1. 元表&#xff08;Metatable&#xff09;&#xff1a;魔法書 是什么&#xff1f; 元表是一本**“規則說明書”**&#xff0c;它本身是一個普通的 Lua 表&#xff0c;但可以綁定到其他表上&#xff0c;用來定義這個表應該…

Spring Boot 通過全局配置去除字符串類型參數的前后空格

1、問題 避免前端輸入的字符串參數兩端包含空格&#xff0c;通過統一處理的方式&#xff0c;trim掉空格 2、實現方式 /*** 去除字符串類型參數的前后空格* author yanlei* since 2022-06-14*/ Configuration AutoConfigureAfter(WebMvcAutoConfiguration.class) public clas…

C語言核心知識點整理:結構體對齊、預處理、文件操作與Makefile

目錄 結構體的字節對齊預處理指令詳解文件操作基礎Makefile自動化構建總結 1. 結構體的字節對齊 字節對齊原理 內存對齊&#xff1a;CPU訪問內存時&#xff0c;對齊的地址能提高效率。操作系統要求變量按類型大小對齊。對齊規則&#xff1a; 每個成員的起始地址必須是min(成…

VBA+BOS單據+插件,解決計劃任務跟蹤的問題之二:導入ERP

第二步&#xff0c;就是要將拆分好的任務導入ERP了 1、將建一個BOS單據叫“任務池”&#xff0c;大概是這樣的 然后在拆分工具中進行導數據&#xff0c;點擊“數據導出準備”&#xff0c;跳轉到“導入ERP”界面&#xff0c;然后點“獲取數據”&#xff0c;將拆分好的數據轉過來…

使用uglifyjs對靜態引入的js文件進行壓縮

前言 因為有時候js文件沒有npm包&#xff0c;或者需要修改&#xff0c;只能引入靜態的js&#xff0c;那么這個時候就可以對js進行壓縮了。我其實想通過vite、webpack等插件進行壓縮的&#xff0c;可是他都不能定位到public目錄下面的文件&#xff0c;所以我只能自己壓縮了。編…

藍橋杯 web 水果拼盤 (css3)

做題步驟&#xff1a; 看結構&#xff1a;html 、css 、f12 分析: f12 查看元素&#xff0c;你會發現水果的高度剛好和拼盤的高度一樣&#xff0c;每一種水果的盤子剛好把頁面填滿了&#xff0c;所以咱們就只要讓元素豎著排列&#xff0c;加上是豎著&#xff0c;排不下的換行…

差分音頻轉單端音頻單電源方案

TI LMV321介紹 TI的LMV321是單通道的低壓軌到軌輸出運算放大器&#xff0c;適用于需要低工作壓、節省空間和低成本的應用。 其中&#xff0c;芯片設計中的軌到軌輸出&#xff08;Rail-to-Rail Output&#xff09; 是指通過特定的電路設計&#xff0c;使芯片&#xff08;如運算…