30天打牢數模基礎-卷積神經網絡講解

案例代碼實現

一、代碼說明

本案例使用PyTorch實現一個改進版LeNet-5模型,用于CIFAR-10數據集的圖像分類任務。代碼包含以下核心步驟:

數據加載與預處理(含數據增強,劃分訓練/驗證/測試集);

定義CNN網絡結構(LeNet-5改進版,適配3通道輸入);

模型訓練(用驗證集評估泛化能力);

模型測試與結果可視化(用獨立測試集最終評估)。

適合人群:數模小白(無需深度學習基礎,代碼注釋詳細,邏輯清晰)。運行環境:Python3.8+、PyTorch1.10+、torchvision0.11+、matplotlib3.5+。

二、完整代碼實現

# 導入必要的庫
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np# ------------------------------
# 1. 配置全局參數(數模小白可調整這里)
# ------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 優先用GPU
BATCH_SIZE = 64  # 每批數據量(越大訓練越快,但占內存越多)
EPOCHS = 10  # 訓練輪數(越大模型越準,但訓練時間越長)
LEARNING_RATE = 0.001  # 學習率(越小收斂越穩,但訓練越慢)
VAL_SPLIT = 0.2  # 驗證集占訓練集的比例(20%)# ------------------------------
# 2. 數據加載與預處理(含數據增強,劃分訓練/驗證/測試集)
# ------------------------------
def load_data():"""加載CIFAR-10數據集,返回訓練/驗證/測試DataLoader"""# 訓練集數據增強(防止過擬合):隨機裁剪、水平翻轉、歸一化train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),  # 隨機裁剪32x32,邊緣補4像素transforms.RandomHorizontalFlip(),     # 隨機水平翻轉(50%概率)transforms.ToTensor(),                 # 轉為Tensor(0-1)transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 歸一化到[-1,1]])# 驗證集/測試集預處理(不增強,保持真實分布)val_test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 下載/加載數據集(第一次運行會下載,約170MB)full_train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)val_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=val_test_transform)test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=val_test_transform)# 劃分訓練集和驗證集(8:2)train_size = int((1 - VAL_SPLIT) * len(full_train_dataset))val_size = len(full_train_dataset) - train_sizetrain_dataset, _ = random_split(full_train_dataset, [train_size, val_size])_, val_dataset = random_split(val_dataset, [train_size, val_size])  # 保持驗證集transform正確# 生成DataLoader(批量加載數據)train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)return train_loader, val_loader, test_loader# ------------------------------
# 3. 定義CNN網絡結構(改進版LeNet-5)
# ------------------------------
class LeNet5(nn.Module):"""改進版LeNet-5,適配CIFAR-10的3通道輸入(3x32x32)"""def __init__(self):super(LeNet5, self).__init__()# 卷積層1:提取邊緣特征(3通道→6通道,5x5 kernel)self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)# 最大池化層1:簡化特征(2x2窗口,步長2)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)# 卷積層2:提取紋理/形狀特征(6通道→16通道,5x5 kernel)self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)# 最大池化層2:進一步簡化特征(2x2窗口,步長2)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)# 全連接層1:整合高級特征(16*5*5→120)self.fc1 = nn.Linear(16 * 5 * 5, 120)# 全連接層2:進一步整合特征(120→84)self.fc2 = nn.Linear(120, 84)# 輸出層:分類決策(84→10類,對應CIFAR-10標簽)self.fc3 = nn.Linear(84, 10)# 激活函數(ReLU,引入非線性,解決線性模型表達能力不足問題)self.relu = nn.ReLU()def forward(self, x):"""前向傳播:定義數據在網絡中的流動路徑"""# 卷積層1 → ReLU → 池化層1:3x32x32 → 6x28x28 → 6x14x14x = self.pool1(self.relu(self.conv1(x)))# 卷積層2 → ReLU → 池化層2:6x14x14 → 16x10x10 → 16x5x5x = self.pool2(self.relu(self.conv2(x)))# 展平:將二維特征圖轉為一維向量(16x5x5 → 400),適配全連接層x = x.view(-1, 16 * 5 * 5)# 全連接層1 → ReLU:400 → 120x = self.relu(self.fc1(x))# 全連接層2 → ReLU:120 → 84x = self.relu(self.fc2(x))# 輸出層:84 → 10(不使用Softmax,因為CrossEntropyLoss會自動處理)x = self.fc3(x)return x# ------------------------------
# 4. 模型訓練與驗證函數(用驗證集評估泛化能力)
# ------------------------------
def train_model(model, train_loader, val_loader, optimizer, criterion):"""訓練模型,每輪輸出訓練/驗證損失與準確率"""best_val_acc = 0.0  # 記錄最佳驗證準確率(用于保存最優模型)for epoch in range(EPOCHS):# ------------------------------# 訓練階段(更新模型參數)# ------------------------------model.train()  # 切換到訓練模式(啟用BatchNorm/ Dropout等訓練專用層)train_loss = 0.0train_correct = 0for inputs, labels in train_loader:inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)  # 數據移至GPU/CPUoptimizer.zero_grad()  # 清空梯度(避免梯度累積)outputs = model(inputs)  # 前向傳播:輸入→模型→輸出(預測值)loss = criterion(outputs, labels)  # 計算損失(預測值與真實值的差距)loss.backward()  # 反向傳播:計算梯度(從損失到各層參數)optimizer.step()  # 更新參數(用梯度調整參數,最小化損失)# 統計訓練損失與準確率train_loss += loss.item() * inputs.size(0)  # 累計損失(乘以批量大小,避免批量大小影響)_, preds = torch.max(outputs, 1)  # 取預測概率最大的類別(0-9)train_correct += (preds == labels).sum().item()  # 統計正確預測的樣本數# 計算訓練集平均損失與準確率train_loss = train_loss / len(train_loader.dataset)train_acc = train_correct / len(train_loader.dataset)# ------------------------------# 驗證階段(評估泛化能力,不更新參數)# ------------------------------model.eval()  # 切換到驗證模式(關閉BatchNorm/ Dropout等)val_loss = 0.0val_correct = 0with torch.no_grad():  # 關閉梯度計算(節省內存,加速驗證)for inputs, labels in val_loader:inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)outputs = model(inputs)loss = criterion(outputs, labels)# 統計驗證損失與準確率val_loss += loss.item() * inputs.size(0)_, preds = torch.max(outputs, 1)val_correct += (preds == labels).sum().item()# 計算驗證集平均損失與準確率val_loss = val_loss / len(val_loader.dataset)val_acc = val_correct / len(val_loader.dataset)# 打印本輪訓練/驗證結果print(f"Epoch {epoch+1}/{EPOCHS}")print(f"訓練集:損失={train_loss:.4f},準確率={train_acc:.4f}")print(f"驗證集:損失={val_loss:.4f},準確率={val_acc:.4f}")print("-" * 50)# 保存最佳模型(驗證準確率最高的模型,避免過擬合)if val_acc > best_val_acc:best_val_acc = val_acctorch.save(model.state_dict(), "best_model.pth")print(f"訓練結束,最佳驗證準確率={best_val_acc:.4f}(模型已保存至best_model.pth)")# ------------------------------
# 5. 模型測試與結果可視化(用獨立測試集最終評估)
# ------------------------------
def test_model(model, test_loader):"""用獨立測試集評估模型性能,輸出準確率并可視化預測結果"""model.eval()  # 切換到驗證模式test_correct = 0with torch.no_grad():  # 關閉梯度計算for inputs, labels in test_loader:inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)outputs = model(inputs)_, preds = torch.max(outputs, 1)test_correct += (preds == labels).sum().item()# 計算測試集準確率test_acc = test_correct / len(test_loader.dataset)print(f"\n測試集最終準確率={test_acc:.4f}")# 可視化10張測試圖像的預測結果(直觀展示模型效果)class_names = ["飛機", "汽車", "鳥", "貓", "鹿", "狗", "青蛙", "馬", "船", "卡車"]inputs, labels = next(iter(test_loader))  # 取一批測試數據(BATCH_SIZE=64)inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)outputs = model(inputs)_, preds = torch.max(outputs, 1)# 繪制圖像(2行5列,顯示10張)plt.figure(figsize=(12, 6))for i in range(10):plt.subplot(2, 5, i+1)# 反歸一化:將[-1,1]轉回[0,1](方便顯示圖像)img = inputs[i].cpu().numpy().transpose((1, 2, 0))  # 轉為HWC格式(高度×寬度×通道)img = img * 0.5 + 0.5  # 反歸一化(原歸一化公式:img = (img - mean) / std → 反推:img = img * std + mean)plt.imshow(img)# 設置標題:真實標簽 vs 預測標簽plt.title(f"真實:{class_names[labels[i]]}\n預測:{class_names[preds[i]]}", fontsize=10)plt.axis("off")  # 隱藏坐標軸plt.tight_layout()  # 調整子圖間距plt.show()# ------------------------------
# 6. 主程序(整合所有步驟,執行訓練與測試)
# ------------------------------
if __name__ == "__main__":# 1. 加載數據(劃分訓練/驗證/測試集)print("正在加載數據...")train_loader, val_loader, test_loader = load_data()print(f"數據加載完成:\n- 訓練集大小:{len(train_loader.dataset)} \n- 驗證集大小:{len(val_loader.dataset)} \n- 測試集大小:{len(test_loader.dataset)}")# 2. 初始化模型、損失函數、優化器print("\n正在初始化模型...")model = LeNet5().to(DEVICE)  # 將模型移至GPU/CPUcriterion = nn.CrossEntropyLoss()  # 交叉熵損失(適用于多分類任務)optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)  # Adam優化器(自適應學習率,收斂更穩定)# 3. 訓練模型(用驗證集評估)print("\n正在訓練模型...")train_model(model, train_loader, val_loader, optimizer, criterion)# 4. 加載最佳模型并測試(用獨立測試集)print("\n正在測試最佳模型...")model.load_state_dict(torch.load("best_model.pth"))  # 加載訓練過程中保存的最佳模型test_model(model, test_loader)

三、代碼使用說明

1.環境安裝

打開命令行,運行以下命令安裝依賴庫(建議使用虛擬環境):

pip install torch torchvision matplotlib numpy

2.運行代碼

將代碼保存為cnn_cifar10.py,在命令行中運行:

python?cnn_cifar10.py

3.結果解釋

訓練過程:每輪(Epoch)輸出訓練集(更新參數)和驗證集(評估泛化能力)的損失(Loss,越小說明預測越準)和準確率(Accuracy,越大說明模型越準)。

最佳模型:訓練結束后,保存驗證準確率最高的模型到best_model.pth(避免過擬合)。

測試結果:加載最佳模型后,用獨立測試集評估,輸出測試集準確率(一般在70%-85%之間,增加EPOCHS可提高),并顯示10張測試圖像的真實標簽預測標簽(直觀看到模型效果)。

四、數模小白調整建議

提高準確率:若訓練集準確率低(<80%),可增加EPOCHS(如改為20),讓模型多學習幾輪;或增大LEARNING_RATE(如改為0.002),加快收斂速度。

緩解過擬合:若驗證集準確率遠低于訓練集(如差 10% 以上),可添加更多數據增強(如transforms.RandomRotation(10)隨機旋轉 10 度、transforms.ColorJitter(brightness=0.2)調整亮度),或減小模型復雜度(如將conv1的out_channels=6改為3)。

加速訓練:若訓練太慢,可增大BATCH_SIZE(如改為128,需確保GPU內存足夠),或使用更高效的優化器(如optim.AdamW,帶權重衰減的Adam)。

五、常見問題解答

Q:為什么要劃分驗證集?A:驗證集用于在訓練過程中評估模型的泛化能力,避免模型“記住”訓練集細節(過擬合)。測試集是最終評估模型性能的“考題”,不能在訓練過程中使用。

Q:數據增強為什么有效?A:數據增強(如隨機裁剪、翻轉)通過生成“虛擬”訓練數據,擴大了訓練集的多樣性,讓模型學習到更通用的特征,從而提高泛化能力。

Q:為什么用Adam優化器而不是SGD?A:Adam優化器會為每個參數自適應調整學習率,比傳統SGD(隨機梯度下降)收斂更快、更穩定,適合新手使用。

通過運行這份代碼,你可以完整體驗CNN從數據預處理到模型部署的全流程,理解“卷積層提取特征、池化層簡化特征、全連接層做決策”的核心邏輯,為后續更復雜的深度學習模型(如ResNet、YOLO)打下基礎!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/92250.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/92250.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/92250.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Dev-C++——winAPI貪吃蛇小游戲

&#x1f680;歡迎互三&#x1f449;&#xff1a;霧狩 &#x1f48e;&#x1f48e; &#x1f680;關注博主&#xff0c;后期持續更新系列文章 &#x1f680;如果有錯誤感謝請大家批評指出&#xff0c;及時修改 &#x1f680;感謝大家點贊&#x1f44d;收藏?評論? 今天水一篇吧…

【openbmc6】entity-manager

文章目錄 2.1 事件監聽:dbus在linux上使用的底層通信方式多半是unix domain socket ,事件的到來可被抽象為:socket上有數據,可讀 2.2 事件處理:由于主線程肯定有邏輯得跑,因此新開一個線程甚至多個線程專門用來監聽和處理事件,但存在多線程就意味著可能存在競爭,存在競…

Java 實現 UDP 多發多收通信

在網絡通信領域&#xff0c;UDP&#xff08;用戶數據報協議&#xff09;以其無連接、高效率的特點&#xff0c;在實時通信場景中占據重要地位。本文將結合一段實現 UDP 多發多收的 Java 代碼&#xff0c;詳細解析其實現邏輯&#xff0c;幫助開發者深入理解 UDP 通信的底層邏輯與…

Java學習第六十二部分——Git

目錄 一、關鍵概述 二、核心概念 三、常用命令 四、優勢因素 五、應用方案 六、使用建議 一、關鍵概述 提問&#xff1a;Git 是什么&#xff1f; 回答&#xff1a;一句話&#xff0c;分布式版本控制系統&#xff08;DVCS&#xff09;&#xff0c;用來跟蹤文件&#…

CDN和DNS 在分布式系統中的作用

一、DNS&#xff1a;域名系統&#xff08;Domain Name System&#xff09; 1. 核心功能 DNS是互聯網的“地址簿”&#xff0c;負責將人類易記的域名&#xff08;如www.baidu.com&#xff09;解析為計算機可識別的IP地址&#xff08;如180.101.50.242&#xff09;。沒有DNS&…

uniapp用webview導入本地網頁,ios端打開頁面空白問題

目前還沒解決&#xff0c;DCloud官方也說不行 IOS下webview加載本地網頁時&#xff0c;無法加載資源 - DCloud問答

軟考 系統架構設計師系列知識點之面向服務架構設計理論與實踐(8)

接前一篇文章:軟考 系統架構設計師系列知識點之面向服務架構設計理論與實踐(7) 所屬章節: 第15章. 面向服務架構設計理論與實踐 第3節 SOA的參考架構 15.3 SOA的參考架構 IBM的Websphere業務集成參考架構(如圖15-2所示,以下簡稱參考架構)是典型的以服務為中心的企業集…

基于 Docker 及 Kubernetes 部署 vLLM:開啟機器學習模型服務的新篇章

在當今數字化浪潮中&#xff0c;機器學習模型的高效部署與管理成為眾多開發者和企業關注的焦點。vLLM 作為一款性能卓越的大型語言模型推理引擎&#xff0c;其在 Docker 及 Kubernetes 上的部署方式如何呢&#xff1f;本文將深入探討如何在 Docker 及 Kubernetes 集群中部署 vL…

工業互聯網六大安全挑戰的密碼“解法”

目錄 工業互聯網密碼技術應用Q&A Q1&#xff1a;設備身份認證與接入控制 Q2&#xff1a;通信數據加密與完整性保護 Q3&#xff1a;遠程安全訪問 Q4&#xff1a;平臺與數據安全 Q5&#xff1a;軟件與固件安全 Q6&#xff1a;日志審計與抗抵賴 首傳信安-解決方案 總…

基于springboot的在線問卷調查系統的設計與實現(源碼+論文)

一、開發環境 1 Java語言 Java語言是當今為止依然在編程語言行業具有生命力的常青樹之一。Java語言最原始的誕生&#xff0c;不僅僅是創造者感覺C語言在編程上面很麻煩&#xff0c;如果只是專注于業務邏輯的處理&#xff0c;會導致忽略了各種指針以及垃圾回收這些操作&#x…

民法學學習筆記(個人向) Part.1

民法學學習筆記(個人向) Part.1有關民法條文背后的事理、人心、經濟社會基礎&#xff1b;民法的結構民法學習的特色就是先學最難的民法總論&#xff0c;再學較難的物權法、合同法等&#xff0c;最后再學習最簡單的婚姻、繼承、侵權部分。這是一個由難到易的過程&#xff0c;尤為…

ElasticSearch Doc Values和Fielddata詳解

一、Doc Values介紹倒排索引在搜索包含指定 term 的文檔時效率極高&#xff0c;但在執行相反操作&#xff0c;比如查詢一個文檔中包含哪些 term&#xff0c;以及進行排序、聚合等與指定字段相關的操作時&#xff0c;表現就很差了&#xff0c;這時候就需要用到 Doc Values。倒排…

【C語言】解決VScode中文亂碼問題

文章目錄【C語言】解決VScode中文亂碼問題彈出無法寫入用戶設置的處理方法彈出無法在只讀編輯器編輯的問題處理方法【C語言】解決VScode中文亂碼問題 &#x1f4ac;歡迎交流&#xff1a;在學習過程中如果你有任何疑問或想法&#xff0c;歡迎在評論區留言&#xff0c;我們可以共…

MySQL筆記4

一、范式1.概念與意義范式&#xff08;Normal Form&#xff09;是數據庫設計需遵循的規范&#xff0c;解決“設計隨意導致后期重構困難”問題。主流有 三大范式&#xff08;1NF、2NF、3NF&#xff09;&#xff0c;還有進階的 BCNF、4NF、5NF 等&#xff0c;范式間是遞進依賴&am…

切比雪夫不等式的理解以及推導【超詳細筆記】

文章目錄參考教程一、意義1. 正態分布的 3σ 法則2. 不等式的含義3. 不等式的意義二、不等式的證明1. 馬爾科夫不等式馬爾可夫不等式證明(YYY 為非負隨機變量 &#xff09;2. 切比雪夫不等式推導參考教程 一個視頻&#xff0c;徹底理解切比雪夫不等式 一、意義 1. 正態分布的…

Spring Boot Jackson 序列化常用配置詳解

一、引言在當今的 Web 開發領域&#xff0c;JSON&#xff08;JavaScript Object Notation&#xff09;已然成為數據交換的中流砥柱。無論是前后端分離架構下前后端之間的數據交互&#xff0c;還是微服務架構里各個微服務之間的通信&#xff0c;JSON 都承擔著至關重要的角色 。它…

Jetpack ViewModel LiveData:現代Android架構組件的核心力量

引言在Android應用開發中&#xff0c;數據管理和界面更新一直是開發者面臨的重大挑戰。傳統的開發方式常常導致Activity和Fragment變得臃腫&#xff0c;難以維護&#xff0c;且無法優雅地處理配置變更&#xff08;如屏幕旋轉&#xff09;。Jetpack中的ViewModel和LiveData組件正…

Python數據分析案例79——基于征信數據開發信貸風控模型

背景 雖然模型基本都是表格數據那一套了&#xff0c;算法都沒什么新鮮點&#xff0c;但是本次數據還是很值得寫個案例的&#xff0c;有征信數據&#xff0c;各種&#xff0c;個人&#xff0c;機構&#xff0c;逾期匯總..... 這么多特征來做機器學習模型應該還不錯。本次帶來&…

板凳-------Mysql cookbook學習 (十二--------3_2)

3.3鏈接表 結構 P79頁 用一個類圖來表示EmployeeNode類的結構&#xff0c;展示其屬性和關系&#xff1a; plaintext ----------------------------------------- | EmployeeNode | ----------------------------------------- | - emp_no: int …

深度學習圖像預處理:統一輸入圖像尺寸方案

在實際訓練中&#xff0c;最常見也最簡單的做法&#xff0c;就是在送入網絡前把所有圖片「變形」到同一個分辨率&#xff08;比如 256256 或 224224&#xff09;&#xff0c;或者先裁剪&#xff0f;填充成同樣大小。具體而言&#xff0c;可以分成以下幾類方案&#xff1a;一、圖…