《深度學習》卷積神經網絡:數據增強與保存最優模型解析及實現

目錄

一、數據增強

1. 核心概念

2. 核心目的

3. 常用方法

4. 實現示例(基于 PyTorch)

5. 自定義數據集加載

二、保存最優模型

1. 核心概念

2. 實現步驟

(1)定義 CNN 模型

(2)定義訓練與測試函數

(3)啟動訓練

3. 模型加載與使用

三、總結


在卷積神經網絡(CNN)的訓練過程中,數據增強和模型保存是提升性能與實用性的關鍵環節。以下結合理論與實例,詳細解析其原理及實現方式。

一、數據增強
1. 核心概念

數據增強是通過對原始訓練數據進行隨機變換(如旋轉、翻轉、調整亮度等),生成新的訓練樣本的技術。其本質是擴展數據多樣性,讓模型在訓練中接觸更多 “變體”,從而提升泛化能力(減少過擬合)。

2. 核心目的
  • 模擬真實場景中的變量(如光照變化、視角差異、遮擋等)。
  • 解決訓練數據不足的問題,通過 “人工擴充” 提升模型魯棒性。
3. 常用方法

4. 實現示例(基于 PyTorch)
import torch
from torchvision import transforms# 定義訓練集和驗證集的數據增強策略
data_transforms = {'train': transforms.Compose([transforms.Resize([300, 300]),  # 縮放圖像transforms.RandomRotation(45),  # 隨機旋轉(-45°~45°)transforms.CenterCrop(256),     # 中心裁剪至256x256transforms.RandomHorizontalFlip(p=0.5),  # 50%概率水平翻轉transforms.ColorJitter(brightness=0.2, contrast=0.1),  # 顏色調整transforms.ToTensor(),  # 轉換為Tensor(像素值歸一化到[0,1])transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 標準化]),'valid': transforms.Compose([transforms.Resize([256, 256]),  # 驗證集僅縮放,不做隨機增強transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
}
5. 自定義數據集加載

通過繼承Dataset類,將增強策略應用于實際數據:

from torch.utils.data import Dataset
from PIL import Image
import numpy as npclass FoodDataset(Dataset):def __init__(self, file_path, transform=None):self.transform = transformself.imgs = []self.labels = []# 從txt文件讀取圖像路徑和標簽(格式:圖像路徑 標簽)with open(file_path, 'r') as f:for line in f.readlines():img_path, label = line.strip().split(' ')self.imgs.append(img_path)self.labels.append(int(label))def __len__(self):return len(self.imgs)def __getitem__(self, idx):# 加載圖像并應用增強image = Image.open(self.imgs[idx]).convert('RGB')if self.transform:image = self.transform(image)# 標簽轉換為Tensorlabel = torch.tensor(self.labels[idx], dtype=torch.long)return image, label# 加載訓練集和驗證集
train_dataset = FoodDataset('./train.1txt', transform=data_transforms['train'])
valid_dataset = FoodDataset('./test.1txt', transform=data_transforms['valid'])# 數據加載器(批量處理)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=64, shuffle=False)

其中train.1txt文件內容為:

test.1txt文件內容:

??其中的每個文件地址都有其對應的圖片,數據量較大,訓練時間會較長,如需使用,可私信發送打包文件。

? ? ? ? 整篇文章所有代碼連接為一份完整代碼。

二、保存最優模型
1. 核心概念

訓練過程中,模型性能(如驗證集準確率)會隨迭代波動。保存最優模型指在訓練中跟蹤關鍵指標(如最高準確率),并保存對應狀態,以便后續直接使用最佳模型。

2. 實現步驟
(1)定義 CNN 模型
import torch.nn as nnclass SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()# 卷積層1:3通道輸入→16通道輸出,5x5卷積核self.conv1 = nn.Sequential(nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2)  # 池化后尺寸減半)# 卷積層2:16通道→32通道self.conv2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),nn.ReLU(),nn.MaxPool2d(2))# 卷積層3:32通道→128通道(無池化)self.conv3 = nn.Sequential(nn.Conv2d(32, 128, kernel_size=5, stride=1, padding=2),nn.ReLU())# 全連接層:輸入為128×64×64(經3次卷積+池化后的尺寸),輸出20類self.fc = nn.Linear(128 * 64 * 64, 20)def forward(self, x):x = self.conv1(x)  # 輸出:16×128×128x = self.conv2(x)  # 輸出:32×64×64x = self.conv3(x)  # 輸出:128×64×64x = x.view(x.size(0), -1)  # 展平為向量x = self.fc(x)return x# 初始化模型并移動到設備(GPU/CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimpleCNN().to(device)

運行結果:

(2)定義訓練與測試函數
# 損失函數與優化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 訓練函數
def train(dataloader, model, loss_fn, optimizer):model.train()  # 開啟訓練模式(啟用 dropout/batchnorm)for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)pred = model(X)loss = loss_fn(pred, y)# 反向傳播optimizer.zero_grad()  # 清空梯度loss.backward()        # 計算梯度optimizer.step()       # 更新參數if batch % 100 == 0:print(f"Batch {batch}, Loss: {loss.item():.4f}")# 測試函數(含最優模型保存)
best_acc = 0.0  # 記錄最佳準確率def test(dataloader, model, loss_fn):global best_accmodel.eval()  # 開啟評估模式(固定 dropout/batchnorm)size = len(dataloader.dataset)num_batches = len(dataloader)test_loss, correct = 0, 0with torch.no_grad():  # 關閉梯度計算,節省內存for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test: Accuracy: {(100*correct):.1f}%, Avg loss: {test_loss:.4f}")# 保存最優模型(準確率提升時)if correct > best_acc:best_acc = correct# 保存完整模型(含結構和參數)torch.save(model, "best_model.pt")# 或僅保存參數(更輕量):torch.save(model.state_dict(), "best_model.pth")
(3)啟動訓練
epochs = 150  # 訓練輪數
for t in range(epochs):print(f"\nEpoch {t+1}/{epochs}")train(train_loader, model, loss_fn, optimizer)test(valid_loader, model, loss_fn)
print("訓練完成!最優模型已保存為 best_model.pt")
3. 模型加載與使用

訓練結束后,可直接加載最優模型進行預測:

# 加載保存的模型
loaded_model = torch.load("best_model.pt").to(device)
loaded_model.eval()  # 切換至評估模式# 示例:對單張圖像預測
def predict(image_path):image = Image.open(image_path).convert('RGB')# 應用驗證集的預處理transform = data_transforms['valid']image = transform(image).unsqueeze(0).to(device)  # 增加批次維度with torch.no_grad():pred = loaded_model(image)return pred.argmax(1).item()  # 返回預測類別# 測試預測
print("預測類別:", predict("test_image.jpg"))

?訓練結束得到當前訓練的最優模型,其為pt\pth\t7文件,此時該文件即為當前模型,可直接調用該文件使用。

三、總結
  • 數據增強通過模擬真實場景變化,提升模型泛化能力,需注意訓練集用隨機增強、驗證集僅做標準化。
  • 保存最優模型通過跟蹤驗證集指標(如準確率),保留性能最佳的模型狀態,避免訓練后期過擬合導致的性能下降。

以上方法可直接應用于圖像分類、目標檢測等 CNN 任務,實際使用時需根據數據集特點調整增強策略和模型結構。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/97859.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/97859.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/97859.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

tcpdump用法

tcpdump用法tcpdump一、什么是tcpdump二、命令格式與參數三、參數列表四、過濾規則組合邏輯運算符過濾器關鍵字理解 Flag 標識符五、常用例子tcpdump 一、什么是tcpdump 二、命令格式與參數 option 可選參數:將在后邊一一解釋。 proto 類過濾器:根據協…

平衡車 - 電機調速

🌈個人主頁:羽晨同學 💫個人格言:“成為自己未來的主人~” 在我們的這篇文章當中,我們主要想要實現的功能的是電機調速功能。在我們的這篇文章中,主要實現的是開環的功能,而非閉環,也就是不加…

從利潤率看價值:哪些公司值得長期持有?

💡 為什么盯緊利潤率? 投資者常常盯著營收增長,卻忽略了一個更關鍵的指標——利潤率。 收入可以靠規模“堆”出來,但利潤率卻是企業護城河的真實體現。心理學研究表明:當一個產品或服務被消費者認定為“不可替代”&a…

小迪web自用筆記25

傳統文件上傳:上傳至服務器本身硬盤。云存儲:借助云存儲oss對象存儲(只能被訪問,不可解析)Oss云存儲Access key與Access ID:有了這兩個東西之后就可以操作云存儲,可以向里面發數據了。這玩意兒泄…

分發餅干——很好的解釋模板

好的,孩子,我們來玩一個“喂餅干”的游戲。 0. 問題的本質是什么? 想象一下,你就是個超棒的家長,手里有幾塊大小不一的餅干,而面前有幾個餓著肚子的小朋友。每個小朋友都有一個最小的“胃口”值&#xff0c…

場景題:如果一個大型項目,某一個時間所有的CPU的已經被占用了,導致服務不可用,我們開發人員應該如何使服務器盡快恢復正常

問:如果一個大型項目,某一個時間所有的CPU的 已經被占用了,導致服務不可用,我們開發人員 應該如何使服務器盡快恢復正常答:應對CPU 100%導致服務不可用的緊急恢復流程面試官,如果遇到這種情況,我會立即按照…

Docker 安裝 RAGFlow保姆教程

前提條件 Ubuntu 服務器(20.04 或 22.04 LTS 推薦) 已安裝 Docker 和 Docker Compose 如果尚未安裝,請先運行以下命令:# 安裝 Docker curl -fsSL https://get.docker.com -o get-docker.sh sudo sh get-docker.sh # 將當前用戶加入 docker 組,避免每次都要 sudo sudo user…

為什么實際工程里 C++ 部署深度學習模型更常見?為什么大家更愛用 TensorRT?

很多人剛接觸深度學習模型部署的時候,都會習慣用 Python,因為訓練的時候就是 PyTorch、TensorFlow 啊,寫起來方便。但一到 實際工程,特別是工業設備、醫療影像、上位機系統這種場景,你會發現大多數人都轉向了 C 部署。…

深入理解 Java 集合框架:底層原理與實戰應用

在日常開發中,集合是 Java 中使用頻率最高的工具之一。從最常見的 ArrayList、HashMap 到更復雜的并發集合,幾乎每一個 Java 程序員都離不開集合框架。集合框架不僅提供了豐富的數據結構實現,還封裝了底層復雜的邏輯,讓開發者能夠…

爬取m3u8視頻完整教程

爬取步驟:1.先找到網頁源代碼2.從網頁源代碼中拿到m3u83.下載m3u84.讀取m3u8文件,下載視頻5.合并視頻首先我們來爬取一個星辰影院的電影:下面我以這個為例:我們需要在源代碼中找到m3u8這個url:緊接著我們利用下面的方法…

Python爬蟲實戰: 基于Scrapy的Amazon跨境電商選品數據爬蟲方案

概述與設計思路 利用Python的Scrapy框架進行大規模頁面抓取和結構化數據提取,配合aiohttp實現高并發請求,從而高效獲取Amazon平臺上的商品列表、詳情、評論等公開信息。通過對這些數據進行清洗與分析,可以識別出有潛力的商品,評估市場競爭程度,并跟蹤競爭對手的動態,為跨…

穩定版IM即時通訊 仿默往APP即時通訊im源碼聊天社交源碼支持二開原生開發獨立部署 含搭建教程

內容目錄一、詳細介紹二、效果展示1.部分代碼2.效果圖展示三、學習資料下載一、詳細介紹 技術開發語言: 后臺管理端:Java GO Mysql數據庫 安卓端:Java iOS端:ob PC端:c 功能簡單介紹: 單聊&#xff…

封裝一個redis獲取并解析數據的工具類

redis獲取并解析數據工具類實現代碼使用示例實現代碼 import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.StrUtil; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.TypeReference; import lom…

23種設計模式——策略模式 (Strategy Pattern)?詳解

?作者簡介:大家好,我是 Meteors., 向往著更加簡潔高效的代碼寫法與編程方式,持續分享Java技術內容。 🍎個人主頁:Meteors.的博客 💞當前專欄:設計模式 ?特色專欄:知識分享 &#x…

CI(持續集成)、CD(持續交付/部署)、CT(持續測試)、CICD、CICT

目錄 **CI、CD、CT 詳解與關系** **1. CI(Continuous Integration,持續集成)** **2. CD(Continuous Delivery/Deployment,持續交付/部署)** **持續交付(Continuous Delivery)** **持續部署(Continuous Deployment)** **3. CT(Continuous Testing,持續測試)** **4.…

【音視頻】WebRTC ICE 模塊深度剖析

原文鏈接: https://mp.weixin.qq.com/s?__bizMzIzMjY3MjYyOA&mid2247498075&idx2&sn6021a2f60b1e7c71ce4d7af6df0b9b89&chksme893e540dfe46c56323322e780d41aec1f851925cfce8b76b3f4d5cfddaa9c7cbb03a7ae4c25&scene178&cur_album_id314699…

linux0.12 head.s代碼解析

重新設置IDT和GDT,為256個中斷門設置默認的中斷處理函數檢查A20地址線是否啟用設置數學協處理器將main函數相關的參數壓棧設置分頁機制,將頁表映射到0~16MB的物理內存上返回main函數執行 源碼詳細注釋如下: /** linux/boot/head.s** (C) 1991 Linus T…

Maven動態控制版本號秘籍:高效發包部署,版本管理不再頭疼!

作者:唐叔在學習 專欄:唐叔的Java實踐 關鍵詞:Maven版本控制、versions插件、動態版本號、持續集成、自動化部署、Java項目管理 摘要:本文介紹如何使用Maven Versions插件動態控制項目版本號和依賴組件版本號,實現無需…

簡述:普瑞時空數據建庫軟件(國土變更建庫)之一(變更預檢查部分規則)

簡述:普瑞時空數據建庫軟件(國土變更建庫)之一(變更預檢查部分規則) 主要包括三種類型:常規檢查、行政區范圍檢查、20X異常滅失檢查 本blog地址:https://blog.csdn.net/hsg77

shell中命令小工具:cut、sort、uniq,tr的使用方式

提示:文章寫完后,目錄可以自動生成,如何生成可參考右邊的幫助文檔 文章目錄前言一、cut —— 按列或字符截取1. 常用選項2. 示例二、sort —— 排序(默認按行首字符升序)1. 常用選項常用 sort 命令選項三、uniq —— 去…