Pytorch中register_buffer和torch.nn.Parameter的異同

說下register_buffer和Parameter的異同

相同點

方面描述
追蹤都會被加入 state_dict(模型保存時會保存下來)。
Module 的綁定都會隨著模型移動到 cuda / cpu / float() 等而自動遷移。
都是 nn.Module 的一部分都可以通過模塊屬性訪問,如 self.x

不同點

方面torch.nn.Parameterregister_buffer
是否是可訓練參數? 是,會被視為模型需要優化的參數(model.parameters() 中包含)? 否,不會被優化器更新
梯度計算默認 requires_grad=True,參與反向傳播默認 requires_grad=False,不參與反向傳播
用途場景模型的權重、偏置等需要學習的參數均值、方差、mask、位置編碼等常量或狀態,如 BatchNorm 中的 running mean/var
注冊方式self.w = nn.Parameter(tensor)self.register_parameter("w", nn.Parameter(...))self.register_buffer("buf", tensor)
是否顯示在 parameters()? 會顯示? 不會顯示
是否能直接賦值注冊? 可以直接賦值? 必須通過 register_buffer() 注冊,否則不會記錄到 state_dict

使用建議

情境推薦使用
需要優化nn.Parameter
只做記錄或參與計算但不優化register_buffer
實現自定義模塊(如 BatchNorm)時的狀態register_buffer
使用位置編碼、attention maskregister_buffer
模型保存中需要但不訓練register_buffer

這里我自己寫了一個測試代碼,分別運行ToyModel1 2 3 保存并讀取,相信會對這兩個函數有很深刻的認識。

import torch
import torch.nn as nn
import torch.nn.functional as Fclass ToyModel(nn.Module):def __init__(self, inChannels, outChannels):super().__init__()self.a1 = 1 # 實例成員,不會保存在ckpt中self.a2 = 2self.linear = nn.Linear(inChannels, outChannels)self.init_weights()def init_weights(self):for m in self.modules():if isinstance(m, nn.Linear):nn.init.xavier_uniform_(m.weight)nn.init.zeros_(m.bias)def forward(self, x):out = self.linear(x)return outclass ToyModel2(nn.Module):def __init__(self, inChannels, outChannels):super().__init__()self.a1 = 1 # 實例成員,不會保存在ckpt中self.a2 = 2self.linear = nn.Linear(inChannels, outChannels)self.init_weights()self.b1 = nn.Parameter(torch.randn(outChannels),) # 模型參數,requires_grad=True, 保存進ckptdef init_weights(self):for m in self.modules():if isinstance(m, nn.Linear):nn.init.xavier_uniform_(m.weight)nn.init.zeros_(m.bias)def forward(self, x):out = self.linear(x)out += self.b1return outclass ToyModel3(nn.Module):def __init__(self, inChannels, outChannels):super().__init__()self.a1 = 1 # 實例成員,不會保存在ckpt中self.a2 = 2self.linear = nn.Linear(inChannels, outChannels)self.init_weights()self.b1 = nn.Parameter(torch.randn(outChannels),)self.register_buffer("c1", torch.ones_like(self.b1), persistent=True) # 類成員,requires_grad=False, 保存進ckpt,用于保存需要直接計算的常量,可以用self.c1訪問def init_weights(self):for m in self.modules():if isinstance(m, nn.Linear):nn.init.xavier_uniform_(m.weight)nn.init.zeros_(m.bias)def forward(self, x):out = self.linear(x)out += self.b1out += self.c1return out
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
from pathlib import Pathfrom models import ToyModel2, ToyModel, ToyModel3logging.basicConfig(level=logging.INFO,format='%(asctime)s - %(name)s - %(levelname)s - %(lineno)s - %(message)s')if __name__ == "__main__":savePath = Path("toymodel3.pth")logger = logging.getLogger(__name__)inp = torch.randn(3, 5)model = ToyModel3(inp.size(1), inp.size(1) * 2)pred = model(inp)logger.info(f"{pred.size()=}")for m in model.modules():logger.info(m)for name, param in model.named_parameters():logger.info(f"{name = }, {param.size() = }, {param.requires_grad=}")for name, buffer in model.named_buffers():logger.info(f"{name = }, {buffer.size() = }")torch.save(model.state_dict(), savePath)
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Pathfrom models import ToyModel, ToyModel2, ToyModel3if __name__ == "__main__":savePath = Path("toymodel3.pth")inp = torch.randn(3, 5)model = ToyModel3(inp.size(1), inp.size(1) * 2)ckpt = torch.load(savePath, map_location="cpu", weights_only=True)model.load_state_dict(ckpt)pred = model(inp)print(f"{pred.size()=}")for m in model.modules():print(m)for name, param in model.named_parameters():print(f"{name = }, {param.size() = }, {param.requires_grad=}")for name, buffer in model.named_buffers():print(f"{name = }, {buffer.size() = }")

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/90600.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/90600.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/90600.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

吉吉巳資源整站源碼完整打包,適用于搭建資源聚合/整合類站點,全網獨家,拿來就用

想要搭建一個資源整合站點,如影視聚合類站點、資訊聚合類站點、圖集聚合類站點等,需要花費大量的時間來查找合適的系統或源碼。然后要去測試,修復bug,一直到能夠正常的運營使用,花費的時間絕對不短,今天分享…

嵌入式學習的第三十五天-進程間通信-HTTP

TCP/IP協議模型:應用層:HTTP;傳輸層:TCP UDP;網絡層:IPv4 IPv6網絡接口層一、HTTP協議1. 萬維網WWW(World Wide Web) 世界范圍內的,聯機式的信息儲藏所。 萬維網解決了獲取互聯網上的數據時需要解決的以下問題&#x…

es 和 lucene 的區別

1. Lucene 是“發動機”,ES 是“整車”Lucene:只是一個 Java 庫,提供倒排索引、分詞、打分等底層能力。你必須自己寫代碼處理索引創建、更新、刪除、分片、分布式、故障恢復、API 封裝等所有邏輯。Elasticsearch:基于 Lucene 的分…

AS32S601 系列 MCU芯片GPIO Sink/Source 能力測試方法

一、引言隨著電子技術的飛速發展,微控制器(MCU)在工業控制、汽車電子、商業航天等眾多領域得到了廣泛應用。國科安芯推出的AS32S601 系列 MCU 以其卓越的性能和可靠性,成為了眾多設計工程師的首選之一。為了確保其在實際應用中的穩…

JAVA-08(2025.07.24學習記錄)

面向對象類package com.mm;public class Person {/*** 名詞-屬性*/String name;int age;double height;/*** 動詞-方法*/public void sleep(String add) {System.out.println("我在" add "睡覺");}public String introduce() {return "我的名字是&q…

地下隧道管廊結構健康監測系統 測點的布設及設備選型

隧道監測背景 隧道所處地下環境復雜,在施工過程中會面臨圍堰變形、拱頂沉降、凈空收斂、初襯應力變化、土體塌方等多種危險情況。在隧道營運過程中,也會受到材料退化、地震、人為破壞等因素影響,引發隧道主體結構的劣化和損壞,若不…

node.js卸載與安裝超詳細教程

文章目錄一、卸載Step1:通過控制面板刪除node版本Step2:刪除node的安裝目錄Step3:查找.npmrc文件是否存在,有就刪除。Step4:查看以下文件是否存在,有就刪除Step5:打開系統設置,檢查系…

飛算JavaAI“刪除接口信息” 功能:3 步清理冗余接口,讓管理效率翻倍

在飛算JavaAI的接口設計與管理流程中,“刪除接口信息” 功能為用戶提供了靈活調整接口方案的便利。該功能的存在,讓用戶能夠在接口生命周期的前期(審核階段)及時清理無需創建的接口,保證接口管理的簡潔性與高效性。一、…

行業熱點丨SimLab解決方案如何高效應對3D IC多物理場與ECAD建模挑戰?

半導體行業正快速超越傳統2D封裝技術,積極采用 3D集成電路(3D ICs)和2.5D 先進封裝等方案。這些技術通過異構芯粒、硅中介層和復雜多層布線實現更高性能與集成度。然而,由于電子計算機輔助設計(ECAD)數據規…

2025暑期—05神經網絡-BP網絡

按誤差反向傳播(簡稱誤差反傳)訓練的多層前饋網絡線性回歸或者分類不需要使用神經元,原有最小二程即可。求解J依次變小。使用泰勒展開,只看第一階。偏導是確定的,需要讓J小于0的delta WkWk構造完成后 J(Wk1)已知&#…

qml的信號槽機制

qml的信號槽機制和qtwidget差不多,但是使用方法不一樣,qtwidget一般直接用connect函數把信號和槽一綁定就完事了,qml分為自動綁定和手動綁定。信號自動綁定在一個組件里面定義一個信號,用signal定義,當事件觸發&#x…

Unity國際版下載鏈接分享(非c1國內版)

轉載Unity國際版下載鏈接分享(非c1國內版) - 嗶哩嗶哩 大家平時使用Unity注意一下會發現,現在我們下載的Unity版本號后面都一個c1,但是大家在B站學習時大神UP主們使用的Unity版本號大都是沒有c1的。 例如:我在用的是…

第4章唯一ID生成器——4.1 分布式唯一ID

在復雜的系統中,每個業務實體都需要使用ID做唯一標識,以方便進行數據操作。例如,每個用戶都有唯一的用戶ID,每條內容都有唯一的內容ID,甚至每條內容下的每條評論都有唯一的評論ID。 4.1.1 全局唯一與UUID 在互聯網還未…

圖論水題日記

cf1805D 題意 給定一棵樹,規定dis(u,v)≥kdis(u,v) \geq kdis(u,v)≥k時(u,v)(u,v)(u,v)之間存在一條無向邊,求k(1,2,...n)k(1,2,...n)k(1,2,...n)時圖中的連通塊個數 思路 前置知識:樹上一點到其最遠的點一定是樹直徑的兩個端點之一若一個點…

自定義線程

每個程序至少有一個線程 —— 主線程 主線程是程序的起點,你可以從它開始創建新的線程來執行任務。為此,你需要創建自定義線程,編寫在線程中執行的代碼,并啟動它。 通過繼承創建自定義線程 創建新線程有兩種主要方式:繼…

2025真實面試試題分析-安卓客戶端開發

以下是對安卓客戶端開發工程師面試問題的分類整理、領域占比分析及高頻問題精選(基于??85道問題,總出現次數118次??)。按技術領域整合為??7大核心類別??,按占比排序并精選高頻問題標注優先級(1-5🌟…

算法學習筆記:29.拓撲排序——從原理到實戰,涵蓋 LeetCode 與考研 408 例題

拓撲排序(Topological Sorting)是一種針對有向無環圖(DAG)的線性排序算法,它將圖中的頂點按照一定規則排列,使得對于圖中的任意一條有向邊 u→v,頂點 u 都排在頂點 v 之前。拓撲排序在任務調度、…

利用Web3加密技術保障您的在線數據安全

在這個信息爆炸的數字化時代,保護個人和企業數據安全變得尤為重要。Web3技術以其去中心化和加密特性,為在線數據安全提供了新的解決方案。本文將探討Web3技術如何通過加密技術保障您的在線數據安全,并介紹如何有效利用這些技術。 什么是Web3技…

Vue實現el-checkbox單選并回顯選中

先說需求 我要在頁面進行checkbox單選并回顯 第一步先把基本的頁面寫好噢&#xff1a;vue代碼&#xff1a;別忘了寫change啊<el-form-item label"按鈕顏色:" prop"menuColor"><el-checkbox-group v-model"buttonColor" change"bin…

動態規劃--序列找優問題【1】

一、說明 動態規劃似乎針對問題很多&#xff0c;五花八門&#xff0c;似乎每一個問題都有一套具體算法。其實不是的&#xff0c;動態規劃只有兩類&#xff1a;1&#xff09;針對圖的路徑問題 2&#xff09;針對一個序列的問題。本篇講動態規劃針對序列的算法范例。 二、動態規劃…