用 PyTorch 實現一個簡單的神經網絡:從數據到預測

PyTorch 是目前最流行的深度學習框架之一,以其靈活性和易用性受到開發者的喜愛。本文將帶你從零開始,用 PyTorch 實現一個簡單的神經網絡,用于解決經典的 MNIST 手寫數字分類問題。我們將涵蓋數據準備、模型構建、訓練和預測的完整流程,并提供可運行的代碼示例。

1. 環境準備

首先,確保你已安裝 PyTorch 和相關依賴。本例使用 Python 3.8+ 和 PyTorch。你可以通過以下命令安裝:

pip install torch torchvision

我們將使用 MNIST 數據集,它包含 28x28 像素的手寫數字圖像(0-9),目標是訓練一個神經網絡來識別這些數字。

2. 數據準備

MNIST 數據集可以通過 PyTorch 的?torchvision?模塊直接加載。我們需要將數據加載為張量,并進行歸一化處理以加速訓練。

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 定義數據預處理:將圖像轉換為張量并歸一化
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))  # MNIST 的均值和標準差
])# 加載 MNIST 數據集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 創建數據加載器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

代碼說明

  • transforms.ToTensor()?將圖像轉換為 PyTorch 張量,并將像素值從 [0, 255] 縮放到 [0, 1]。

  • transforms.Normalize?標準化數據,加速梯度下降收斂。

  • DataLoader?用于批量加載數據,batch_size=64?表示每次處理 64 張圖像。

3. 構建神經網絡

我們將定義一個簡單的全連接神經網絡,包含兩個隱藏層,適合處理 MNIST 的分類任務。

import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.flatten = nn.Flatten()  # 將 28x28 圖像展平為 784 維向量self.fc1 = nn.Linear(28 * 28, 128)  # 第一個全連接層self.relu = nn.ReLU()  # 激活函數self.fc2 = nn.Linear(128, 64)  # 第二個全連接層self.fc3 = nn.Linear(64, 10)   # 輸出層,10 個類別(0-9)def forward(self, x):x = self.flatten(x)x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.fc3(x)return x# 實例化模型
model = SimpleNN()

代碼說明

  • nn.Module?是 PyTorch 模型的基類,自定義模型需要繼承它。

  • forward?方法定義了前向傳播的計算流程。

  • 網絡結構:輸入層 (784) → 隱藏層1 (128) → ReLU → 隱藏層2 (64) → ReLU → 輸出層 (10)。

4. 定義損失函數和優化器

我們使用交叉熵損失(適合分類任務)和 Adam 優化器來訓練模型。

import torch.optim as optim# 定義損失函數和優化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

代碼說明

  • nn.CrossEntropyLoss?結合了 softmax 和負對數似然損失,適合多分類任務。

  • Adam?優化器以 0.001 的學習率更新模型參數。

5. 訓練模型

接下來,我們訓練模型 5 個 epoch,觀察損失變化。

def train(model, train_loader, criterion, optimizer, epochs=5):model.train()  # 切換到訓練模式for epoch in range(epochs):running_loss = 0.0for images, labels in train_loader:optimizer.zero_grad()  # 清零梯度outputs = model(images)  # 前向傳播loss = criterion(outputs, labels)  # 計算損失loss.backward()  # 反向傳播optimizer.step()  # 更新參數running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")# 開始訓練
train(model, train_loader, criterion, optimizer)

代碼說明

  • model.train()?啟用訓練模式(影響 dropout 和 batch norm 等層)。

  • 每次迭代清零梯度、計算損失、反向傳播并更新參數。

  • 每輪 epoch 打印平均損失。

6. 測試模型

訓練完成后,我們在測試集上評估模型的準確率。

def test(model, test_loader, criterion):model.  # 切換到評估模式correct = 0total = 0test_loss = 0.0with torch.no_grad():  # 禁用梯度計算for images, labels in test_loader:outputs = model(images)loss = criterion(outputs, labels)test_loss += loss.item()_, predicted = torch.max(outputs.data, 1)  # 獲取預測類別total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint(f"Test Loss: {test_loss/len(test_loader):.4f}, Accuracy: {accuracy:.2f}%")# 測試模型
test(model, test_loader, criterion)

代碼說明

  • model.?切換到評估模式,禁用 dropout 等。

  • 使用?torch.no_grad()?減少內存消耗。

  • 計算測試集的損失和準確率。

7. 進行預測

最后,我們用訓練好的模型對單張圖像進行預測。

import matplotlib.pyplot as plt# 獲取一張測試圖像
dataiter = iter(test_loader)
images, labels = next(dataiter)
image, label = images[0], labels[0]# 預測
model.
with torch.no_grad():output = model(image.unsqueeze(0))  # 增加 batch 維度_, predicted = torch.max(output, 1)# 顯示圖像和預測結果
plt.imshow(image.squeeze(), cmap='gray')
plt.title(f"Predicted: {predicted.item()}, Actual: {label.item()}")
plt.savefig('prediction.png')  # 保存圖像

代碼說明

  • 從測試集取一張圖像,調用模型進行預測。

  • 使用 Matplotlib 顯示圖像及其預測結果,保存為 PNG 文件。

8. 完整代碼

以下是完整的可運行代碼,整合了上述所有步驟:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 數據準備
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# 定義模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.flatten = nn.Flatten()self.fc1 = nn.Linear(28 * 28, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 64)self.fc3 = nn.Linear(64, 10)def forward(self, x):x = self.flatten(x)x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.fc3(x)return x# 實例化模型、損失函數和優化器
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 訓練函數
def train(model, train_loader, criterion, optimizer, epochs=5):model.train()for epoch in range(epochs):running_loss = 0.0for images, labels in train_loader:optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")# 測試函數
def test(model, test_loader, criterion):model.correct = 0total = 0test_loss = 0.0with torch.no_grad():for images, labels in test_loader:outputs = model(images)loss = criterion(outputs, labels)test_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint(f"Test Loss: {test_loss/len(test_loader):.4f}, Accuracy: {accuracy:.2f}%")# 訓練和測試
train(model, train_loader, criterion, optimizer)
test(model, test_loader, criterion)# 預測單張圖像
dataiter = iter(test_loader)
images, labels = next(dataiter)
image, label = images[0], labels[0]
model.
with torch.no_grad():output = model(image.unsqueeze(0))_, predicted = torch.max(output, 1)
plt.imshow(image.squeeze(), cmap='gray')
plt.title(f"Predicted: {predicted.item()}, Actual: {label.item()}")
plt.savefig('prediction.png')

9. 總結

通過本文,可以了解到如何用 PyTorch 實現一個簡單的神經網絡,包括:

  • 加載和預處理 MNIST 數據集。

  • 構建一個全連接神經網絡。

  • 使用交叉熵損失和 Adam 優化器進行訓練。

  • 在測試集上評估模型性能。

  • 對單張圖像進行預測并可視化結果。

這個模型雖然簡單,但在 MNIST 數據集上通常能達到 95% 以上的準確率。可以進一步嘗試調整網絡結構(如增加層數)、優化超參數(如學習率)或使用卷積神經網絡(CNN)來提升性能。希望這篇文章對你理解 PyTorch 和深度學習有所幫助!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/91967.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/91967.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/91967.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

四級頁表通俗講解與實踐(以 64 位 ARM Cortex-A 為例)

📖 🎥 B 站博文精講視頻:點擊鏈接,配合視頻深度學習 四級頁表通俗講解與實踐(以 64 位 ARM Cortex-A 為例) 本文面向希望徹底理解現代 64 位架構下四級頁表的開發者,結合 ARM Cortex-A 系列處理…

AI模型整合包上線!一鍵部署ComfyUI,2.19TB模型全解析

最近體驗了AIStarter平臺上線的AI模型整合包,包含2.19TB ComfyUI大模型,整合市面主流模型,一鍵部署ComfyUI,省去重復下載煩惱!以下是使用心得和部署步驟,適合AI開發者參考。工具亮點這款AI模型整合包由熊哥…

灰色優選模型及算法MATLAB代碼

電子裝備試驗方案優選是一個典型的多屬性決策問題,通常涉及指標復雜、信息不完整、數據量少且存在不確定性的特點。灰色系統理論(Grey System Theory)特別擅長處理“小樣本、貧信息”的不確定性問題,因此非常適合用于此類方案的優…

AI框架工具FastRTC快速上手6——視頻流案例之物體檢測(下)

一 前言 上一篇,我們實現了用YOLO對圖片上的物體進行檢測,并在圖片上框出具體的對象并打出標簽。但只是應用在單張圖片,且還沒用上FastRTC。 本篇,我們希望結合FastRTC的能力,實現基于YOLO的實時視頻流的物體檢測。 本篇文字將不會太多。學習完本篇,對比前面的文章,你…

PHP常見中高面試題匯總

一、 PHP部分 1、PHP如何實現靜態化 PHP的靜態化分為:純靜態和偽靜態。其中純靜態又分為:局部純靜態和全部純靜態。 PHP偽靜態:利用Apache mod_rewrite實現URL重寫的方法; PHP純靜態,就是生成HTML文件的方式&#xff0…

基于Java AI(人工智能)生成末日題材的實踐

Java AI 生成《全球末日》文章的實例 使用Java結合AI技術生成《全球末日》題材的文章可以通過多種方式實現,包括調用預訓練模型、使用自然語言處理庫或結合生成式AI框架。以下是30個實例的生成方法和示例代碼片段。 調用預訓練模型(如GPT-3或GPT-4) 使用OpenAI API生成末日…

針對軟件定義車載網絡的動態服務導向機制

我是穿拖鞋的漢子,魔都中堅持長期主義的汽車電子工程師。 老規矩,分享一段喜歡的文字,避免自己成為高知識低文化的工程師: 做到欲望極簡,了解自己的真實欲望,不受外在潮流的影響,不盲從,不跟風。把自己的精力全部用在自己。一是去掉多余,凡事找規律,基礎是誠信;二是…

Pytorch實現嬰兒哭聲檢測和識別

Pytorch實現嬰兒哭聲檢測和識別 目錄 Pytorch實現嬰兒哭聲檢測識別 1. 項目說明 2. 數據說明 (1)嬰兒哭聲語音數據集 (2)自定義數據集 3. 模型訓練 (1)項目安裝 (2)準備Tra…

海信IP810N/海信IP811N_海思MV320-安卓9.0主板-TTL燒錄包-可救磚

海信IP810N/海信IP811N_海思MV320處理器-安卓9主板-TTL燒錄包-可救磚準備工作:TTL線自備跑碼工具【putty跑碼中文版】路徑:【工具大全】-【putty跑碼中文版】測試跑碼以后將跑碼窗口關閉;然后到下方下載燒錄工具并大致看下教程燒錄…

Go 中的 interface{} 與 Java 中的 Object:相似之處與本質差異

在軟件系統開發中,“通用類型”的處理是各語言設計中不可忽視的一部分。Java 使用 Object,Go 使用 interface{},它們都可以容納任意類型的值,是實現動態行為或通用容器的基礎類型。然而,雖然兩者在使用層面看似相似&am…

Docker-07.Docker基礎-數據卷掛載

一.案例首先我們通過一則案例來引出問題。我們要修改nginx容器內的html目錄下的index.html文件,并且要將靜態資源部署到nginx的html目錄,就要首先知道該html目錄的所在位置。我們首先查看nginx鏡像的幫助文檔,這里就是將有關靜態資源目錄的&a…

數據結構(三)雙向鏈表

一、什么是 make 工具?make 是一個自動化構建工具,主要用于管理 C/C 項目的編譯和鏈接過程。它通過讀取 Makefile 文件中定義的規則,自動判斷哪些文件被修改,并僅重新編譯這些部分,從而大幅提高構建效率。二、什么是 M…

如何在沒有iCloud的情況下將聯系人轉移到新iPhone?

升級到新 iPhone 后,設置已完成,想在不使用 iCloud 的情況下將聯系人從 iPhone 轉移到 iPhone 嗎?別擔心。還有其他 5 種方法可以幫助您輕松地將聯系人轉移到新 iPhone。這樣,您就無需再次重置新設備了。第 1 部分:如何…

SpringBoot3.x入門到精通系列:4.2 整合 Kafka 詳解

SpringBoot 3.x 整合 Kafka 詳解 🎯 Kafka簡介 Apache Kafka是一個分布式流處理平臺,主要用于構建實時數據管道和流應用程序。它具有高吞吐量、低延遲、可擴展性和容錯性等特點。 核心概念 Producer: 生產者,發送消息到Kafka集群Consumer: 消…

Android audio之 AudioDeviceInventory

1. 類介紹 AudioDeviceInventory 是 Android 音頻系統中的一個核心類,位于 frameworks/base/services/core/java/com/android/server/audio/ 路徑下。它負責 管理所有音頻設備的連接狀態,包括設備的添加、移除、狀態更新以及策略應用。 設備連接狀態管理:記錄所有已連接的音…

系統設計入門:成為更優秀的工程師

系統設計入門指南 動機 現在你可以學習如何設計大規模系統,為系統設計面試做準備。本指南包含的是一個有組織的資源集合,旨在幫助你了解如何構建可擴展的系統。 學習設計大規模系統 學習如何設計可擴展系統將幫助你成為更優秀的工程師。系統設計是一個…

Pandas數據分析工具基礎

文章目錄 0. 學習目標 1. Pandas的數據結構分析 1.1 Series - 序列 1.1.1 Series概念 1.1.2 Series類的構造方法 1.1.3 創建Series對象 1.1.3.1 基于列表創建Series對象 1.1.3.2 基于字典創建Series對象 1.1.4 獲取Series對象的數據 1.1.5 Series對象的運算 1.1.6 增刪Series對…

大模型——Qwen開源會寫中文的生圖模型Qwen-Image

Qwen開源會寫中文的生圖模型Qwen-Image 會寫中文,這基本上是開源圖片生成模型的獨一份了。 這次開源的Qwen-Image 的最大賣點是“像素級文字生成”。它能直接在像素空間內完成排版:從小字注腳到整版海報均可清晰呈現,且同時支持英文字母與漢字。 以下圖片均來自官網的生成…

大模型知識庫(1)京東云 JoyAgent介紹

一、核心定位? JoyAgent 是京東云推出的 ?首個 100% 開源的企業級多智能體平臺,定位為“可插拔的智能發動機”,旨在通過開箱即用的產品級能力,降低企業部署智能體的門檻。其特點包括: ?完整開源?:前端&#xff0…

PowerShell 入門2: 使用幫助系統

PowerShell 入門 2:使用幫助系統 🎯 一、認識 PowerShell 幫助系統 1. 使用 Get-Help 查看命令說明 Get-Help Get-Service或使用別名: gsv2. 更新幫助系統 Update-Help3. 搜索包含關鍵詞的命令(模糊搜索) Help *log*&a…