動手學深度學習—卷積神經網絡LeNet(代碼詳解)

1. LeNet

LeNet由兩個部分組成:

  • 卷積編碼器:由兩個卷積層組成;
  • 全連接層密集塊:由三個全連接層組成。

在這里插入圖片描述

  1. 每個卷積塊中的基本單元是一個卷積層、一個sigmoid激活函數和平均匯聚層;
  2. 每個卷積層使用5×5卷積核和一個sigmoid激活函數;
  3. 這些層將輸入映射到多個二維特征輸出,通常同時增加通道的數量;
  4. 每個4×4池操作(步幅2)通過空間下采樣將維數減少4倍。
import torch
from torch import nn
from d2l import torch as d2l# 定義模型net
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))

該模型去掉了最后一層的高斯激活,下面將一個大小為28×28的單通道(黑白)圖像通過LeNet,打印每一層輸出的形狀。

# 觀察各層的輸入輸出通道數,寬度和高度
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t', X.shape)

在這里插入圖片描述

  1. 第一個卷積層使用2個像素的填充,來補償5×5卷積核導致的特征減少;
  2. 第二個卷積層沒有填充,因此高度和寬度都減少了4個像素;
  3. 隨著層疊的上升,通道的數量從輸入時的1個,增加到第一個卷積層之后的6個,再到第二個卷積層之后的16個;
  4. 每個匯聚層的高度和寬度都減半;
  5. 每個全連接層減少維數,最終輸出一個維數與結果分類數相匹配的輸出。

2. 模型訓練

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
"""定義精度評估函數:1、將數據集復制到顯存中2、通過調用accuracy計算數據集的精度
"""
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save# 判斷net是否屬于torch.nn.Module類if isinstance(net, nn.Module):net.eval()# 如果不在參數選定的設備,將其傳輸到設備中if not device:device = next(iter(net.parameters())).device# Accumulator是累加器,定義兩個變量:正確預測的數量,總預測的數量。metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:# 將X, y復制到設備中if isinstance(X, list):# BERT微調所需的(之后將介紹)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)# 計算正確預測的數量,總預測的數量,并存儲到metric中metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]
"""定義GPU訓練函數:1、為了使用gpu,首先需要將每一小批量數據移動到指定的設備(例如GPU)上;2、使用Xavier隨機初始化模型參數;3、使用交叉熵損失函數和小批量隨機梯度下降。
"""
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU訓練模型(在第六章定義)"""# 定義初始化參數,對線性層和卷積層生效def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)# 在設備device上進行訓練print('training on', device)net.to(device)# 優化器:隨機梯度下降optimizer = torch.optim.SGD(net.parameters(), lr=lr)# 損失函數:交叉熵損失函數loss = nn.CrossEntropyLoss()# Animator為繪圖函數animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])# 調用Timer函數統計時間timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# Accumulator(3)定義3個變量:損失值,正確預測的數量,總預測的數量metric = d2l.Accumulator(3)net.train()# enumerate() 函數用于將一個可遍歷的數據對象for i, (X, y) in enumerate(train_iter):timer.start() # 進行計時optimizer.zero_grad() # 梯度清零X, y = X.to(device), y.to(device) # 將特征和標簽轉移到devicey_hat = net(X)l = loss(y_hat, y) # 交叉熵損失l.backward() # 進行梯度傳遞返回optimizer.step()with torch.no_grad():# 統計損失、預測正確數和樣本數metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop() # 計時結束train_l = metric[0] / metric[2] # 計算損失train_acc = metric[1] / metric[2] # 計算精度# 進行繪圖if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))# 測試精度test_acc = evaluate_accuracy_gpu(net, test_iter) animator.add(epoch + 1, (None, None, test_acc))# 輸出損失值、訓練精度、測試精度print(f'loss {train_l:.3f}, train acc {train_acc:.3f},'f'test acc {test_acc:.3f}')# 設備的計算能力print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec'f'on {str(device)}')

在這里插入圖片描述

lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在這里插入圖片描述

3. 小結

  1. 卷積神經網絡(CNN)是一類使用卷積層的網絡;
  2. 卷積神經網絡中,可以組合使用卷積層、非線性激活函數和匯聚層;
  3. 為了構造高性能的卷積神經網絡,通常對卷積層進行排列,逐漸降低其表示的空間分辨率,同時增加通道數;
  4. 在傳統的卷積神經網絡中,卷積塊編碼得到的表征在輸出之前需由一個或多個全連接層進行處理。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/43110.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/43110.shtml
英文地址,請注明出處:http://en.pswp.cn/news/43110.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

LeetCode--HOT100題(35)

目錄 題目描述:23. 合并 K 個升序鏈表(困難)題目接口解題思路1代碼解題思路2代碼 PS: 題目描述:23. 合并 K 個升序鏈表(困難) 給你一個鏈表數組,每個鏈表都已經按升序排列。 請你將所有鏈表合…

UDP 的報文結構以及注意事項

UDP協議 1.UDP協議端格式 1.圖中的16位UDP長度,表示整個數據報(UDP首部UDP數據)的最大長度 2.若校驗和出錯,會直接丟棄 2.UDP的報文結構 UDP報文主體分為兩個部分:UDP報頭(占8個字節)UDP載荷/UDP數據 1.源端口號 16位,2個字節 2.目的端口號 16位,2個字節 3.包長度 指示了…

sd-webui安裝comfyui擴展

文章目錄 導讀ComfyUI 環境安裝1. 安裝相關組件2. 啟動sd-webui3. 訪問sd-webui 錯誤信息以及解決辦法 導讀 這篇文章主要給大家介紹如何在sd-webui中來安裝ComfyUI插件 ComfyUI ComfyUI是一個基于節點流程式的stable diffusion的繪圖工具,它集成了stable diffus…

兩個list如何根據一個list中的屬性去過濾掉另一個list中不包含這部分的屬性,用流實現

你可以使用Java 8的流來實現這個功能。假設你有兩個包含對象的List,每個對象有一個屬性,你想根據一個List中的屬性值來過濾掉另一個List中不包含這個屬性值的對象。下面是一種使用流的方式來實現這個功能 import java.util.ArrayList; import java.util…

什么是閉包(closure)?為什么它在JavaScript中很有用?

聚沙成塔每天進步一點點 ? 專欄簡介? 閉包(Closure)是什么?? 閉包的用處? 寫在最后 ? 專欄簡介 前端入門之旅:探索Web開發的奇妙世界 記得點擊上方或者右側鏈接訂閱本專欄哦 幾何帶你啟航前端之旅 歡迎來到前端入門之旅&…

IO流面試題

題目一: 在磁盤中新建一個文件(如果目錄結構不存在,則創建目錄) 文件名:data.txt 文件日錄:C:\demo\test\files (盤符不限) linux目錄~/demo/test/files 題二 在新建的data.txt中添加如下內容: 張三,測試,2019-02-18 …

windows10 安裝WSL2, Ubuntu,docker

AI- 通過docker開發調試部署ChatLLM 閱讀時長:10分鐘 本文內容: window上安裝ubuntu虛擬機,并在虛擬機中安裝docker,通過docker部署數字人模型,通過vscode鏈接到虛擬機進行開發調試.調試完成后,直接部署在云…

優漫動游零基礎如何學習好UI設計

智能時代的來臨,很多企業都越來越注重用戶體驗這一塊,想要有一個吸引用戶的好頁面,UI設計師崗位不可或缺,如今越來越多的人想要學習UI設計技術,那么對于零基礎小白如何學習好UI設計呢? 零基礎小白如何學習好UI設計…

變更通知在開源SpringBoot/SpringCloud微服務中的最佳實踐

目錄導讀 變更通知在開源SpringBoot/SpringCloud微服務中的最佳實踐1. 什么是變更通知2. 變更通知的場景分析3. 變更通知的技術方案3.1 變更通知的技術實現方案 4. 變更通知的最佳實踐總結5. 參考資料 變更通知在開源SpringBoot/SpringCloud微服務中的最佳實踐 1. 什么是變更通…

Ubuntu在自己的項目中使用pcl

1、建立一個文件夾,如pcl_demos,里面建立一個.cpp文件和一個cmake文件 2、打開終端并進入該文件夾下,建立一個build文件夾存放編譯的結果并進入該文件夾 3、對上一級進行編譯 cmake .. 4、生成可執行文件 make 5、運行該可執行文件 6、可視…

最強自動化測試框架Playwright(30)-JS句柄

在 Playwright 中,JSHandle 是一個表示瀏覽器中 JavaScript 對象的類。它提供了與網頁中的 JavaScript 對象進行交互和操作的方法。 可以通過調用 Playwright中的 evaluateHandle 或 evaluate 方法來獲取 JSHandle from playwright.sync_api import sync_playwrig…

微服務中間件-分布式緩存Redis

分布式緩存 a.Redis持久化1) RDB持久化1.a) RDB持久化-原理 2) AOF持久化3) 兩者對比 b.Redis主從1) 搭建主從架構2) 數據同步原理(全量同步)3) 數據同步原理(增量同步) c.Redis哨兵1) 哨兵的作用2) 搭建Redis哨兵集群3) RedisTem…

金融語言模型:FinGPT

項目簡介 FinGPT是一個開源的金融語言模型(LLMs),由FinNLP項目提供。這個項目讓對金融領域的自然語言處理(NLP)感興趣的人們有了一個可以自由嘗試的平臺,并提供了一個與專有模型相比更容易獲取的金融數據。…

Java根據List集合中的一個字段對集合進行去重

利用HashSet 創建了一個HashSet用于存儲唯一的字段值&#xff0c;并創建了一個新的列表uniqueList用于存儲去重后的對象。遍歷原始列表時&#xff0c;如果字段值未在HashSet中出現過&#xff0c;則將其添加到HashSet和uniqueList中。 List<Person> originalList new Ar…

VS2015項目中,MFC內存中調用DLL函數(VC6生成的示例DLL)

本例主要講一下&#xff0c;用VC6如何生成DLL&#xff0c;用工具WinHex取得DLL全部內容&#xff0c;VC2015項目加載內存中的DLL函數&#xff0c;并調用函數的示例。 本例中的示例代碼下載&#xff0c;點擊可以下載 一、VC6.0生成示例DLL項目 1.新建項目&#xff0c;…

mysql中的is null和空字符串

相比于oracle&#xff0c;mysql中的is null 和空坑就沒那么多&#xff0c;直接寫就行。 不為空 and (username is not null and username !)注&#xff1a; 不為空中間用的是and。 為空 and (username is null or username !)注&#xff1a; 為空中間用的是or。

java應用運行在docker,并且其他組件也在docker

docker啟動redis容器 # create redis docker run -d --name redis-container -p 6379:6379 redis:latest創建java 應用 dockerfile FROM openjdk:17##Pre-create related directories RUN mkdir -p /data/etax/ms-app WORKDIR /data/etax/ms-appEXPOSE 10133 COPY ./target…

SQL Server Express 自動備份方案

文章目錄 SQL Server Express 自動備份方案前言方案原理SQL Server Express 自動備份1.創建存儲過程2.設定計劃任務3.結果檢查sqlcmd 參數說明SQL Server Express 自動備份方案 前言 對于許多小型企業和個人開發者來說,SQL Server Express是一個經濟實惠且強大的數據庫解決方…

Spring Framework中的Bean生命周期

目錄 一.Bean生命周期的簡介 1.基本概念 2.Spring生命周期的幾大階段 3.注意點及小結 4.生活案例 5.Spring容器管理JavaBean的初始化過程 二. Bean的單例選擇與多例選擇 1.單例選擇與多例選擇的優缺點 1.1單例模式的優點&#xff1a; 1.2單例模式的缺點&#xff1a; 1…

JDK 8 升級 JDK 17 全流程教學指南

JDK 8 升級 JDK 17 首先已有項目升級是會經歷一個較長的調試和自測過程來保證允許和兼容沒有問題。先說幾個重要的點 遇到問題別放棄仔細閱讀報錯&#xff0c;精確到每個單詞每一行&#xff0c;不是自己項目的代碼也要點進去看看源碼到底是為啥報錯明確你項目引入的包&#x…