深度學習J6周 ResNeXt-50實戰解析

  • 🍨 本文為🔗365天深度學習訓練營中的學習記錄博客
  • 🍖 原作者:K同學啊

本周任務:

1.閱讀ResNeXt論文,了解作者的構建思路

2.對比之前介紹的ResNet50V2、DenseNet算法

3.復現ResNeXt-50算法

一、模型結構

ResNeXt由何凱明團隊,2017年CVPR會議上提出新型圖像分類網絡。它是ResNet升級版,在ResNet的基礎上,引入cardinality概念。

在論文中,作者提出當時普遍存在的一個問題,如果要提高模型準確率,往往采取加深網絡或者加寬網絡的方法。但網絡設計的難度和計算開銷也增加了。為了一點精度的提升往往付出更大的代價。因此,需要在不額外增加計算代價的情況下,提升網絡精度。

左邊--ResNet,輸入的具有256個通道的特征經過1*1卷積壓縮到64個通道,之后3*3的卷積核用于處理特征,經1*1卷積擴大通道數與原特征殘差連接后輸出。

右邊--ResNeXt,輸入的具有256個通道的特征被分為32個組,每組被壓縮到4個通道后處理,32個組相加后與原特征殘差連接后輸出。cardinality指的是一個block中所具有相同的分支的數目。

二、分組卷積

1.ResNeXt采用分組卷積:將特征圖分為不同的組,再對每組特征圖分別進行卷積,有效降低計算量。

2.分組卷積中,每個卷積核只處理部分通道,如下圖,紅色卷積核只處理紅色通道,綠色卷積核只處理綠色通道,黃色卷積核只處理黃色通道。此時,每個卷積核有2個通道,每個卷積核生成一張特征圖。

三、代碼

學習于深度學習第J6周:ResNeXt-50實戰解析_resnext50-CSDN博客

?1.前期準備

#配置GPU
import os, PIL, random, pathlib
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import torch.nn.functional as Fdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(device)#導入數據集
data_dir = './data/'
data_dir = pathlib.Path(data_dir)data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[1] for path in data_paths]
print(classeNames)image_count = len(list(data_dir.glob('*/*')))
print("圖片總數為:", image_count)#數據預處理+劃分數據集
train_transforms = transforms.Compose([transforms.Resize([224, 224]),  # 將輸入圖片resize成統一尺寸# transforms.RandomHorizontalFlip(), # 隨機水平翻轉transforms.ToTensor(),  # 將PIL Image或numpy.ndarray轉換為tensor,并歸一化到[0,1]之間transforms.Normalize(  # 標準化處理-->轉換為標準正太分布(高斯分布),使模型更容易收斂mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]與std=[0.229,0.224,0.225] 從數據集中隨機抽樣計算得到的。
])test_transform = transforms.Compose([transforms.Resize([224, 224]),  # 將輸入圖片resize成統一尺寸transforms.ToTensor(),  # 將PIL Image或numpy.ndarray轉換為tensor,并歸一化到[0,1]之間transforms.Normalize(  # 標準化處理-->轉換為標準正太分布(高斯分布),使模型更容易收斂mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]與std=[0.229,0.224,0.225] 從數據集中隨機抽樣計算得到的。
])total_data = datasets.ImageFolder("./data/", transform=train_transforms)
print(total_data.class_to_idx)train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])batch_size = 32
train_dl = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=0)
test_dl = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,shuffle=True,num_workers=0)
for X, y in test_dl:print("Shape of X [N, C, H, W]: ", X.shape)print("Shape of y: ", y.shape, y.dtype)break

結果:

2.模型

class BN_Conv2d(nn.Module):"""BN_CONV_RELU"""def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, bias=False):super(BN_Conv2d, self).__init__()self.seq = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,padding=padding, dilation=dilation, groups=groups, bias=bias),nn.BatchNorm2d(out_channels))def forward(self, x):return F.relu(self.seq(x))class ResNeXt_Block(nn.Module):"""ResNeXt block with group convolutions"""def __init__(self, in_chnls, cardinality, group_depth, stride):super(ResNeXt_Block, self).__init__()self.group_chnls = cardinality * group_depthself.conv1 = BN_Conv2d(in_chnls, self.group_chnls, 1, stride=1, padding=0)self.conv2 = BN_Conv2d(self.group_chnls, self.group_chnls, 3, stride=stride, padding=1, groups=cardinality)self.conv3 = nn.Conv2d(self.group_chnls, self.group_chnls*2, 1, stride=1, padding=0)self.bn = nn.BatchNorm2d(self.group_chnls*2)self.short_cut = nn.Sequential(nn.Conv2d(in_chnls, self.group_chnls*2, 1, stride, 0, bias=False),nn.BatchNorm2d(self.group_chnls*2))def forward(self, x):out = self.conv1(x)out = self.conv2(out)out = self.bn(self.conv3(out))out += self.short_cut(x)return F.relu(out)class ResNeXt(nn.Module):"""ResNeXt builder"""def __init__(self, layers: object, cardinality, group_depth, num_classes) -> object:super(ResNeXt, self).__init__()self.cardinality = cardinalityself.channels = 64self.conv1 = BN_Conv2d(3, self.channels, 7, stride=2, padding=3)d1 = group_depthself.conv2 = self.___make_layers(d1, layers[0], stride=1)d2 = d1 * 2self.conv3 = self.___make_layers(d2, layers[1], stride=2)d3 = d2 * 2self.conv4 = self.___make_layers(d3, layers[2], stride=2)d4 = d3 * 2self.conv5 = self.___make_layers(d4, layers[3], stride=2)self.fc = nn.Linear(self.channels, num_classes)   # 224x224 input sizedef ___make_layers(self, d, blocks, stride):strides = [stride] + [1] * (blocks-1)layers = []for stride in strides:layers.append(ResNeXt_Block(self.channels, self.cardinality, d, stride))self.channels = self.cardinality*d*2return nn.Sequential(*layers)def forward(self, x):out = self.conv1(x)out = F.max_pool2d(out, 3, 2, 1)out = self.conv2(out)out = self.conv3(out)out = self.conv4(out)out = self.conv5(out)out = F.avg_pool2d(out, 7)out = out.view(out.size(0), -1)out = F.softmax(self.fc(out),dim=1)return out
# 定義完成,測試一下
model = ResNeXt([3, 4, 6, 3], 32, 4, 4)
model.to(device)# 統計模型參數量以及其他指標
import torchsummary as summary
summary.summary(model, (3, 224, 224))

結果:

?

?

?3.訓練運行

 
# 訓練循環
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 訓練集的大小num_batches = len(dataloader)  # 批次數目, (size/batch_size,向上取整)train_loss, train_acc = 0, 0  # 初始化訓練損失和正確率for X, y in dataloader:  # 獲取圖片及其標簽X, y = X.to(device), y.to(device)# 計算預測誤差pred = model(X)  # 網絡輸出loss = loss_fn(pred, y)  # 計算網絡輸出和真實值之間的差距,targets為真實值,計算二者差值即為損失# 反向傳播optimizer.zero_grad()  # grad屬性歸零loss.backward()  # 反向傳播optimizer.step()  # 每一步自動更新# 記錄acc與losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_lossdef test(dataloader, model, loss_fn):size = len(dataloader.dataset)  # 測試集的大小num_batches = len(dataloader)  # 批次數目test_loss, test_acc = 0, 0# 當不進行訓練時,停止梯度更新,節省計算內存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 計算losstarget_pred = model(imgs)loss = loss_fn(target_pred, target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss
 
import copyoptimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()  # 創建損失函數epochs = 10train_loss = []
train_acc = []
test_loss = []
test_acc = []best_acc = 0  # 設置一個最佳準確率,作為最佳模型的判別指標for epoch in range(epochs):# 更新學習率(使用自定義學習率時使用)# adjust_learning_rate(optimizer, epoch, learn_rate)model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)# scheduler.step() # 更新學習率(調用官方動態學習率接口時使用)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)# 保存最佳模型到 best_modelif epoch_test_acc > best_acc:best_acc = epoch_test_accbest_model = copy.deepcopy(model)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 獲取當前的學習率lr = optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,epoch_test_acc * 100, epoch_test_loss, lr))# 保存最佳模型到文件中
PATH = './best_model.pth'  # 保存的參數文件名
torch.save(model.state_dict(), PATH)print('Done')

結果:

?

4.打印訓練圖

import matplotlib.pyplot as plt
# 隱藏警告
import warningswarnings.filterwarnings("ignore")  # 忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用來正常顯示中文標簽
plt.rcParams['axes.unicode_minus'] = False  # 用來正常顯示負號
plt.rcParams['figure.dpi'] = 100  # 分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

四、總結

1.讀論文原文要花很長時間,但有講義,就會快速知道論文的創新點是什么。

2.實驗的流程已經很熟悉,現在就在慢慢學每一步的具體內容,爭取下次能自己寫出。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/64435.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/64435.shtml
英文地址,請注明出處:http://en.pswp.cn/web/64435.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Langchain Chat Model 和 Chat Prompt Template

0. 簡介 Chat Model 不止是一個用于聊天對話的模型抽象,更重要的是提供了多角色提示能力(System,AI,Human,Function)。 Chat Prompt Template 則為開發者提供了便捷維護不同角色的提示模板與消息記錄的接口。 1. 構造 ChatPromptTemplate from langch…

對話 Project Astra 研究主管:打造通用 AI 助理,主動視頻交互和全雙工對話是未來重點

Project Astra 愿景之一:「系統不僅能在你說話時做出回應,還能在持續的過程中幫助你。」 近期,Google DeepMind 的 YouTube 頻道采訪了 Google DeepMind 研究主管格雷格韋恩 (Greg Wayne)。 格雷格韋恩的研究工作為 DeepMind 的諸多突破性成…

全國青少年信息學奧林匹克競賽(信奧賽)備考實戰之循環結構(for循環語句)(四)

實戰訓練1—最大差值 問題描述: 輸入n個非負整數,找出這個n整數的最大值與最小值,并求最大值和最小值的差值。 輸入格式: 共兩行,第一行為整數的個數 n(1≤n≤1000)。第二行為n個整數的值(整…

純Dart Flutter庫適配HarmonyOS

純Dart Flutter庫適配HarmonyOS介紹: Flutter基本組件、Flutter布局組件、Flutter圖片組件、Flutter字體、Flutter圖標、Fluter路由、flutter動畫、 Flutter表單、flutter異步等,純Dart庫無需任何處理,可以直接編譯成HarmonyOs應用。 具體步…

LunarVim安裝

LunarVim以其豐富的功能和靈活的定制性,迅速在Nvim用戶中流行開來。它不僅提供了一套完善的默認配置,還允許用戶根據自己的需求進行深度定制。無論是自動補全、內置終端、文件瀏覽器,還是模糊查找、LSP支持、代碼檢測、格式化和調試&#xff…

劍指Offer|LCR 015. 找到字符串中所有字母異位詞

LCR 015. 找到字符串中所有字母異位詞 給定兩個字符串 s 和 p,找到 s 中所有 p 的 變位詞 的子串,返回這些子串的起始索引。不考慮答案輸出的順序。 變位詞 指字母相同,但排列不同的字符串。 示例 1: 輸入: s "cbaebaba…

高質量 Next.js 后臺管理模板源碼分享,開發者必備

高質量 Next.js后臺管理模板源碼分享,開發者必備 Taplox 是一個基于 Bootstrap 5 和 Next.js 構建的現代化后臺管理模板和 UI 組件庫。它不僅設計精美,還提供了一整套易用的工具,適合各種 Web 應用、管理系統和儀表盤項目。無論你是初學者還是…

開發場景中Java 集合的最佳選擇

在 Java 開發中,集合類是處理數據的核心工具。合理選擇集合,不僅可以提高代碼效率,還能讓代碼更簡潔。本篇文章將重點探討 List、Set 和 Map 的適用場景及優缺點,幫助你在實際開發中找到最佳解決方案。 一、List:有序存…

Java包裝類型的緩存

Java 基本數據類型的包裝類型的大部分都用到了緩存機制來提升性能。 Byte,Short,Integer,Long 這 4 種包裝類默認創建了數值 [-128,127] 的相應類型的緩存數據,Character 創建了數值在 [0,127] 范圍的緩存數據,Boolean 直接返回 True or Fal…

工程師 - MinGW

MinGW Minimalist GNU for Windows,前身為mingw32,是一個免費開源的軟件開發環境,從2010年開始項目停止并不再使用。后續提供MinGW-w64。 MinGW包括: - 移植到Windows上的GNU編譯器集(GCC),包括C、C、ADA和…

EasyExcel(讀取操作和填充操作)

文章目錄 1.準備Read.xlsx(具有兩個sheet)2.讀取第一個sheet中的數據1.模板2.方法3.結果 3.讀取所有sheet中的數據1.模板2.方法3.結果 EasyExcel填充1.簡單填充1.準備 Fill01.xlsx2.無模版3.方法4.結果 2.列表填充1.準備 Fill02.xlsx2.模板3.方法4.結果 …

CKA認證 | Day7 K8s存儲

第七章 Kubernetes存儲 1、數據卷與數據持久卷 為什么需要數據卷? 容器中的文件在磁盤上是臨時存放的,這給容器中運行比較重要的應用程序帶來一些問題。 問題1:當容器升級或者崩潰時,kubelet會重建容器,容器內文件會…

Python調用R語言中的程序包來執行回歸樹、隨機森林、條件推斷樹和條件推斷森林算法

要使用Python調用R語言中的程序包來執行回歸樹、隨機森林、條件推斷樹和條件推斷森林算法,重新計算中國居民收入不平等,并進行分類匯總,我們可以使用rpy2庫。rpy2允許在Python中嵌入R代碼并調用R函數。以下是一個詳細的步驟和示例代碼&#x…

關于JAVA方法值傳遞問題

1.1 前言 之前在學習C語言的時候,將實參傳遞給方法(或函數)的方式分為兩種:值傳遞和引用傳遞,但在JAVA中只有值傳遞(顛覆認知,基礎沒學踏實) 參考文章:https://blog.csd…

Excel基礎知識

一:數組 一行或者一列數據稱為一維數組,多行多列稱為二維數組,數組支持算術運算(如加減乘除等)。 行:{1,2,3,4} 數組中的每個值用逗號分隔列:{1;2;3;4} 數組中的每個值用分號分隔行列&#xf…

基于DIODES AP43781+PI3USB31531+PI3DPX1207C的USB-C PD Video 之全功能顯示器連接端口方案

隨著USB-C連接器和PD功能的出現,新一代USB-C PD PC顯示器可以用作個人和專業PC工作環境的電源和數據集線器。 雖然USB-C PD顯示器是唯一插入墻壁插座的交流電源輸入設備,但它可以作為數據UFP(上游接口)連接到連接到TCD&#xff0…

gazebo_world 基本圍墻。

如何使用&#xff1f; 參考gazebo harmonic的官方教程。 本人使用harmonic的template&#xff0c;在里面進行修改就可以分流暢地使用下去。 以下是world 文件. <?xml version"1.0" ?> <!--Try sending commands:gz topic -t "/model/diff_drive/…

解決無法在 Ubuntu 24.04 上運行 AppImage 應用

在 Ubuntu 24.04 中運行 AppImage 應用的完整指南 在 Ubuntu 24.04 中&#xff0c;許多用戶可能會遇到 AppImage 應用無法啟動的問題。即使你已經設置了正確的文件權限&#xff0c;AppImage 仍然拒絕運行。這通常是由于缺少必要的庫文件所致。 問題根源&#xff1a;缺少 FUSE…

Pytorch使用手冊-DCGAN 指南(專題十四)

1. Introduction 本教程將通過一個示例介紹 DCGANs(深度卷積生成對抗網絡)。我們將訓練一個生成對抗網絡(GAN),在給它展示大量真實名人照片后,它能夠生成新的“名人”圖片。這里的大部分代碼來源于 PyTorch 官方示例中的 DCGAN 實現,而本文檔將對該實現進行詳細解釋,并…

springboot配置oracle+達夢數據庫多數據源配置并動態切換

項目場景&#xff1a; 在工作中很多情況需要跨數據庫進行數據操作,自己總結的經驗希望對各位有所幫助 問題描述 總結了幾個問題 1.識別不到mapper 2.識別不到xml 3.找不到數據源 原因分析&#xff1a; 1.配置文件編寫導致識別mapper 2.配置類編寫建的格式有問題 3.命名…