目錄
池化層?
最大池化層
MaxPool2d
最大池化操作圖示?
最大池化操作代碼演示?
綜合代碼案例?
池化層?
池化層(Pooling Layer)
核心作用:通過降采樣減少特征圖尺寸,降低計算量,增強特征魯棒性。
1. 常見類型
-
最大池化(Max Pooling):提取局部區域最大值。
-
平均池化(Average Pooling):計算局部區域平均值。
-
全局池化(Global Pooling):將每個通道的特征圖壓縮為一個標量(常用于分類任務)。
2. 參數與計算
-
窗口大小(Kernel Size):如 2×2、3×3。
-
步長(Stride):窗口滑動的步長,通常等于窗口大小(如 2)。
-
填充(Padding):邊緣填充策略,保持輸出尺寸。
MaxPool->下采樣
MaxUnPool->上采樣?
“池化層”(Pooling Layer)的命名源于其核心操作與 “池”(Pool)這一概念的類比 —— 就像從一個 “池子” 里提取所需內容,本質是對局部區域內的信息進行匯總、篩選并輸出。
具體來說,“池” 在這里可理解為 “局部數據區域”:
池化層會將輸入特征圖劃分為多個不重疊的小區域(比如 2×2 的窗口),每個小區域就像一個 “池子”;然后對每個 “池子” 里的所有數據(像素值或特征值)執行特定操作(最大池化取最大值、平均池化取平均值等),最終從每個 “池子” 里只輸出一個結果。
這個過程就像從每個 “池子” 里 “提取” 出最具代表性的信息(比如最大池化提取 “最顯著特征”,平均池化提取 “平均特征”),因此被形象地稱為 “池化”。
最大池化層
?最大池化層是卷積神經網絡(CNN)中用于下采樣(Downsampling)?的關鍵組件,通過在輸入特征圖的局部非重疊區域(池化窗口)內選取最大值作為輸出,實現特征篩選與維度壓縮。其核心是保留局部區域內最顯著的特征信號,同時降低特征圖的空間分辨率。
核心參數
-
池化窗口尺寸(Kernel Size):
常用?2×2?或?3×3,決定局部特征的感知范圍。窗口越大,壓縮率越高,但可能丟失細粒度特征。 -
步長(Stride):
窗口滑動的步幅,通常與窗口尺寸一致(如?2×2?窗口對應步長 2),此時輸出尺寸為輸入的?1/2(沿高度和寬度)。 -
通道獨立性:
池化操作在每個通道內獨立進行,不跨通道融合(輸出通道數與輸入一致)。
功能與意義
-
維度縮減與計算效率提升:
通過降低特征圖的?H×W?維度,減少后續網絡層的參數量和計算量(如?2×2?池化可使特征圖面積變為原來的?1/4)。 -
特征魯棒性增強:
-
平移不變性:對輸入特征的輕微位置偏移(如目標小幅移動)具有容錯性(只要最大值仍在窗口內,輸出不變)。
-
噪聲抑制:通過選取局部最大值,過濾次要信息(如背景噪聲),強化關鍵特征(如邊緣、紋理的強響應區域)。
-
-
防止過擬合:
減少特征冗余,降低模型對局部細節的過度依賴,提升泛化能力。
簡易解釋:
最大池化可以理解成 “抓重點” 的操作,用一個生活化的例子就能說清楚:
假設你有一張照片(對應輸入的特征圖),現在用一個小方格(比如 2x2 的池化窗口)在照片上 “掃”—— 每次掃到一個方格,就只留下這個方格里最亮的那個點(取最大值),其他點都忽略;然后方格按固定步長(比如每次挪 2 格)移到下一個位置,重復同樣的操作。
最后你會得到一張更小的照片:原來的細節少了,但保留了每個小區域里最突出的特征(比如最亮的色塊、最明顯的邊緣)。
這么做的好處很簡單:
- 照片變小了,后續處理起來更快(降維,減少計算量);
- 就算原照片里的物體稍微挪了一點位置(比如小方格稍微偏了點),只要最亮的點還在方格里,結果就不變(增強對位置變化的抗干擾能力)。?
簡單說,最大池化就是 “用最小的信息損失,把數據變小,同時抓住核心特征”。
MaxPool2d
?參數:
tip:
Floor和Ceiling兩個操作
簡單來說,就是向上下取整
?
此處ceil和floor兩個模式,表示池化核部分超出輸入圖像邊界時候是否保留?
池化后的形狀大小:
??
最大池化操作圖示?
最大池化操作代碼演示?
import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2dinput = torch.tensor([[1, 2, 0, 3, 1],[0, 1, 2, 3, 1],[1, 2, 1, 0, 0],[5, 2, 3, 1, 1],[2, 1, 0, 1, 1]
])
print(input.shape)
"""
打印結果:
torch.Size([5, 5])
不符合卷積層的輸入要求
在最簡單的情況下,輸入尺寸為 (N,C,H,W)
N:批量數
C:通道數
H:高度
W:寬度
"""
input = torch.reshape(input, (1, 1, 5, 5))class Mymodule(nn.Module):def __init__(self):super().__init__()self.maxpool = MaxPool2d(kernel_size=3, ceil_mode=True)def forward(self, input):output = self.maxpool(input)return outputmodel = Mymodule()
output = model(input)
print(output)
??Ceil_model=True時
?Ceil_model=False時?
綜合代碼案例?
import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10("../torchvision_dataset", train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset, batch_size=64)class Mymodule(nn.Module):def __init__(self):super().__init__()self.maxpool = MaxPool2d(kernel_size=3, ceil_mode=True)def forward(self, input):output = self.maxpool(input)return outputmodel = Mymodule()
step = 0
writer = SummaryWriter("logs_test5")
for data in dataloader:imgs, targets = datawriter.add_images("input", imgs, step)output = model(imgs)writer.add_images("output", output, step)step += 1writer.close()