Pytorch8實現CNN卷積神經網絡

CNN卷積神經網絡

本章提供一個對CNN卷積網絡的快速實現

全連接網絡 VS 卷積網絡

全連接神經網絡之所以不太適合圖像識別任務,主要有以下幾個方面的問題:

  • 參數數量太多 考慮一個輸入10001000像素的圖片(一百萬像素,現在已經不能算大圖了),輸入層有10001000=100萬節點。假設第一個隱藏層有100個節點(這個數量并不多),那么僅這一層就有(1000*1000+1)*100=1億參數,這實在是太多了!我們看到圖像只擴大一點,參數數量就會多很多,因此它的擴展性很差。
  • 沒有利用像素之間的位置信息 對于圖像識別任務來說,每個像素和其周圍像素的聯系是比較緊密的,和離得很遠的像素的聯系可能就很小了。如果一個神經元和上一層所有神經元相連,那么就相當于對于一個像素來說,把圖像的所有像素都等同看待,這不符合前面的假設。當我們完成每個連接權重的學習之后,最終可能會發現,有大量的權重,它們的值都是很小的(也就是這些連接其實無關緊要)。努力學習大量并不重要的權重,這樣的學習必將是非常低效的。
  • 網絡層數限制 我們知道網絡層數越多其表達能力越強,但是通過梯度下降方法訓練深度全連接神經網絡很困難,因為全連接神經網絡的梯度很難傳遞超過3層。因此,我們不可能得到一個很深的全連接神經網絡,也就限制了它的能力。

那么,卷積神經網絡又是怎樣解決這個問題的呢?主要有三個思路:

  • 局部連接 這個是最容易想到的,每個神經元不再和上一層的所有神經元相連,而只和一小部分神經元相連。這樣就減少了很多參數。
  • 權值共享 一組連接可以共享同一個權重,而不是每個連接有一個不同的權重,這樣又減少了很多參數。
  • 下采樣 可以使用Pooling來減少每層的樣本數,進一步減少參數數量,同時還可以提升模型的魯棒性。

對于圖像識別任務來說,卷積神經網絡通過盡可能保留重要的參數,去掉大量不重要的參數,來達到更好的學習效果。

卷積結構

卷積層

卷積層可以產生一組平行的特征圖(feature map),它通過在輸入圖像上滑動不同的卷積核并執行一定的運算而組成。此外,在每一個滑動的位置上,卷積核與輸入圖像之間會執行一個元素對應乘積并求和的運算以將感受視野內的信息投影到特征圖中的一個元素。這一滑動的過程可稱為步幅 Z_s,步幅 Z_s 是控制輸出特征圖尺寸的一個因素。卷積核的尺寸要比輸入圖像小得多,且重疊或平行地作用于輸入圖像中,一張特征圖中的所有元素都是通過一個卷積核計算得出的,也即一張特征圖共享了相同的權重和偏置項。

池化層

池化(Pooling)是卷積神經網絡中另一個重要的概念,它實際上是一種非線性形式的降采樣。有多種不同形式的非線性池化函數,而其中“最大池化(Max pooling)”是最為常見的。它是將輸入的圖像劃分為若干個矩形區域,對每個子區域輸出最大值。

一個特征的精確位置遠不及它相對于其他特征的粗略位置重要。池化層會不斷地減小數據的空間大小,因此參數的數量和計算量也會下降,這在一定程度上也控制了過擬合。通常來說,CNN的網絡結構中的卷積層之間都會周期性地插入池化層。池化操作提供了另一種形式的平移不變性。因為卷積核是一種特征發現器,我們通過卷積層可以很容易地發現圖像中的各種邊緣。但是卷積層發現的特征往往過于精確,我們即使高速連拍拍攝一個物體,照片中的物體的邊緣像素位置也不大可能完全一致,通過池化層我們可以降低卷積層對邊緣的敏感性。

全連接層

最后,在經過幾個卷積和最大池化層之后,神經網絡中的高級推理通過完全連接層來完成。就和常規的非卷積人工神經網絡中一樣,完全連接層中的神經元與前一層中的所有激活都有聯系。因此,它們的激活可以作為仿射變換來計算,也就是先乘以一個矩陣然后加上一個偏差(bias)偏移量(向量加上一個固定的或者學習來的偏差量)。

卷積神經網絡(LeNet)

模型實現

LeNet是最早發布的卷積神經網絡之一,因其在計算機視覺任務中的高效性能而受到廣泛關注。

用Pytorch框架實現此類模型非常簡單。我們只需要實例化一個Sequential塊并將需要的層連接在一起。

import torch
from torch import nn
from d2l import torch as d2lnet = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))

將一個大小為28×28的單通道(黑白)圖像通過LeNet。通過在每一層打印輸出的形狀,我們可以檢查模型
在這里插入圖片描述

X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape: \t',X.shape)
Conv2d output shape:         torch.Size([1, 6, 28, 28])
Sigmoid output shape:        torch.Size([1, 6, 28, 28])
AvgPool2d output shape:      torch.Size([1, 6, 14, 14])
Conv2d output shape:         torch.Size([1, 16, 10, 10])
Sigmoid output shape:        torch.Size([1, 16, 10, 10])
AvgPool2d output shape:      torch.Size([1, 16, 5, 5])
Flatten output shape:        torch.Size([1, 400])
Linear output shape:         torch.Size([1, 120])
Sigmoid output shape:        torch.Size([1, 120])
Linear output shape:         torch.Size([1, 84])
Sigmoid output shape:        torch.Size([1, 84])
Linear output shape:         torch.Size([1, 10])

模型訓練

現在我們已經實現了LeNet,讓我們看看LeNet在Fashion-MNIST數據集上的表現。

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

由于完整的數據集位于內存中,因此在模型使用GPU計算數據集之前,我們需要將其復制到顯存中。

def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU計算模型在數據集上的精度"""if isinstance(net, nn.Module):net.eval()  # 設置為評估模式if not device:device = next(iter(net.parameters())).device# 正確預測的數量,總預測的數量metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微調所需的(之后將介紹)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

與全連接層一樣,我們使用交叉熵損失函數和小批量隨機梯度下降。

#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU訓練模型"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# 訓練損失之和,訓練準確率之和,樣本數metric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')

訓練和評估LeNet-5模型。

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.469, train acc 0.823, test acc 0.779
55296.6 examples/sec on cuda:0

在這里插入圖片描述

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/86819.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/86819.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/86819.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

平地起高樓: 環境搭建

技術選型 本小冊是采用純前端的技術棧模擬實現小程序架構的系列文章,所以主要以前端技術棧為主,但是為了模擬一個App應用的效果,以及小程序包下載管理流程的實現,我們還是需要搭建一個基礎的App應用。這里我們將選擇 Tauri2.0 來…

langgraph學習2 - MCP編程

3中通信方式: 目前sse用的很少 3.開發mcp框架 主流框架2個: MCP skd 官方 Fast Mcp V2 ,(V1捐給MCP 官方) 大模型如何識別用哪個tools, 以及如何使用tools:

CSS 與 JavaScript 加載優化

📄 CSS 與 JavaScript 加載優化指南:位置、阻塞與性能 讓你的網頁飛起來!🚀 本文詳細解析 CSS 和 JavaScript 標簽的放置位置如何影響頁面性能,涵蓋阻塞原理、瀏覽器機制和最佳實踐。掌握這些知識可顯著提升用戶體驗…

WSL安裝發行版上安裝podman

WSL安裝發行版上安裝podman 1.WSL拉取發行版1.1 拉取2.2.修改系統拉取的鏡像,可以加速軟件包的更新 2.podman安裝2.1.安裝podman 容器工具2.2.配置podman的鏡像倉庫2.3.拉取n8n鏡像并創建容器 本文在windows11上,使用WSL拉取并創建ubuntu24.04虛擬機&…

Excel 常用快捷鍵與對應 VBA 方法/屬性清單

功能描述快捷鍵VBA 對應方法/屬性 (核心邏輯)說明導航 (類似 End 方向鍵)這些是 End 鍵行為的直接對應向下到連續區域末尾Ctrl ↓ActiveCell.End(xlDown)從當前單元格向下,遇到第一個空單元格停止。向上到連續區域開頭Ctrl ↑ActiveCell.End(xlUp)從當前單元格向上…

計算機組成原理與體系結構-實驗四 微程序控制器 (Proteus 8.15)

一、實驗目的 1、理解“微程序”設計思想,了解“指令-微指令-微命令”的微程序結構。 2、掌握微程序控制器的結構和設計方法。 二、實驗內容 設計一個“最簡版本”的 CPU 模型機:利用時序發生器來產生 CPU 的預定時序,通過微程序控制器的自…

安卓端某音樂類 APP 逆向分享(二)協議分析

以歌曲搜索協議為例,查看charles中歌曲搜索協議詳情 拷貝出搜索協議的Curl形式 curl -H Host: interface3.music.xxx.com -H Cookie: EVNSM1.0.0; NMCIDoufhty.1667355455436.01.4; versioncode8008050; buildver221010200836; resolution2392x1440; deviceIdYDwXa…

七天學會SpringCloud分布式微服務——03——Nacos遠程調用

1、微服務項目配置類放在地方 配置類型應放位置說明通用配置類(如:跨服務通用的攔截器、全局異常處理、統一響應體封裝等)可放在一個**公共模塊(common/config)**中,被各服務引入實現代碼復用,…

基于Java+Spring Boot的校園閑置物品交易系統

源碼編號:S561 源碼名稱:基于Spring Boot的校園閑置物品交易系統 用戶類型:多角色,用戶、商家、管理員 數據庫表數量:12 張表 主要技術:Java、Vue、ElementUl 、SpringBoot、Maven 運行環境&#xff1…

SpringBoot 的 jar 包為什么可以直接運行?

一、普通jar包和SpringBoot jar包有什么區別?什么是jar包?? (1)什么是Jar包? 定義: JAR 包(Java Archive) 是 Java 平臺標準的歸檔文件格式,用于將多個 Jav…

算法-基礎算法-遞歸算法(Python)

文章目錄 前言遞歸和數學歸納法遞歸三步走遞歸的注意點避免棧溢出避免重復運算 題目斐波那契數反轉鏈表 前言 遞歸(Recursion):指的是一種通過重復將原問題分解為同類的子問題而解決的方法。在絕大數編程語言中,可以通過在函數中再…

TVFEMD-CPO-TCN-BiLSTM多輸入單輸出模型

47-TVFEMD-CPO-TCN-BiLSTM多輸入單輸出模型 適合單變量,多變量時間序列預測模型(可改進,加入各種優化算法) 時變濾波的經驗模態分解TVFEMD時域卷積TCN雙向長短期記憶網絡BiLSTM時間序列預測模型 另外以及有 TCN-BILSTM …

深入淺出Node.js中間件機制

我們用一個實際的例子來看看中間件是如何運作的。假設我們有一個非常簡單的Express應用,它只有兩個中間件函數: const express require(express); const app express();app.use((req, res, next) > {console.log(第一個中間件);next(); });app.use…

Vue-15-前端框架Vue之應用基礎編程式路由導航

文章目錄 1 RouterLink的replace屬性1.1 App.vue1.2 應用效果2 編程式路由導航2.1 場景一Home.vue2.2 場景二News.vue3 路由重定向3.1 index.ts3.2 Detail.vue3.3 About.vue1 RouterLink的replace屬性 路由每次跳轉都有記錄,默認是push,可以改為replace。 RouterLink支持兩…

android14 設置下連續點擊5次Settings標題跳轉到撥號界面

部分項目隱藏了撥號器,但開發者需要間距跳轉到撥號界面 設置一級界面: packages/apps/Settings/src/com/android/settings/homepage/SettingsHomepageActivity.java 通過dispatchTouchEvent方法先獲取Settings標題的區域X,Y數據。 import java.util.Set…

MP分頁和連表常用寫法

1. 分頁查詢 方案一&#xff1a;MyBatis XML MyBatis 內置的使用方式&#xff0c;步驟如下&#xff1a; ① 創建 AdminUserMapper.xml 文件&#xff0c;編寫兩個 SQL 查詢語句&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <!DOCTYPE m…

使用 Spring AI Alibaba構建 AI Code Review 應用

很早的時候就想著用AI來做Code Review&#xff0c;最近也看到了一些不錯的實現&#xff0c;但是沒有一個使用Java來構建的&#xff0c;看的比較費勁&#xff0c;雖然說語言只是一種工具&#xff0c;但是還是想用Java重新寫一遍&#xff0c;正好最近Spring AI Alibaba出了正式版…

力扣1590. 使數組和能被 P 整除

這一題的難點在于模運算&#xff0c;對模運算足夠了解&#xff0c;對式子進行變換就很容易得到結果&#xff0c;本質上還是一道前綴和哈希表的題 這里重點講一下模運算。 常見的模運算的用法 (a-b)%k0等價于 a%kb%k 而在這一題中由于多了一個len&#xff0c;&#xff08;數組的…

FPGA內部資源介紹

FPGA內部資源介紹 目錄 邏輯資源塊LUT&#xff08;查找表&#xff09;加法器寄存器MUX&#xff08;復用器&#xff09;時鐘網絡資源 全局時鐘網絡資源區域時鐘網絡資源IO時鐘網絡資源 時鐘處理單元BLOCK RAMDSP布線資源接口資源 用戶IO資源專用高速接口資源 總結 1. 邏輯資源…

CSS 列表

CSS 列表 引言 CSS 列表是網頁設計中常用的一種布局方式&#xff0c;它能夠幫助我們以更靈活、更美觀的方式展示數據。本文將詳細介紹 CSS 列表的創建、樣式設置以及常用技巧&#xff0c;幫助您更好地掌握這一重要技能。 CSS 列表概述 CSS 列表主要包括兩種類型&#xff1a…