【學習筆記】深度學習實戰 | LeNet

在這里插入圖片描述

簡要聲明


  1. 學習相關網址
    1. [雙語字幕]吳恩達深度學習deeplearning.ai
    2. Papers With Code
    3. Datasets
  2. 深度學習網絡基于PyTorch學習架構,代碼測試可跑。
  3. 本學習筆記單純是為了能對學到的內容有更深入的理解,如果有錯誤的地方,懇請包容和指正。

參考文獻


  1. PyTorch Tutorials [https://pytorch.org/tutorials/]
  2. PyTorch Docs [https://pytorch.org/docs/stable/index.html]
  3. LeNet (1998) [Gradient-based learning applied to document recognition]

簡要介紹


LeNet

在這里插入圖片描述

DatasetMNIST
Input (feature maps)32×32 (28×28)
CONV Layers2
FC Layers2
ActivationSigmoid
Output10

代碼分析


函數庫調用

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

處理數據

數據下載

# 從開放數據集中下載訓練數據
train_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),
)# 從開放數據集中下載測試數據
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)print(f'Number of training examples: {len(train_data)}')
print(f'Number of testing examples: {len(test_data)}')

Number of training examples: 60000
Number of testing examples: 10000

數據加載器(可選)

batch_size = 64# 創建數據加載器
train_dataloader = DataLoader(train_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)for X, y in test_dataloader:print(f"Shape of X [N, C, H, W]: {X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64

創建模型

# 選擇訓練設備
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} device")

Using cuda device

class LeNet(nn.Module):def __init__(self, output_dim):super().__init__()self.conv_1 = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2, stride=2))self.conv_2 = nn.Sequential(nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2, stride=2))self.fc_1 = nn.Sequential(nn.Linear(16*5*5, 120),nn.Sigmoid())self.fc_2 = nn.Sequential(nn.Linear(120, 84),nn.Sigmoid())self.fc_3 = nn.Linear(84, output_dim)def forward(self, x):x = self.conv_1(x)x = self.conv_2(x)x = x.view(x.size(0), -1)x = self.fc_1(x)x = self.fc_2(x)x = self.fc_3(x)return xmodel = LeNet(10).to(device)
print(model)

LeNet(
(conv_1): Sequential(
(0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(1): Sigmoid()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(conv_2): Sequential(
(0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(1): Sigmoid()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(fc_1): Sequential(
(0): Linear(in_features=400, out_features=120, bias=True)
(1): Sigmoid()
)
(fc_2): Sequential(
(0): Linear(in_features=120, out_features=84, bias=True)
(1): Sigmoid()
)
(fc_3): Linear(in_features=84, out_features=10, bias=True)
)

訓練模型

選擇損失函數和優化器

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

訓練循環

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationloss.backward()optimizer.step()optimizer.zero_grad()if batch % 100 == 0:loss, current = loss.item(), (batch + 1) * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

測試循環

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

訓練模型

epochs = 10.
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 10
loss: 0.015569 [ 64/60000]
loss: 0.029817 [ 6464/60000]
loss: 0.043169 [12864/60000]
loss: 0.027709 [19264/60000]
loss: 0.021492 [25664/60000]
loss: 0.011533 [32064/60000]
loss: 0.045418 [38464/60000]
loss: 0.042875 [44864/60000]
loss: 0.152001 [51264/60000]
loss: 0.040214 [57664/60000]
Test Error:
Accuracy: 98.6%, Avg loss: 0.044844

模型處理

保存模型

model_name = 'LeNet'
model_file = model_name + ".pth"
torch.save(model.state_dict(), model_file)
print("Saved PyTorch Model State to " + model_file)

Saved PyTorch Model State to LeNet.pth

Summary


安裝torchsummary

pip install torchsummary

調用summary

from torchsummary import summarymodel = LeNet(10).to(device)
summary(model, (1, 28, 28))
----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv2d-1            [-1, 6, 28, 28]             156Sigmoid-2            [-1, 6, 28, 28]               0MaxPool2d-3            [-1, 6, 14, 14]               0Conv2d-4           [-1, 16, 10, 10]           2,416Sigmoid-5           [-1, 16, 10, 10]               0MaxPool2d-6             [-1, 16, 5, 5]               0Linear-7                  [-1, 120]          48,120Sigmoid-8                  [-1, 120]               0Linear-9                   [-1, 84]          10,164Sigmoid-10                   [-1, 84]               0Linear-11                   [-1, 10]             850
================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.11
Params size (MB): 0.24
Estimated Total Size (MB): 0.35
----------------------------------------------------------------

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/711428.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/711428.shtml
英文地址,請注明出處:http://en.pswp.cn/news/711428.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

KubeEdge 邊緣計算

文章目錄 1.KubeEdge2.KubeEdge 特點3.KubeEdge 組成4.KubeEdge 架構 KubeEdge # KubeEdgehttps://iothub.org.cn/docs/kubeedge/ https://iothub.org.cn/docs/kubeedge/kubeedge-summary/1.KubeEdge KubeEdge 是一個開源的系統,可將本機容器化應用編排和管理擴展…

藍牙耳機和筆記本電腦配對連接上了,播放設備里沒有顯示藍牙耳機這個設備,選不了輸出設備

環境: WIN10 雜牌藍牙耳機6s 問題描述: 藍牙耳機和筆記本電腦配對連接上了,播放設備里沒有顯示藍牙耳機這個設備,選不了輸出設備 解決方案: 1.打開設備和打印機,找到這個設備 2.選中這個設備&#…

Linux下gcc編譯常用命令詳解

在Linux環境下,使用gcc編譯器進行源代碼的編譯是程序員日常工作的一部分。本篇將介紹一些常用的gcc編譯命令,幫助開發者更好地理解和使用這些命令。 1. 基本編譯命令 gcc工作流程: 編譯單個源文件 gcc source.c -o output這個命令將sour…

20240229筆記

瀏覽器預加載器 手動&#xff1a;prefetch preload <link rel"prefetch" href"next.html"> <link rel"preload" as"style" href"styles.css"> <link rel"preload" as"javascript" hr…

調試工具vue,react,redux

React Developer Tools Redux DevTools Vue devtools 使用瀏覽器官方組件擴展搜索安裝

C語言練習:(力扣645)錯誤的集合

題目鏈接&#xff1a;645. 錯誤的集合 - 力扣&#xff08;LeetCode&#xff09; 集合 s 包含從 1 到 n 的整數。不幸的是&#xff0c;因為數據錯誤&#xff0c;導致集合里面某一個數字復制了成了集合里面的另外一個數字的值&#xff0c;導致集合 丟失了一個數字 并且 有一個數字…

枚舉和聯合(共用體)

目錄 枚舉枚舉類型的定義枚舉的優點 聯合&#xff08;共用體&#xff09;聯合類型的定義聯合的特點聯合大小的計算 枚舉 枚舉顧名思義就是一一列舉&#xff0c;把可能的取值一一列舉 枚舉類型的定義 enum Day &#xff0c; enum Sex &#xff0c;enum Color 都是枚舉類型{}中…

springboot生成圖片驗證碼(借鑒并分析)

目錄 一、CaptchaUtil代碼展示二、CaptchaController 代碼展示 一、CaptchaUtil代碼展示 package com.minster.yanapi.utils;import com.google.code.kaptcha.impl.DefaultKaptcha;import com.google.code.kaptcha.util.Config; import org.springframework.context.annotatio…

MMDetection3D v1.3.0安裝教程

MMDetection3D v1.3.0安裝教程 1. 系統環境2. 安裝2.1 基本環境安裝2.2 調整具體版本2.3 驗證2.4 安裝MinkowskiEngine和TorchSparse 3. 最終環境配置 根據 v1.3.0版本官方手冊測試后的安裝配置&#xff0c;親測可行。 1. 系統環境 項目版本日期Ubuntu18.04.06 LTS-顯卡RTX 2…

曾桂華:車載座艙音頻體驗探究與思考| 演講嘉賓公布

智能車載音頻 I 分論壇將于3月27日同期舉辦&#xff01; 我們正站在一個前所未有的科技革新的交匯點上&#xff0c;重塑我們出行體驗的變革正在悄然發生。當人工智能的磅礴力量與車載音頻相交融&#xff0c;智慧、便捷與未來的探索之旅正式揚帆起航。 在駕駛的旅途中&#xff0…

安裝 Distribution Registry

Distribution Registry是由容器部署&#xff0c;所有前提是需要安裝docker 參考文檔&#xff1a;https://docs.docker.com/engine/install/centos/ Registry 官網文檔 https://distribution.github.io/distribution/ 安裝Registry倉庫 docker run -d -p 5000:5000 --restartalw…

通過css修改video標簽的原生樣式

通過css修改video標簽的原生樣式 描述實現結果 描述 修改video標簽的原生樣式 實現 在控制臺中打開設置&#xff0c;勾選顯示用戶代理 shadow DOM&#xff0c;就可以審查video標簽的內部樣式了 箭頭處標出來的就是shodow DOM的內容&#xff0c;這些內容正常不可見的&#x…

MySQL 用了哪種默認隔離級別,實現原理是什么?

MySQL 的默認隔離級別是 RR - 可重復讀&#xff0c;可以通過命令來查看 MySQL 中的默認隔離級別。 RR - 可重復讀是基于多版本并發控制&#xff08;Multi-Version Concurrency Control&#xff0c;MVCC &#xff09;實現的。MVCC&#xff0c;在讀取數據時通過一種類似快照的方…

視覺三維重建colmap框架的現狀與未來

注&#xff1a;該文章首發3D視覺工坊&#xff0c;鏈接如下3D視覺工坊 前言 眾所周知&#xff0c;三維重建的發展已經進入了穩定期&#xff0c;尤其是離線方案的發展幾乎處于停滯期&#xff0c;在各大論刊上也很少見到傳統sfmmvs亮眼的文章。這也不難理解&#xff0c;傳統的多視…

MYSQL 解釋器小記

解釋器的結果通常通過上述表格展示&#xff1a; 1. select_type 表示查詢的類型 simple: 表示簡單的選擇查詢&#xff0c;沒有子查詢或連接操作 primary:表示主查詢&#xff0c;通常是最外層的查詢 subquery :表示子查詢&#xff0c;在主查詢中嵌套的查詢 derived: 表示派…

【王道數據結構】【chapter8排序】【P360t2】

試編寫一個算法&#xff0c;使之能夠在數組L[1……n]中找出第k小的元素&#xff08;即從小到大排序后處于第k個位置的元素&#xff09;&#xff08;可以直接采用排序&#xff0c;但下面的排序的代碼只是為了方便核對是不是第k小的元素&#xff0c;k從0開始計算&#xff09; #in…

出海手游收入一路高歌,營銷上如何成功?

出海手游收入一路高歌&#xff0c;營銷上如何成功&#xff1f; 以RPG和SLG為代表的中重度游戲一直是國內廠商在海外市場的傳統優勢品類&#xff0c;因為它們具有較高的投資回報率&#xff0c;是國內廠商在國際市場上取得成功的“吸金”利器。 據伽馬數據發布的《2023全球移動游…

SpringCloud搭建微服務之Consul服務配置

1. 概述 前面有介紹過Consul既可以用于服務注冊和發現&#xff0c;也可以用于服務配置&#xff0c;本文主要介紹如何使用Consul實現微服務的配置中心&#xff0c;有需要了解如何安裝Consul的小伙伴&#xff0c;請查閱SpringCloud搭建微服務之Consul服務注冊與發現 &#xff0c…

steam怎么付款

信用卡支付 登錄Steam賬戶&#xff0c;選擇需要購買的游戲或其他物品&#xff0c;點擊“加入購物車”。在購物車頁面點擊“去結賬”按鈕&#xff0c;進入付款頁面。在付款頁面選擇信用卡付款方式&#xff0c;填寫信用卡信息&#xff0c;輸入驗證碼&#xff0c;點擊確認付款。 …

Servlet 新手村引入-編寫一個簡單的servlet項目

Servlet 新手村引入-編寫一個簡單的servlet項目 文章目錄 Servlet 新手村引入-編寫一個簡單的servlet項目一、編寫一個 Hello world 項目1.創建項目2.引入依賴3.手動創建一些必要的目錄/文件4.編寫代碼5.打包程序6.部署7.驗證程序 二、更方便的處理方案&#xff08;插件引入&am…