【入門】5分鐘了解卷積神經網絡CNN是什么

本文來自《老餅講解-BP神經網絡》https://www.bbbdata.com/

目錄

  • 一、卷積神經網絡的結構
    • 1.1.卷積與池化的作用
    • 2.2.全連接層的作用
  • 二、卷積神經網絡的運算
    • 2.1.卷積層的運算
    • 2.2.池化的運算
    • 2.3.全連接層運算
  • 三、pytorch實現一個CNN例子
    • 3.1.模型的搭建
    • 3.2.CNN完整訓練代碼

CNN神經網絡常用于圖片識別,是深度學習中常用的模型。
本文簡單快速了解卷積神經網絡是什么東西,并展示一個簡單的示例。

一、卷積神經網絡的結構

一個經典的卷積神經網絡的結構如下:
卷積神經網絡
C代表卷積層,P代表池化層,F代表全連接層。
卷積神經網絡主要的、樸素的用途是圖片識別。即輸入圖片,然后識別圖片的類別,例如輸入一張圖片,識別該圖片是貓還是狗。

1.1.卷積與池化的作用

卷積層與池化層共同是卷積神經網絡的核心,它用于將輸入圖片進行壓縮,例如一張224x224的圖片,經過卷積+池化后,可能得到的就是55x55的圖片,也就是說,卷積與池化的目的就是使得輸入圖片變小,同時盡量不要損失太多與類別相關的信息。例如一張貓的圖片經過卷積與池化之后,盡量減少圖片的大小,但要盡可能地保留"貓"的信息。

2.2.全連接層的作用

全連接層主要用于預測圖片的類別。全連接層實際可以看作一個BP神經網絡模型, 使用"卷積+池化"之后得到的特征來擬合圖片的類別。

二、卷積神經網絡的運算

2.1.卷積層的運算

卷積層的運算如下:
卷積運算
卷積層中的卷積核就是一個矩陣,直觀來看它就是一個窗口,卷積窗口一般為正方形,即長寬一致,
卷積運算通過從左到右,從上往下移動卷積核窗口,將窗口覆蓋的每一小塊輸入進行加權,作為輸出

2.2.池化的運算

池化層是通過一個池化窗口,對輸入進行逐塊掃描,每次將窗口的元素合并為一個元素,
池化層的運算如下:
池化運算
池化層一般分為均值池化與最大值池化,顧名思義,就是計算時使用均值還是最大值:
均值池化與最大值池化

2.3.全連接層運算

全連接層就相當于一個BP神經網絡模型,即每一層與下一層都是全連接形式。
全連接層

假設前一層傳過來的輸入的是X,則當前層的輸出是tanh(WX+b)

三、pytorch實現一個CNN例子

下面以手寫數字識別為例,展示如何使用pytorch實現一個CNN
在這里插入圖片描述

3.1.模型的搭建

如下所示,就搭建了一個CNN模型

# 卷積神經網絡的結構
class ConvNet(nn.Module):def __init__(self,in_channel,num_classes):super(ConvNet, self).__init__()self.nn_stack=nn.Sequential(#--------------C1層-------------------nn.Conv2d(in_channel,6, kernel_size=5,stride=1,padding=2),nn.ReLU(inplace=True),  nn.AvgPool2d(kernel_size=2,stride=2),# 輸出14*14#--------------C2層-------------------nn.Conv2d(6,16, kernel_size=5,stride=1,padding=2),nn.ReLU(inplace=True),nn.AvgPool2d(kernel_size=2,stride=2),# 輸出7*7#--------------C3層-------------------nn.Conv2d(16,80,kernel_size=7,stride=1,padding=0),# 輸出1*1*80#--------------全連接層F4----------nn.Flatten(),          # 對C3的結果進行展平nn.Linear(80, 120),  nn.ReLU(inplace=True),                                   #--------------全連接層F5----------                      nn.Linear(120, num_classes)                       )def forward(self, x):p = self.nn_stack(x)return p

從代碼里可以看到,只需按自己所設定的結構進行隨意搭建就可以了。
搭建了之后再使用數據進行訓練可以了,然后就可以使用模型對樣本進行預測。

3.2.CNN完整訓練代碼

完整的CNN訓練代碼示例如下:

import torch
from   torch import nn
from   torch.utils.data   import DataLoader
import torchvision
import numpy as np#--------------------模型結構--------------------------------------------
# 卷積神經網絡的結構
class ConvNet(nn.Module):def __init__(self,in_channel,num_classes):super(ConvNet, self).__init__()self.nn_stack=nn.Sequential(#--------------C1層-------------------nn.Conv2d(in_channel,6, kernel_size=5,stride=1,padding=2),nn.ReLU(inplace=True),  nn.AvgPool2d(kernel_size=2,stride=2),# 輸出14*14#--------------C2層-------------------nn.Conv2d(6,16, kernel_size=5,stride=1,padding=2),nn.ReLU(inplace=True),nn.AvgPool2d(kernel_size=2,stride=2),# 輸出7*7#--------------C3層-------------------nn.Conv2d(16,80,kernel_size=7,stride=1,padding=0),# 輸出1*1*80#--------------全連接層F4----------nn.Flatten(),          # 對C3的結果進行展平nn.Linear(80, 120),  nn.ReLU(inplace=True),                                   #--------------全連接層F5----------                      nn.Linear(120, num_classes)                       )def forward(self, x):p = self.nn_stack(x)return p#-----------------------模型訓練---------------------------------------
# 參數初始化函數
def init_param(model):# 初始化權重閾值                                                                         param_list = list(model.named_parameters())                                                # 將模型的參數提取為列表                      for i in range(len(param_list)):                                                           # 逐個初始化權重、閾值is_weight = i%2==0                                                                     # 如果i是偶數,就是權重參數,i是奇數就是閾值參數if is_weight:                                                                          torch.nn.init.normal_(param_list[i][1],mean=0,std=0.01)                            # 對于權重,以N(0,0.01)進行隨機初始化else:                                                                                  torch.nn.init.constant_(param_list[i][1],val=0)                                     # 閾值初始化為0# 訓練函數                                                                                     
def train(dataloader,valLoader,model,epochs,goal,device):                                      for epoch in range(epochs):                                                                err_num  = 0                                                                           # 本次epoch評估錯誤的樣本eval_num = 0                                                                           # 本次epoch已評估的樣本print('-----------當前epoch:',str(epoch),'----------------')                           for batch, (imgs, labels) in enumerate(dataloader):                                    # -----訓練模型-----                                                               x, y = imgs.to(device), labels.to(device)                                          # 將數據發送到設備optimizer.zero_grad()                                                              # 將優化器里的參數梯度清空py   = model(x)                                                                    # 計算模型的預測值   loss = lossFun(py, y)                                                              # 計算損失函數值loss.backward()                                                                    # 更新參數的梯度optimizer.step()                                                                   # 更新參數# ----計算錯誤率----                                                               idx      = torch.argmax(py,axis=1)                                                 # 模型的預測類別eval_num = eval_num + len(idx)                                                     # 更新本次epoch已評估的樣本err_num  = err_num +sum(y != idx)                                                  # 更新本次epoch評估錯誤的樣本if(batch%10==0):                                                                   # 每10批打印一次結果print('err_rate:',err_num/eval_num)                                            # 打印錯誤率# -----------驗證數據誤差---------------------------                                   model.eval()                                                                           # 將模型調整為評估狀態val_acc_rate = calAcc(model,valLoader,device)                                          # 計算驗證數據集的準確率model.train()                                                                          # 將模型調整回訓練狀態print("驗證數據的準確率:",val_acc_rate)                                                # 打印準確率    if((err_num/eval_num)<=goal):                                                          # 檢查退出條件break                                                                              print('訓練步數',str(epoch),',最終訓練誤差',str(err_num/eval_num))                         # 計算數據集的準確率                                                                           
def calAcc(model,dataLoader,device):                                                           py = np.empty(0)                                                                           # 初始化預測結果y  = np.empty(0)                                                                           # 初始化真實結果for batch, (imgs, labels) in enumerate(dataLoader):                                        # 逐批預測cur_py =  model(imgs.to(device))                                                       # 計算網絡的輸出cur_py = torch.argmax(cur_py,axis=1)                                                   # 將最大者作為預測結果py     = np.hstack((py,cur_py.detach().cpu().numpy()))                                 # 記錄本批預測的yy      = np.hstack((y,labels))                                                         # 記錄本批真實的yacc_rate = sum(y==py)/len(y)                                                               # 計算測試樣本的準確率return acc_rate                                                                               #--------------------------主流程腳本----------------------------------------------       
#-------------------加載數據--------------------------------
train_data = torchvision.datasets.MNIST(root       = 'D:\pytorch\data'                                                             # 路徑,如果路徑有,就直接從路徑中加載,如果沒有,就聯網獲取,train     = True                                                                          # 獲取訓練數據,transform = torchvision.transforms.ToTensor()                                             # 轉換為tensor數據,download  = True                                                                          # 是否下載,選為True,就下載到root下面,target_transform= None)                                                                   
val_data = torchvision.datasets.MNIST(root       = 'D:\pytorch\data'                                                             # 路徑,如果路徑有,就直接從路徑中加載,如果沒有,就聯網獲取,train     = False                                                                         # 獲取測試數據,transform = torchvision.transforms.ToTensor()                                             # 轉換為tensor數據,download  = True                                                                          # 是否下載,選為True,就下載到root下面,target_transform= None)                                                                   #-------------------模型訓練--------------------------------                                   
trainLoader = DataLoader(train_data, batch_size=1000, shuffle=True)                            # 將數據裝載到DataLoader
valLoader   = DataLoader(val_data  , batch_size=100)                                           # 將驗證數據裝載到DataLoader 
device      = torch.device('cuda' if torch.cuda.is_available() else 'cpu')                     # 設置訓練設備  
model       = ConvNet(in_channel =1,num_classes=10).to(device)                                 # 初始化模型,并發送到設備  
lossFun     = torch.nn.CrossEntropyLoss()                                                      # 定義損失函數為交叉熵損失函數
optimizer   = torch.optim.SGD(model.parameters(), lr=0.01,momentum =0.9,dampening=0.0005)      # 初始化優化器
train(trainLoader,valLoader,model,1000,0.01,device)                                            # 訓練模型,訓練100步,錯誤低于1%時停止訓練# -----------模型效果評估--------------------------- 
model.eval()                                                                                   # 將模型切換到評估狀態(屏蔽Dropout)
train_acc_rate = calAcc(model,trainLoader,device)                                              # 計算訓練數據集的準確率
print("訓練數據的準確率:",train_acc_rate)                                                      # 打印準確率
val_acc_rate = calAcc(model,valLoader,device)                                                  # 計算驗證數據集的準確率
print("驗證數據的準確率:",val_acc_rate)                                                        # 打印準確率

運行結果如下:

-----------當前epoch: 0 ---------------- 
err_rate: tensor(0.7000)                 
驗證數據的準確率: 0.3350877192982456     
-----------當前epoch: 1 ---------------- 
err_rate: tensor(0.6400)                 
驗證數據的準確率: 0.3350877192982456     
-----------當前epoch: 2 ---------------- 
.......
.......
-----------當前epoch: 77 ----------------
err_rate: tensor(0.0100)                 
驗證數據的準確率: 1.0                    
-----------當前epoch: 78 ----------------
err_rate: tensor(0.)                     
驗證數據的準確率: 1.0                    
-----------當前epoch: 79 ----------------
err_rate: tensor(0.0200)                 
驗證數據的準確率: 1.0                    
-----------當前epoch: 80 ----------------
err_rate: tensor(0.0100)                 
驗證數據的準確率: 0.9982456140350877     
-----------------------------------------
訓練步數 80 ,最終訓練誤差 tensor(0.0088) 
訓練數據的準確率: 0.9982456140350877     
驗證數據的準確率: 0.9982456140350877 

可以看到,識別效果達到了99.8%。CNN模型對圖片的識別是非常有效的。


相關鏈接:

《老餅講解-機器學習》:老餅講解-機器學習教程-通俗易懂
《老餅講解-神經網絡》:老餅講解-matlab神經網絡-通俗易懂
《老餅講解-神經網絡》:老餅講解-深度學習-通俗易懂

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/38033.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/38033.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/38033.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【Dison夏令營 Day 04】如何用 Python 編寫簡單的數字猜謎游戲代碼

上個周末&#xff0c;我整理了一份可以用 Python 編寫的游戲列表。但為什么呢&#xff1f; 如果您是 Python 程序員初學者&#xff0c;編寫有趣的游戲可以幫助您更快更好地學習 Python 語言&#xff0c;而不會被語法之類的東西所困擾。我在學習 Python 的時候曾制作過一些這樣…

Hadoop-03-Hadoop集群 免密登錄 超詳細 3節點公網云 分發腳本 踩坑筆記 SSH免密 服務互通 集群搭建 開啟ROOT

章節內容 上一節完成&#xff1a; HDFS集群XML的配置MapReduce集群XML的配置Yarn集群XML的配置統一權限DNS統一配置 背景介紹 這里是三臺公網云服務器&#xff0c;每臺 2C4G&#xff0c;搭建一個Hadoop的學習環境&#xff0c;供我學習。 之前已經在 VM 虛擬機上搭建過一次&…

短視頻矩陣系統搭建APP源碼開發

前言 短視頻矩陣系統不僅有助于提升品牌影響力和營銷效率&#xff0c;還能幫助企業更精準地觸達目標受眾&#xff0c;增強用戶互動&#xff0c;并利用數據分析來持續優化營銷策略。 一、短視頻矩陣系統是什么&#xff1f; 短視頻矩陣系統是一種通過多個短視頻平臺進行內容創作…

Vue 3 實戰教程(快速入門)

Vue 3 實戰教程&#xff08;快速入門&#xff09; Vue.js 是一個用于構建用戶界面的漸進式框架&#xff0c;Vue 3 是 Vue 的最新版本&#xff0c;帶來了許多改進和新特性。本文將通過一個簡單的項目示例&#xff0c;帶你快速入門 Vue 3 的基礎使用。 環境設置 安裝 Node.js …

多多代播24小時值守:電商直播時代是帶貨爆單的關鍵

在電商直播盛行的今天&#xff0c;直播帶貨已成為品牌與消費者溝通的關鍵。然而&#xff0c;流量波動大&#xff0c;競爭激烈&#xff0c;使品牌面臨諸多挑戰。因此&#xff0c;許多品牌尋求專業代播服務&#xff0c;并特別強調24小時值守的重要性。 流量來源的不穩定性是一個顯…

《VUE.js 實戰》讀書筆記

1. 初識vue.js MVVM模式從MVC模式演化而來&#xff0c;但是MVVM模式更多應用在前端&#xff0c;MVC則是前后端共同表現。傳統開發模式&#xff1a;jQuery RequireJS ( SeaJS ) artTemplate ( doT ) Gulp ( Grunt)。vue.js可以直接通過script引入方式開發&#xff0c;也可以…

Linux下安裝RocketMQ:從零開始的消息中間件之旅

感謝您閱讀本文&#xff0c;歡迎“一鍵三連”。作者定會不負眾望&#xff0c;按時按量創作出更優質的內容。 ?? 1. 畢業設計專欄&#xff0c;畢業季咱們不慌&#xff0c;上千款畢業設計等你來選。 RocketMQ是一款分布式消息中間件&#xff0c;具有高吞吐量、低延遲、高可用性…

本末倒置!做660+880一定要避免出現這3種情況!

每年都有不少人做過660題&#xff0c;但是做過之后&#xff0c;并沒有真正理解其中的題目&#xff0c;所以做過之后效果也不好&#xff01;再去做880題&#xff0c;做的也會比較吃力。 那該怎么辦呢&#xff0c;不建議你繼續做880題&#xff0c;先把660給吃透再說。 接下來給…

PostgreSQL使用教程

安裝 PostgreSQL 您可以從 PostgreSQL 官方網站下載適合您操作系統的安裝程序&#xff0c;并按照安裝向導進行安裝。 啟動數據庫服務器 安裝完成后&#xff0c;根據您的操作系統&#xff0c;通過相應的方式啟動數據庫服務器。 連接到數據庫 可以使用命令行工具&#xff08;如 p…

Objective-C使用塊枚舉的細節

對元素類型的要求 在 Objective-C 中&#xff0c;NSArray 只能存儲對象類型&#xff0c;而不能直接存儲基本類型&#xff08;例如 int&#xff09;。但是&#xff0c;可以將基本類型封裝在 NSNumber 等對象中&#xff0c;然后將這些對象存儲在 NSArray 中。這樣&#xff0c;en…

Maven編譯打包時報“PKIX path building failed”異常

提示&#xff1a;文章寫完后&#xff0c;目錄可以自動生成&#xff0c;如何生成可參考右邊的幫助文檔 文章目錄 方法11.報錯信息2.InstallCert.java3.生成證書文件 jssecacerts4.復制 jssecacerts 文件5. 然后重啟Jenkins 或者maven即可 方法21.下載證書2. 導入證書執行keytool…

7.優化算法之分治-快排歸并

0.分治 分而治之 1.顏色分類 75. 顏色分類 - 力扣&#xff08;LeetCode&#xff09; 給定一個包含紅色、白色和藍色、共 n 個元素的數組 nums &#xff0c;原地對它們進行排序&#xff0c;使得相同顏色的元素相鄰&#xff0c;并按照紅色、白色、藍色順序排列。 我們使用整數…

Elasticsearch (1):ES基本概念和原理簡單介紹

Elasticsearch&#xff08;簡稱 ES&#xff09;是一款基于 Apache Lucene 的分布式搜索和分析引擎。隨著業務的發展&#xff0c;系統中的數據量不斷增長&#xff0c;傳統的關系型數據庫在處理大量模糊查詢時效率低下。因此&#xff0c;ES 作為一種高效、靈活和可擴展的全文檢索…

PHP爬蟲類的使用技巧與注意事項

php爬蟲類的使用技巧與注意事項 隨著互聯網的迅猛發展&#xff0c;大量的數據被不斷地生成和更新。為了方便獲取和處理這些數據&#xff0c;爬蟲技術應運而生。PHP作為一種廣泛應用的編程語言&#xff0c;也有許多成熟且強大的爬蟲類庫可供使用。在本文中&#xff0c;我們將介…

Qt Creator 的設置文件保存位置

在使用 Qt Creator 進行開發時,備份或遷移設置(例如文本編輯器偏好、語法高亮等)是常見需求。了解這些設置文件在不同操作系統中的保存位置,可以簡化這個過程。本文將為您詳細介紹 Qt Creator 保存設置文件的位置。 默認文件位置 Qt Creator 會創建多個文件和目錄來存儲其…

springboot系列八: springboot靜態資源訪問,Rest風格請求處理, 接收參數相關注解

文章目錄 WEB開發-靜態資源訪問官方文檔基本介紹快速入門注意事項和細節 Rest風格請求處理基本介紹應用實例注意事項和細節思考題 接收參數相關注解基本介紹應用實例PathVariableRequestHeaderRequestParamCookieValueRequestBodyRequestAttributeSessionAttribute ?? 上一篇…

微服務-網關Gateway

個人對于網關路由的理解&#xff1a; 網關就相當于是一個項目里面的保安&#xff0c;主要作用就是做一個限制項。&#xff08;zuul和gateway兩個不同的網關&#xff09; 在路由中進行配置過濾器 過濾器工廠&#xff1a;對請求或響應進行加工 其中filters&#xff1a;過濾器配置…

探索QCS6490目標檢測AI應用開發(三):模型推理

作為《探索QCS6490目標檢測AI應用開發》文章&#xff0c;緊接上一期&#xff0c;我們介紹如何在應用程序中介紹如何使用解碼后的視頻幀結合Yolov8n模型推理。 高通 Qualcomm AI Engine Direct 是一套能夠針對高通AI應用加速的軟件SDK&#xff0c;更多的內容可以訪問&#xff1a…

摸魚大數據——Spark基礎——Spark環境安裝——PySpark搭建

三、PySpark環境安裝 PySpark: 是Python的庫, 由Spark官方提供. 專供Python語言使用. 類似Pandas一樣,是一個庫 Spark: 是一個獨立的框架, 包含PySpark的全部功能, 除此之外, Spark框架還包含了對R語言\ Java語言\ Scala語言的支持. 功能更全. 可以認為是通用Spark。 功能 P…

Golang | Leetcode Golang題解之第199題二叉樹的右視圖

題目&#xff1a; 題解&#xff1a; /** 102. 二叉樹的遞歸遍歷*/ func levelOrder(root *TreeNode) [][]int {arr : [][]int{}depth : 0var order func(root *TreeNode, depth int)order func(root *TreeNode, depth int) {if root nil {return}if len(arr) depth {arr a…