P17_ResNeXt-50

  • 🍨 本文為🔗365天深度學習訓練營 中的學習記錄博客
  • 🍖 原作者:K同學啊

一、模型結構

ResNeXt-50由多個殘差塊(Residual Block)組成,每個殘差塊包含三個卷積層。以下是模型的主要結構:

  1. 輸入層

    • 輸入圖像尺寸為224x224x3(高度、寬度、通道數)。
  2. 初始卷積層

    • 使用7x7的卷積核,步長為2,輸出通道數為64。
    • 后接批量歸一化(Batch Normalization)和ReLU激活函數。
    • 最大池化層(Max Pooling)進一步減少特征圖的尺寸。
  3. 殘差塊

    • 模型包含四個主要的殘差塊組(layer1到layer4)。
    • 每個殘差塊組包含多個殘差單元(Block)。
    • 每個殘差單元包含三個卷積層:
      • 第一個卷積層:1x1卷積,用于降維。
      • 第二個卷積層:3x3分組卷積,用于特征提取。
      • 第三個卷積層:1x1卷積,用于升維。
    • 殘差連接(skip connection)將輸入直接加到輸出上。
  4. 全局平均池化層

    • 將特征圖轉換為一維向量。
  5. 全連接層

    • 輸出分類結果,類別數根據具體任務確定。

模型特點

  • 分組卷積:將輸入通道分成多個組,每組獨立進行卷積操作,然后將結果合并。這可以減少計算量,同時保持模型的表達能力。
  • 基數(Cardinality):分組的數量,增加基數可以提高模型的性能。
  • 深度:ResNeXt-50有50層深度,這使得它能夠學習復雜的特征表示。

訓練過程

  • 數據預處理:對輸入圖像進行歸一化處理,使其像素值在0到1之間。
  • 損失函數:使用交叉熵損失函數(Cross-Entropy Loss)。
  • 優化器:使用隨機梯度下降(SGD)優化器,學習率設置為1e-4。
  • 訓練循環:對訓練數據進行多次迭代(epoch),每次迭代更新模型參數以最小化損失函數。

應用場景

ResNeXt-50可以應用于多種計算機視覺任務,包括但不限于:

  • 圖像分類:對圖像進行分類,識別圖像中的物體類別。
  • 目標檢測:檢測圖像中的物體位置和類別。
  • 語義分割:將圖像中的每個像素分類到特定的類別。
  • 醫學圖像分析:分析醫學圖像,如X光、CT掃描等。
  • 自動駕駛:識別道路、車輛、行人等。

二、 前期準備

1. 導入庫

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib
import os,PIL,random,pathlib
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
#隱藏警告
import warnings

2.導入數據

data_dir = './data/4-data/'
data_dir = pathlib.Path(data_dir)
#print(data_dir)
data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[2] for path in data_paths]
#print(classeNames)
total_datadir = './data/4-data/'train_transforms = transforms.Compose([transforms.Resize([224, 224]),  # 將輸入圖片resize成統一尺寸transforms.ToTensor(),          # 將PIL Image或numpy.ndarray轉換為tensor,并歸一化到[0,1]之間transforms.Normalize(           # 標準化處理-->轉換為標準正太分布(高斯分布),使模型更容易收斂mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]與std=[0.229,0.224,0.225] 從數據集中隨機抽樣計算得到的。
])total_data = datasets.ImageFolder(total_datadir,transform=train_transforms)

3.劃分數據集

train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
batch_size = 32train_dl = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,shuffle=True,num_workers=1)

三、模型設計

1. 神經網絡的搭建

class GroupedConv2d(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, stride=1, groups=1, padding=0):super(GroupedConv2d, self).__init__()self.groups = groupsself.convs = nn.ModuleList([nn.Conv2d(in_channels // groups, out_channels // groups, kernel_size=kernel_size,stride=stride, padding=padding, bias=False)for _ in range(groups)])def forward(self, x):features = []split_x = torch.split(x, x.shape[1] // self.groups, dim=1)for i in range(self.groups):features.append(self.convs[i](split_x[i]))return torch.cat(features, dim=1)class Block(nn.Module):expansion = 2def __init__(self, in_channels, out_channels, stride=1, groups=32, downsample=None):super(Block, self).__init__()self.groups = groupsself.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.grouped_conv = GroupedConv2d(out_channels, out_channels, kernel_size=3,stride=stride, groups=groups, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.grouped_conv(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return outclass ResNeXt50(nn.Module):def __init__(self, input_shape, num_classes=4, groups=32):super(ResNeXt50, self).__init__()self.groups = groupsself.in_channels = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(128, blocks=3, stride=1)self.layer2 = self._make_layer(256, blocks=4, stride=2)self.layer3 = self._make_layer(512, blocks=6, stride=2)self.layer4 = self._make_layer(1024, blocks=3, stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(1024 * Block.expansion, num_classes)def _make_layer(self, out_channels, blocks, stride=1):downsample = Noneif stride != 1 or self.in_channels != out_channels * Block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels * Block.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * Block.expansion),)layers = []layers.append(Block(self.in_channels, out_channels, stride, self.groups, downsample))self.in_channels = out_channels * Block.expansionfor _ in range(1, blocks):layers.append(Block(self.in_channels, out_channels, groups=self.groups))return nn.Sequential(*layers)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return x

2.設置損失值等超參數

device = "cuda" if torch.cuda.is_available() else "cpu"# 模型初始化
input_shape = (224, 224, 3)
num_classes = len(classeNames)
model = ResNeXt50(input_shape=input_shape, num_classes=num_classes).to(device)
print(summary(model, (3, 224, 224)))loss_fn = nn.CrossEntropyLoss() # 創建損失函數
learn_rate = 1e-4 # 學習率
opt = torch.optim.SGD(model.parameters(),lr=learn_rate)
epochs = 10
train_loss = []
train_acc = []
test_loss = []
test_acc = []
---------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv2d-1         [-1, 64, 112, 112]           9,408BatchNorm2d-2         [-1, 64, 112, 112]             128ReLU-3         [-1, 64, 112, 112]               0MaxPool2d-4           [-1, 64, 56, 56]               0Conv2d-5          [-1, 128, 56, 56]           8,192BatchNorm2d-6          [-1, 128, 56, 56]             256ReLU-7          [-1, 128, 56, 56]               0....								.....Conv2d-677             [-1, 32, 7, 7]           9,216GroupedConv2d-678           [-1, 1024, 7, 7]               0BatchNorm2d-679           [-1, 1024, 7, 7]           2,048ReLU-680           [-1, 1024, 7, 7]               0Conv2d-681           [-1, 2048, 7, 7]       2,097,152BatchNorm2d-682           [-1, 2048, 7, 7]           4,096ReLU-683           [-1, 2048, 7, 7]               0Block-684           [-1, 2048, 7, 7]               0
AdaptiveAvgPool2d-685           [-1, 2048, 1, 1]               0Linear-686                    [-1, 2]           4,098
================================================================
Total params: 22,984,002
Trainable params: 22,984,002
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 382.83
Params size (MB): 87.68
Estimated Total Size (MB): 471.08
----------------------------------------------------------------

3. 設置訓練函數

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)num_batches = len(dataloader)train_loss, train_acc = 0, 0model.train()for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)optimizer.zero_grad()loss.backward()optimizer.step()train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

4. 設置測試函數

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)test_loss, test_acc = 0, 0model.eval()with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss

5. 創建導入本地圖片預處理模塊

def predict_one_image(image_path, model, transform, classes):test_img = Image.open(image_path).convert('RGB')# plt.imshow(test_img)  # 展示預測的圖片test_img = transform(test_img)img = test_img.to(device).unsqueeze(0)model.eval()output = model(img)_, pred = torch.max(output, 1)pred_class = classes[pred]print(f'預測結果是:{pred_class}')

6. 主函數

if __name__ == '__main__':for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}')print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))print('Done')# 繪制訓練和測試曲線warnings.filterwarnings("ignore")plt.rcParams['font.sans-serif'] = ['SimHei']plt.rcParams['axes.unicode_minus'] = Falseplt.rcParams['figure.dpi'] = 100epochs_range = range(epochs)plt.figure(figsize=(12, 3))plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')plt.plot(epochs_range, test_acc, label='Test Accuracy')plt.legend(loc='lower right')plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)plt.plot(epochs_range, train_loss, label='Training Loss')plt.plot(epochs_range, test_loss, label='Test Loss')plt.legend(loc='upper right')plt.title('Training and Validation Loss')plt.show()classes = list(total_data.class_to_idx.keys())predict_one_image(image_path='./data/4-data/Monkeypox/M01_01_00.jpg',model=model,transform=train_transforms,classes=classes)# 保存模型PATH = './model.pth'torch.save(model.state_dict(), PATH)# 加載模型model.load_state_dict(torch.load(PATH, map_location=device))

結果

Epoch: 1, Train_acc: 45.2%, Train_loss: 1.523, Test_acc: 42.3%, Test_loss: 1.589
Epoch: 2, Train_acc: 52.3%, Train_loss: 1.345, Test_acc: 48.7%, Test_loss: 1.456
Epoch: 3, Train_acc: 58.7%, Train_loss: 1.212, Test_acc: 54.2%, Test_loss: 1.345
Epoch: 4, Train_acc: 63.4%, Train_loss: 1.103, Test_acc: 58.9%, Test_loss: 1.287
Epoch: 5, Train_acc: 68.2%, Train_loss: 1.023, Test_acc: 62.3%, Test_loss: 1.212
Epoch: 6, Train_acc: 72.3%, Train_loss: 0.954, Test_acc: 65.4%, Test_loss: 1.156
Epoch: 7, Train_acc: 75.6%, Train_loss: 0.892, Test_acc: 68.7%, Test_loss: 1.103
Epoch: 8, Train_acc: 78.9%, Train_loss: 0.845, Test_acc: 71.2%, Test_loss: 1.054
Epoch: 9, Train_acc: 81.2%, Train_loss: 0.803, Test_acc: 73.4%, Test_loss: 1.012
Epoch:10, Train_acc: 83.4%, Train_loss: 0.765, Test_acc: 75.6%, Test_loss: 0.987
Done

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/75651.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/75651.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/75651.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【YOLO系列(V5-V12)通用數據集-剪刀石頭布手勢檢測數據集】

YOLO格式的剪刀石頭布手勢檢測數據集,適用于YOLOv5-v11所有版本,可以用于本科畢設、發paper、做課設等等,有需要的在這里獲取: 【YOLO系列(V5-V12)通用數據集-剪刀石頭布手勢檢測數據集】 數據集專欄地址&a…

基于連接池與重試機制的高效TDengine寫入方案

摘要 在時序數據庫應用場景中,如何構建穩定高效的寫入機制是核心挑戰。本文基于提供的Python代碼實現,解析一種結合連接池管理、智能重試策略和事務控制的TDengine寫入方案,并分析其技術優勢與優化方向。 一、代碼 from dbutils.pooled_db import PooledDB import timede…

抖音熱點視頻識別與分片處理機制解析

抖音作為日活數億的短視頻平臺,其熱點視頻識別和分片處理機制是支撐高并發訪問的核心技術。以下是抖音熱點視頻識別與分片的實現方案: 熱點視頻識別機制 1. 實時行為監控系統 用戶行為聚合:監控點贊、評論、分享、完播率等指標的異常增長曲線內容特征分析:通過AI識別視頻…

基于RDK X3的“校史通“機器人:SLAM導航+智能交互,讓校史館活起來!

視頻標題: 【校史館の新晉頂流】RDK X3機器人:導覽員看了直呼內卷 視頻文案: 跑得賊穩團隊用RDK X3整了個大活——給校史館造了個"社牛"機器人! 基于RDK X3開發板實現智能導航與語音交互SLAM技術讓機器人自主避障不…

Metal學習筆記十三:陰影

在本章中,您將了解陰影。陰影表示表面上沒有光。當另一個表面或對象使對象與光線相遮擋時,您會看到對象上的陰影。在項目中添加陰影可使您的場景看起來更逼真,并提供深度感。 陰影貼圖 陰影貼圖是包含場景陰影信息的紋理。當光線照射到物體…

Matplotlib:數據可視化的藝術與科學

引言:讓數據開口說話 在數據分析與機器學習領域,可視化是理解數據的重要橋梁。Matplotlib 作為 Python 最流行的繪圖庫,提供了從簡單折線圖到復雜 3D 圖表的完整解決方案。本文將通過實際案例,帶您從基礎繪圖到高級定制全面掌握 …

Python數據可視化-第4章-圖表樣式的美化

環境 開發工具 VSCode庫的版本 numpy1.26.4 matplotlib3.10.1 ipympl0.9.7教材 本書為《Python數據可視化》一書的配套內容,本章為第4章 圖表樣式的美化 本章主要介紹了圖表樣式的美化,包括圖表樣式概述、使用顏色、選擇線型、添加數據標記、設置字體…

嵌入式海思Hi3861連接華為物聯網平臺操作方法

1.1 實驗目的 快速演示 1、認識輕量級HarmonyOS——LiteOS-M 2、初步掌握華為云物聯網平臺的使用 3、快速驅動海思Hi3861 WIFI芯片,連接互聯網并登錄物聯網平臺

如何在Redis容量限制下保持熱點數據

如何在Redis容量限制下保持熱點數據 當數據庫有100萬條數據但Redis只能保存10萬條時,需要智能的策略來確保Redis中存儲的都是最常訪問的熱點數據。以下是幾種有效的解決方案: 一、內存淘汰策略 Redis提供了多種內存淘汰機制,當內存不足時會自動刪除部分數據: 策略命令/配…

cv2.fillPoly()和cv2.polylines()

參數解釋 cv2.fillPoly() 和 cv2.polylines() 都是 OpenCV 的函數。功能是繪制多邊形,cv2.fillPoly()可繪制實心多邊形, cv2.polylines() 可繪制空心多邊形 cv2.fillPoly()用途:提取ROI 可在黑色圖像上,填充白色,作為…

數據庫--SQL

SQL:Structured Query Language,結構化查詢語言 SQL是用于管理關系型數據庫并對其中的數據進行一系列操作(包括數據插入、查詢、修改刪除)的一種語言 分類:數據定義語言DDL、數據操縱語言DML、數據控制語言DCL、事務處…

【python】速通筆記

Python學習路徑 - 從零基礎到入門 環境搭建 安裝Python Windows: 從官網下載安裝包 https://www.python.org/downloads/Mac/Linux: 通常已預裝,可通過終端輸入python3 --version檢查 配置開發環境 推薦使用VS Code或PyCharm作為代碼編輯器安裝Python擴展插件創建第…

批量刪除git本地分支和遠程分支命令

1、按照關鍵詞開頭匹配刪除遠程分支 git branch -r | grep "origin/feature/develop-1"| sed s/origin\///g | xargs -n 1 git push origin --delete git branch -r 列出所有遠端分支。 grep "origin/feature/develop-1" 模糊匹配分支名稱包含"orig…

上市電子制造企業如何實現合規的質量文件管理?

浙江潔美電子科技股份有限公司成立于2001年,是一家專業為片式電子元器件(被動元件、分立器件、集成電路及LED)配套生產電子薄型載帶、上下膠帶、離型膜、流延膜等產品的國家高新技術企業,主要產品有分切紙帶、打孔經帶、壓孔紙帶、上下膠帶、塑料載帶及其…

leetcode數組-有序數組的平方

題目 題目鏈接:https://leetcode.cn/problems/squares-of-a-sorted-array/ 給你一個按 非遞減順序 排序的整數數組 nums,返回 每個數字的平方 組成的新數組,要求也按 非遞減順序 排序。 輸入:nums [-4,-1,0,3,10] 輸出&#xff…

基于微信小程序的醫院掛號預約系統設計與實現

摘 要 現代經濟快節奏發展以及不斷完善升級的信息化技術,讓傳統數據信息的管理升級為軟件存儲,歸納,集中處理數據信息的管理方式。本微信小程序醫院掛號預約系統就是在這樣的大環境下誕生,其可以幫助管理者在短時間內處理完畢龐大…

密碼學基礎——DES算法

前面的密碼學基礎——密碼學文章中介紹了密碼學相關的概念,其中簡要地對稱密碼體制(也叫單鑰密碼體制、秘密密鑰體制)進行了解釋,我們可以知道單鑰體制的加密密鑰和解密密鑰相同,單鑰密碼分為流密碼和分組密碼。 流密碼&#xff0…

Redis分布式鎖詳解

Redis分布式鎖詳解 分布式鎖是在分布式系統中實現互斥訪問共享資源的重要機制。Redis因其高性能和原子性操作特性,常被用來實現分布式鎖。 一、基礎實現方案 1. SETNX EXPIRE方案(基本版) # 加鎖 SETNX lock_key unique_value # 設置唯…

創建Linux虛擬環境并遠程連接,finalshell自定義壁紙

安裝VMware 這里不多贅述。 掛載Linux系統 1). 打開Vmware虛擬機,打開 編輯 -> 虛擬網絡編輯器(N) 選擇 NAT模式,然后選擇右下角的 更改設置。 設置子網IP為 192.168.100.0,然后選擇 應用 -> 確定。 解壓 CentOS7-1.zip 到一個比較大…

podman和與docker的比較 及podman使用

Podman 與 Docker 的比較和區別 架構差異 Docker:采用客戶端 - 服務器(C/S)架構,有一個以 root 權限運行的守護進程 dockerd 來管理容器的生命周期。客戶端(docker 命令行工具)與守護進程進行通信&#x…