深度學習基礎--CNN經典網絡之分組卷積與ResNext網絡實驗探究(pytorch復現)

  • 🍨 本文為🔗365天深度學習訓練營 中的學習記錄博客
  • 🍖 原作者:K同學啊

前言

  • ResNext是分組卷積的開始之作,這里本文將學習ResNext網絡;
  • 本文復現了ResNext50神經網絡,并用其進行了猴痘病分類實驗;
  • 沒有最好的網絡。只有最適合的網絡,網絡不是越復雜,越優秀越好,必須根據實際數據情況,目標要求決定,很多時候,簡單的網絡反而效果更好
  • 歡迎收藏 + 關注,本人將會持續更新

文章目錄

  • 1、知識簡介
    • 1、分組卷積
    • 2、split-transform-merge
    • 3、ResNext-50簡介
  • 2、ResNext-50實驗
    • 1、導入數據
      • 1、導入庫
      • 2、查看數據信息和導入數據
      • 3、展示數據
      • 4、數據導入
      • 5、數據劃分
      • 6、動態加載數據
    • 2、構建ResNext-50網絡
    • 3、模型訓練
      • 1、構建訓練集
      • 2、構建測試集
      • 3、設置超參數
    • 4、模型訓練
    • 5、結果可視化
    • 6、模型評估
  • 3、參考資料

1、知識簡介

1、分組卷積

分組卷積最早出現在AlexNet網絡中,在這里將通道數分成兩組,采用兩個GPU并行提取特征,網絡結構如下:

在這里插入圖片描述

提取到的特征圖如下:

在這里插入圖片描述

作者發現第一組提取的主要是黑白特征,第二組提取的主要是彩色特征,這樣分組特征可以更好的提取不同特征數據。


普通卷積 VS 分組卷積

先看常規卷積,在常規卷積中,輸入feature map尺寸為 n 個,輸出feature map與卷積和數量相同也是n個,卷積核大小為:c * k * k,n個卷積核總大小為:n * c * k * k,最后輸出的維度是:n * h1 * w1如下圖左邊所示

在這里插入圖片描述

分組卷積,就是對輸入的feature map進行分組,然后每組分別卷積。假設輸入feature map的尺寸為 c * h * w,輸出的feature map為 n,假設分為 g 組,則每組的輸入的feature map數量為 c / g,每組輸出的feature map為 n / g。但是注意只是每個卷積核的輸入通道數量變成了 c / g,卷積核大小是不變的,每一組卷積核運算后得到了 (n / g) * h1 * w1,最后將各組矩陣進行拼接就可以得出最后的結果,最后輸出的維度依然是n * h1 * w1,與常規卷積一樣。

參數了對比

  • 常規卷積:c * k * k * n,c通道數,k * k:卷積核矩陣大小,n卷積核數量;
  • 分組卷積:(c / g) * k * k * (n / g) * g = k * k * c * n * (1 / g),從參數了來看,分組卷積更小

更詳細的圖如下

在這里插入圖片描述

2、split-transform-merge

“Split-Transform-Merge” 是一種常見的設計模式或處理流程,廣泛應用于軟件開發、數據處理和系統架構中。它的核心思想是將一個復雜的問題分解為更小的部分(Split),對每個部分進行獨立的處理或轉換(Transform),最后將處理后的結果重新組合(Merge)以完成整體任務。


1. Split(拆分)

在這一階段,輸入數據或任務被分解成更小、更易于管理的部分。拆分的方式取決于具體問題和上下文。例如:

  • 數據拆分:將大數據集分割成多個小塊。
  • 任務拆分:將一個復雜的任務分解為多個子任務。
  • 并行化:通過拆分實現并行處理,提高效率。

示例

  • 分組卷積中,輸入通道分組拆分,分組進行卷積。

2. Transform(轉換/處理)

在拆分后,每個部分被獨立處理或轉換。這是整個流程的核心階段,通常涉及計算、分析或修改操作。轉換的具體內容取決于任務需求:

  • 數據清洗、格式轉換。
  • 算法計算或模型推理。
  • 對子任務的獨立執行。

示例

  • 分組卷積中 ,每一組分別進行卷積計算,互補干擾。

3. Merge(合并)

在所有子任務完成后,將處理后的結果重新組合起來,形成最終的輸出。合并的方式需要確保結果的完整性和一致性:

  • 數據合并:將多個處理后的數據塊拼接成完整的數據集。
  • 結果整合:將多個子任務的結果匯總為最終答案。
  • 沖突解決:如果子任務之間存在沖突或重復,需要在合并階段解決。

示例

  • 分組卷積中,最后將每一組卷積的結果進行組合。

3、ResNext-50簡介

ResNext網絡被譽為,分組卷積的開山之作,是何凱明團隊在2017年CVPR會與提出的,是ResNet網絡的升級版。

在論文中,作者提到了一個普遍存在的現象,提高模型準確率,往往采用的是加深或加寬網絡的方法,這種方法雖然有一定效果,但是網絡設計的難度和計算了也隨著增加,因為不代表網絡越深就越好,有時候提升了精度,但是代價也大,就如VGG16提出來的時候,計算了龐大。

在論文中,作者提出了在不額外增加計算代價的情況下,提升網絡精度,提出了cardinality概念(cardinality指的是分組卷積中的“組數”).

下圖中,左邊是(Resnet)右邊數(Resnext)的模塊差異,在ResNet中,輸入具有256個通道特征經過1 * 1卷積壓縮到4倍到64個通道特征,然后通過3 * 3卷積核進行特征提取,最后經過 3 * 3卷積核進行還原通道數量輸出,并于原來特征進行殘差連接。在ResNext中,將256個輸入通道特征分成32個組,每個組首先進行64倍壓縮到4個通道,然后用3 * 3卷積核大小進行特征提取,最后通過1 * 1卷積核進行通道還原,后會將每個分組的結構進行維度拼接并與原始特征進行殘差連接。

在這里插入圖片描述

cardinatity指的是一個block中所具有的相同分支的數目,即“組數”.

下面進行ResNext-50網絡圖的搭建(pytorch復現)

2、ResNext-50實驗

1、導入數據

1、導入庫

import torch  
import torch.nn as nn
import torchvision 
import numpy as np 
import os, PIL, pathlib # 設置設備
device = "cuda" if torch.cuda.is_available() else "cpu"device 
'cuda'

2、查看數據信息和導入數據

數據目錄有兩個文件:一個數據文件,一個權重。

data_dir = "./data/"data_dir = pathlib.Path(data_dir)# 類別數量
classnames = [str(path).split('/')[0] for path in os.listdir(data_dir)]classnames
['Monkeypox', 'Others']

3、展示數據

import matplotlib.pylab as plt  
from PIL import Image # 獲取文件名稱
data_path_name = "./data/Others"
data_path_list = [f for f in os.listdir(data_path_name) if f.endswith(('jpg', 'png'))]# 創建畫板
fig, axes = plt.subplots(2, 8, figsize=(16, 6))for ax, img_file in zip(axes.flat, data_path_list):path_name = os.path.join(data_path_name, img_file)img = Image.open(path_name) # 打開# 顯示ax.imshow(img)ax.axis('off')plt.show()

?
在這里插入圖片描述

?

4、數據導入

from torchvision import transforms, datasets # 數據統一格式
img_height = 224
img_width = 224 data_tranforms = transforms.Compose([transforms.Resize([img_height, img_width]),transforms.ToTensor(),transforms.Normalize(   # 歸一化mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225] )
])# 加載所有數據
total_data = datasets.ImageFolder(root=data_dir, transform=data_tranforms)

5、數據劃分

# 大小 8 : 2
train_size = int(len(total_data) * 0.8)
test_size = len(total_data) - train_size train_data, test_data = torch.utils.data.random_split(total_data, [train_size, test_size])

6、動態加載數據

batch_size = 32 train_dl = torch.utils.data.DataLoader(train_data,batch_size=batch_size,shuffle=True
)test_dl = torch.utils.data.DataLoader(test_data,batch_size=batch_size,shuffle=False
)
# 查看數據維度
for data, labels in train_dl:print("data shape[N, C, H, W]: ", data.shape)print("labels: ", labels)break
data shape[N, C, H, W]:  torch.Size([32, 3, 224, 224])
labels:  tensor([1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,0, 1, 0, 0, 0, 1, 0, 0])

2、構建ResNext-50網絡

ResNet-50網絡結構圖
在這里插入圖片描述

在這里插入圖片描述

在復現ResNext50網絡中,我查閱了不少資料,但是我好像都沒怎么看懂那個代碼,后面我發現這個就是在ResNet50上加了分組卷積,其他網絡結構就是在每一層,第二層的數量是resnet的2倍,后面基于以前搭建的ResNet網絡結果進行修改,代碼如下所示。

在ResNext50中,有幾個參數需要注意:

  • 分組卷積:cardinality參數代表分組卷積數量,在Conv2d中groups參數就是分組卷積數量。
  • 通道數計算:每組的輸出通道數由 group_depth 決定,總輸出通道數為 cardinality × group_depth。這里,下面本人搭建的ResNext50網絡結構,每一層輸入通道數,輸出通道數,都是自己手動輸入的,故這里group_depth隱藏在filters中(手動計算).

回憶
Bottleneck 的基本概念

Bottleneck 結構通常由三個卷積層組成,他是ResNet以及其變體的基本網絡層單元。

  1. 第一個 1×1 卷積:降低輸入特征圖的通道數,減少后續計算量。
  2. 中間的 3×3 卷積:核心特征提取過程。在 ResNeXt 中,這一層使用分組卷積來增強表達能力。
  3. 最后一個 1×1 卷積:恢復通道數到原始或者更高的數量,以便與輸入特征圖進行殘差連接。

注意:

  • 在ResNext網絡結構中,分組卷積只在Bottleneck只在第二層使用
import torch.nn.functional as F# Bottleneck: 分為殘差模塊一、殘差模塊二# 定義殘差模塊一,這個用于處理輸入和輸出通道一樣的情況
'''  
卷積核大小:1       3       1
核心特點:尺寸不變:輸入和輸出的尺寸保持一致。 沒有下采樣:沒有使用步長大于1的卷積操作,因此沒有改變特征圖的空間尺寸
'''
class Identity_block(nn.Module):def __init__(self, in_channels, kernel_size, filters, cardinality):super(Identity_block, self).__init__()# 輸出通道filter1, filter2, filter3 = filters# 卷積層一, 降維self.conv1 = nn.Conv2d(in_channels, filter1, kernel_size=1, stride=1)self.bn1 = nn.BatchNorm2d(filter1)# 卷積層2, 分組卷積, 核心:特征提取self.conv2 = nn.Conv2d(filter1, filter2, kernel_size=kernel_size, padding=1,groups=cardinality)   # 通過卷積輸入輸出公式發現,padding=1,可以保證輸入和輸出尺寸相同self.bn2 = nn.BatchNorm2d(filter2)# 卷積層3, 升維self.conv3 = nn.Conv2d(filter2, filter3, kernel_size=1, stride=1)self.bn3 = nn.BatchNorm2d(filter3)def forward(self, x):# 記錄原始值xx = xx = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = self.bn3(self.conv3(x))# 殘差連接,輸入、輸出維度不變x += xxx = F.relu(x)return x # 定義卷積模塊二:用于處理輸入和輸出不一樣的情況
'''  
* 卷積核還是:1 3 1
* stride=2
* 這里的分支是采用一個Conv2D,和一個歸一化BN層,也是為了處理數據維度吧, 這種維度的變化,可以用ai舉例子核心特點:尺寸變化,stride=2降維
'''
class ConvBlock(nn.Module):def __init__(self, in_channels, kernel_size, filters, cardinality, stride=2):super(ConvBlock, self).__init__()filter1, filter2, filter3= filters# 卷積層1, 降維self.conv1 = nn.Conv2d(in_channels, filter1, kernel_size=1, stride=stride)self.bn1 = nn.BatchNorm2d(filter1)# 卷積2, 分組卷積,核心:特征提取self.conv2 = nn.Conv2d(filter1, filter2, kernel_size=kernel_size, padding=1,groups=cardinality) # 需要維持維度不變self.bn2 = nn.BatchNorm2d(filter2)# 卷積3, 降維self.conv3 = nn.Conv2d(filter2, filter3, kernel_size=1, stride=1)  # stride = 1,維持通道不變self.bn3 = nn.BatchNorm2d(filter3)# 用于匹配維度的shortcut卷積,這個就是上面Identity_block的x分支self.shortcut = nn.Conv2d(in_channels, filter3, kernel_size=1, stride=stride)self.shortcut_bn = nn.BatchNorm2d(filter3)def forward(self, x):xx = xx = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = self.bn3(self.conv3(x))temp = self.shortcut_bn(self.shortcut(xx))x += tempx = F.relu(x)return x # 定義ResNext50
class ResNext50(nn.Module):def __init__(self, classes):   # 類別數量super().__init__()# 頭頂, resnet以及變體一般都是這個self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)self.bn1 = nn.BatchNorm2d(64)self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# 第一部分self.part1_1 = ConvBlock(64, 3, [128, 128, 256], cardinality=32, stride=1)self.part1_2 = Identity_block(256, 3, [128, 128, 256], cardinality=32)self.part1_3 = Identity_block(256, 3, [128, 128, 256], cardinality=32)# 第二部分self.part2_1 = ConvBlock(256, 3, [256, 256, 512], cardinality=32)self.part2_2 = Identity_block(512, 3, [256, 256, 512], cardinality=32)self.part2_3 = Identity_block(512, 3, [256, 256, 512], cardinality=32)self.part2_4 = Identity_block(512, 3, [256, 256, 512], cardinality=32)# 第三部分self.part3_1 = ConvBlock(512, 3, [512, 512, 1024], cardinality=32)self.part3_2 = Identity_block(1024, 3, [512, 512, 1024], cardinality=32)self.part3_3 = Identity_block(1024, 3, [512, 512, 1024], cardinality=32)self.part3_4 = Identity_block(1024, 3, [512, 512, 1024], cardinality=32)self.part3_5 = Identity_block(1024, 3, [512, 512, 1024], cardinality=32)self.part3_6 = Identity_block(1024, 3, [512, 512, 1024], cardinality=32)# 第四部分self.part4_1 = ConvBlock(1024, 3, [1024, 1024, 2048], cardinality=32)self.part4_2 = Identity_block(2048, 3, [1024, 1024, 2048], cardinality=32)self.part4_3 = Identity_block(2048, 3, [1024, 1024, 2048], cardinality=32)# 平均池化self.avg_pool = nn.AvgPool2d(kernel_size=7)# 全連接self.fn1 = nn.Linear(2048, classes)def forward(self, x):# 頭部x = F.relu(self.bn1(self.conv1(x)))x = self.max_pool(x)x = self.part1_1(x)x = self.part1_2(x)x = self.part1_3(x)x = self.part2_1(x)x = self.part2_2(x)x = self.part2_3(x)x = self.part2_4(x)x = self.part3_1(x)x = self.part3_2(x)x = self.part3_3(x)x = self.part3_4(x)x = self.part3_5(x)x = self.part3_6(x)x = self.part4_1(x)x = self.part4_2(x)x = self.part4_3(x)x = self.avg_pool(x)x = x.view(x.size(0), -1)  # 扁平化x = self.fn1(x)return x model = ResNext50(classes=len(classnames)).to(device)model
ResNext50((conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(max_pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(part1_1): ConvBlock((conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))(shortcut_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part1_2): Identity_block((conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part1_3): Identity_block((conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_1): ConvBlock((conv1): Conv2d(256, 256, kernel_size=(1, 1), stride=(2, 2))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))(shortcut_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_2): Identity_block((conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_3): Identity_block((conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part2_4): Identity_block((conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_1): ConvBlock((conv1): Conv2d(512, 512, kernel_size=(1, 1), stride=(2, 2))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2))(shortcut_bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_2): Identity_block((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_3): Identity_block((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_4): Identity_block((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_5): Identity_block((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part3_6): Identity_block((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part4_1): ConvBlock((conv1): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(2, 2))(bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(shortcut): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2))(shortcut_bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part4_2): Identity_block((conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(part4_3): Identity_block((conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1))(bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)(bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1))(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))(avg_pool): AvgPool2d(kernel_size=7, stride=7, padding=0)(fn1): Linear(in_features=2048, out_features=2, bias=True)
)

3、模型訓練

1、構建訓練集

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)batch_size = len(dataloader)train_acc, train_loss = 0, 0 for X, y in dataloader:X, y = X.to(device), y.to(device)# 訓練pred = model(X)loss = loss_fn(pred, y)# 梯度下降法optimizer.zero_grad()loss.backward()optimizer.step()# 記錄train_loss += loss.item()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_acc /= sizetrain_loss /= batch_sizereturn train_acc, train_loss

2、構建測試集

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)batch_size = len(dataloader)test_acc, test_loss = 0, 0 with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)test_loss += loss.item()test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()test_acc /= sizetest_loss /= batch_sizereturn test_acc, test_loss

3、設置超參數

loss_fn = nn.CrossEntropyLoss()  # 損失函數     
learn_lr = 1e-4            # 超參數
optimizer = torch.optim.Adam(model.parameters(), lr=learn_lr)   # 優化器

4、模型訓練

import copy train_acc = []
train_loss = []
test_acc = []
test_loss = []epoches = 50best_acc = 0for i in range(epoches):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)# 保存最佳模型到 best_model     if epoch_test_acc > best_acc:         best_acc   = epoch_test_acc         best_model = copy.deepcopy(model)  # 拷貝最好模型train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 獲取當前的學習率     lr = optimizer.state_dict()['param_groups'][0]['lr']# 輸出template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}')print(template.format(i + 1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))print("Done")PATH = './best_model.pth'  # 保存的參數文件名 
torch.save(best_model.state_dict(), PATH)
Epoch: 1, Train_acc:62.3%, Train_loss:0.696, Test_acc:66.4%, Test_loss:0.604
Epoch: 2, Train_acc:67.9%, Train_loss:0.620, Test_acc:69.9%, Test_loss:0.580
Epoch: 3, Train_acc:69.5%, Train_loss:0.580, Test_acc:68.3%, Test_loss:0.603
Epoch: 4, Train_acc:71.6%, Train_loss:0.547, Test_acc:73.9%, Test_loss:0.530
Epoch: 5, Train_acc:74.7%, Train_loss:0.519, Test_acc:75.1%, Test_loss:0.520
Epoch: 6, Train_acc:78.2%, Train_loss:0.464, Test_acc:67.8%, Test_loss:0.683
Epoch: 7, Train_acc:78.1%, Train_loss:0.459, Test_acc:69.0%, Test_loss:0.652
Epoch: 8, Train_acc:80.8%, Train_loss:0.411, Test_acc:72.7%, Test_loss:0.643
Epoch: 9, Train_acc:84.8%, Train_loss:0.362, Test_acc:74.8%, Test_loss:0.575
Epoch:10, Train_acc:87.4%, Train_loss:0.314, Test_acc:77.9%, Test_loss:0.536
Epoch:11, Train_acc:89.3%, Train_loss:0.266, Test_acc:79.0%, Test_loss:0.505
Epoch:12, Train_acc:89.4%, Train_loss:0.260, Test_acc:78.3%, Test_loss:0.601
Epoch:13, Train_acc:90.7%, Train_loss:0.226, Test_acc:81.4%, Test_loss:0.493
Epoch:14, Train_acc:93.9%, Train_loss:0.159, Test_acc:80.4%, Test_loss:0.616
Epoch:15, Train_acc:93.8%, Train_loss:0.152, Test_acc:80.4%, Test_loss:0.620
Epoch:16, Train_acc:92.2%, Train_loss:0.190, Test_acc:82.3%, Test_loss:0.621
Epoch:17, Train_acc:94.0%, Train_loss:0.142, Test_acc:82.3%, Test_loss:0.582
Epoch:18, Train_acc:95.8%, Train_loss:0.106, Test_acc:79.3%, Test_loss:0.625
Epoch:19, Train_acc:95.5%, Train_loss:0.127, Test_acc:81.1%, Test_loss:0.625
Epoch:20, Train_acc:95.4%, Train_loss:0.113, Test_acc:83.0%, Test_loss:0.482
Epoch:21, Train_acc:96.7%, Train_loss:0.087, Test_acc:83.0%, Test_loss:0.667
Epoch:22, Train_acc:97.3%, Train_loss:0.083, Test_acc:80.4%, Test_loss:0.695
Epoch:23, Train_acc:97.1%, Train_loss:0.077, Test_acc:83.7%, Test_loss:0.634
Epoch:24, Train_acc:96.6%, Train_loss:0.086, Test_acc:82.5%, Test_loss:0.732
Epoch:25, Train_acc:96.6%, Train_loss:0.098, Test_acc:83.9%, Test_loss:0.711
Epoch:26, Train_acc:96.0%, Train_loss:0.107, Test_acc:75.3%, Test_loss:0.821
Epoch:27, Train_acc:95.6%, Train_loss:0.105, Test_acc:81.6%, Test_loss:0.596
Epoch:28, Train_acc:96.7%, Train_loss:0.088, Test_acc:84.4%, Test_loss:0.606
Epoch:29, Train_acc:97.5%, Train_loss:0.071, Test_acc:86.5%, Test_loss:0.615
Epoch:30, Train_acc:98.2%, Train_loss:0.051, Test_acc:80.4%, Test_loss:0.772
Epoch:31, Train_acc:98.5%, Train_loss:0.041, Test_acc:83.7%, Test_loss:0.694
Epoch:32, Train_acc:98.5%, Train_loss:0.048, Test_acc:82.8%, Test_loss:0.671
Epoch:33, Train_acc:97.7%, Train_loss:0.064, Test_acc:84.1%, Test_loss:0.745
Epoch:34, Train_acc:98.4%, Train_loss:0.054, Test_acc:83.7%, Test_loss:0.661
Epoch:35, Train_acc:98.2%, Train_loss:0.068, Test_acc:83.0%, Test_loss:0.605
Epoch:36, Train_acc:96.8%, Train_loss:0.086, Test_acc:83.2%, Test_loss:0.551
Epoch:37, Train_acc:97.8%, Train_loss:0.063, Test_acc:82.3%, Test_loss:0.739
Epoch:38, Train_acc:97.6%, Train_loss:0.065, Test_acc:83.0%, Test_loss:0.583
Epoch:39, Train_acc:98.2%, Train_loss:0.045, Test_acc:83.4%, Test_loss:0.697
Epoch:40, Train_acc:98.1%, Train_loss:0.048, Test_acc:82.5%, Test_loss:0.710
Epoch:41, Train_acc:98.2%, Train_loss:0.054, Test_acc:83.2%, Test_loss:0.564
Epoch:42, Train_acc:98.4%, Train_loss:0.051, Test_acc:85.5%, Test_loss:0.514
Epoch:43, Train_acc:99.0%, Train_loss:0.025, Test_acc:83.9%, Test_loss:0.663
Epoch:44, Train_acc:99.1%, Train_loss:0.029, Test_acc:85.5%, Test_loss:0.594
Epoch:45, Train_acc:98.3%, Train_loss:0.036, Test_acc:84.6%, Test_loss:0.719
Epoch:46, Train_acc:98.7%, Train_loss:0.036, Test_acc:84.4%, Test_loss:0.631
Epoch:47, Train_acc:97.7%, Train_loss:0.055, Test_acc:81.4%, Test_loss:0.643
Epoch:48, Train_acc:98.7%, Train_loss:0.040, Test_acc:85.1%, Test_loss:0.607
Epoch:49, Train_acc:98.8%, Train_loss:0.037, Test_acc:80.2%, Test_loss:0.897
Epoch:50, Train_acc:98.6%, Train_loss:0.042, Test_acc:84.4%, Test_loss:0.601
Done

5、結果可視化

import matplotlib.pyplot as plt
#隱藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息epochs_range = range(epoches)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training= Loss')
plt.show()

?
在這里插入圖片描述

?

6、模型評估

# 加載最好模型
best_model.load_state_dict(torch.load(PATH, map_location=device)) 
# 模型測試
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)print(epoch_test_acc, epoch_test_loss)
0.8648018648018648 0.6145411878824234

3、參考資料

  • 深度學習——分類之ResNeXt - 知乎
  • 通義 - 你的個人AI助手
  • ResNeXt代碼復現+超詳細注釋(PyTorch)-CSDN博客

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/75483.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/75483.shtml
英文地址,請注明出處:http://en.pswp.cn/web/75483.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

SQL 全文檢索原理

全文檢索(Full-Text Search)是SQL中用于高效搜索文本數據的技術,與傳統的LIKE操作或簡單字符串比較相比,它能提供更強大、更靈活的文本搜索能力。 基本概念 全文檢索的核心思想是將文本內容分解為可索引的單元(通常是詞或詞組),然后建立倒排…

【Linux】Orin NX編譯 linux 內核及內核模塊

1、下載交叉編譯工具:gcc 1)下載地址:https://developer.nvidia.com/embedded/jetson-linux 選擇TOOLS中的交叉編譯工具:gcc 11.3 2)解壓 將gcc編譯器解壓到指定目錄中,如:/home/laoer/nvidia/gcc 3)配置環境變量 創建: ~/nvidia/gcc/env.sh添加: #!/bin/bash e…

Transformers 是工具箱,BERT 是工具。

Transformers 是工具箱,BERT 是工具。 🔍 詳細解釋: 名稱作用比喻理解舉例🤖 transformers(庫)一個框架,提供很多 NLP 模型的“使用方式”,包括文本分類、問答、摘要等相當于一個“…

k8s之Service類型詳解

1.ClusterIP 類型 2.NodePort 類型 3.LoadBalancer 類型 4.ExternalName 類型 類型為 ExternalName 的 Service 將 Service 映射到 DNS 名稱,而不是典型的選擇算符, 例如 my-service 或者 cassandra。你可以使用 spec.externalName 參數指定這些服務…

find指令中使用正則表達式

linux查找命令能結合正則表達式嗎 find命令要使用正則表達式需要結合-regex參數 另,-type參數可以指定查找類型(f為文件,d為文件夾) rootlocalhost:~/regular_expression# ls -alh 總計 8.0K drwxr-xr-x. 5 root root 66 4月 8日 16:26 . dr-xr-…

《穿透表象,洞察分布式軟總線“無形”之奧秘》

分布式系統已成為眾多領域的關鍵支撐技術,而分布式軟總線作為實現設備高效互聯的核心技術,正逐漸走入大眾視野。它常被描述為一條“無形”的總線,這一獨特屬性不僅是理解其技術內涵的關鍵,更是把握其在未來智能世界中重要作用的切…

Ubuntu虛擬機連不上網

橋接 虛擬機Ubuntu系統必須能連接到外網,不然不能更新軟件安裝包 配置虛擬機網絡(關機或者掛起狀態) 第一步1.重啟虛擬機網絡編輯器(還原配置) 第二步2.重啟虛擬機網絡適配器(移除再添加) 啟…

rom定制系列------紅米9A批量線刷原生安卓14雙版 miui系統解鎖可登陸線刷固件

紅米9A。聯發科Helio G25芯片。該處理器支持64位運算?,但此機miui系統運行環境是32位的,這意味著盡管處理器本身支持64位計算,但miui系統限制在32位環境下運行?。官方miui系統穩定版最終為12.5.21安卓11的版本。 原生安卓14批量線刷功能固…

Matlab 分數階PID控制永磁同步電機

1、內容簡介 Matlab 203-分數階PID控制永磁同步電機 可以交流、咨詢、答疑 2、內容說明 略 3、仿真分析 略 4、參考論文 略

Flink的 RecordWriter 數據通道 詳解

本文從基礎原理到代碼層面逐步解釋 Flink 的RecordWriter 數據通道,盡量讓初學者也能理解。 1. 什么是 RecordWriter? 通俗理解 RecordWriter 是 Flink 中負責將數據從一個任務(Task)發送到下游任務的組件。想象一下,…

Dubbo、HTTP、RMI之間的區別

Dubbo、HTTP、RMI之間的區別如下: 表格 復制 特性DubboHTTPRMI通信機制基于Netty的NIO異步通信,采用長連接,支持多種序列化方式基于標準的HTTP協議,無狀態,每次請求獨立基于Java原生的RMI機制,支持Java對…

wkhtmltopdf生成圖片的實踐教程,包含完整的環境配置、參數解析及多語言調用示例

歡迎來到濤濤聊AI,最近在研究HTML生成卡片的功能,一起學習下吧。 一、工具特性與安裝 wkhtmltoimage是基于WebKit引擎的開源命令行工具,可將HTML網頁轉換為JPG/PNG等圖片格式,支持CSS渲染、JavaScript執行和響應式布局。安裝方式…

【在Node.js項目中引入TypeScript:提高開發效率及框架選型指南】

一、TypeScript在Node.js中的核心價值 1.1 靜態類型檢測 // 錯誤示例:TypeScript會報錯 function add(a: number, b: string) {return a b }1.2 工具鏈增強 # 安裝必要依賴 npm install --save-dev typescript types/node ts-node tsconfig.json1.3 代碼維護性提…

化工企業數字化轉型:從數據貫通到生態重構的實踐路徑

一、戰略定位:破解行業核心痛點 化工行業面臨生產安全風險高(全國危化品企業事故率年增5%)、能耗與排放壓力大(占工業總能耗12%)、供應鏈協同低效(庫存周轉率低于制造業均值30%)三大挑戰。《石…

C#網絡編程(Socket編程)

文章目錄 0、寫在前面的話1、Socket 介紹1.1 Socket是什么1.2 Socket在網絡中的位置 2、C# 中的Socket參數2.1 超時控制參數2.2 緩沖區參數2.3 UDP專用參數 3、C# 中的Socket API3.1 Socket(構造函數)3.1.1 SocketType3.1.2 ProtocolType3.1.3 AddressFa…

Docker部署ES集群

引言: Elasticsearch(ES)作為分布式搜索引擎,其核心價值在于通過集群部署實現高可用性和數據冗余。 本實驗對比兩種典型部署方案: 原生Linux部署:直接安裝ES服務,適用于生產環境,資…

老硬件也能運行的Win11 IoT LTSC (OEM)物聯網版

#記錄工作 Windows 11 IoT Enterprise LTSC 2024 屬于物聯網相關的版本。 Windows 11 IoT Enterprise 是為物聯網設備和場景設計的操作系統版本。它通常針對特定的工業控制、智能設備等物聯網應用進行了優化和定制,以滿足這些領域對穩定性、安全性和長期支持的需求…

【教程】xrdp修改遠程桌面環境為xfce4

轉載請注明出處:小鋒學長生活大爆炸[xfxuezhagn.cn] 如果本文幫助到了你,歡迎[點贊、收藏、關注]哦~ 目錄 xfce4 vs GNOME對比 配置教程 1. 安裝 xfce4 桌面環境 2. 安裝 xrdp 3. 配置 xrdp 使用 xfce4 4. 重啟 xrdp 服務 5. 配置防火墻&#xff…

【數據結構 · 初階】- 順序表

目錄 一、線性表 二、順序表 1.實現動態順序表 SeqList.h SeqList.c Test.c 問題 經驗:free 出問題,2種可能性 解決問題 (2)尾刪 (3)頭插,頭刪 (4)在 pos 位…

windows主機中構建適用于K8S Operator開發環境

基于win 10 打造K8S應用開發環境(wsl & kind) 一、wsl子系統安裝 1.1 確認windows系統版本 cmd/powershell 或者win r 運行winver 操作系統要> 19044 1.2 開啟wsl功能 控制面板 -> 程序 -> 啟用或關閉Windows功能 開啟適用于Linu…