用 PyTorch 實現全連接網絡識別 MNIST 手寫數字

目錄

一、什么是全連接網絡

二、代碼實現步驟

1. 導入必要的庫

2. 數據準備

3. 定義網絡結構

4. 模型訓練

5. 模型保存和加載

6. 預測單張圖片

7. 主函數

三、運行結果說明

四、小結


一、什么是全連接網絡

全連接神經網絡(Fully Connected Neural Network)是一種最基礎的神經網絡結構,其特點是每一層的每個神經元都與上一層的所有神經元相連。

打個比方,就像公司里的部門架構:輸入層是基層員工,隱藏層是中層管理,輸出層是高層決策。基層的每個人都要向所有中層匯報,中層再向所有高層匯報,這樣信息就能經過多層處理后得到最終結果。

但全連接網絡處理圖像時有個缺點:它會把圖像的二維像素矩陣轉換成一維向量,這就像把一張完整的圖片撕成一條線,會丟失圖像的空間特征。

二、代碼實現步驟

1. 導入必要的庫

import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image

這些庫就像我們的工具包:

  • torch?是 PyTorch 的核心庫
  • nn?模塊包含神經網絡相關的工具
  • optim?提供優化器
  • torchvision?有現成的數據集和圖像處理工具
  • DataLoader?幫助我們批量加載數據
  • PIL?用于處理圖像

2. 數據準備

def build_data():transform = transforms.Compose([transforms.ToTensor(),])train_set = datasets.MNIST(root = '../dataset',train = True,download = True,transform = transform)test_set = datasets.MNIST(root = '../dataset',train = False,download = True,transform = transform)train_loader = DataLoader(dataset = train_set,batch_size = 128,shuffle = True)test_loader = DataLoader(dataset = test_set,batch_size = 64,shuffle = True)return train_loader, test_loader

這段代碼做了三件事:

  • 定義了數據轉換方式,ToTensor()會把圖像轉換成張量并歸一化
  • 加載 MNIST 數據集(手寫數字數據集,包含 0-9 共 10 類數字)
  • DataLoader把數據分成批次,方便訓練時批量處理

batch_size表示每次處理多少張圖片,shuffle=True表示打亂數據順序,讓模型學習更全面。

3. 定義網絡結構

class MNISTNet(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(28 * 28, 256)self.relu1 = nn.ReLU()self.fc2 = nn.Linear(256, 128)self.relu2 = nn.ReLU()self.fc3 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28 * 28)  # 把28x28的圖像展平成784維向量x = self.relu1(self.fc1(x))x = self.relu2(self.fc2(x))x = self.fc3(x)return x

我們定義了一個 3 層的全連接網絡:

  • 輸入層:MNIST 圖像是 28x28 的,展平后是 784 個像素點
  • 第一個隱藏層:256 個神經元,使用 ReLU 激活函數
  • 第二個隱藏層:128 個神經元,同樣使用 ReLU 激活函數
  • 輸出層:10 個神經元(對應 0-9 十個數字)

激活函數 ReLU 的作用是引入非線性,讓網絡能夠學習復雜的模式,就像給計算器增加了更多運算功能。

4. 模型訓練

def train(model, train_loader, epochs):criterion = nn.CrossEntropyLoss()  # 交叉熵損失函數,適合分類問題opt = optim.SGD(model.parameters(), lr=0.01)  # 隨機梯度下降優化器for epoch in range(epochs):loss_sum = 0count = 0for x, y in train_loader:y_pred = model(x)  # 前向傳播,得到預測結果loss = criterion(y_pred, y)  # 計算損失# 反向傳播更新參數opt.zero_grad()  # 清空梯度loss.backward()  # 計算梯度opt.step()  # 更新參數loss_sum += loss.item()_, pred = torch.max(y_pred, dim=1)  # 找到概率最大的類別count += (pred == y).sum().item()  # 統計正確的數量acc = count / len(train_loader.dataset)  # 計算準確率print(f'epoch: {epoch+1}, Loss: {loss_sum:.4f}, Acc: {acc:.4f}')

訓練過程就像學生做習題:

  1. 先用當前模型做預測(前向傳播)
  2. 計算預測結果和正確答案的差距(損失函數)
  3. 分析哪里錯了,怎么改進(反向傳播計算梯度)
  4. 調整模型參數(優化器更新參數)

我們用交叉熵損失函數來衡量預測錯誤的程度,用隨機梯度下降(SGD)來優化模型參數,學習率lr=0.01控制每次調整的幅度。

5. 模型保存和加載

def save_model(model, model_path):torch.save(model.state_dict(), model_path)  # 保存模型參數def load_model(model_path):model = MNISTNet()model.load_state_dict(torch.load(model_path))  # 加載模型參數return model

訓練好的模型可以保存下來,下次用的時候直接加載,不用重新訓練,就像保存游戲進度一樣。

6. 預測單張圖片

def predict(model, filePath):img = Image.open(filePath)# 圖像預處理:調整大小、轉成張量、歸一化transform = transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])t_img = transform(img)with torch.no_grad():  # 預測時不需要計算梯度y_pred = model(t_img)_, pred = torch.max(y_pred, dim=1)print(f'預測結果: {pred.item()}')

預測時需要對輸入圖片做和訓練數據相同的預處理,with torch.no_grad()可以加快計算速度,因為預測時不需要更新參數。

7. 主函數

if __name__ == '__main__':train_loader, test_loader = build_data()model = MNISTNet()# 訓練模型train(model, train_loader, epochs=10)# 保存模型save_model(model, './mnist.pt')# 加載模型并預測model_pred = load_model('./mnist.pt')predict(model_pred, './img/3.png')

三、運行結果說明

訓練過程中,我們會看到損失(Loss)逐漸減小,準確率(Acc)逐漸提高,這說明模型在不斷進步。

對于 MNIST 這種簡單數據集,用這個全連接網絡通常能達到 97% 以上的準確率。如果想進一步提高性能,可以考慮使用卷積神經網絡(CNN),它能更好地保留圖像的空間特征。

四、小結

本文用 PyTorch 實現了一個全連接神經網絡來識別 MNIST 手寫數字,主要步驟包括:

  1. 準備數據:加載并預處理 MNIST 數據集
  2. 定義網絡:設計 3 層全連接網絡
  3. 訓練模型:使用交叉熵損失和 SGD 優化器
  4. 保存和加載模型:方便復用
  5. 單張圖片預測:實際應用模型

全連接網絡雖然簡單,但它是理解更復雜神經網絡的基礎。通過這個例子,我們可以了解神經網絡的基本工作原理和 PyTorch 的使用方法。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/90328.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/90328.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/90328.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

vscode怎么安裝MINGW

下載: 第一步選擇MINGW官網:MinGW-w64 - for 32 and 64 bit Windows - SourceForge.net 點擊Files 點擊Toolchains targetting Win64 點擊第一個 Personal Builds 點擊mingw-builds 選擇8.1.0 點擊第二個 threads-posix 點擊第二個seh 最后左鍵點擊下…

CSS圖片分層設置

在CSS中實現圖片分層效果,主要通過定位屬性和層疊上下文控制。以下是核心實現方法和示例: 一、核心實現原理定位方式 使用 position: relative/absolute/fixed 使圖片脫離文檔流 .layer {position: absolute; /* 關鍵屬性 */top: 0;left: 0; }層疊控制 通…

GEMINUS 和 Move to Understand a 3D Scene

論文鏈接:https://arxiv.org/abs/2507.14456 代碼鏈接:https://github.com/newbrains1/GEMINUS 端到端自動駕駛的挑戰 端到端自動駕駛是一種“一站式”方法:模型直接從傳感器輸入(如攝像頭圖像)生成駕駛軌跡或控制信號…

算法與數據結構:線性表

C語言數據結構基礎:線性表詳解線性表是數據結構中最基礎、最常用的形式,就像一列整齊排隊的游客:每個元素有固定位置(前驅和后繼),長度可動態變化。在C語言中,它主要通過順序表(數組…

制作mac 系統U盤

使用 installinstallmacos.py(更兼容) 蘋果官方不提供所有歷史版本的安裝器,但可以通過一個開源腳本下載(Apple 提供的企業支持工具): git clone https://github.com/munki/macadmin-scripts.git cd macadm…

滲透部分總結

docker環境搭建以及dns等原理講解Docker搭建:Linux 系統上安裝 Docker 引擎并啟動服務:# 安裝Docker引擎 curl -fsSL https://get.docker.com | sh 通過 curl 下載并執行 Docker 官方的安裝腳本,這會自動配置 Docker 倉庫并安裝最新版本的 Do…

k8s pvc是否可綁定在多個pod上

1.pvc是否可綁定在多個podPVC 是否能被多個 Pod 使用,取決于它的 accessModes。PVC 的 accessModes是否支持多個 Pod 同時使用說明ReadWriteOnce (RWO)? 若多個Pod,需在相同節點上(僅允許被單個節點上的Pod掛載)常用于本地磁盤、…

如何加固Endpoint Central服務器的安全?(下)

Endpoint Central 作為企業終端管理的 “中樞系統”,掌控著全網終端的補丁推送、軟件部署、配置管理、遠程控制等關鍵權限,存儲著大量終端資產信息、用戶數據及企業策略配置。一旦服務器被攻破,攻擊者可能篡改管理指令(如推送惡意…

信息整合注意力IIA,通過雙方向注意力機制重構空間位置信息,動態增強目標關鍵特征并抑制噪聲

在遙感圖像語義分割等視覺任務中,編碼器 - 解碼器結構通過跳躍連接融合多尺度特征時,常面臨兩大挑戰:一是編碼器的局部細節特征與解碼器的全局語義特征融合時,空間位置信息易丟失,導致目標定位不準;二是復雜…

如何遷移jenkins至另一臺服務器

前言公司舊的服務器快到期了,需要將部署在其上的jenkins整體遷移到另一臺服務器,兩臺都是aws ec2服務器。文章主要提供給大家一種遷移思路,并不一定是最優解,僅供參考,大家根據實際情況自行選用和修改,舉一…

在vue中遇到Uncaught TypeError: Assignment to constant variable(常亮無法修改)

1.問題如下:2.出現這個問題的原因----在設計變量的時候采用了const來進行修飾,在修改的時候直接對其進行修改3.利用響應式變量的特點,修改為下面這樣就可以正常了

RCE隨筆-奇技淫巧(2)

Linux命令長度限制在7個字符的情況下&#xff0c;如何拿到shell <?php $param $_REQUEST[param]; If ( strlen($param) < 8 ) { echo shell_exec($param); }分析代碼&#xff1a;這段代碼傳入參數param然后進入if語句判斷是否小于8個字符&#xff0c;然后如果小于就會進…

設計模式九:構建器模式 (Builder Pattern)

動機(Motivation)1、在軟件系統中&#xff0c;有時候面臨著“一個復雜對象”的創建工作&#xff0c;其通常由各個部分的子對象用一定的算法構成&#xff1b;由于需求的變化&#xff0c;這個復雜對象的各個部分經常面臨著劇烈的變化&#xff0c;但是將它們組合在一起的算法卻相對…

如何高效合并音視頻文件

在自我學習或者進行視頻剪輯的時候&#xff0c;經常從資源網址下載音視頻分離的文件&#xff0c;例如audio_file1.m4a和video_1.mp4&#xff0c;之后需要把這兩個文件合并在一起。于是條件反射得想要利用剪映等第三方工具&#xff0c;進行音視頻的封裝。可惜不幸的是&#xff0…

虛幻 5 與 3D 軟件的協作:實時渲染,所見所得

《曼達洛人》的星際飛船在片場實時掠過虛擬荒漠&#xff0c;游戲開發者拖動滑塊就能即時看到角色皮膚的通透變化&#xff0c;實時渲染技術正以 “所見即所得” 的核心優勢&#xff0c;重塑著 3D 創作的整個邏輯。虛幻引擎 5&#xff08;UE5&#xff09;憑借 Lumen 全局光照和 N…

?Eyeriss 架構中的訪存行為解析(騰訊元寶)

?Eyeriss 架構中的訪存行為解析?Eyeriss 是 MIT 提出的面向卷積神經網絡&#xff08;CNN&#xff09;的能效型 NPU&#xff08;神經網絡處理器&#xff09;架構&#xff0c;其核心創新在于通過硬件結構優化訪存行為&#xff0c;以解決傳統 GPU 在處理 CNN 時因數據搬運導致的…

數字圖像處理(三:圖像如果當作矩陣,那加減乘除處理了矩陣,那圖像咋變):從LED冬奧會、奧運會及春晚等等大屏,到手機小屏,快來挖一挖里面都有什么

數字圖像處理&#xff08;三&#xff09;一、&#xff08;準備工作&#xff1a;咋玩&#xff0c;用什么玩具&#xff09;圖像以矩陣形式存儲&#xff0c;那矩陣一變、圖像立刻跟著變&#xff1f;1. Python Jupyter Notebook/Lab 庫 (NumPy, OpenCV, Matplotlib, scikit-image…

docker-desktop啟動失敗

報錯提示deploying WSL2 distributions ensuring main distro is deployed: checking if main distro is up to date: checking main distro bootstrap version: getting main distro bootstrap version: open \\wsl$\docker-desktop\etc\wsl_bootstrap_version: The network n…

基于FastMCP創建MCP服務器的小白級教程

以下是基于windows 11操作系統環境的開發步驟。 1、python環境搭建 訪問官網&#xff1a;https://www.python.org/。下載相應的版本&#xff08;如&#xff1a;3.13.5&#xff09;&#xff0c;然后安裝。 安裝完成之后&#xff0c;使用命令行工具輸入python&#xff0c;顯示…

網絡協議與層次對應表

網絡協議與層次對應表&#xff08;OSI & TCP/IP模型&#xff09;OSI七層模型TCP/IP四層模型協議/技術核心功能與應用?應用層?應用層HTTP/HTTPS網頁傳輸協議&#xff08;HTTP&#xff09;及其加密版&#xff08;HTTPS&#xff09;FTP文件上傳/下載協議SMTP/POP3/IMAPSMTP發…