28-29【動手學深度學習】批量歸一化 + ResNet

1. 批量歸一化

1.1 原理

當神經網絡比較深的時候會發現:數據在下面,損失函數在上面,這樣會出現什么問題?

  • 正向傳遞的時候,數據是從下往上一步一步往上傳遞
  • 反向傳遞的時候,數據是從上面往下傳遞,這時候就會出現問題:梯度在上面的時候比較大,越到下面就越容易變小(因為是n個很小的數進行相乘,越到后面結果就越小,也就是說越靠近數據的,層的梯度就越小
  • 上面的梯度比較大,那么每次更新的時候上面的層就會不斷地更新;但是下面層因為梯度比較小,所以對權重地更新就比較少,這樣的話就會導致上面的收斂比較快,而下面的收斂比較慢,這樣就會導致底層靠近數據的內容(網絡所嘗試抽取的網絡底層的特征:簡單的局部邊緣、紋理等信息)變化比較慢,上層靠近損失的內容(高層語義信息)收斂比較快,所以每一次底層發生變化,所有的層都得跟著變(底層的信息發生變化就導致上層的權重全部白學了),這樣就會導致模型的收斂比較慢

所以提出了假設:能不能在改變底部信息的時候,避免頂部不斷的重新訓練?(這也是批量歸一化所考慮的問題)

\varepsilon?是為了避免除以0

全連接層

通常,我們將批量規范化層置于全連接層中的仿射變換和激活函數之間。 設全連接層的輸入為x,權重參數和偏置參數分別為W和b,激活函數為?,批量規范化的運算符為BN。 那么,使用批量規范化的全連接層的輸出的計算詳情如下:

h=?(BN(Wx+b)).

回想一下,均值和方差是在應用變換的"相同"小批量上計算的。

卷積層

同樣,對于卷積層,我們可以在卷積層之后和非線性激活函數之前應用批量規范化。 當卷積有多個輸出通道時,我們需要對這些通道的“每個”輸出執行批量規范化,每個通道都有自己的拉伸(scale)和偏移(shift)參數,這兩個參數都是標量。 假設我們的小批量包含m個樣本,并且對于每個通道,卷積的輸出具有高度p和寬度q。 那么對于卷積層,我們在每個輸出通道的m?p?q個元素上同時執行每個批量規范化。 因此,在計算平均值和方差時,我們會收集所有空間位置的值,然后在給定通道內應用相同的均值和方差,以便在每個空間位置對值進行規范化。

批量歸一化需要在激活函數之前,因為BN是線性的嗎,而激活函數是非線性的

?使用BN,可以增大學習率,因此可以加速收斂速度

預測過程中的批量規范化

正如我們前面提到的,批量規范化在訓練模式和預測模式下的行為通常不同。 首先,將訓練好的模型用于預測時,我們不再需要樣本均值中的噪聲以及在微批次上估計每個小批次產生的樣本方差了。 其次,例如,我們可能需要使用我們的模型對逐個樣本進行預測。 一種常用的方法是通過移動平均估算整個訓練數據集的樣本均值和方差,并在預測時使用它們得到確定的輸出。 可見,和暫退法一樣,批量規范化層在訓練模式和預測模式下的計算結果也是不一樣的。

1.2 代碼

從零實現

import torch
from torch import nn
from d2l import torch as d2ldef batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):# 通過is_grad_enabled來判斷當前模式是訓練模式還是預測模式if not torch.is_grad_enabled():# 如果是在預測模式下,直接使用傳入的移動平均所得的均值和方差X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:# 使用全連接層的情況,計算特征維上的均值和方差mean = X.mean(dim=0)var = ((X - mean) ** 2).mean(dim=0)else:# 使用二維卷積層的情況,計算通道維上(axis=1)的均值和方差。# 這里我們需要保持X的形狀以便后面可以做廣播運算mean = X.mean(dim=(0, 2, 3), keepdim=True)var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)# 訓練模式下,用當前的均值和方差做標準化X_hat = (X - mean) / torch.sqrt(var + eps)# 更新移動平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta  # 縮放和移位return Y, moving_mean.data, moving_var.data

創建一個正確的 BatchNorm 圖層

class BatchNorm(nn.Module):# num_features:完全連接層的輸出數量或卷積層的輸出通道數。# num_dims:2表示完全連接層,4表示卷積層def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2:shape = (1, num_features)else:shape = (1, num_features, 1, 1)# 參與求梯度和迭代的拉伸和偏移參數,分別初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型參數的變量初始化為0和1self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def forward(self, X):# 如果X不在內存上,將moving_mean和moving_var# 復制到X所在顯存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新過的moving_mean和moving_varY, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y

應用BatchNorm 于LeNet模型

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4),nn.Sigmoid(), nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16,kernel_size=5), BatchNorm(16, num_dims=4),nn.Sigmoid(), nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(), nn.Linear(16 * 4 * 4, 120),BatchNorm(120, num_dims=2), nn.Sigmoid(),nn.Linear(120, 84), BatchNorm(84, num_dims=2),nn.Sigmoid(), nn.Linear(84, 10))

在Fashion-MNIST數據集上訓練網絡

lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

?拉伸參數 gamma 和偏移參數 beta

net[1].gamma.reshape((-1, )), net[1].beta.reshape((-1, ))

?簡明實現

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6),nn.Sigmoid(), nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16),nn.Sigmoid(), nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(), nn.Linear(256, 120), nn.BatchNorm1d(120),nn.Sigmoid(), nn.Linear(120, 84), nn.BatchNorm1d(84),nn.Sigmoid(), nn.Linear(84, 10))
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

小結

  • 在模型訓練過程中,批量規范化利用小批量的均值和標準差,不斷調整神經網絡的中間輸出,使整個神經網絡各層的中間輸出值更加穩定。
  • 批量規范化在全連接層和卷積層的使用略有不同。
  • 批量規范化層和暫退層一樣,在訓練模式和預測模式下計算不同。
  • 批量規范化有許多有益的副作用,主要是正則化。另一方面,”減少內部協變量偏移“的原始動機似乎不是一個有效的解釋。

2. ResNet

2.1 原理

????????只有當較復雜的函數類包含較小的函數類時,我們才能確保提高它們的性能。 對于深度神經網絡,如果我們能將新添加的層訓練成恒等映射(identity function)f(x)=x,新模型和原模型將同樣有效。 同時,由于新模型可能得出更優的解來擬合訓練數據集,因此添加層似乎更容易降低訓練誤差。?

當經過很多層卷積之后,可能通道數會產生變化,所以要加上1×1的卷積轉換通道數。 (通常情況下是,當高寬減半時,通道數變為原來的一倍)

2.2 代碼

殘差塊

import torch 
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lclass Residual(nn.Module):def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3,padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3,padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels,kernel_size=1, stride=strides)else:self.conv3 = Noneself.bn1 =  nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)# 當 inplace=True 時,ReLU 會直接在輸入張量上修改數據(覆蓋原值),不分配額外內存存儲輸出self.relu = nn.ReLU(inplace=True)  def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)Y = Y + Xreturn F.relu(Y)

輸入和輸出形狀一致

blk = Residual(3, 3)
X = torch.rand(4, 3, 6, 6)
Y = blk(X)
Y.shape

?增加輸出通道數的同時,減半輸出的高和寬

blk = Residual(3, 6, use_1x1conv=True, strides=2)
blk(X).shape

?ResNet模型

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))def resnet_block(input_channels, num_channels, num_residuals,first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels, use_1x1conv=True,strides=2))else:blk.append(Residual(num_channels, num_channels))return blk# *的含義是將list展開,變成一個個的輸入
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))net = nn.Sequential(b1, b2, b3, b4, b5, nn.AdaptiveAvgPool2d((1, 1)),nn.Flatten(), nn.Linear(512, 10))

觀察一下ResNet中不同模塊的輸入形狀是如何變化的

X = torch.rand(size=(1, 1, 224, 224))
for layer in net:X = layer(X)print(layer.__class__.__name__, 'output shape:\t', X.shape)

訓練模型

lr, num_epochs, batch_size = 0.05, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

小結

  • 學習嵌套函數(nested function)是訓練神經網絡的理想情況。在深層神經網絡中,學習另一層作為恒等映射(identity function)較容易(盡管這是一個極端情況)。
  • 殘差映射可以更容易地學習同一函數,例如將權重層中的參數近似為零。
  • 利用殘差塊(residual blocks)可以訓練出一個有效的深層神經網絡:輸入可以通過層間的殘余連接更快地向前傳播。
  • 殘差網絡(ResNet)對隨后的深層神經網絡設計產生了深遠影響。

ResNet的梯度計算

3. 第二次kaggle競賽

競賽地址:https://www.kaggle.com/c/classify-leaves

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/80720.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/80720.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/80720.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【Linux網絡】Http服務優化 - 增加請求后綴、狀態碼描述、重定向、自動跳轉及注冊多功能服務

📢博客主頁:https://blog.csdn.net/2301_779549673 📢博客倉庫:https://gitee.com/JohnKingW/linux_test/tree/master/lesson 📢歡迎點贊 👍 收藏 ?留言 📝 如有錯誤敬請指正! &…

AIGC(生成式AI)試用 32 -- AI做軟件程序測試 3

總結之前的AI做程序測試過程,試圖優化提問方式,整合完成的AI程序測試提問,探索更多可能的AI測試 AIGC(生成式AI)試用 30 -- AI做軟件程序測試 1 AIGC(生成式AI)試用 31 -- AI做軟件程序…

C語言實現迪杰斯特拉算法進行路徑規劃

使用C語言實現迪杰斯特拉算法進行路徑規劃 迪杰斯特拉算法是一種用于尋找加權圖中最短路徑的經典算法。它特別適合用于計算從一個起點到其他所有節點的最短路徑,前提是圖中的邊權重為非負數。 一、迪杰斯特拉算法的基本原理 迪杰斯特拉算法的核心思想是“貪心法”…

引領印尼 Web3 變革:Mandala Chain 如何助力 1 億用戶邁向數字未來?

當前 Web3 的發展正處于關鍵轉折點,行業亟需吸引新用戶以推動 Web3 的真正大規模采用。然而,大規模采用面臨著核心挑戰:數據泄露風險、集中存儲的安全漏洞、跨系統互操作性障礙,以及低效的服務訪問等問題。如何才能真正突破這些瓶…

WebSocket是h5定義的,雙向通信,節省資源,更好的及時通信

瀏覽器和服務器之間的通信更便利,比http的輪詢等效率提高很多, WebSocket并不是權限的協議,而是利用http協議來建立連接 websocket必須由瀏覽器發起請求,協議是一個標準的http請求,格式如下 GET ws://example.com:3…

Kaamel白皮書:IoT設備安全隱私評估實踐

1. IoT安全與隱私領域的現狀與挑戰 隨著物聯網技術的快速發展,IoT設備在全球范圍內呈現爆發式增長。然而,IoT設備帶來便捷的同時,也引發了嚴峻的安全與隱私問題。根據NSF(美國國家科學基金會)的研究表明,I…

php安裝swoole擴展

PHP安裝swoole擴展 Swoole官網 安裝準備 安裝前必須保證系統已經安裝了下列軟件 4.8 版本需要 PHP-7.2 或更高版本5.0 版本需要 PHP-8.0 或更高版本6.0 版本需要 PHP-8.1 或更高版本gcc-4.8 或更高版本makeautoconf 安裝Swool擴展 安裝官方文檔安裝后需要再php.ini中增加…

服務器傳輸數據存儲數據建議 傳輸慢的原因

一、JSON存儲的局限性 1. 性能瓶頸 全量讀寫:JSON文件通常需要整體加載到內存中才能操作,當數據量大時(如幾百MB),I/O延遲和內存占用會顯著增加。 無索引機制:查找數據需要遍歷所有條目(時間復…

Android四大核心組件

目錄 一、為什么需要四大組件? 二、Activity:看得見的界面 核心功能 生命周期圖解 代碼示例 三、Service:看不見的勞動者 兩大類型 生命周期對比 注意陷阱 四、BroadcastReceiver:消息傳遞專員 兩種注冊方式 廣播類型 …

「Mac暢玩AIGC與多模態01」架構篇01 - 展示層到硬件層的架構總覽

一、概述 AIGC(AI Generated Content)系統由多個結構層級組成,自上而下涵蓋交互界面、API 通信、模型推理、計算框架、底層驅動與硬件支持。本篇梳理 AIGC 應用的六層體系結構,明確各組件在系統中的職責與上下游關系,…

[MERN 項目實戰] MERN Multi-Vendor 電商平臺開發筆記(v2.0 從 bug 到結構優化的工程記錄)

[MERN 項目實戰] MERN Multi-Vendor 電商平臺開發筆記(v2.0 從 bug 到結構優化的工程記錄) 其實之前沒想著這么快就能把 2.0 的筆記寫出來的,之前的預期是,下一個階段會一直維持到將 MERN 項目寫完,畢竟后期很多東西都…

互斥量函數組

頭文件 #include <pthread.h> pthread_mutex_init 函數原型&#xff1a; int pthread_mutex_init(pthread_mutex_t *restrict mutex, const pthread_mutexattr_t *restrict attr); 函數參數&#xff1a; mutex&#xff1a;指向要初始化的互斥量的指針。 attr&#xf…

互聯網的下一代脈搏:深入理解 QUIC 協議

互聯網的下一代脈搏&#xff1a;深入理解 QUIC 協議 互聯網是現代社會的基石&#xff0c;而數據在其中高效、安全地傳輸是其運轉的關鍵。長期以來&#xff0c;傳輸層的 TCP&#xff08;傳輸控制協議&#xff09;一直是互聯網的主力軍。然而&#xff0c;隨著互聯網應用場景的日…

全球城市范圍30米分辨率土地覆蓋數據(1985-2020)

Global urban area 30 meter resolution land cover data (1985-2020) 時間分辨率年空間分辨率10m - 100m共享方式保護期 277 天 5 時 42 分 9 秒數據大小&#xff1a;8.98 GB數據時間范圍&#xff1a;1985-2020元數據更新時間2024-01-11 數據集摘要 1985~2020全球城市土地覆…

【Vue】單元測試(Jest/Vue Test Utils)

個人主頁&#xff1a;Guiat 歸屬專欄&#xff1a;Vue 文章目錄 1. Vue 單元測試簡介1.1 為什么需要單元測試1.2 測試工具介紹 2. 環境搭建2.1 安裝依賴2.2 配置 Jest 3. 編寫第一個測試3.1 組件示例3.2 編寫測試用例3.3 運行測試 4. Vue Test Utils 核心 API4.1 掛載組件4.2 常…

數據湖的管理系統管什么?主流產品有哪些?

一、數據湖的管理系統管什么&#xff1f; 數據湖的管理系統主要負責管理和優化存儲在數據湖中的大量異構數據&#xff0c;確保這些數據能夠被有效地存儲、處理、訪問和治理。以下是數據湖管理系統的主要職責&#xff1a; 數據攝入管理&#xff1a;管理系統需要支持從多種來源&…

英文中日期讀法

英文日期的讀法和寫法因地區&#xff08;英式英語與美式英語&#xff09;和正式程度有所不同&#xff0c;以下是詳細說明&#xff1a; 一、日期格式 英式英語 (日-月-年) 寫法&#xff1a;1(st) January 2023 或 1/1/2023讀法&#xff1a;"the first of January, twenty t…

衡量矩陣數值穩定性的關鍵指標:矩陣的條件數

文章目錄 1. 定義2. 為什么要定義條件數&#xff1f;2.1 分析線性系統 A ( x Δ x ) b Δ b A(x \Delta x) b \Delta b A(xΔx)bΔb2.2 分析線性系統 ( A Δ A ) ( x Δ x ) b (A \Delta A)(x \Delta x) b (AΔA)(xΔx)b2.3 定義矩陣的條件數 3. 性質及幾何意義3…

4月22日復盤-開始卷積神經網絡

4月24日復盤 一、CNN 視覺處理三大任務&#xff1a;圖像分類、目標檢測、圖像分割 上游&#xff1a;提取特征&#xff0c;CNN 下游&#xff1a;分類、目標、分割等&#xff0c;具體的業務 1. 概述 ? 卷積神經網絡是深度學習在計算機視覺領域的突破性成果。在計算機視覺領…

【網絡原理】從零開始深入理解TCP的各項特性和機制.(三)

上篇介紹了網絡原理傳輸層TCP協議的知識,本篇博客給大家帶來的是網絡原理剩余的內容, 總體來說,這部分內容沒有上兩篇文章那么重要,本篇知識有一個印象即可. &#x1f40e;文章專欄: JavaEE初階 &#x1f680;若有問題 評論區見 ? 歡迎大家點贊 評論 收藏 分享 如果你不知道分…