15.手動實現BatchNorm(BN)

15.1 BatchNorm操作手動實現

import torch 
from torch import nndef batch_norm(X,gamma,beta,moving_mean,moving_var,eps,momentum):if not torch.is_grad_enabled():#這個是推理模式X_hat=(X-moving_mean)/torch.sqrt(moving_var+eps)else:assert len(X.shape) in (2,4)if len(X.shape)==2:mean=X.mean(dim=0)var=((X-mean)**2).mean(dim=0)else:mean=X.mean(dim=(0,2,3),keepdim=True)var=((X-mean)**2).mean(dim=(0,2,3),keepdim=True)# 更新移動平均的均值和方差X_hat=(X-mean)/torch.sqrt(var+eps)moving_mean=momentum*moving_mean+(1.0-momentum)*meanmoving_var=momentum*moving_var+(1.0-momentum)*varY=gamma*X_hat+betareturn Y,moving_mean.data,moving_var.data
class BatchNorm(nn.Module):def __init__(self, num_features,num_dims):super().__init__()if num_dims==2:shape=(1,num_features)else:shape=(1,num_features,1,1)#這是兩個需要更新的參數self.gamma=nn.Parameter(torch.ones(shape))self.beta=nn.Parameter(torch.zeros(shape))self.moving_mean=torch.zeros(shape)self.moving_var=torch.ones(shape)#這個不能為0,應該是/sqrt(var)def forward(self,X):#計算設備對齊if self.moving_mean.device!=X.device:self.moving_mean=self.moving_mean.to(X.device)self.moving_var=self.moving_var.to(X.device)Y,self.moving_mean,self.moving_var=batch_norm(X,self.gamma,self.beta,self.moving_mean,self.moving_var,eps=1e-5,momentum=0.9)return Y
model=nn.Sequential(nn.Conv2d(1,6,kernel_size=5),BatchNorm(6,num_dims=4),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2,stride=2),nn.Conv2d(6,16,kernel_size=5),BatchNorm(16,num_dims=4),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2,stride=2),nn.Flatten(),#Flatten()之后就是[batch_size,features] 2維度的向量矩陣nn.Linear(16*4*4,120),BatchNorm(120,num_dims=2),nn.Sigmoid(),nn.Linear(120,84),BatchNorm(84,num_dims=2),nn.Sigmoid(),nn.Linear(84,10))

15.2 BatchNorm實驗效果

################################################################################################################
"""BatchNorm"""
################################################################################################################
import torch
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from torch.nn import functional as F
plt.rcParams['font.family']=['Times New Roman']
class Reshape(torch.nn.Module):def forward(self,x):return x.view(-1,1,28,28)#[bs,1,28,28]
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):epochs = range(1, len(train_loss_list) + 1)plt.figure(figsize=(4, 3))plt.plot(epochs, train_loss_list, label='Train Loss')plt.plot(epochs, train_acc_list, label='Train Acc',linestyle='--')plt.plot(epochs, test_acc_list, label='Test Acc', linestyle='--')plt.xlabel('Epoch')plt.ylabel('Value')plt.title(title)plt.legend()plt.grid(True)plt.tight_layout()plt.show()
def train_model(model,train_data,test_data,num_epochs):train_loss_list = []train_acc_list = []test_acc_list = []for epoch in range(num_epochs):total_loss=0total_acc_sample=0total_samples=0loop=tqdm(train_data,desc=f"EPOCHS[{epoch+1}/{num_epochs}]")for X,y in loop:#X=X.reshape(X.shape[0],-1)#print(X.shape)X=X.to(device)y=y.to(device)y_hat=model(X)loss=CEloss(y_hat,y)optimizer.zero_grad()loss.backward()optimizer.step()#loss累加total_loss+=loss.item()*X.shape[0]y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()y_true=y.detach().cpu().numpy()total_acc_sample+=accuracy_score(y_pred,y_true)*X.shape[0]#保存樣本數total_samples+=X.shape[0]test_acc_samples=0test_samples=0for X,y in test_data:X=X.to(device)y=y.to(device)#X=X.reshape(X.shape[0],-1)y_hat=model(X)y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()y_true=y.detach().cpu().numpy()test_acc_samples+=accuracy_score(y_pred,y_true)*X.shape[0]#保存樣本數test_samples+=X.shape[0]avg_train_loss=total_loss/total_samplesavg_train_acc=total_acc_sample/total_samplesavg_test_acc=test_acc_samples/test_samplestrain_loss_list.append(avg_train_loss)train_acc_list.append(avg_train_acc)test_acc_list.append(avg_test_acc)print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f},Trian Accuracy: {avg_train_acc:.4f},test Accuracy: {avg_test_acc:.4f}")plot_metrics(train_loss_list, train_acc_list, test_acc_list)return model
def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)
def batch_norm(X,gamma,beta,moving_mean,moving_var,eps,momentum):if not torch.is_grad_enabled():#這個是推理模式X_hat=(X-moving_mean)/torch.sqrt(moving_var+eps)else:assert len(X.shape) in (2,4)if len(X.shape)==2:mean=X.mean(dim=0)var=((X-mean)**2).mean(dim=0)else:mean=X.mean(dim=(0,2,3),keepdim=True)var=((X-mean)**2).mean(dim=(0,2,3),keepdim=True)# 更新移動平均的均值和方差X_hat=(X-mean)/torch.sqrt(var+eps)moving_mean=momentum*moving_mean+(1.0-momentum)*meanmoving_var=momentum*moving_var+(1.0-momentum)*varY=gamma*X_hat+betareturn Y,moving_mean.data,moving_var.data
class BatchNorm(nn.Module):def __init__(self, num_features,num_dims):super().__init__()if num_dims==2:shape=(1,num_features)else:shape=(1,num_features,1,1)#這是兩個需要更新的參數self.gamma=nn.Parameter(torch.ones(shape))self.beta=nn.Parameter(torch.zeros(shape))self.moving_mean=torch.zeros(shape)self.moving_var=torch.ones(shape)#這個不能為0,應該是/sqrt(var)def forward(self,X):#計算設備對齊if self.moving_mean.device!=X.device:self.moving_mean=self.moving_mean.to(X.device)self.moving_var=self.moving_var.to(X.device)Y,self.moving_mean,self.moving_var=batch_norm(X,self.gamma,self.beta,self.moving_mean,self.moving_var,eps=1e-5,momentum=0.9)return Y
################################################################################################################
transforms=transforms.Compose([transforms.Resize(28),transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
train_img=torchvision.datasets.FashionMNIST(root="./data",train=True,transform=transforms,download=True)
test_img=torchvision.datasets.FashionMNIST(root="./data",train=False,transform=transforms,download=True)
train_data=DataLoader(train_img,batch_size=128,num_workers=4,shuffle=True)
test_data=DataLoader(test_img,batch_size=128,num_workers=4,shuffle=False)
################################################################################################################
device=torch.device("cuda:1" if torch.cuda.is_available() else 'cpu')
model=nn.Sequential(nn.Conv2d(1,6,kernel_size=5),BatchNorm(6,num_dims=4),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2,stride=2),nn.Conv2d(6,16,kernel_size=5),BatchNorm(16,num_dims=4),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2,stride=2),nn.Flatten(),#Flatten()之后就是[batch_size,features] 2維度的向量矩陣nn.Linear(16*4*4,120),BatchNorm(120,num_dims=2),nn.Sigmoid(),nn.Linear(120,84),BatchNorm(84,num_dims=2),nn.Sigmoid(),nn.Linear(84,10)).to(device)
model.apply(init_weights)
optimizer=torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.9)
CEloss=nn.CrossEntropyLoss()
model=train_model(model,train_data,test_data,num_epochs=15)
################################################################################################################
print("BatchNorm算法學習參數效果:")
print("gamma:",model[1].gamma.reshape((-1,)))
print("beta:",model[1].beta.reshape((-1,)))

在這里插入圖片描述
在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/914777.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/914777.shtml
英文地址,請注明出處:http://en.pswp.cn/news/914777.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【項目實踐】SMBMS(Javaweb版)匯總版

文章目錄前期準備工作數據庫、數據表創建web項目創建項目文件目錄配置Tomcat,導入依賴建立實體類編寫基礎公共方法類導入基礎資源登錄功能登錄頁面持久層dao層的用戶登錄及接口實現dao層接口實現所需的方法業務層sevice層的接口的實現接口實現相關的業務邏輯編寫ser…

隱藏源IP的核心方案與高防實踐

一、源IP暴露的風險 直接DDoS攻擊:2025年Q2全球DDoS攻擊峰值達3.8Tbps(來源:Cloudflare報告)漏洞利用:暴露的SSH端口平均每天遭受12,000暴力破解嘗試數據泄露:直接連接數據庫風險提升300% 二、4種有效隱藏方…

深度學習圖像分類數據集—五種電器識別分類

該數據集為圖像分類數據集,適用于ResNet、VGG等卷積神經網絡,SENet、CBAM等注意力機制相關算法,Vision Transformer等Transformer相關算法。 數據集信息介紹:五種電器識別分類:[notebook, phone, powerbank, tablet, w…

Windows11家庭版配置frigate 嵌入自研算法(基于Yolov8)-【2】

使用 YOLOv8 的 results.xyxy 結構,下面是一個完整的 MQTT 推送腳本,用于把識別到的目標(比如突涌水、水漬、障礙物等)發送到 Frigate 的 MQTT 接口。? 前提假設 YOLOv8 推理代碼已經運行并生成 results.xyxy。每一行是 [x1, y1,…

安裝llama-factory報錯 error: subprocess-exited-with-error

報錯信息如下 Using cached https://mirrors.aliyun.com/pypi/packages/17/89/940a509ee7e9449f0c877fa984b37b7cc485546035cc67bbc353f2ac20f3/av-15.0.0.tar.gz (3.8 MB)Preparing metadata (pyproject.toml) ... errorerror: subprocess-exited-with-error Preparing metad…

QT 多線程 管理串口

記錄一下自己使用多線程進行串口管理和數據讀取的過程。如果有問題的話可以發消息給我。背景在使用QT制作一個串口數據讀取處理的小軟件的時候,發現了存在界面卡頓的情況,感覺性能太低,于是考慮把串口數據的讀取和處理都放到子線程的緩沖區中…

在虛擬環境中復現論文(環境配置)

前提:已經下載condawinR,輸入cmd進入命令行conda create -n PPT python3.8.3 pytorch1.7.0conda create -n PPT(虛擬環境名) python3.8.3(包名) pytorch1.7.0(包名)安裝完畢,激活虛擬環境:conda activate PPT根據論文readme要求安…

Flutter Web 的發展歷程:Dart、Flutter 與 WasmGC

Flutter Web 應該是 Flutter 開發者里最不“受寵”的平臺了,但是其實 Flutter 和 Dart 團隊對于 Web 的投入一直沒有減少,這也和 Flutter 還有 Dart 的"出生"有關系,今天就借著 Dart 團隊的 mer A?acan 和 Martin Kustermann 在油…

c#方法關鍵字,ref、out、int

在 C# 中,ref、out 和 in 是用于方法參數傳遞的關鍵字,它們控制參數如何在方法和調用者之間傳遞數據。以下是對這三個關鍵字的詳細分析:1. ref 關鍵字(引用傳遞)作用允許方法修改調用者的變量:通過引用傳遞…

設計模式—初識設計模式

1.設計模式經典面試題分析幾個常見的設計模式對應的面試題。1.1原型設計模式1.使用UML類圖畫出原型模式核心角色(意思就是使用會考察使用UML畫出設計模式中關鍵角色和關系圖等)2.原型設計模式的深拷貝和淺拷貝是什么,寫出深拷貝的兩種方式的源…

深度學習-參數初始化、損失函數

A、參數初始化參數初始化對模型的訓練速度、收斂性以及最終的性能產生重要影響。它可以盡量避免梯度消失和梯度爆炸的情況。一、固定值初始化在神經網絡訓練開始時,將權重或偏置初始化為常數。但這種方法在實際操作中并不常見。1.1全零初始化將所有的權重參數初始化…

格密碼--Ring-SIS和Ring-LWE

1. 多項式環&#xff08;Polynomial Rings&#xff09; 設 f∈Z[x]f \in \mathbb{Z}[x]f∈Z[x] 是首一多項式&#xff08;最高次項系數為1&#xff09; 則環 RZ[x]/(f)R \mathbb{Z}[x]/(f)RZ[x]/(f) 元素為&#xff1a;所有次數 <deg?(f)< \deg(f)<deg(f) 的多項式…

前端工作需要和哪些人打交道?

前端工作中需要協作的角色及協作要點 前端工作中需要協作的角色及協作要點 前端開發處于產品實現的 “中間環節”,既要將設計方案轉化為可交互的界面,又要與后端對接數據,還需配合團隊推進項目進度。日常工作中,需要頻繁對接的角色包括以下幾類,每類協作都有其核心目標和…

萬字長文解析 OneCode3.0 AI創新設計

一、研究概述與背景 1.1 研究背景與意義 在 AI 技術重塑軟件開發的浪潮中&#xff0c;低代碼平臺正經歷從 “可視化編程” 到 “意圖驅動開發” 的根本性轉變。這種變革不僅提升了開發效率&#xff0c;更重新定義了人與系統的交互方式。作為國內領先的低代碼平臺&#xff0c;On…

重學前端006 --- 響應式網頁設計 CSS 彈性盒子

文章目錄盒模型一、盒模型的基本概念二、兩種盒模型的對比 舉例三、總結Flexbox 彈性盒子布局一、Flexbox 的核心概念??二、Flexbox 的基本語法????1. 定義 Flex 容器???2. Flex 容器的主要屬性????3. Flex 項目的主要屬性????三、Flexbox 的常見布局示例??…

rLLM:用于LLM Agent RL后訓練的創新框架

rLLM&#xff1a;用于LLM Agent RL后訓練的創新框架 本文介紹了rLLM&#xff0c;一個用于語言智能體后訓練的可擴展框架。它能讓用戶輕松構建自定義智能體與環境&#xff0c;通過強化學習進行訓練并部署。文中還展示了用其訓練的DeepSWE等智能體的出色表現&#xff0c;以及rLL…

rocky8 --Elasticsearch+Logstash+Filebeat+Kibana部署【7.1.1版本】

軟件說明&#xff1a; 所有軟件包下載地址&#xff1a;Past Releases of Elastic Stack Software | Elastic 打開頁面后選擇對應的組件及版本即可&#xff01; 所有軟件包名稱如下&#xff1a; 架構拓撲&#xff1a; 集群模式&#xff1a; 單機模式 架構規劃&#xff1a…

【JVM】內存分配與回收原則

在 Java 開發中&#xff0c;自動內存管理是 JVM 的核心能力之一&#xff0c;而內存分配與回收的策略直接影響程序的性能和穩定性。本文將詳細解析 JVM 的內存分配機制、對象回收規則以及背后的設計思想&#xff0c;幫助開發者更好地理解 JVM 的 "自動化" 內存管理邏輯…

Qt獲取hid設備信息

Qt 中通過 HID&#xff08;Human Interface Device&#xff09;接口獲取指定的 USB 設備&#xff0c;并讀取其數據。資源文件中包含了 hidapi.h、hidapi.dll 和 hidapi.lib。通過這些文件&#xff0c;您可以在 Qt 項目中實現對 USB 設備的 HID 接口調用。#include <QObject&…

Anaconda Jupyter 使用注意事項

Anaconda Jupyter 使用注意事項 1.將cell轉換為markdown。 First, select the cell you want to convertPress Esc to enter command mode (the cell border should turn blue)Press M to convert the cell to Markdown在編輯模式下按下ESC鍵&#xff0c;使單元塊&#xff08;c…