深度學習篇---MNIST:手寫數字數據集

下面我將詳細介紹使用 PyTorch 處理 MNIST 手寫數字數據集的完整流程,包括數據加載、模型定義、訓練和評估,并解釋每一行代碼的含義和注意事項。

整個流程可以分為五個主要步驟:準備工作、數據加載與預處理、模型定義、模型訓練和模型評估

# MNIST手寫數字數據集完整處理流程
# 包含數據加載、模型定義、訓練和評估的全步驟# 1. 導入必要的庫
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt# 2. 設置超參數
batch_size = 64       # 每次訓練的樣本數量
learning_rate = 0.001 # 學習率
num_epochs = 5        # 訓練輪數
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 注意:如果有GPU,會使用cuda加速訓練,否則使用CPU# 3. 數據預處理與加載
# 定義數據變換:將圖像轉為Tensor并標準化
transform = transforms.Compose([transforms.ToTensor(),  # 轉換為Tensor格式,像素值從0-255歸一化到0-1# 標準化處理:均值為0.1307,標準差為0.3081(MNIST數據集的統計特性)transforms.Normalize((0.1307,), (0.3081,))
])# 加載訓練集
train_dataset = datasets.MNIST(root='./data',        # 數據保存路徑train=True,           # True表示加載訓練集download=True,        # 如果數據不存在則自動下載transform=transform   # 應用上面定義的數據變換
)# 加載測試集
test_dataset = datasets.MNIST(root='./data',train=False,          # False表示加載測試集download=True,transform=transform
)# 創建數據加載器,用于批量加載數據
train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True          # 訓練時打亂數據順序
)test_loader = DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False         # 測試時不需要打亂順序
)# 4. 可視化樣本數據(可選,用于理解數據)
def show_samples():# 獲取一些隨機的訓練樣本dataiter = iter(train_loader)images, labels = next(dataiter)# 顯示6個樣本plt.figure(figsize=(10, 4))for i in range(6):plt.subplot(1, 6, i+1)plt.imshow(images[i].numpy().squeeze(), cmap='gray')plt.title(f'Label: {labels[i].item()}')plt.axis('off')plt.show()# 調用函數顯示樣本
show_samples()# 5. 定義神經網絡模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()# 第一個卷積塊:卷積層 + 激活函數 + 池化層self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.relu1 = nn.ReLU()self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)# 第二個卷積塊self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)# 全連接層self.fc1 = nn.Linear(7 * 7 * 64, 128)  # 經過兩次池化后,28x28變為7x7self.relu3 = nn.ReLU()self.fc2 = nn.Linear(128, 10)  # 10個輸出,對應0-9十個數字def forward(self, x):# 前向傳播過程x = self.pool1(self.relu1(self.conv1(x)))x = self.pool2(self.relu2(self.conv2(x)))x = x.view(-1, 7 * 7 * 64)  # 展平操作x = self.relu3(self.fc1(x))x = self.fc2(x)return x# 初始化模型并移動到設備上
model = SimpleCNN().to(device)# 6. 定義損失函數和優化器
criterion = nn.CrossEntropyLoss()  # 交叉熵損失,適合分類問題
optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # Adam優化器# 7. 訓練模型
def train_model():# 記錄訓練過程中的損失和準確率train_losses = []train_accuracies = []# 開始訓練model.train()  # 設置為訓練模式for epoch in range(num_epochs):running_loss = 0.0correct = 0total = 0# 遍歷訓練數據for i, (images, labels) in enumerate(train_loader):# 將數據移動到設備上images = images.to(device)labels = labels.to(device)# 清零梯度optimizer.zero_grad()# 前向傳播outputs = model(images)loss = criterion(outputs, labels)# 反向傳播和優化loss.backward()  # 計算梯度optimizer.step()  # 更新參數# 統計損失和準確率running_loss += loss.item()# 計算預測結果_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# 每100個批次打印一次信息if (i + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], 'f'Loss: {running_loss/100:.4f}, Accuracy: {100*correct/total:.2f}%')running_loss = 0.0# 記錄每個epoch的平均損失和準確率epoch_loss = running_loss / len(train_loader)epoch_acc = 100 * correct / totaltrain_losses.append(epoch_loss)train_accuracies.append(epoch_acc)print(f'Epoch [{epoch+1}/{num_epochs}] completed. Training Accuracy: {epoch_acc:.2f}%')print('訓練完成!')return train_losses, train_accuracies# 調用訓練函數
train_losses, train_accuracies = train_model()# 8. 繪制訓練曲線
def plot_training_curves(losses, accuracies):plt.figure(figsize=(12, 5))# 損失曲線plt.subplot(1, 2, 1)plt.plot(range(1, num_epochs+1), losses)plt.title('Training Loss')plt.xlabel('Epoch')plt.ylabel('Loss')# 準確率曲線plt.subplot(1, 2, 2)plt.plot(range(1, num_epochs+1), accuracies)plt.title('Training Accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.tight_layout()plt.show()# 繪制訓練曲線
plot_training_curves(train_losses, train_accuracies)# 9. 評估模型
def evaluate_model():model.eval()  # 設置為評估模式correct = 0total = 0# 不需要計算梯度with torch.no_grad():for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'測試集準確率: {100 * correct / total:.2f}%')return 100 * correct / total# 評估模型在測試集上的表現
test_accuracy = evaluate_model()# 10. 可視化預測結果(可選)
def show_predictions():model.eval()dataiter = iter(test_loader)images, labels = next(dataiter)# 選擇前6個樣本進行預測images = images[:6].to(device)labels = labels[:6]with torch.no_grad():outputs = model(images)_, predicted = torch.max(outputs.data, 1)# 顯示結果plt.figure(figsize=(10, 4))for i in range(6):plt.subplot(1, 6, i+1)plt.imshow(images[i].cpu().numpy().squeeze(), cmap='gray')plt.title(f'真實: {labels[i].item()}\n預測: {predicted[i].item()}')plt.axis('off')plt.tight_layout()plt.show()# 顯示預測結果
show_predictions()# 11. 保存模型(可選)
torch.save(model.state_dict(), 'mnist_cnn_model.pth')
print("模型已保存為 'mnist_cnn_model.pth'")

代碼詳細解釋與注意事項

1. 準備工作

  • 我們導入了 PyTorch 的核心庫以及處理圖像所需的工具
  • device設置會自動檢測是否有可用的 GPU,如果有則使用 GPU 加速訓練,否則使用 CPU

2. 數據加載與預處理

  • 數據變換 (transforms)

    • ToTensor()將圖像從 PIL 格式轉換為 PyTorch 的 Tensor 格式,并將像素值從 0-255 歸一化到 0-1 范圍
    • Normalize()進行標準化,使用的均值和標準差是 MNIST 數據集的統計特性,這有助于模型更快收斂
  • 數據集加載

    • datasets.MNIST會自動下載數據(如果本地沒有)并加載
    • train=True加載訓練集(60,000 張圖片),train=False加載測試集(10,000 張圖片)
  • DataLoader

    • 用于批量加載數據,支持自動打亂數據順序
    • batch_size=64表示每次處理 64 張圖片
    • 訓練時shuffle=True打亂數據順序,測試時shuffle=False保持順序

3. 模型定義

  • 我們定義了一個簡單的卷積神經網絡 (SimpleCNN),包含:

    • 兩個卷積塊:每個卷積塊由卷積層、ReLU 激活函數和池化層組成
    • 兩個全連接層:最后一層輸出 10 個值,對應 0-9 十個數字的預測概率
  • 卷積操作的作用:

    • 提取圖像的局部特征,如邊緣、紋理等
    • 池化層用于降低特征圖尺寸,減少計算量

4. 模型訓練

  • 損失函數:使用CrossEntropyLoss,適合多分類問題

  • 優化器:使用Adam優化器,比傳統的 SGD 收斂更快

  • 訓練過程中的關鍵步驟:

    1. 清零梯度:optimizer.zero_grad()
    2. 前向傳播:計算模型輸出和損失
    3. 反向傳播:loss.backward()計算梯度
    4. 更新參數:optimizer.step()應用梯度更新
  • 注意事項:

    • 訓練前調用model.train()設置為訓練模式
    • 定期打印損失和準確率,監控訓練進度
    • 將數據和模型移動到相同的設備上(CPU 或 GPU)

5. 模型評估

  • 評估時調用model.eval()設置為評估模式,這會關閉 dropout 等訓練特有的操作
  • 使用torch.no_grad()關閉梯度計算,節省內存并加速計算
  • 計算測試集上的準確率,評估模型的泛化能力

6. 常見問題與解決方法

  1. 訓練速度慢

    • 檢查是否使用了 GPU(代碼會自動檢測,但需要正確安裝 PyTorch GPU 版本)
    • 嘗試調大batch_size(受限于 GPU 內存)
  2. 過擬合

    • 增加訓練輪數
    • 添加正則化(如 Dropout)
    • 增加數據增強
  3. 準確率低

    • 檢查模型結構是否合理
    • 嘗試調整學習率
    • 增加訓練輪數

通過這個完整流程,你可以加載 MNIST 數據集,訓練一個卷積神經網絡對手寫數字進行分類,并評估模型性能。對于初學者來說,這個例子涵蓋了深度學習的基本流程和關鍵概念,是一個很好的入門練習。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/97947.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/97947.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/97947.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

k8s集群搭建(二)-------- 集群搭建

安裝 containerd 需要在集群內的每個節點上都安裝容器運行時&#xff08;containerd runtime&#xff09;&#xff0c;這個軟件是負責運行容器的軟件。 1. 啟動 ipv4 數據包轉發 # 設置所需的 sysctl 參數&#xff0c;參數在重新啟動后保持不變 cat <<EOF | sudo tee …

【Docker】P1 前言:容器化技術發展之路

目錄容器發展之路物理服務器時代&#xff1a;一機一應用的局限虛擬化時代&#xff1a;突破與局限并存容器化時代&#xff1a;輕量級的革新技術演進的價值體現各位&#xff0c;歡迎來到容器化時代。 容器發展之路 現代業務的核心是應用程序&#xff08;Application&#xff09;…

WPF依賴屬性和依賴屬性的包裝器:

依賴屬性是WPF&#xff08;Windows Presentation Foundation&#xff09;中的一種特殊類型的屬性&#xff0c;特別適用于內存使用優化和屬性值繼承。依賴屬性的定義包括以下幾個步驟&#xff1a; 使用 DependencyProperty.Register 方法注冊依賴屬性。 該方法需要四個參數&…

圖生圖算法

圖生圖算法研究細分&#xff1a;技術演進、應用與爭議 1. 基于GAN的傳統圖生圖方法 定義&#xff1a;利用生成對抗網絡&#xff08;GAN&#xff09;將輸入圖像轉換為目標域圖像&#xff08;如語義圖→照片、草圖→彩圖&#xff09;。關鍵發展與趨勢&#xff1a; Pix2Pix&#…

Go 自建庫的使用教程與測試

附加一個Go庫的實現&#xff0c;相較于Python&#xff0c;Go的實現更較為日常&#xff0c;不需要額外增加setup.py類的文件去額外定義,計算和并發的性能更加。 1. 創建 Go 模塊項目結構 首先創建完整的項目結構&#xff1a; gomathlib/ ├── go.mod ├── go.sum ├── cor…

What is a prototype network in few-shot learning?

A prototype network is a method used in few-shot learning to classify new data points when only a small number of labeled examples (the “shots”) are available per class. It works by creating a representative “prototype” for each class, which is typical…

Linux中用于線程/進程同步的核心函數——`sem_wait`函數

<摘要> sem_wait 是 POSIX 信號量操作函數&#xff0c;用于對信號量執行 P 操作&#xff08;等待、獲取&#xff09;。它的核心功能是原子地將信號量的值減 1。如果信號量的值大于 0&#xff0c;則減 1 并立即返回&#xff1b;如果信號量的值為 0&#xff0c;則調用線程&…

25高教社杯數模國賽【B題超高質量思路+問題分析】

注&#xff1a;本內容由”數模加油站“ 原創出品&#xff0c;雖無償分享&#xff0c;但創作不易。 歡迎參考teach&#xff0c;但請勿抄襲、盜賣或商用。 B 題 碳化硅外延層厚度的確定碳化硅作為一種新興的第三代半導體材料&#xff0c;以其優越的綜合性能表現正在受到越來越多…

【Linux篇章】再續傳輸層協議UDP :從低可靠到極速傳輸的協議重生之路,揭秘無連接通信的二次進化密碼!

&#x1f4cc;本篇摘要&#xff1a; 本篇將承接上次的UDP系列網絡編程&#xff0c;來深入認識下UDP協議的結構&#xff0c;特性&#xff0c;底層原理&#xff0c;注意事項及應用場景&#xff01; &#x1f3e0;歡迎拜訪&#x1f3e0;&#xff1a;點擊進入博主主頁 &#x1f4c…

《A Study of Probabilistic Password Models》(IEEE SP 2014)——論文閱讀

提出更高效的密碼評估工具&#xff0c;將統計語言建模技術引入密碼建模&#xff0c;系統評估各類概率密碼模型性能&#xff0c;打破PCFGw的 “最優模型” 認知。一、研究背景當前研究存在兩大關鍵問題&#xff1a;一是主流的 “猜測數圖” 計算成本極高&#xff0c;且難以覆蓋強…

校園外賣點餐系統(代碼+數據庫+LW)

摘要 隨著校園生活節奏的加快&#xff0c;學生對外賣的需求日益增長。然而&#xff0c;傳統的外賣服務存在諸多不便&#xff0c;如配送時間長、菜品選擇有限、信息更新不及時等。為解決這些問題&#xff0c;本研究開發了一款校園外賣點餐系統&#xff0c;采用前端 Vue、后端 S…

友思特案例 | 食品行業視覺檢測案例集錦(三)

食品制造質量檢測對保障消費者安全和產品質量穩定至關重要&#xff0c;覆蓋原材料至成品全階段&#xff0c;含過程中檢測與成品包裝檢測。近年人工智能深度學習及自動化系統正日益融入食品生產。本篇文章將介紹案例三&#xff1a;友思特Neuro-T深度學習平臺進行面餅質量檢測。在…

SQLynx 3.7 發布:數據庫管理工具的性能與交互雙重進化

目錄 &#x1f511; 核心功能更新 1. 單頁百萬級數據展示 2. 更安全的數據更新與刪除機制 3. 更智能的 SQL 代碼提示 4. 新增物化視圖與外表支持 5. 數據庫搜索與過濾功能重構 ? 總結與思考 在大數據與云原生應用快速發展的今天&#xff0c;數據庫管理工具不僅要“能用…

10G網速不是夢!5G-A如何“榨干”毫米波,跑出比5G快10倍的速度?

5G-A&#xff08;5G-Advanced&#xff09;網絡技術已經在中國福建省廈門市軟件園成功實現萬兆&#xff08;10Gbps&#xff09;速率驗證&#xff0c;標志著我國正式進入5G增強版商用階段。這一突破性成果不僅驗證了5G-A技術的可行性&#xff0c;也為6G網絡的發展奠定了堅實基礎。…

Linux筆記---UDP套接字實戰:簡易聊天室

1. 項目需求分析 我們要設計的是一個簡單的匿名聊天室&#xff0c;用戶的客戶端要求用戶輸入自己的昵稱之后即可在一個公共的群聊當中聊天。 為了簡單起見&#xff0c;我們設計用戶在終端當中與客戶端交互&#xff0c;而在一個文件當中顯式群聊信息&#xff1a; 當用戶輸入的…

RTP打包與解包全解析:從RFC規范到跨平臺輕量級RTSP服務和低延遲RTSP播放器實現

引言 在實時音視頻系統中&#xff0c;RTSP&#xff08;Real-Time Streaming Protocol&#xff09;負責會話與控制&#xff0c;而 RTP&#xff08;Real-time Transport Protocol&#xff09;負責媒體數據承載。開發者在實現跨平臺、低延遲的 RTSP 播放器或輕量級 RTSP 服務時&a…

Ubuntu 用戶和用戶組

一、 Linux 用戶linux 是一個多用戶操作系統&#xff0c;不同的用戶擁有不同的權限&#xff0c;可以查看和操作不同的文件。 Ubuntu 有三種用戶1、初次創建的用戶2、root 用戶---上帝3、普通用戶初次創建的用戶權限比普通用戶要多&#xff0c;但是沒有 root 用戶多。Linux 用戶…

FastGPT社區版大語言模型知識庫、Agent開源項目推薦

? FastGPT 項目說明 項目概述 FastGPT 是一個基于大語言模型&#xff08;LLM&#xff09;的知識庫問答系統&#xff0c;提供開箱即用的數據處理和模型調用能力&#xff0c;支持通過可視化工作流編排實現復雜問答場景。 技術架構 前端: Next.js TypeScript Chakra UI 后…

jsencrypt公鑰分段加密,支持后端解密

前端使用jsencryp實現分段加密。 解決長文本RSA加密報錯問題。 支持文本包含中文。 支持后端解密。前端加密代碼&#xff1a; // import { JSEncrypt } from jsencrypt const JSEncrypt require(jsencrypt) /*** 使用 JSEncrypt 實現分段 RSA 加密&#xff08;正確處理中文字符…

生成一份關于電腦電池使用情況、健康狀況和壽命估算的詳細 HTML 報告

核心作用 powercfg /batteryreport 是一個在 Windows 命令提示符或 PowerShell 中運行的命令。它的核心作用是&#xff1a;生成一份關于電腦電池使用情況、健康狀況和壽命估算的詳細 HTML 報告。 這份報告非常有用&#xff0c;特別是對于筆記本電腦用戶&#xff0c;它可以幫你&…