【一起來學AI大模型】PyTorch 實戰示例:使用 BatchNorm 處理張量(Tensor)

PyTorch 實戰示例?演示如何在神經網絡中使用?BatchNorm?處理張量(Tensor),涵蓋關鍵實現細節和常見陷阱。示例包含數據準備、模型構建、訓練/推理模式切換及結果分析。


示例場景:在 CIFAR-10 數據集上實現帶 BatchNorm 的 CNN

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader# 設備配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 1. 數據準備 & 預處理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 標準化到[-1,1]
])train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = DataLoader(test_set, batch_size=100, shuffle=False)# 2. 定義帶 BatchNorm 的 CNN
class CNNWithBN(nn.Module):def __init__(self):super().__init__()self.features = nn.Sequential(# Conv-BN-ReLU-Pool 模塊nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.BatchNorm2d(64),  # 關鍵!通道數=64nn.ReLU(),nn.MaxPool2d(2, 2),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.BatchNorm2d(128),  # 通道數=128nn.ReLU(),nn.MaxPool2d(2, 2))self.classifier = nn.Sequential(nn.Linear(128 * 8 * 8, 512),nn.BatchNorm1d(512),  # 全連接層也適用BNnn.ReLU(),nn.Linear(512, 10))def forward(self, x):x = self.features(x)x = x.view(x.size(0), -1)  # 展平return self.classifier(x)model = CNNWithBN().to(device)# 3. 訓練循環(重點:BN的訓練模式)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)  # 配合BN的Weight Decaydef train(epoch):model.train()  # 切換到訓練模式(啟用BN的mini-batch統計)for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 4. 測試推理(重點:BN的推理模式)
def test():model.eval()  # 切換到評估模式(使用全局統計量)correct = 0with torch.no_grad():  # 禁用梯度計算for inputs, labels in test_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = outputs.max(1)correct += predicted.eq(labels).sum().item()accuracy = 100. * correct / len(test_set.dataset)print(f'Test Accuracy: {accuracy:.2f}%')return accuracy# 5. 執行訓練與測試
for epoch in range(10):train(epoch)acc = test()# 6. 查看BN層參數(實戰調試)
print("\nBatchNorm層參數檢查:")
for name, module in model.named_modules():if isinstance(module, nn.BatchNorm2d):print(f"{name}: weight={module.weight.data.mean().item():.4f}, "f"bias={module.bias.data.mean().item():.4f}")print(f"  Running Mean: {module.running_mean.mean().item():.4f}, "f"Running Var: {module.running_var.mean().item():.4f}")

關鍵實戰細節解析

1. BatchNorm 層初始化
nn.BatchNorm2d(num_features)  # 必須與輸入通道數一致
nn.BatchNorm1d(512)          # 全連接層適用
2. 模式切換的重要性
模式代碼BN行為忘記切換的后果
訓練model.train()使用當前batch的統計量更新?running_mean/running_var推理時統計量錯誤,精度大幅下降
推理model.eval()固定使用訓練積累的?running_mean/running_var訓練引入測試噪聲,收斂不穩定
3. 參數解讀(以?nn.BatchNorm2d?為例)
# 可學習參數
bn_layer.weight   # γ (縮放因子), shape=(C,)
bn_layer.bias     # β (偏移因子), shape=(C,)# 自動統計量(訓練時更新)
bn_layer.running_mean   # 全局均值估計, shape=(C,)
bn_layer.running_var    # 全局方差估計, shape=(C,)
4. 常見錯誤及解決方案
  • 錯誤1:Batch Size 過小(<16)

    # 解決方案:使用GroupNorm替代
    nn.GroupNorm(num_groups=32, num_channels=128)

  • 錯誤2:忘記在測試時調用?model.eval()

    # 正確做法:在推理前顯式切換模式
    model.eval()
    with torch.no_grad():output = model(input_tensor)

  • 錯誤3:微調時錯誤處理 BN 統計量

    # 凍結BN的統計量(只更新γ/β)
    for module in model.modules():if isinstance(module, nn.BatchNorm2d):module.eval()  # 固定running_mean/var


BatchNorm 張量變換可視化

假設輸入張量維度:(batch_size, channels, height, width) = (4, 3, 2, 2)

input_tensor = torch.randn(4, 3, 2, 2)  # 模擬輸入數據# BatchNorm2d 操作步驟
bn = nn.BatchNorm2d(3)  # 通道數=3# 前向傳播分解:
# 1. 計算每個通道的均值和方差
mean_per_channel = input_tensor.mean(dim=[0, 2, 3])  # shape=(3,)
var_per_channel = input_tensor.var(dim=[0, 2, 3], unbiased=False)# 2. 標準化 (x - μ) / √(σ2 + ε)
normalized = (input_tensor - mean_per_channel[None, :, None, None]) / torch.sqrt(var_per_channel[None, :, None, None] + 1e-5)# 3. 縮放和偏移
output = normalized * bn.weight[None, :, None, None] + bn.bias[None, :, None, None]

性能對比(CIFAR-10 實驗結果)

模型測試精度收斂速度訓練穩定性
無 BatchNorm78.2%慢 (20 epochs)需要精細調參
帶 BatchNorm86.7%快 (8 epochs)高學習率魯棒
BatchNorm + Dropout85.9%最優正則化

注意:BN 的輕微正則化效果可能部分替代 Dropout,但組合使用需調整丟棄概率

通過這個實戰示例,你可以直觀理解 BatchNorm 如何操作張量,以及它在實際訓練中的關鍵作用。建議在 Colab 中運行代碼并嘗試修改 BN 參數(如?momentum?參數控制統計量更新速度),觀察對結果的影響。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/88161.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/88161.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/88161.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

第8章:應用層協議HTTP、SDN軟件定義網絡、組播技術、QoS

應用層協議HTTP 應用層協議概述 應用層協議非常多&#xff0c;我們重點熟悉以下常見協議功能即可。 Telnet:遠程登錄協議&#xff0c;基于TCP 23端口&#xff0c;用于遠程管理設備&#xff0c;采用明文傳輸。安全外殼協議 (SecureShell,SSH) ,基于TCP 22端口&#xff0c;用于…

uniapp頁面間通信

uniapp中通過eventChannel實現頁面間通信的方法&#xff0c;這是一種官方推薦的高效傳參方式。我來解釋下這種方式的完整實現和注意事項&#xff1a;?發送頁面&#xff08;父頁面&#xff09;?&#xff1a;uni.navigateTo({url: /pages/detail/detail,success: (res) > {/…

Android ViewModel機制與底層原理詳解

Android 的 ViewModel 是 Jetpack 架構組件庫的核心部分&#xff0c;旨在以生命周期感知的方式存儲和管理與 UI 相關的數據。它的核心目標是解決兩大痛點&#xff1a; 數據持久化&#xff1a; 在配置變更&#xff08;如屏幕旋轉、語言切換、多窗口模式切換&#xff09;時保留數…

雙倍硬件=雙倍性能?TDengine線性擴展能力深度實測驗證!

軟件擴展能力是軟件架構設計中的一個關鍵要素&#xff0c;具有良好擴展能力的軟件能夠充分利用新增的硬件資源。當軟件性能與硬件增加保持同步比例增長時&#xff0c;我們稱這種現象為軟件具有線性擴展能力。要實現這種線性擴展并不簡單&#xff0c;它要求軟件架構精心設計&…

頻繁迭代下完成iOS App應用上架App Store:一次快速交付項目的完整回顧

在一次面向商戶的會員系統App開發中&#xff0c;客戶要求每周至少更新一次版本&#xff0c;涉及功能迭代、UI微調和部分支付方案的更新。團隊使用Flutter進行跨平臺開發&#xff0c;但大部分成員日常都在Windows或Linux環境&#xff0c;只有一臺云Mac用于打包。如何在高頻率發布…

springsecurity03--異常攔截處理(認證異常、權限異常)

目錄 Spingsecurity異常攔截處理 認證異常攔截 權限異常攔截 注冊異常攔截器 設置跨域訪問 Spingsecurity異常攔截處理 認證異常攔截 /*自定義認證異常處理器類*/ Component public class MyAuthenticationExceptionHandler implements AuthenticationEntryPoint {Overr…

企業如何制作網站?網站制作的步驟與流程?

以下是2025年網站制作的綜合指南&#xff0c;涵蓋核心概念、主流技術及實施流程&#xff1a; 一、定義與范疇 網站制作是通過頁面結構設計、程序設計、數據庫開發等技術&#xff0c;將視覺設計轉化為可交互網頁的過程&#xff0c;包含前端展示與后臺功能實現。其核心目標是為企…

Rust+Blender:打造高性能游戲引擎

基于Rust和Blender的游戲引擎 以下是基于Rust和Blender的游戲引擎開發實例,涵蓋不同應用場景和技術方向的實際案例。案例分為工具鏈整合、渲染技術、物理模擬等類別,每個案例附核心代碼片段或實現邏輯。 工具鏈整合案例 案例1:Blender模型導出到Bevy引擎 使用blender-bev…

Git基本操作1

Git 是一款分布式版本控制系統&#xff0c;主要用于高效管理代碼版本和團隊協作開發。它能精確記錄每次代碼修改&#xff0c;支持版本回溯和分支管理&#xff0c;讓開發者可以并行工作而互不干擾。通過本地提交和遠程倉庫同步&#xff0c;Git 既保障了代碼安全&#xff0c;又實…

React Native 組件間通信方式詳解

React Native 組件間通信方式詳解 在 React Native 開發中&#xff0c;組件間通信是核心概念之一。以下是幾種主要的組件通信方式及其適用場景&#xff1a; 簡單父子通信&#xff1a;使用 props 和回調函數兄弟組件通信&#xff1a;提升狀態到共同父組件跨多級組件&#xff1a;…

TCP的可靠傳輸機制

TCP通過校驗和、序列號、確認應答、重發控制、連接管理以及窗口控制等機制實現可靠性的傳輸。 先來看第一個可靠性傳輸的方法。 通過序列號和可靠性提供可靠性 TCP是面向字節的。TCP把應用層交下來的報文&#xff08;可能要劃分為許多較短的報文段&#xff09;看成一個一個字節…

沒有DBA的敏捷開發管理

前言一家人除了我都去旅游了&#xff0c;我這項請假&#xff0c;請不動啊。既然在家了&#xff0c;閑著也是閑著&#xff0c;就復盤下最近的工作&#xff0c;今天就復盤表結構管理吧&#xff0c;隨系統啟動的&#xff0c;不是flyway&#xff0c;而是另一個liquibase&#xff0c…

go-carbon v2.6.10發布,輕量級、語義化、對開發者友好的 golang 時間處理庫

carbon 是一個輕量級、語義化、對開發者友好的 Golang 時間處理庫&#xff0c;提供了對時間穿越、時間差值、時間極值、時間判斷、星座、星座、農歷、儒略日 / 簡化儒略日、波斯歷 / 伊朗歷的支持。 carbon 目前已捐贈給 dromara 開源組織&#xff0c;已被 awesome-go 收錄&am…

【AI News | 20250708】每日AI進展

AI Repos 1、claude-code-templates Claude Code Templates是一款全面的命令行工具&#xff0c;旨在為不同編程語言和框架&#xff08;如JavaScript/TypeScript、Python等&#xff0c;Go和Rust即將推出&#xff09;提供優化的Claude Code配置。它通過交互式設置、自動化鉤子&a…

Nginx源碼安裝+靜態站點部署指南(CentOS 7)

安裝包&#xff1a;可自行前往我的飛書下載 Docs 也可以進入 nginx 官網&#xff0c;下載自己所需適應版本 nginx 開始安裝nginx 1. 創建準備目錄 cd /opt mkdir soft module # 創建軟件包和源碼解壓目錄 2. 安裝依賴環境 yum -y install make zlib zlib-devel gcc-c l…

交換機的核心原理和作用

一、交換機的核心原理交換機是一種用于連接多臺設備的網絡硬件&#xff0c;其核心原理基于二層網絡&#xff08;數據鏈路層&#xff09;的 MAC 地址尋址1. MAC 地址學習與存儲當交換機接收到數據幀時&#xff0c;會讀取幀中的源 MAC 地址&#xff0c;并將該地址與對應的端口號記…

【工具變量】上市公司企業金融強監管數據、資管新規數據(2001-2024年)

數據簡介&#xff1a;參考頂刊《經濟研究》李青原&#xff08;2022&#xff09;老師的做法&#xff0c;Post 為時間虛擬變量&#xff0c;根據資管新規實施的時間&#xff0c;當觀測期為2018 年上半年及之后時&#xff0c;Post 取值1&#xff0c;否則取值0。PreFin 為資管新規實…

CSS Grid與Flexbox布局實戰對比

概述 CSS布局技術在過去幾年經歷了重大變革&#xff0c;從傳統的基于浮動和定位的方法&#xff0c;到現在強大的Flexbox和Grid布局系統。這兩種現代布局方法極大地簡化了復雜界面的開發過程&#xff0c;但它們各自適用于不同的場景。本文將對Flexbox和Grid進行深入比較&#x…

[Pytest][Part 4]多種測試運行方式

實現需求2&#xff1a;有兩種運行測試的方式&#xff1a;通過config配置文件運行&#xff0c;測試只需要修改config配置文件cmdline 運行這里是新建一個config類來存儲所有的測試配置&#xff0c;以后配置有修改的話也只需要修改這個類。根據目前的測試需求&#xff0c;config中…

平衡二叉樹的刪除操作

對于平衡二叉樹的操作應對與考試只需要模擬出過程即可&#xff0c;且他的過程和插入的平衡方法一樣&#xff0c;不一樣的只是對于平衡因子的計算上。接下來我將給出方法①刪除結點&#xff08;方法同“二叉排序樹”&#xff09; ②一路向北找到最小不平衡子樹&#xff0c;找不到…