神經網絡基礎知識:LeNet的搭建-訓練-預測

1.參考視頻:

2.1 pytorch官方demo(Lenet)_嗶哩嗶哩_bilibili

2.總結:

(1)LeNet網絡就是 我最開始用來預測mnist數據集的那個網絡,簡單的2個conv+2個maxpool+3個linear層

(2)up主整理的train.py等內容里面的細節分析值得學習

(3)對于預測代碼的撰寫,可以參考代碼的predict.py文件

3.幾個文件的源代碼我都貼一下(都不多——但很精):

(1)首先是 model.py:

import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 16, 5)self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28)x = self.pool1(x)            # output(16, 14, 14)x = F.relu(self.conv2(x))    # output(32, 10, 10)x = self.pool2(x)            # output(32, 5, 5)x = x.view(-1, 32*5*5)       # output(32*5*5)x = F.relu(self.fc1(x))      # output(120)x = F.relu(self.fc2(x))      # output(84)x = self.fc3(x)              # output(10)return x

模型 == 2個conv + 2個max_pool + 3個linear

(2) train.py訓練模型的文件:

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transformsdef main():# 定義transform的數據增強transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 處理cifar10的 train和val的數據集的問題# 50000張訓練圖片# 第一次使用時要將download設置為True才會自動去下載數據集train_set = torchvision.datasets.CIFAR10(root='./data', train=True,download=False, transform=transform)train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,shuffle=True, num_workers=0)# 10000張驗證圖片# 第一次使用時要將download設置為True才會自動去下載數據集val_set = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)val_loader = torch.utils.data.DataLoader(val_set, batch_size=5000,shuffle=False, num_workers=0)val_data_iter = iter(val_loader)val_image, val_label = next(val_data_iter)# classes = ('plane', 'car', 'bird', 'cat',#            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 訓練前的準備: 實例化model網絡net , 定義 loss函數 CrossEntropyLoss() 和 Adam優化器net = LeNet()loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)# 開始訓練:zero_grad() + outputs + loss backward + optim stepfor epoch in range(5):  # loop over the dataset multiple timesrunning_loss = 0.0for step, data in enumerate(train_loader, start=0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs = net(inputs)loss = loss_function(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if step % 500 == 499:    # print every 500 mini-batcheswith torch.no_grad():outputs = net(val_image)  # [batch, 10]predict_y = torch.max(outputs, dim=1)[1]accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print('Finished Training')# 最后把 model的 參數save 為一個.pth文件save_path = './Lenet.pth'torch.save(net.state_dict(), save_path)if __name__ == '__main__':main()

分析:數據集劃分 + 實例化網絡_優化器_loss函數 + 分epoch開始尋 + save_pth權重

(3)predict.py:

import torch
import torchvision.transforms as transforms
from PIL import Imagefrom model import LeNetdef main():# 將需要檢測圖像 裁剪為32*32transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')#實例化網絡 + 才入權重net = LeNet()net.load_state_dict(torch.load('Lenet.pth'))# 打開圖像,轉換格式im = Image.open('1.jpg')im = transform(im)  # [C, H, W]im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]# 輸入到網絡中, 得到預測的結果with torch.no_grad():outputs = net(im)predict = torch.max(outputs, dim=1)[1].numpy()print(classes[int(predict)])if __name__ == '__main__':main()

predict == 處理圖像 + 實例化權重 + 得到預測結果

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/714562.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/714562.shtml
英文地址,請注明出處:http://en.pswp.cn/news/714562.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

SQL面試題(2)

第一題 創建trade_orders表: create table `trade_orders`( `trade_id` varchar(255) NULL DEFAULT NULL, `uers_id` varchar(255), `trade_fee` int(20), `product_id` varchar(255), `time` varchar(255) )ENGINE = InnoDB CHARACTER SET = utf8mb4 COLLATE = utf8mb4_0900_…

web自動化筆記九:驗證碼的處理方式

一、驗證碼常用的處理方式 ①、說明:Selenium中并沒有對驗證碼處理的方法,在這里我們介紹一下針對驗證碼的幾種常用處理方式 ②、方式: 1)、去掉驗證碼(測試環境下采用) …

RDD算子介紹

1. RDD算子 RDD算子也叫RDD方法,主要分為兩大類:轉換和行動。轉換,即一個RDD轉換為另一個RDD,是功能的轉換與補充,比如map,flatMap。行動,則是觸發任務的執行,比如collect。所謂算子…

LeetCode 1551.是數組中所有元素相等的最小操作數

存在一個長度為 n 的數組 arr &#xff0c;其中 arr[i] (2 * i) 1 &#xff08; 0 < i < n &#xff09;。 一次操作中&#xff0c;你可以選出兩個下標&#xff0c;記作 x 和 y &#xff08; 0 < x, y < n &#xff09;并使 arr[x] 減去 1 、arr[y] 加上 1 &…

Mac專用投屏工具AirServer 7.27 for Mac中文版2024最新圖文教程

Mac專用投屏工具AirServer 7.27 for Mac中文版是一款適用于Mac的投屏工具&#xff0c;可以將Mac屏幕快速投影到其他設備上&#xff0c;如電視、投影儀、平板等。 Mac專用投屏工具AirServer 7.27 for Mac中文版具有優秀的兼容性&#xff0c;可以與各種設備配合使用。無論是iPhon…

基于springboot+vue的在線考試系統(源碼+論文)

文章目錄 目錄 文章目錄 前言 一、功能設計 二、功能頁面 三、論文 前言 現在我國關于在線考試系統的發展以及專注于對無紙化考試的完善程度普遍不高&#xff0c;關于對考試的模式還大部分還停留在紙介質使用的基礎上&#xff0c;這種教學模式已不能解決現在的時代所產生的考試…

【MySQL】數據庫的操作

【MySQL】數據庫的操作 目錄 【MySQL】數據庫的操作創建數據庫數據庫的編碼集和校驗集查看系統默認字符集以及校驗規則查看數據庫支持的字符集查看數據庫支持的字符集校驗規則校驗規則對數據庫的影響數據庫的刪除 數據庫的備份和恢復備份還原不備份整個數據庫&#xff0c;而是備…

YOLOv9改進|增加SPD-Conv無卷積步長或池化:用于低分辨率圖像和小物體的新 CNN 模塊

專欄介紹&#xff1a;YOLOv9改進系列 | 包含深度學習最新創新&#xff0c;主力高效漲點&#xff01;&#xff01;&#xff01; 一、文章摘要 卷積神經網絡(CNNs)在計算即使覺任務中如圖像分類和目標檢測等取得了顯著的成功。然而&#xff0c;當圖像分辨率較低或物體較小時&…

【LeetCode刷題】146. LRU 緩存

請你設計并實現一個滿足 LRU (最近最少使用) 緩存 約束的數據結構。 實現 LRUCache 類&#xff1a; LRUCache(int capacity) 以 正整數 作為容量 capacity 初始化 LRU 緩存int get(int key) 如果關鍵字 key 存在于緩存中&#xff0c;則返回關鍵字的值&#xff0c;否則返回 -…

全量知識系統問題及SmartChat給出的答復 之9 三套工具之4語法解析器 之2

Q23. 一個語言的語法簡約規則 這些規則顯示show 在一個給定單詞&#xff08;a given word&#xff09;的右邊或左邊可能出現的單詞的類別。句型的多樣性variety不是復雜文法&#xff08;a complex grammar&#xff09;的結果&#xff0c;而是簡單語法&#xff08;a simple gra…

【InternLM 實戰營筆記】浦語·靈筆的圖文理解及創作部署、 Lagent 工具調用 Demo

浦語靈筆的圖文理解及創作部署 浦語靈筆是基于書生浦語大語言模型研發的視覺-語言大模型&#xff0c;提供出色的圖文理解和創作能力&#xff0c;結合了視覺和語言的先進技術&#xff0c;能夠實現圖像到文本、文本到圖像的雙向轉換。使用浦語靈筆大模型可以輕松的創作一篇圖文推…

進程間的通信 -- 共享內存

一 共享內存的概念 1. 1 共享內存的原理 之前我們學過管道通信&#xff0c;分為匿名管道和命名管道&#xff0c;匿名管道通過父子進程的屬性繼承原理來完成父子進程看到同一份資源的目的&#xff0c;而命名管道則是通過路徑與文件名來唯一標識管道文件&#xff0c;來讓不同的進…

學習Android的第二十一天

目錄 Android ProgressDialog (進度條對話框) 例子 Android DatePickerDialog 日期選擇對話框 例子 Android TimePickerDialog 時間選擇對話框 Android PopupWindow 懸浮框 構造函數 方法 例子 官方文檔 Android OptionMenu 選項菜單 例子 官方文檔 Android Progr…

Java實戰:Spring Boot中各類參數校驗機制

引言 在開發Web應用程序時&#xff0c;對客戶端傳入的參數進行有效校驗是保證系統安全性和穩定性的重要環節。Spring Boot作為一個現代化的Java開發框架&#xff0c;提供了多種參數校驗的方法和工具&#xff0c;以滿足不同場景下的需求。本文將深入探討Spring Boot中實現各種參…

typescript 的常用方式

文章目錄 前言一、綁定props 默認值的方式&#xff1a;withDefaults1.vue2 的props設置默認值2.vue3 的props設置默認值(1) 不設置默認值的寫法(2) 設置默認值的寫法&#xff08;分離模式&#xff09;(3) 設置默認值的寫法&#xff08;組合模式&#xff09; 二、定義一個二維數…

Matlab在同一張圖中如何加入多個圖例

根據代碼最終畫出的圖片如下&#xff1a; 其實原理很簡單&#xff0c;就是在一張figure中畫多個坐標軸&#xff0c;每個坐標軸都有對應的圖例&#xff0c;之后再將多余坐標軸隱藏&#xff0c;只保留一個即可。 代碼如下&#xff1a; clear all; close all;dd_linewidth 1;a …

maven archetype 項目原型

拓展閱讀 maven 包管理平臺-01-maven 入門介紹 Maven、Gradle、Ant、Ivy、Bazel 和 SBT 的詳細對比表格 maven 包管理平臺-02-windows 安裝配置 mac 安裝配置 maven 包管理平臺-03-maven project maven 項目的創建入門 maven 包管理平臺-04-maven archetype 項目原型 ma…

Spring學習筆記(六)利用Spring的jdbc實現學生管理系統的用戶登錄功能

一、案例分析 本案例要求學生在控制臺輸入用戶名密碼&#xff0c;如果用戶賬號密碼正確則顯示用戶所屬班級&#xff0c;如果登錄失敗則顯示登錄失敗。 &#xff08;1&#xff09;為了存儲學生信息&#xff0c;需要創建一個數據庫。 &#xff08;2&#xff09;為了程序連接數…

洛谷P1927防護傘

題目描述 據說 20122012 的災難和太陽黑子的爆發有關。于是地球防衛小隊決定制造一個特殊防護傘&#xff0c;擋住太陽黑子爆發的區域&#xff0c;減少其對地球的影響。由于太陽相對于地球來說實在是太大了&#xff0c;我們可以把太陽表面看作一個平面&#xff0c;中心定為(0,0…

C 基本語法

我們已經看過 C 程序的基本結構&#xff0c;這將有助于我們理解 C 語言的其他基本的構建塊。 C 的令牌&#xff08;Token&#xff09; C 程序由各種令牌組成&#xff0c;令牌可以是關鍵字、標識符、常量、字符串值&#xff0c;或者是一個符號。例如&#xff0c;下面的 C 語句…