- 🍨 This post is a learning-log entry for the 🔗365-day deep learning training camp
- 🍖 Original author: K同學啊
Preface
- ResNeXt is the pioneering work on grouped convolution; in this post we study the ResNeXt network;
- This post reproduces the ResNeXt-50 network and uses it for a monkeypox classification experiment;
- There is no "best" network, only the most suitable one. A network is not better simply because it is deeper or more complex; the choice has to follow from the actual data and the goal of the task, and quite often a simpler network works better;
- Feel free to bookmark and follow; this post will keep being updated.
Table of Contents
- 1. Background
  - 1. Grouped convolution
  - 2. Split-Transform-Merge
  - 3. ResNeXt-50 overview
- 2. ResNeXt-50 experiment
  - 1. Importing the data
    - 1. Import libraries
    - 2. Inspect and load the data
    - 3. Display the data
    - 4. Load the data
    - 5. Split the data
    - 6. Load the data dynamically
  - 2. Build the ResNeXt-50 network
  - 3. Model training
    - 1. Training function
    - 2. Test function
    - 3. Set hyperparameters
    - 4. Train the model
    - 5. Visualize results
    - 6. Evaluate the model
- 3. References
1. Background
1. Grouped convolution
Grouped convolution first appeared in AlexNet, where the channels were split into two groups and features were extracted in parallel on two GPUs. The network structure is shown below:
The extracted feature maps look like this:
The authors found that the first group mainly extracted grayscale features while the second group mainly extracted color features, so grouping lets different groups specialize in different kinds of features.
Regular convolution vs. grouped convolution
Look at regular convolution first. With an input feature map of size c × h × w, the number of output feature maps equals the number of kernels, n. Each kernel has size c × k × k, so the n kernels hold n × c × k × k parameters in total, and the output has shape n × h1 × w1, as shown on the left of the figure below.
Grouped convolution splits the input feature map into groups and convolves each group separately. Suppose the input feature map has size c × h × w, the number of output feature maps is n, and we use g groups. Each group then takes c / g input channels and produces n / g output channels. Note that only the number of input channels per kernel drops to c / g; the spatial kernel size is unchanged. Each group's convolution produces an output of shape (n / g) × h1 × w1, and concatenating the groups gives the final result, whose shape is still n × h1 × w1, exactly the same as regular convolution.
Parameter comparison:
- Regular convolution: c × k × k × n, where c is the number of input channels, k × k the kernel size, and n the number of kernels;
- Grouped convolution: (c / g) × k × k × (n / g) × g = k × k × c × n × (1 / g); in terms of parameter count, grouped convolution is g times smaller (a quick code check follows below);
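To make the saving concrete, here is a minimal sketch (my addition, not from the original post) that counts the weights of a regular and a grouped nn.Conv2d; the channel, kernel, and group sizes are arbitrary illustrative values.

```python
import torch.nn as nn

c, n, k, g = 64, 128, 3, 32  # illustrative input channels, kernels, kernel size, groups

regular = nn.Conv2d(c, n, kernel_size=k, bias=False)            # c * k * k * n weights
grouped = nn.Conv2d(c, n, kernel_size=k, groups=g, bias=False)  # (c/g) * k * k * (n/g) * g weights

print(regular.weight.numel())   # 73728
print(grouped.weight.numel())   # 2304
print(regular.weight.numel() // grouped.weight.numel())  # 32, i.e. exactly g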
A more detailed diagram is shown below:
2. Split-Transform-Merge
"Split-Transform-Merge" is a common design pattern or processing pipeline, widely used in software development, data processing, and system architecture. Its core idea is to break a complex problem into smaller parts (Split), process or transform each part independently (Transform), and finally recombine the processed results (Merge) to complete the overall task.
1. Split
In this stage, the input data or task is broken into smaller, more manageable pieces. How to split depends on the specific problem and context. For example:
- Data splitting: cut a large dataset into several smaller chunks.
- Task splitting: decompose a complex task into several subtasks.
- Parallelization: splitting enables parallel processing and improves efficiency.
Example:
- In grouped convolution, the input channels are split into groups and each group is convolved separately.
2. Transform
After the split, each part is processed or transformed independently. This is the core stage of the pipeline and usually involves computation, analysis, or modification; what it does depends on the task:
- Data cleaning and format conversion.
- Algorithmic computation or model inference.
- Independent execution of subtasks.
Example:
- In grouped convolution, each group runs its own convolution, without interfering with the others.
3. Merge
Once all subtasks are finished, their results are recombined into the final output. The merge has to preserve the completeness and consistency of the results:
- Data merging: concatenate the processed chunks back into a complete dataset.
- Result aggregation: combine the subtask results into the final answer.
- Conflict resolution: if subtasks conflict or overlap, resolve it in the merge stage.
Example:
- In grouped convolution, the outputs of all groups are concatenated at the end (the sketch below walks through all three stages).
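The three stages map directly onto tensor operations. The sketch below (my illustration, not from the original post) splits the input with torch.chunk, convolves each group with its slice of a grouped layer's weights, merges with torch.cat, and checks that the result matches nn.Conv2d with groups=g.

```python
import torch
import torch.nn as nn

c, n, g = 64, 128, 32                       # channels in/out and number of groups
x = torch.randn(1, c, 8, 8)

grouped = nn.Conv2d(c, n, kernel_size=3, padding=1, groups=g, bias=False)

# Split: cut the input channels into g groups
xs = torch.chunk(x, g, dim=1)               # g tensors of shape (1, c/g, 8, 8)

# Transform: convolve each group with its own block of the grouped weights
ws = torch.chunk(grouped.weight, g, dim=0)  # g weight blocks of shape (n/g, c/g, 3, 3)
ys = [torch.nn.functional.conv2d(xi, wi, padding=1) for xi, wi in zip(xs, ws)]

# Merge: concatenate the group outputs along the channel dimension
y_manual = torch.cat(ys, dim=1)

print(torch.allclose(y_manual, grouped(x), atol=1e-5))  # True
```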
3. ResNeXt-50 overview
The ResNeXt network is regarded as the pioneering work on grouped convolution. It was proposed by Kaiming He's team at CVPR 2017 and is an upgraded version of ResNet.
In the paper, the authors point out a common pattern: to raise accuracy, people usually make the network deeper or wider. That does help to a degree, but the design difficulty and the computational cost grow with it, and deeper does not automatically mean better; sometimes a small gain in accuracy comes at a large cost, much like VGG16, whose computation was enormous when it was introduced.
In the paper, the authors propose to improve accuracy without adding extra computational cost and introduce the concept of cardinality (the number of groups in the grouped convolution).
The figure below contrasts a ResNet block (left) with a ResNeXt block (right). In ResNet, the 256-channel input is first compressed by a factor of 4 to 64 channels with a 1 × 1 convolution, features are extracted with a 3 × 3 convolution, and a final 1 × 1 convolution restores the channel count for the output, which is added to the input through the residual connection. In ResNeXt, the 256 input channels are split into 32 groups; each group is first compressed down to 4 channels, features are extracted with a 3 × 3 convolution, a 1 × 1 convolution restores the channel count, and the outputs of all groups are then concatenated along the channel dimension and added to the input through the residual connection.
Cardinality is the number of identical branches in a block, i.e. the number of groups.
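As a quick check of "same cost, higher cardinality" (my own sketch; the channel widths follow the figure: 256 → 64 → 64 → 256 for ResNet, 256 → 128 → 128 → 256 with 32 groups for ResNeXt), the two block designs use almost the same number of weights:

```python
import torch.nn as nn

def bottleneck_params(c_in, width, c_out, groups=1):
    """Weight count of a 1x1 -> 3x3 -> 1x1 bottleneck (biases and BN ignored)."""
    reduce = nn.Conv2d(c_in, width, 1, bias=False)
    spatial = nn.Conv2d(width, width, 3, padding=1, groups=groups, bias=False)
    expand = nn.Conv2d(width, c_out, 1, bias=False)
    return sum(m.weight.numel() for m in (reduce, spatial, expand))

print(bottleneck_params(256, 64, 256))              # ResNet bottleneck:  69632
print(bottleneck_params(256, 128, 256, groups=32))  # ResNeXt bottleneck: 70144
```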
Next, let's build the ResNeXt-50 network (a PyTorch reproduction).
2. ResNeXt-50 experiment
1. Importing the data
1. Import libraries
```python
import torch
import torch.nn as nn
import torchvision
import numpy as np
import os, PIL, pathlib

# set the device
device = "cuda" if torch.cuda.is_available() else "cpu"
device
```
'cuda'
2. Inspect and load the data
The data directory contains two subfolders, one per class (Monkeypox and Others).
```python
data_dir = "./data/"
data_dir = pathlib.Path(data_dir)

# class names
classnames = [str(path).split('/')[0] for path in os.listdir(data_dir)]
classnames
```
['Monkeypox', 'Others']
3. Display the data
```python
import matplotlib.pyplot as plt
from PIL import Image

# collect the image file names
data_path_name = "./data/Others"
data_path_list = [f for f in os.listdir(data_path_name) if f.endswith(('jpg', 'png'))]

# create the figure
fig, axes = plt.subplots(2, 8, figsize=(16, 6))
for ax, img_file in zip(axes.flat, data_path_list):
    path_name = os.path.join(data_path_name, img_file)
    img = Image.open(path_name)  # open the image
    ax.imshow(img)               # display it
    ax.axis('off')

plt.show()
```
4. Load the data
```python
from torchvision import transforms, datasets

# unify the image format
img_height = 224
img_width = 224

data_tranforms = transforms.Compose([
    transforms.Resize([img_height, img_width]),
    transforms.ToTensor(),
    transforms.Normalize(            # normalization
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# load the whole dataset
total_data = datasets.ImageFolder(root=data_dir, transform=data_tranforms)
```
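Optionally (my addition), you can confirm how ImageFolder mapped the two folders to label indices:

```python
print(total_data.class_to_idx)  # {'Monkeypox': 0, 'Others': 1}
```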
5. Split the data
```python
# 80 : 20 split
train_size = int(len(total_data) * 0.8)
test_size = len(total_data) - train_size

train_data, test_data = torch.utils.data.random_split(total_data, [train_size, test_size])
```
6. Load the data dynamically
```python
batch_size = 32

train_dl = torch.utils.data.DataLoader(train_data,
                                       batch_size=batch_size,
                                       shuffle=True)
test_dl = torch.utils.data.DataLoader(test_data,
                                      batch_size=batch_size,
                                      shuffle=False)

# inspect the data shapes
for data, labels in train_dl:
    print("data shape[N, C, H, W]: ", data.shape)
    print("labels: ", labels)
    break
```
data shape[N, C, H, W]: torch.Size([32, 3, 224, 224])
labels: tensor([1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,0, 1, 0, 0, 0, 1, 0, 0])
2. Build the ResNeXt-50 network
ResNet-50 network structure diagram:
While reproducing ResNeXt-50 I read quite a few references, but their code never really clicked for me. Eventually I realized that ResNeXt-50 is essentially ResNet-50 with grouped convolution added, and that in each stage the width of the first two convolutions of the bottleneck is twice ResNet's. So I modified the ResNet network I had built before; the code is shown below.
In ResNeXt-50, a few parameters deserve attention:
- Grouped convolution: the cardinality parameter is the number of groups, passed to Conv2d through its groups argument.
- Channel computation: the output channels of each group are set by group_depth, and the total output channels are cardinality × group_depth. In the ResNeXt-50 built below, the input and output channels of every layer are entered by hand, so group_depth is implicit in filters (computed manually); see the small example after this list.
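For reference, here is a small sketch (my addition) of how the hand-entered filters values below relate to cardinality and group_depth; it simply reproduces the widths used in the network definition.

```python
cardinality = 32   # number of groups: the "32x" in ResNeXt-50 (32x4d)
group_depth = 4    # channels per group in the first stage: the "4d"

for stage in range(4):
    width = cardinality * group_depth * (2 ** stage)   # 128, 256, 512, 1024
    filters = [width, width, 2 * width]                # e.g. [128, 128, 256] in stage 1
    print(f"stage {stage + 1}: filters = {filters}")
```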
Recap:
The basic idea of the Bottleneck
A Bottleneck usually consists of three convolutional layers; it is the basic building unit of ResNet and its variants.
- The first 1×1 convolution: reduces the channel count of the input feature map to cut the cost of the following computation.
- The middle 3×3 convolution: the core feature-extraction step. In ResNeXt this layer uses grouped convolution to strengthen its representational power.
- The last 1×1 convolution: restores the channel count to the original (or a larger) number so the result can be added to the input through the residual connection.
Note:
- In ResNeXt, grouped convolution is used only in the second (3×3) layer of the Bottleneck.
```python
import torch.nn.functional as F

# Bottleneck: split into two kinds of residual blocks
# Residual block 1: handles the case where input and output channels are the same
'''
Kernel sizes: 1, 3, 1
Key properties: the spatial size is unchanged, so input and output have the same shape.
No downsampling: no convolution uses a stride greater than 1, so the feature-map size stays the same.
'''
class Identity_block(nn.Module):
    def __init__(self, in_channels, kernel_size, filters, cardinality):
        super(Identity_block, self).__init__()

        # output channels
        filter1, filter2, filter3 = filters

        # conv 1: reduce the channel count
        self.conv1 = nn.Conv2d(in_channels, filter1, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(filter1)

        # conv 2: grouped convolution, the core feature extraction
        # from the conv output-size formula, padding=1 keeps the spatial size unchanged
        self.conv2 = nn.Conv2d(filter1, filter2, kernel_size=kernel_size, padding=1,
                               groups=cardinality)
        self.bn2 = nn.BatchNorm2d(filter2)

        # conv 3: expand the channel count back
        self.conv3 = nn.Conv2d(filter2, filter3, kernel_size=1, stride=1)
        self.bn3 = nn.BatchNorm2d(filter3)

    def forward(self, x):
        # keep the original input
        xx = x

        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))

        # residual connection; input and output shapes are identical
        x += xx
        x = F.relu(x)
        return x

# Residual block 2: handles the case where input and output shapes differ
'''
* Kernel sizes are still 1, 3, 1
* stride=2
* The shortcut branch uses a Conv2d plus a BatchNorm layer to match the changed dimensions.
Key property: the spatial size changes; stride=2 downsamples the feature map.
'''
class ConvBlock(nn.Module):
    def __init__(self, in_channels, kernel_size, filters, cardinality, stride=2):
        super(ConvBlock, self).__init__()

        filter1, filter2, filter3 = filters

        # conv 1: reduce the channel count (and downsample when stride=2)
        self.conv1 = nn.Conv2d(in_channels, filter1, kernel_size=1, stride=stride)
        self.bn1 = nn.BatchNorm2d(filter1)

        # conv 2: grouped convolution, the core feature extraction; keeps the spatial size
        self.conv2 = nn.Conv2d(filter1, filter2, kernel_size=kernel_size, padding=1,
                               groups=cardinality)
        self.bn2 = nn.BatchNorm2d(filter2)

        # conv 3: expand the channel count; stride=1 keeps the spatial size
        self.conv3 = nn.Conv2d(filter2, filter3, kernel_size=1, stride=1)
        self.bn3 = nn.BatchNorm2d(filter3)

        # shortcut convolution to match dimensions (the role played by xx in Identity_block)
        self.shortcut = nn.Conv2d(in_channels, filter3, kernel_size=1, stride=stride)
        self.shortcut_bn = nn.BatchNorm2d(filter3)

    def forward(self, x):
        xx = x

        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))

        # project the shortcut so its shape matches, then add
        temp = self.shortcut_bn(self.shortcut(xx))
        x += temp
        x = F.relu(x)
        return x

# ResNeXt-50
class ResNext50(nn.Module):
    def __init__(self, classes):  # number of classes
        super().__init__()

        # stem: ResNet and its variants generally start like this
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # stage 1
        self.part1_1 = ConvBlock(64, 3, [128, 128, 256], cardinality=32, stride=1)
        self.part1_2 = Identity_block(256, 3, [128, 128, 256], cardinality=32)
        self.part1_3 = Identity_block(256, 3, [128, 128, 256], cardinality=32)

        # stage 2
        self.part2_1 = ConvBlock(256, 3, [256, 256, 512], cardinality=32)
        self.part2_2 = Identity_block(512, 3, [256, 256, 512], cardinality=32)
        self.part2_3 = Identity_block(512, 3, [256, 256, 512], cardinality=32)
        self.part2_4 = Identity_block(512, 3, [256, 256, 512], cardinality=32)

        # stage 3
        self.part3_1 = ConvBlock(512, 3, [512, 512, 1024], cardinality=32)
        self.part3_2 = Identity_block(1024, 3, [512, 512, 1024], cardinality=32)
        self.part3_3 = Identity_block(1024, 3, [512, 512, 1024], cardinality=32)
        self.part3_4 = Identity_block(1024, 3, [512, 512, 1024], cardinality=32)
        self.part3_5 = Identity_block(1024, 3, [512, 512, 1024], cardinality=32)
        self.part3_6 = Identity_block(1024, 3, [512, 512, 1024], cardinality=32)

        # stage 4
        self.part4_1 = ConvBlock(1024, 3, [1024, 1024, 2048], cardinality=32)
        self.part4_2 = Identity_block(2048, 3, [1024, 1024, 2048], cardinality=32)
        self.part4_3 = Identity_block(2048, 3, [1024, 1024, 2048], cardinality=32)

        # average pooling
        self.avg_pool = nn.AvgPool2d(kernel_size=7)

        # fully-connected classifier
        self.fn1 = nn.Linear(2048, classes)

    def forward(self, x):
        # stem
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.max_pool(x)

        x = self.part1_1(x)
        x = self.part1_2(x)
        x = self.part1_3(x)

        x = self.part2_1(x)
        x = self.part2_2(x)
        x = self.part2_3(x)
        x = self.part2_4(x)

        x = self.part3_1(x)
        x = self.part3_2(x)
        x = self.part3_3(x)
        x = self.part3_4(x)
        x = self.part3_5(x)
        x = self.part3_6(x)

        x = self.part4_1(x)
        x = self.part4_2(x)
        x = self.part4_3(x)

        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)  # flatten
        x = self.fn1(x)
        return x

model = ResNext50(classes=len(classnames)).to(device)
model
```
ResNext50((conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(max_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(part1_1): ConvBlock((conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(shortcut_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part1_2): Identity_block((conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part1_3): Identity_block((conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_1): ConvBlock((conv1): Conv2d(256, 256, kernel_size=(1, 1), stride=(2, 2))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))(shortcut_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_2): Identity_block((conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_3): Identity_block((conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_4): 
Identity_block((conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_1): ConvBlock((conv1): Conv2d(512, 512, kernel_size=(1, 1), stride=(2, 2))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2))(shortcut_bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_2): Identity_block((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_3): Identity_block((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_4): Identity_block((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_5): Identity_block((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_6): Identity_block((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): 
BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part4_1): ConvBlock((conv1): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(2, 2))(bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2))(shortcut_bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part4_2): Identity_block((conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part4_3): Identity_block((conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(avg_pool): AvgPool2d(kernel_size=7, stride=7, padding=0)(fn1): Linear(in_features=2048, out_features=2, bias=True)
)
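As an optional sanity check (my addition, not part of the original post), counting the trainable parameters gives a rough confirmation that the network is in the right ballpark for a ResNeXt-50 backbone:

```python
# total number of trainable parameters
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {total_params / 1e6:.1f} M")  # roughly 23 M with this 2-class head
```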
3. Model training
1. Training function
```python
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)

    train_acc, train_loss = 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # forward pass
        pred = model(X)
        loss = loss_fn(pred, y)

        # backpropagation and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # bookkeeping
        train_loss += loss.item()
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()

    train_acc /= size
    train_loss /= num_batches
    return train_acc, train_loss
```
2. Test function
```python
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)

    test_acc, test_loss = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            pred = model(X)
            loss = loss_fn(pred, y)

            test_loss += loss.item()
            test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss
```
3. Set hyperparameters
```python
loss_fn = nn.CrossEntropyLoss()  # loss function
learn_lr = 1e-4                  # learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=learn_lr)  # optimizer
```
4. Train the model
```python
import copy

train_acc = []
train_loss = []
test_acc = []
test_loss = []

epoches = 50
best_acc = 0

for i in range(epoches):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    # keep a copy of the best model so far
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # current learning rate
    lr = optimizer.state_dict()['param_groups'][0]['lr']

    # progress report
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}')
    print(template.format(i + 1, epoch_train_acc * 100, epoch_train_loss,
                          epoch_test_acc * 100, epoch_test_loss))

print("Done")

PATH = './best_model.pth'  # file name for the saved weights
torch.save(best_model.state_dict(), PATH)
```
Epoch: 1, Train_acc:62.3%, Train_loss:0.696, Test_acc:66.4%, Test_loss:0.604
Epoch: 2, Train_acc:67.9%, Train_loss:0.620, Test_acc:69.9%, Test_loss:0.580
Epoch: 3, Train_acc:69.5%, Train_loss:0.580, Test_acc:68.3%, Test_loss:0.603
Epoch: 4, Train_acc:71.6%, Train_loss:0.547, Test_acc:73.9%, Test_loss:0.530
Epoch: 5, Train_acc:74.7%, Train_loss:0.519, Test_acc:75.1%, Test_loss:0.520
Epoch: 6, Train_acc:78.2%, Train_loss:0.464, Test_acc:67.8%, Test_loss:0.683
Epoch: 7, Train_acc:78.1%, Train_loss:0.459, Test_acc:69.0%, Test_loss:0.652
Epoch: 8, Train_acc:80.8%, Train_loss:0.411, Test_acc:72.7%, Test_loss:0.643
Epoch: 9, Train_acc:84.8%, Train_loss:0.362, Test_acc:74.8%, Test_loss:0.575
Epoch:10, Train_acc:87.4%, Train_loss:0.314, Test_acc:77.9%, Test_loss:0.536
Epoch:11, Train_acc:89.3%, Train_loss:0.266, Test_acc:79.0%, Test_loss:0.505
Epoch:12, Train_acc:89.4%, Train_loss:0.260, Test_acc:78.3%, Test_loss:0.601
Epoch:13, Train_acc:90.7%, Train_loss:0.226, Test_acc:81.4%, Test_loss:0.493
Epoch:14, Train_acc:93.9%, Train_loss:0.159, Test_acc:80.4%, Test_loss:0.616
Epoch:15, Train_acc:93.8%, Train_loss:0.152, Test_acc:80.4%, Test_loss:0.620
Epoch:16, Train_acc:92.2%, Train_loss:0.190, Test_acc:82.3%, Test_loss:0.621
Epoch:17, Train_acc:94.0%, Train_loss:0.142, Test_acc:82.3%, Test_loss:0.582
Epoch:18, Train_acc:95.8%, Train_loss:0.106, Test_acc:79.3%, Test_loss:0.625
Epoch:19, Train_acc:95.5%, Train_loss:0.127, Test_acc:81.1%, Test_loss:0.625
Epoch:20, Train_acc:95.4%, Train_loss:0.113, Test_acc:83.0%, Test_loss:0.482
Epoch:21, Train_acc:96.7%, Train_loss:0.087, Test_acc:83.0%, Test_loss:0.667
Epoch:22, Train_acc:97.3%, Train_loss:0.083, Test_acc:80.4%, Test_loss:0.695
Epoch:23, Train_acc:97.1%, Train_loss:0.077, Test_acc:83.7%, Test_loss:0.634
Epoch:24, Train_acc:96.6%, Train_loss:0.086, Test_acc:82.5%, Test_loss:0.732
Epoch:25, Train_acc:96.6%, Train_loss:0.098, Test_acc:83.9%, Test_loss:0.711
Epoch:26, Train_acc:96.0%, Train_loss:0.107, Test_acc:75.3%, Test_loss:0.821
Epoch:27, Train_acc:95.6%, Train_loss:0.105, Test_acc:81.6%, Test_loss:0.596
Epoch:28, Train_acc:96.7%, Train_loss:0.088, Test_acc:84.4%, Test_loss:0.606
Epoch:29, Train_acc:97.5%, Train_loss:0.071, Test_acc:86.5%, Test_loss:0.615
Epoch:30, Train_acc:98.2%, Train_loss:0.051, Test_acc:80.4%, Test_loss:0.772
Epoch:31, Train_acc:98.5%, Train_loss:0.041, Test_acc:83.7%, Test_loss:0.694
Epoch:32, Train_acc:98.5%, Train_loss:0.048, Test_acc:82.8%, Test_loss:0.671
Epoch:33, Train_acc:97.7%, Train_loss:0.064, Test_acc:84.1%, Test_loss:0.745
Epoch:34, Train_acc:98.4%, Train_loss:0.054, Test_acc:83.7%, Test_loss:0.661
Epoch:35, Train_acc:98.2%, Train_loss:0.068, Test_acc:83.0%, Test_loss:0.605
Epoch:36, Train_acc:96.8%, Train_loss:0.086, Test_acc:83.2%, Test_loss:0.551
Epoch:37, Train_acc:97.8%, Train_loss:0.063, Test_acc:82.3%, Test_loss:0.739
Epoch:38, Train_acc:97.6%, Train_loss:0.065, Test_acc:83.0%, Test_loss:0.583
Epoch:39, Train_acc:98.2%, Train_loss:0.045, Test_acc:83.4%, Test_loss:0.697
Epoch:40, Train_acc:98.1%, Train_loss:0.048, Test_acc:82.5%, Test_loss:0.710
Epoch:41, Train_acc:98.2%, Train_loss:0.054, Test_acc:83.2%, Test_loss:0.564
Epoch:42, Train_acc:98.4%, Train_loss:0.051, Test_acc:85.5%, Test_loss:0.514
Epoch:43, Train_acc:99.0%, Train_loss:0.025, Test_acc:83.9%, Test_loss:0.663
Epoch:44, Train_acc:99.1%, Train_loss:0.029, Test_acc:85.5%, Test_loss:0.594
Epoch:45, Train_acc:98.3%, Train_loss:0.036, Test_acc:84.6%, Test_loss:0.719
Epoch:46, Train_acc:98.7%, Train_loss:0.036, Test_acc:84.4%, Test_loss:0.631
Epoch:47, Train_acc:97.7%, Train_loss:0.055, Test_acc:81.4%, Test_loss:0.643
Epoch:48, Train_acc:98.7%, Train_loss:0.040, Test_acc:85.1%, Test_loss:0.607
Epoch:49, Train_acc:98.8%, Train_loss:0.037, Test_acc:80.2%, Test_loss:0.897
Epoch:50, Train_acc:98.6%, Train_loss:0.042, Test_acc:84.4%, Test_loss:0.601
Done
5. Visualize results
```python
import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings("ignore")  # suppress warnings

epochs_range = range(epoches)

plt.figure(figsize=(12, 3))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training Loss')

plt.show()
```
6. Evaluate the model
```python
# load the best weights
best_model.load_state_dict(torch.load(PATH, map_location=device))

# evaluate on the test set
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)
print(epoch_test_acc, epoch_test_loss)
```
0.8648018648018648 0.6145411878824234
3. References
- 深度學習——分類之ResNeXt - 知乎
- 通義 - 你的個人AI助手
- ResNeXt代碼復現+超詳細注釋(PyTorch)-CSDN博客