【Pytorch深度學習開發實踐學習】【AlexNet】經典算法復現-Pytorch實現AlexNet神經網絡(1)model.py

在這里插入圖片描述

算法簡介

AlexNet是人工智能深度學習在CV領域的開山之作,是最先把深度卷積神經網絡應用于圖像分類領域的研究成果,對后面的諸多研究起到了巨大的引領作用,因此有必要學習這個算法并能夠實現它。

主要的創新點在于:

  1. 首次使用GPU進行神經網絡加速訓練
  2. 使用使用了非飽和的激活函數ReLU,而不是傳統的sigmoid和tanh
  3. 使用了數據增強手段抑制過擬合
  4. 提出了Dropout隨機失活抑制過擬合
  5. 提出了LRN局部響應歸一化
  6. 使用了重疊池化抑制過擬合

model.py代碼講解

import torch.nn as nn
import torchclass AlexNet(nn.Module):def __init__(self, num_classes=1000, init_weights=False):super(AlexNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  使用48個11*11的卷積核,步長為4,padding為2 output[48, 55, 55]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # input[48, 55, 55]  output[48, 27, 27]nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6])self.classifier = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(128 * 6 * 6, 2048),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes),)def forward(self, x):x = self.features(x)x = torch.flatten(x, start_dim=1)x = self.classifier(x)return x

model.py的全部代碼如上
現在逐行進行分析

class AlexNet(nn.Module):def __init__(self, num_classes=1000, init_weights=False):super(AlexNet, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  使用48個11*11的卷積核,步長為4,padding為2 output[48, 55, 55]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # input[48, 55, 55]  output[48, 27, 27]nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]nn.ReLU(inplace=True),nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6])

class AlexNet(nn.Module):
定義了一個AlexNet的類,這個類繼承了nn.Module
def init(self,num_classes=1000):
定義了類的初始化函數,它有個可選的參數 num_classes是我們這個神經網絡在輸出的分類數

super(AlexNet,self).__init()
這是為了調用父類的初始化函數

self.features = nn.Sequential()
在這里插入圖片描述
這里非常重要,我們可以去Pytorch的官方文檔上看看,
官方的解釋是:
模塊將按照傳入構造函數的順序添加到其中。另外,也可以傳入一個有序字典的模塊。Sequential的forward()方法接受任何輸入,并將其轉發給它包含的第一個模塊。然后,對于每個后續模塊,它將輸出“鏈接”到輸入,最終返回最后一個模塊的輸出。
Sequential相對于手動調用一系列模塊的優勢在于,它允許將整個容器視為單個模塊,這樣對Sequential執行的轉換將應用于其存儲的每個模塊(它們分別是Sequential的注冊子模塊)。
Sequential和torch.nn.ModuleList之間有什么區別?ModuleList就像它的名字一樣-用于存儲Module的列表!另一方面,Sequential中的層以級聯方式連接。

論文中的AlexNet網絡結構圖如下:
在這里插入圖片描述
AlexNet是第一個網絡結構開始變得更加復雜的神經網絡模型(Lenet)只有兩個卷積層和兩個全連接層,而AlexNet有5個卷積層和3個全連接層,對于逐漸復雜的網絡結構,我們可以利用Sequential函數搭建序列化的網絡模塊

比如這里我們首先定義了一個features模塊
nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),
第一個卷積層 輸入是2242243 48個1111的卷積核 步長是4,填充是2
輸出是55
55*48

nn.ReLU(inplace=True),ReLU激活函數

nn.MaxPool2d(kernel_size=3, stride=2),
定義一個最大池化層,使用3x3的池化核,步長為2。這將進一步減少特征圖的尺寸。

nn.Conv2d(48, 128, kernel_size=5, padding=2),
又是一個卷積層,輸入是272748 128個55的卷積核 填充是2,輸出是2727*128

然后以此類推
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2), 又是激活函數和池化,池化后輸出 1313128
nn.Conv2d(128, 192, kernel_size=3, padding=1), 輸入1313128 輸出1313192

nn.ReLU(inplace=True),
nn.Conv2d(192, 192, kernel_size=3, padding=1),輸入1313192 輸出1313192

nn.ReLU(inplace=True),
nn.Conv2d(192, 128, kernel_size=3, padding=1), 輸入1313192
輸出1313128

nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2), 輸入1313128 輸出 66128

self.classifier = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(128 * 6 * 6, 2048),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes),)

第二個模塊,上一個是5層卷積層加3層池化層提取特征
下面這個模塊就是全連接層做分類

首先是drouput隨機失活抑制過擬合的操作
然后是 nn.Linear(128 * 6 * 6, 2048),12866的原因是全連接層是接著前面的最后一個也是第三個池化層,池化層的輸出就是12866
后面再接兩個全連接層,最后一個全連接層的輸出就是對1000個類的預測結果

   def forward(self, x):x = self.features(x)x = torch.flatten(x, start_dim=1)x = self.classifier(x)return x

def forward(self, x):

定義一個名為forward的方法,這是PyTorch中自定義神經網絡層或模型的標準做法。這個方法描述了輸入數據x通過網絡的前向傳播過程。
x = self.features(x)
將輸入數據x傳遞給feature模塊
x = torch.flatten(x, start_dim=1)

使用PyTorch的flatten函數將特征圖x在指定的維度(start_dim=1,通常是指從第二個維度開始,即特征圖的深度維度)展平。這通常是為了將多維的特征圖轉換為一維的張量,以便輸入到全連接層。
這里要重點說明一下,在feature后輸出的x是一個四維的參數(B,C,H,W)分別是batchsize channel 高、寬 而這個函數的意思是從第二維channel開始,對后三維 通道數、寬、高進行展開,轉為一維的向量輸入全連接層

x = self.classifier(x)
將展平后的特征x傳遞給classifier
return x
返回經過分類器處理后的輸出。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/710962.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/710962.shtml
英文地址,請注明出處:http://en.pswp.cn/news/710962.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

AI語音識別的技術解析

從語音識別算法的發展來看,語音識別技術主要分為三大類,第一類是模型匹配法,包括矢量量化(VQ) 、動態時間規整(DTW)等;第二類是概率統計方法,包括高斯混合模型(GMM) 、隱馬爾科夫模型(HMM)等;第三類是辨別器…

golang gin單獨部署vue3.0前后端分離應用

概述 因為公司最近的項目前端使用vue 3.0,后端api使用golang gin框架。測試通過后,博文記錄,用于備忘。 步驟 npm run build,構建出前端項目的dist目錄,dist目錄的結構具體如下圖 將dist目錄復制到后端程序同級目錄…

嵌入式軟件bug從哪里來,到哪里去

摘要:軟件從來不是一次就能完美的,需要以包容的眼光看待它的殘缺。那問題究竟為何產生,如何去除呢? 1、軟件問題從哪來 軟件缺陷問題千千萬萬,主要是需求、實現、和運行環境三方面。 1.1 需求描述偏差 客戶角度的描…

PSO-CNN-LSTM多輸入回歸預測|粒子群算法優化的卷積-長短期神經網絡回歸預測(Matlab)——附代碼數據

目錄 一、程序及算法內容介紹: 基本內容: 亮點與優勢: 二、實際運行效果: 三、算法介紹: 四、完整程序數據分享下載: 一、程序及算法內容介紹: 基本內容: 本代碼基于Matlab平臺…

5 局域網基礎(3)

1.AAA 服務器 AAA 是驗證、授權和記賬(Authentication、Authorization、Accounting)3個英文單詞的簡稱,是一個能夠處理用戶訪問請求的服務器程序,提供驗證授權以及帳戶服務,主要目的是管理用戶訪問網絡服務器,對具有訪問權的用戶提供服務。AAA服務器通常…

Java TCP文件上傳案例

文件上傳分析 【客戶端】輸入流,從硬盤讀取文件數據到程序中。【客戶端】輸出流,寫出文件數據到服務端。【服務端】輸入流,讀取文件數據到服務端程序。【服務端】輸出流,寫出文件數據到服務器硬盤中。 基本實現 服務端實現 pu…

【二分查找】樸素二分查找

二分查找 題目描述 給定一個 n 個元素有序的(升序)整型數組 nums 和一個目標值 target ,寫一個函數搜索 nums 中的 target,如果目標值存在返回下標,否則返回 -1。 示例 1: 輸入: nums [-1,0,3,5,9,12], target 9…

網絡編程:基于TCP和UDP的服務器、客戶端

1.基于TCP通信服務器 程序代碼&#xff1a; 1 #include<myhead.h>2 #define SER_IP "192.168.126.121"//服務器IP3 #define SER_PORT 8888//服務器端口號4 int main(int argc, const char *argv[])5 {6 //1.創建用于監聽的套接字7 int sfd-1;8 sf…

MYSQL C++鏈接接口編程

使用MYSQL 提供的C接口來訪問數據庫,官網比較零碎,又不想全部精讀一下,百度CSDN都是亂七八糟的,大部分不可用 官網教程地址 https://dev.mysql.com/doc/connector-cpp/1.1/en/connector-cpp-examples-connecting.html 網上之所以亂七八糟,主要是MYSQL提供了3個接口兩個包,使用…

C++ //練習 10.9 實現你自己的elimDups。測試你的程序,分別在讀取輸入后、調用unique后以及調用erase后打印vector的內容。

C Primer&#xff08;第5版&#xff09; 練習 10.9 練習 10.9 實現你自己的elimDups。測試你的程序&#xff0c;分別在讀取輸入后、調用unique后以及調用erase后打印vector的內容。 環境&#xff1a;Linux Ubuntu&#xff08;云服務器&#xff09; 工具&#xff1a;vim 代碼…

Flask g對象和插件

四、Flask進階 1. Flask插件 I. flask-caching 安裝 pip install flask-caching初始化 from flask_cache import Cache cache Cache(config(CACHE_TYPE:"simple" )) cache.init_app(appapp)使用 在視圖函數上添加緩存 blue.route("/") cache.cached(tim…

django5生產級部署和并發測試(開發者服務器和uvicorn服務器)

目錄 1. 創建django項目2. 安裝壓力測試工具3. 安裝生產級服務器uvicorn4. 多進程部署 1. 創建django項目 在桌面創建一個名為django_test的項目&#xff1a; django-admin startproject django_test然后使用cd命令進入django_test文件夾內&#xff0c;使用開發者服務器運行項…

前端架構: 腳手架包管理工具之lerna的全流程開發教程

Lerna 1 &#xff09;文檔 Lerna 文檔 https://www.npmjs.com/package/lernahttps://lerna.js.org [請直達這個鏈接] 使用 Lerna 幫助我們做包管理&#xff0c;并不復雜&#xff0c;中間常用的命令并不是很多這里是命令直達&#xff1a;https://lerna.js.org/docs/api-referen…

掌匯云 | FBIF個性化票務系統,展會活動數據好沉淀

“把票全賣光&#xff01;賣到一票難求&#xff0c;現場座無虛席。” 賣票人和買票人可能永遠不在一個頻道上。 2022年辦活動&#xff0c;就是一個字&#xff0c;搏&#xff01;和“黑天鵝”趕時間&#xff0c;能不能辦不由主辦方說了算。這種情況在2023年得到了改善&#xff…

【字典樹】【KMP】【C++算法】3045統計前后綴下標對 II

作者推薦 動態規劃的時間復雜度優化 本文涉及知識點 字符串 字典樹 KMP 前后綴 LeetCode:3045統計前后綴下標對 II 給你一個下標從 0 開始的字符串數組 words 。 定義一個 布爾 函數 isPrefixAndSuffix &#xff0c;它接受兩個字符串參數 str1 和 str2 &#xff1a; 當 st…

C++——內存管理(new和delete)詳解

目錄 C/C內存管理 案例&#xff1a;變量在內存中到底會在哪&#xff1f; New和delete Operator new和operator delete函數 New和delete的原理 對內置類型 對自定義類型 定位new New/delete和malloc/free的區別 C/C內存管理 C/C內存管理分布圖&#xff1a;&#xff08;從…

項目案例:圖像分類技術在直播電商中的應用與實踐

一、引言 在數字化浪潮的推動下&#xff0c;電商行業迎來了一場革命性的變革。直播電商&#xff0c;作為一種新興的購物模式&#xff0c;正以其獨特的互動性和娛樂性&#xff0c;重塑著消費者的購物習慣。通過實時的直播展示&#xff0c;商品的細節得以清晰呈現&#xff0c;而互…

matlab:涉及復雜函數圖像的交點求解

matlab&#xff1a;涉及復雜函數圖像的交點求解 在MATLAB中求解兩個圖像的交點是一個常見的需求。本文將通過一個示例&#xff0c;展示如何求解兩個圖像的交點&#xff0c;并提供相應的MATLAB代碼。 畫出圖像 首先&#xff0c;我們需要繪制兩個圖像&#xff0c;以便直觀地看…

【JavaEE】_HttpServletResponse類

目錄 1. 核心方法 2. 關于setStatus(400)與sendError 2.1 setStatus(400) 2.2 sendError 3. setHeader方法 4. 構造重定向響應 4.1 使用setHeader和setStatus實現重定向 4.2 使用sendRedirect實現重定向 本專欄已有文章介紹HttpServlet和HttpServletRequest類&#…

仿真科普|CAE技術賦能無人機 低空經濟蓄勢起飛

喝一杯無人機送來的現磨熱咖啡&#xff1b;在擁堵的早高峰打個“空中的士”上班&#xff1b;乘坐水陸兩棲飛機來一場“陸海空”立體式觀光……曾經只出現在科幻片里的5D城市魔幻場景&#xff0c;正逐漸走進現實。而推動上述場景實現的&#xff0c;就是近年來越來越熱的“低空經…