【深度學習】神經網絡 批量標準化-part6

九、批量標準化

是一種廣泛使用的神經網絡正則化技術,對每一層的輸入進行標準化,進行縮放和平移,目的是加速訓練,提高模型穩定性和泛化能力,通常在全連接層或是卷積層之和,激活函數之前使用

核心思想

對每一批數據的通道進行標準化,解決內部協變量偏移

? ? ? ? 加速網絡訓練;運行使用更大的學習率;減少對初始化的依賴;提供輕微的正則化效果

思路:在輸入上執行標準化操作,學習兩可訓練的參數:縮放因子γ和偏移量β

?批量標準化操作 在訓練階段和測試階段行為是不同的。測試階段沒有mini_batch數據,無法直接計算當前batch的均值和方差,所以使用訓練階段計算的全局統量(均值和方差)進行標準化

1. 訓練階段的批量標準化

1.1 計算均值和方差

對于給定的神經網絡層,輸入,m是批次大小。我們計算該批次數據的均值和方差

均值

方差

1.2 標準化

用計算得到的均值和方差對數據進行標準化,使得沒個特征的均值為0,方差為1

標準化后的值

ε是很小的常數,防止除0

1.3 縮放和平移

標準化的數據通常會通過可訓練的參數進行縮放和平移,以揮發模型的表達能力

縮放

平移

γ和β是在訓練過程中學習到的參數,會隨著網絡的訓練過程通過反向傳播進行更新

1.4 更新全局統計量

指數移動平均更新全局均值和方差

momentum是超變量,控制當前mini-batch統計量對全局統計量的貢獻

它在0到1之間,控制mini-batch統計量的權重,在pytorch默認為0.1

與優化器中的momentum的區別

標準化中的:

更新全局統計量

控制當前mini-batch統計量對全局統計量的貢獻

優化器中:

加速梯度下降,跳出局部最優

2.測試階段的批量標準化

測試階段沒有mini-batch數據,所以通過EMA計算的全局統計量來進行標準化

測試階段用全局統計量對輸入數據進行標準化

對標準化后的數據進行縮放和平移

為什么用全局統計量

一致性

  • 測試階段,輸入數據通常是單個樣本或少量樣本無法準確計算均值和方差

  • 使用全局統計量可以確保測試階段的行為與訓練階段一致

穩定性

  • 全局統計量是通過訓練階段的大量 mini-batch 數據計算得到的,能夠更好地反映數據的整體分布

  • 使用全局統計量可以減少測試階段的隨機性,使模型的輸出更加穩定

效率

  • 在測試階段,使用預先計算的全局統計量可以避免重復計算,提高效率。

3. 作用

3.1 緩解梯度問題

防止激活值過大或過小,避免激活函數的飽和,緩解梯度消失或爆炸

3.2 加速訓練

輸入值分布更穩定,提高學習訓練的效率,加速收斂

3.3 減少過擬合

類似于正則化,有助于提高模型的泛化能力

避免對單一數據點的過度擬合

4. 函數說明

torch.nn.BatchNorm1d 是 PyTorch 中用于一維數據的批量標準化(Batch Normalization)模塊。

torch.nn.BatchNorm1d(num_features, ? ? ? ? # 輸入數據的特征維度eps=1e-05, ? ? ? ? ? # 用于數值穩定性的小常數momentum=0.1, ? ? ? ?# 用于計算全局統計量的動量affine=True, ? ? ? ? # 是否啟用可學習的縮放和平移參數track_running_stats=True, ?# 是否跟蹤全局統計量device=None, ? ? ? ? # 設備類型(如 CPU 或 GPU)dtype=None ? ? ? ? ? # 數據類型
)

參數說明:

eps:用于數值穩定性的小常數,添加到方差的分母中,防止除零錯誤。默認值:1e-05

momentum:用于計算全局統計量(均值和方差)的動量默認值:0.1,參考本節1.4

affine:是否啟用可學習的縮放和平移參數(γ和 β)。如果 affine=True,則模塊會學習兩個參數;如果 affine=False,則不學習參數,直接輸出標準化后的值 。默認值:True

track_running_stats:是否跟蹤全局統計量(均值和方差)。如果 track_running_stats=True,則在訓練過程中計算并更新全局統計量,并在測試階段使用這些統計量。如果 track_running_stats=False,則不跟蹤全局統計量,每次標準化都使用當前 mini-batch 的統計量。默認值:True

4. 代碼實現

import torch
from torch import nn
from matplotlib import pyplot as pltfrom sklearn.datasets import make_circles
from sklearn.model_selection import train_test_split
from torch.nn import functional as F
from torch import optim# 生成數據集:兩個同心圓,內圈和外圈的點分別屬于兩個類別
x, y = make_circles(n_samples=2000, noise=0.1, factor=0.4, random_state=42)
# 轉換為PyTorch張量
x = torch.tensor(x, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)# 劃分訓練集和測試集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3,random_state=42)# 可視化數據集
plt.scatter(x[:, 0], x[:, 1], c=y, cmap='coolwarm', edgecolors="k")
plt.show()# 定義帶批量歸一化的神經網絡
class NetWithBN(nn.Module):def __init__(self):super().__init__()# 第一層全連接層,輸入維度2,輸出維度64self.fc1 = nn.Linear(2, 64)# 第一層批量歸一化self.bn1 = nn.BatchNorm1d(64)# 第二層全連接層,輸入維度64,輸出維度32self.fc2 = nn.Linear(64, 32)# 第二層批量歸一化self.bn2 = nn.BatchNorm1d(32)# 第三層全連接層,輸入維度32,輸出維度2(兩個類別)self.fc3 = nn.Linear(32, 2)def forward(self, x):# 前向傳播:ReLU激活函數+批量歸一化+全連接層x = F.relu(self.bn1(self.fc1(x)))x = F.relu(self.bn2(self.fc2(x)))x = self.fc3(x)return x# 定義不帶批量歸一化的神經網絡
class NetWithoutBN(nn.Module):def __init__(self):super().__init__()# 第一層全連接層,輸入維度2,輸出維度64self.fc1 = nn.Linear(2, 64)# 第二層全連接層,輸入維度64,輸出維度32self.fc2 = nn.Linear(64, 32)# 第三層全連接層,輸入維度32,輸出維度2(兩個類別)self.fc3 = nn.Linear(32, 2)def forward(self, x):# 前向傳播:ReLU激活函數+全連接層x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 定義訓練函數
def train(model, x_train, y_train, x_test, y_test, name, lr=0.1, epoches=500):# 定義交叉熵損失函數criterion = nn.CrossEntropyLoss()# 定義SGD優化器optimizer = optim.SGD(model.parameters(), lr=lr)# 用于記錄訓練損失和測試準確率train_loss = []test_acc = []for epoch in range(epoches):# 設置模型為訓練模式model.train()# 前向傳播y_pred = model(x_train)# 計算損失loss = criterion(y_pred, y_train)# 反向傳播optimizer.zero_grad()loss.backward()optimizer.step()# 記錄訓練損失train_loss.append(loss.item())# 設置模型為評估模式model.eval()# 禁用梯度計算with torch.no_grad():# 前向傳播y_test_pred = model(x_test)# 獲取預測類別_, pred = torch.max(y_test_pred, dim=1)# 計算正確預測的數量correct = (pred == y_test).sum().item()# 計算測試準確率test_acc.append(correct / len(y_test))# 每100個epoch打印一次日志if epoch % 100 == 0:print(F"{name}|Epoch:{epoch},loss:{loss.item():.4f},acc:{test_acc[-1]:.4f}")return train_loss, test_acc# 創建帶批量歸一化的模型
model_bn = NetWithBN()
# 創建不帶批量歸一化的模型
model_nobn = NetWithoutBN()# 訓練帶批量歸一化的模型
bn_train_loss, bn_test_acc = train(model_bn, x_train, y_train, x_test, y_test,name="BN")
# 訓練不帶批量歸一化的模型
nobn_train_loss, nobn_test_acc = train(model_nobn, x_train, y_train, x_test, y_test,name="NoBN")# 定義繪圖函數
def plot(bn_train_loss, nobn_train_loss, bn_test_acc, nobn_test_acc):# 創建繪圖窗口fig = plt.figure(figsize=(10, 5))# 添加子圖1:訓練損失ax1 = fig.add_subplot(1, 2, 1)ax1.plot(bn_train_loss, "b", label="BN")ax1.plot(nobn_train_loss, "r", label="NoBN")ax1.legend()# 添加子圖2:測試準確率ax2 = fig.add_subplot(1, 2, 2)ax2.plot(bn_test_acc, "b", label="BN")ax2.plot(nobn_test_acc, "r", label="NoBN")ax2.legend()# 顯示圖像plt.show()# 調用繪圖函數
plot(bn_train_loss, nobn_train_loss, bn_test_acc, nobn_test_acc)

?

十、模型的保存和加載

?1.標準網絡模型構建

class MyModel(nn.Module):def __init__(self,input_size,output_size):super(MyModel,self).__init__()self.fc1 = nn.Linear(input_size,128)self.fc2 = nn.Linear(128,64)self.fc3 = nn.Linear(64,output_size)def forward(self,x):x = self.fc1(x)x = self.fc2(x)output = self.fc3(x)return outputmodel = MyModel(input_size=10,output_size = 2)
x  =torch.randn(5,10)output = model(x)

?2. 序列化模型對象

模型保存

torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

參數說明:

  • obj:要保存的對象,可以是模型、張量、字典等。

  • f:保存文件的路徑或文件對象。可以是字符串(文件路徑)或文件描述符。

  • pickle_module:用于序列化的模塊,默認是 Python 的 pickle 模塊。

  • pickle_protocol:pickle 模塊的協議版本,默認是 DEFAULT_PROTOCOL(通常是最高版本)。

模型加載

torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)

參數說明:

  • f:文件路徑或文件對象。可以是字符串(文件路徑)或文件描述符。

  • map_location:指定加載對象的設備位置(如 CPU 或 GPU)。默認是 None,表示保持原始設備位置。例如:map_location=torch.device('cpu') 將對象加載到 CPU。

  • pickle_module:用于反序列化的模塊,默認是 Python 的 pickle 模塊。

  • pickle_load_args:傳遞給 pickle_module.load() 的額外參數。

import torch
import torch.nn as nn
import pickleclass MyModel(nn.Module):def __init__(self,input_size,output_size):super(MyModel,self).__init__()self.fc1 = nn.Linear(input_size,output_size,128)self.fc2 = nn.Linear(128,64)self.fc3 = nn.Linear(64,output_size)def forward(self,x):x = self.fc1(x)x = self.fc2(x)output = self.fc3(x)return output
def test001():model = MyModel(input_size=128,output_size=32)torch.save(model,"model.pkl",pickle_module=pickle,pickle_protocol=2)def test002():model = torch.load("model.pkl",map_location = "cpu",pickle_module=pickle)print(model)test001()
test002()

.pkl是二進制文件,內容是通過pickle模塊化序列的python對象。可能存在兼容問題(python2,3的區別)

.pth是二進制文件,序列化的pytorch模型或張量。

3. 模型保存參數

import torch
import torch.nn as nn
import torch.optim as optim
import pickleclass MyModle(nn.Module):def __init__(self,input_size,output_size):super(MyModle,self).__init__()self.fc1 = nn.Linear(input_size,128)self.fc2 = nn.Linear(128,64)self.fc3 = nn.Linear(64,output_size)def forward(self,x):x = self.fc1(x)x = self.fc2(x)output = self.fc3(x)return outputdef test003():model = MyModle(input_size=128,output_size=32)optimizer = optim.SGD(model.parameters(),lr = 0.01)save_dict = {"init_params":{"input_size":128,"output_size":32,},"accuracy":0.99,"model_state_dict":model.state_dict(),"optimizer_state_dict":optimizer.state_dict(),}torch.save(save_dict,"model_dict.pth")def test004():save_dict = torch.load("model_dict.pth")model = MyModle(input_size = save_dict["init_params"]["input_size"],output_size = save_dict["init_params"]["output_size"],)model.load_state_dict(save_dict["model_state_dict"])optimizer = optim.SGD(model.parameters(),lr = 0.01)optimizer.load_state_dict(save_dict["optimizer_state_dict"])print(save_dict["accuracy"])print(model)test003()
test004()

推理時加載模型參數簡單如下

# 保存模型狀態字典
torch.save(model.state_dict(), 'model.pth')
?
# 加載模型狀態字典
model = MyModel(128, 32)
model.load_state_dict(torch.load('model.pth'))
?

十一、項目實戰

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/89498.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/89498.shtml
英文地址,請注明出處:http://en.pswp.cn/web/89498.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【數據可視化-67】基于pyecharts的航空安全深度剖析:墜毀航班數據集可視化分析

🧑 博主簡介:曾任某智慧城市類企業算法總監,目前在美國市場的物流公司從事高級算法工程師一職,深耕人工智能領域,精通python數據挖掘、可視化、機器學習等,發表過AI相關的專利并多次在AI類比賽中獲獎。CSDN…

【科研繪圖系列】R語言繪制分組箱線圖

文章目錄 介紹 加載R包 數據下載 導入數據 畫圖1 畫圖2 合并圖 系統信息 參考 介紹 【科研繪圖系列】R語言繪制分組箱線圖 加載R包 library(ggplot2) library(patchwork)rm(list = ls()) options(stringsAsFactors = F)

基于Android的旅游計劃App

項目介紹系統打開進入登錄頁面,如果沒有注冊過賬號,點擊注冊按鈕輸入賬號、密碼、郵箱即可注冊,注冊后可登錄進入系統,系統分為首頁、預訂、我的三大模塊,下面具體詳細說說三大模塊功能說明。1.首頁顯示旅游備忘或旅游…

【LeetCode 2163. 刪除元素后和的最小差值】解析

目錄LeetCode中國站原文原始題目題目描述示例 1:示例 2:提示:講解分割線的藝術:前后綴分解與優先隊列的完美邂逅第一部分:算法思想 —— “分割線”與前后綴分解1. 想象一條看不見的“分割線”2. 前后綴分解&#xff1…

控制鼠標和鍵盤

控制鼠標和鍵盤的Python庫Python中有多個庫可以用于控制鼠標和鍵盤,常用的包括pyautogui、pynput、keyboard和mouse等。這些庫提供了模擬用戶輸入的功能,適用于自動化測試、GUI操作等場景。使用pyautogui控制鼠標pyautogui是一個跨平臺的庫,支…

基于按鍵開源MultiButton框架深入理解代碼框架(二)(指針的深入理解與應用)

文章目錄2、針對該開源框架理解3、分析代碼3.1 再談指針、數組、數組指針3.2 繼續分析源碼2、針對該開源框架理解 在編寫按鍵模塊的框架中,一定要先梳理按鍵相關的結構體、枚舉等變量。這些數據是判斷按鍵按下、狀態跳轉、以及綁定按鍵事件的核心。 這一部分定義是…

web前端渡一大師課 CSS屬性計算過程

你是否了解CSS 的屬性計算過程呢? <body> <h1>這是一個h1標題</h1> </body> 目前我們沒有設置改h1的任何樣式,但是卻能看到改h1有一定的默認樣式,例如有默認的字體大小,默認的顏色 那么問題來了,我們這個h1元素上面除了有默認字體大小,默認顏色等…

Redis高頻面試題:利用I/O多路復用實現高并發

Redis 通過 I/O 多路復用&#xff08;I/O Multiplexing&#xff09;技術實現高并發&#xff0c;這是其單線程模型能夠高效處理大量客戶端連接的關鍵。以下是通俗易懂的解釋&#xff0c;結合 Redis 的工作原理&#xff0c;詳細說明其實現過程。 1. 什么是 I/O 多路復用&#xff…

爬蟲小知識(二)網頁進行交互

一、提交信息到網頁 1、模塊核心邏輯 “提交信息到網頁” 是網絡交互關鍵環節&#xff0c;借助 requests 庫的 post() 函數&#xff0c;能模擬瀏覽器向網頁發數據&#xff08;如表單、文件 &#xff09;&#xff0c;實現信息上傳&#xff0c;讓我們能與網頁背后的服務器 “溝通…

WPF學習(五)

文章目錄一、FileStream和StreamWriter理解1.1、具體關系解析1.2、類比理解1.3、總結1.4、示例代碼1.5、 WriteLine()和 Write&#xff08;&#xff09;的區別1.6、 StreamWriter.Close的作用二、一、FileStream和StreamWriter理解 在 C# 中&#xff0c;StreamWriter 和 FileS…

ctf.show-web習題-web2-最簡單的sql注入-flag獲取詳解、總結

解題思路打開靶場既然提示是最簡單的sql注入了&#xff0c;那么直接嘗試永真登錄1 or 11#這里閉合就是簡單的單引號可以看到沒登錄成功&#xff0c;但是有回顯&#xff1a;歡迎你&#xff0c;ctfshowsql注入最喜歡的就是回顯了&#xff01;這題的思路就是靠這個回顯&#xff0c…

upload-labs 靶場通關(1-20)

目錄 Pass-01(JS 繞過) Pass-02(文件類型驗證) Pass-03(黑名單驗證) Pass-04(黑名單驗證.htaccess) Pass-05(大小寫繞過) Pass-06(末尾空格) Pass-07(增加一個.) Pass-08(增加一個::$DATA) Pass-09&#xff08;代碼不嚴謹&#xff09; Pass-10&#xff08;PPHPHP&am…

[附源碼+數據庫+畢業論文]基于Spring+MyBatis+MySQL+Maven+vue實現的酒店預訂管理系統,推薦!

摘 要 使用舊方法對酒店預訂信息進行系統化管理已經不再讓人們信賴了&#xff0c;把現在的網絡信息技術運用在酒店預訂信息的管理上面可以解決許多信息管理上面的難題&#xff0c;比如處理數據時間很長&#xff0c;數據存在錯誤不能及時糾正等問題。 這次開發的酒店預訂管理系…

LSTM入門案例(時間序列預測)| pytorch實現(可復現)

需求 假如我有一個時間序列&#xff0c;例如是前113天的價格數據&#xff08;訓練集&#xff09;&#xff0c;然后我希望借此預測后30天的數據&#xff08;測試集&#xff09;&#xff0c;實際上這143天的價格數據都已經有了。這里為了簡單&#xff0c;每一天的數據只有一個價…

Axure RP 10 預覽顯示“無標題文檔”的空白問題探索【護航版】

1. 安裝情況 官網 Axure RP 10&#xff1a;Download Axure RP 10 - Axure &#xff08;PS&#xff1a;11都出了&#xff09; 版本&#xff1a;10.0.0.3924 激活碼&#xff1a;49bb9513c40444b9bcc3ce49a7a022f9 &#xff08;10/11都可以用&#xff0c;但只嘗試了10&#xff…

基于SpringBoot+Vue的汽車租賃系統(協同過濾算法、騰訊地圖API、支付寶沙盒支付、WebsSocket實時聊天、ECharts圖形化分析)

系統亮點&#xff1a;協同過濾算法、騰訊地圖API、支付寶沙盒支付、WebsSocket實時聊天、ECharts圖形化分析&#xff1b;01系統開發工具與環境搭建—前后端分離架構項目架構&#xff1a;B/S架構運行環境&#xff1a;win10/win11、jdk17前端&#xff1a;技術&#xff1a;框架Vue…

數據結構入門:像整理收納一樣簡單!

在我們生活中&#xff0c;經常會面對這樣的問題&#xff1a; “我要怎么整理我的衣柜&#xff1f;” “電腦里照片太多了&#xff0c;怎么歸類才方便查找&#xff1f;” 其實&#xff0c;程序員也有類似的煩惱。他們不整理衣柜&#xff0c;而是“整理數據”。而這門關于如何“收…

力扣每日一題--2025.7.15

&#x1f4da; 力扣每日一題–2025.7.15 3135. 有效單詞 &#xff08;簡單&#xff09; 大家好&#xff01;今天我們要來聊聊一道有趣的編程題——有效單詞 &#x1f4dd; 題目描述 題目分析 &#x1f4da; 題目要求我們判斷一個字符串是否為有效單詞。有效單詞需要滿足以下…

Mysql數據庫——增刪改查CRUD

文章目錄一、數據庫的基礎命令二、創建表三、增(create)四、查詢&#xff08;retrieve)五、條件查詢&#xff08;where&#xff09;六、修改&#xff08;update&#xff09;七、刪除&#xff08;delete&#xff09;一、數據庫的基礎命令 1.使用客戶端連接服務器 mysql -u root…

關于pytorch虛擬環境及具體bug問題修改

本篇博客包含對于虛擬環境概念的講解和代碼實現過程中相關bug的解決關于虛擬環境我的pytorch虛擬環境在D盤&#xff0c;相應python解釋器也在D盤&#xff08;一起&#xff09;&#xff0c;但是我的pycharm中的項目在C盤&#xff0c;使用的是pytorch的虛擬環境&#xff0c;這是為…