Explain the similarities and differences between register_buffer and Parameter
Similarities:

- Tracking: both are added to `state_dict`, so they are saved when the model is saved.
- Binding to the Module: both migrate automatically when the model is moved with `cuda()` / `cpu()` / `float()` and similar calls.
- Both are part of `nn.Module` and can be accessed as module attributes, e.g. `self.x`.
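To make this shared behaviour concrete, here is a minimal sketch (the module and attribute names are my own, not from the original text):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(4))        # learnable parameter
        self.register_buffer("m", torch.zeros(4))    # non-learnable state

demo = Demo()

# Both are recorded in state_dict, so both are saved with the model.
print(list(demo.state_dict().keys()))   # ['w', 'm']

# Both are ordinary attributes of the module.
print(demo.w.shape, demo.m.shape)

# Both follow the module across dtype / device changes.
demo = demo.float()
if torch.cuda.is_available():
    demo = demo.cuda()
    print(demo.w.device, demo.m.device)  # both report a cuda device
```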
Differences:

| | torch.nn.Parameter | register_buffer |
| --- | --- | --- |
| Trainable parameter? | Yes: treated as a parameter the model should optimize (included in `model.parameters()`) | No: never updated by the optimizer |
| Gradient computation | `requires_grad=True` by default; participates in backpropagation | `requires_grad=False` by default; does not participate in backpropagation |
| Typical use | Learnable quantities such as weights and biases | Constants or state such as means, variances, masks, positional encodings, e.g. the running mean/var in BatchNorm |
| How to register | `self.w = nn.Parameter(tensor)` or `self.register_parameter("w", nn.Parameter(...))` | `self.register_buffer("buf", tensor)` |
| Appears in `parameters()`? | Yes | No |
| Registered by plain assignment? | Yes, direct assignment is enough | No, it must go through `register_buffer()`, otherwise it is not recorded in `state_dict` |
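As a quick sanity check of the table above, the following sketch (names are made up for illustration) shows that only the Parameter is visible to the optimizer:

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(3))                   # trainable
        self.register_buffer("running_stat", torch.zeros(3))    # not trainable

demo = Demo()

# parameters() only yields the Parameter; the buffer lives in buffers().
print([n for n, _ in demo.named_parameters()])  # ['w']
print([n for n, _ in demo.named_buffers()])     # ['running_stat']

# The default requires_grad differs.
print(demo.w.requires_grad, demo.running_stat.requires_grad)  # True False

# Hence an optimizer built from parameters() never touches the buffer.
opt = torch.optim.SGD(demo.parameters(), lr=0.1)
loss = (demo.w.sum() - 1.0) ** 2
loss.backward()
opt.step()
print(demo.w)             # updated
print(demo.running_stat)  # still all zeros
```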
Usage recommendations:

| Situation | Recommended |
| --- | --- |
| The tensor needs to be optimized | nn.Parameter |
| Only recorded or used in computation, never optimized | register_buffer |
| Internal state of a custom module (e.g. BatchNorm) | register_buffer |
| Positional encodings, attention masks | register_buffer |
| Needed in the saved model but never trained | register_buffer |
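As one concrete use case for the mask/positional-encoding rows above, here is a small sketch of my own (not from the original post): a causal attention mask kept as a buffer so it is saved with the model and moves across devices, but is never optimized.

```python
import torch
import torch.nn as nn

class MaskedScoresStub(nn.Module):
    """Toy module: demonstrates buffer registration only, not real attention."""

    def __init__(self, seq_len):
        super().__init__()
        # A constant causal mask: state we want saved and moved with the model,
        # but never updated by the optimizer, so it is a buffer, not a Parameter.
        mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
        self.register_buffer("causal_mask", mask)

    def forward(self, scores):
        # scores: (batch, seq_len, seq_len) attention logits
        return scores.masked_fill(self.causal_mask, float("-inf"))

m = MaskedScoresStub(seq_len=4)
print("causal_mask" in m.state_dict())  # True: saved with the checkpoint
print(list(m.parameters()))             # []: nothing for an optimizer to update
```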
Below is a small test I wrote myself. Run ToyModel, ToyModel2 and ToyModel3 in turn, save and reload each one, and the behaviour of these two mechanisms becomes very clear.
```python
# models.py: three toy models that differ only in how extra state is registered
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyModel(nn.Module):
    # Plain attributes a1/a2 and a Linear layer; no extra Parameter or buffer.
    def __init__(self, inChannels, outChannels):
        super().__init__()
        self.a1 = 1
        self.a2 = 2
        self.linear = nn.Linear(inChannels, outChannels)
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        out = self.linear(x)
        return out


class ToyModel2(nn.Module):
    # Same as ToyModel, plus a learnable bias b1 registered as nn.Parameter.
    def __init__(self, inChannels, outChannels):
        super().__init__()
        self.a1 = 1
        self.a2 = 2
        self.linear = nn.Linear(inChannels, outChannels)
        self.init_weights()
        self.b1 = nn.Parameter(torch.randn(outChannels))

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        out = self.linear(x)
        out += self.b1
        return out


class ToyModel3(nn.Module):
    # Same as ToyModel2, plus a constant c1 registered as a persistent buffer.
    def __init__(self, inChannels, outChannels):
        super().__init__()
        self.a1 = 1
        self.a2 = 2
        self.linear = nn.Linear(inChannels, outChannels)
        self.init_weights()
        self.b1 = nn.Parameter(torch.randn(outChannels))
        self.register_buffer("c1", torch.ones_like(self.b1), persistent=True)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        out = self.linear(x)
        out += self.b1
        out += self.c1
        return out
```
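For reference, a1 and a2 are plain Python ints, so they never appear in state_dict. ToyModel's checkpoint therefore holds only linear.weight and linear.bias; ToyModel2 adds the parameter b1; ToyModel3 additionally stores the buffer c1 because it was registered with persistent=True (a buffer registered with persistent=False would be left out of state_dict).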
```python
# Save script: build ToyModel3, inspect its parameters/buffers, save its state_dict
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
from pathlib import Path

from models import ToyModel2, ToyModel, ToyModel3

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(lineno)s - %(message)s",
)

if __name__ == "__main__":
    savePath = Path("toymodel3.pth")
    logger = logging.getLogger(__name__)

    inp = torch.randn(3, 5)
    model = ToyModel3(inp.size(1), inp.size(1) * 2)
    pred = model(inp)
    logger.info(f"{pred.size()=}")

    for m in model.modules():
        logger.info(m)
    # b1 shows up here as a trainable parameter ...
    for name, param in model.named_parameters():
        logger.info(f"{name=}, {param.size()=}, {param.requires_grad=}")
    # ... while c1 only shows up among the buffers.
    for name, buffer in model.named_buffers():
        logger.info(f"{name=}, {buffer.size()=}")

    # state_dict contains the parameters and the persistent buffer c1.
    torch.save(model.state_dict(), savePath)
```
```python
# Load script: rebuild ToyModel3 and restore the saved state_dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path

from models import ToyModel, ToyModel2, ToyModel3

if __name__ == "__main__":
    savePath = Path("toymodel3.pth")

    inp = torch.randn(3, 5)
    model = ToyModel3(inp.size(1), inp.size(1) * 2)

    # Both parameters (linear.*, b1) and the persistent buffer c1 are restored.
    ckpt = torch.load(savePath, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt)

    pred = model(inp)
    print(f"{pred.size()=}")

    for m in model.modules():
        print(m)
    for name, param in model.named_parameters():
        print(f"{name=}, {param.size()=}, {param.requires_grad=}")
    for name, buffer in model.named_buffers():
        print(f"{name=}, {buffer.size()=}")
```
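If you repeat the save/load round trip with each of the three models, you should see that named_parameters() lists b1 only for ToyModel2 and ToyModel3, while named_buffers() lists c1 only for ToyModel3. Trying to load a ToyModel2 checkpoint into ToyModel3 should also fail with a missing key error for c1 (load_state_dict is strict by default), which confirms that the persistent buffer really is part of the checkpoint.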