基于深度學習的圖像分類:使用ResNet實現高效分類

最近研學過程中發現了一個巨牛的人工智能學習網站,通俗易懂,風趣幽默,忍不住分享一下給大家。點擊鏈接跳轉到網站人工智能及編程語言學習教程。讀者們可以通過里面的文章詳細了解一下人工智能及其編程等教程和學習方法。下面開始對正文內容的介紹。

前言
圖像分類是計算機視覺領域中的一個基礎任務,其目標是將輸入的圖像分配到預定義的類別中。近年來,深度學習技術,尤其是卷積神經網絡(CNN),在圖像分類任務中取得了顯著的進展。ResNet(Residual Network)作為深度學習中的一個重要架構,通過引入殘差學習解決了深層網絡訓練中的梯度消失和梯度爆炸問題。本文將詳細介紹如何使用ResNet實現高效的圖像分類,從理論基礎到代碼實現,帶你一步步掌握圖像分類的完整流程。
一、圖像分類的基本概念
(一)圖像分類的定義
圖像分類是指將輸入的圖像分配到預定義的類別中的任務。圖像分類模型通常需要從大量的標注數據中學習,以便能夠準確地識別新圖像的類別。
(二)圖像分類的應用場景
1. ?醫學圖像分析:識別醫學圖像中的病變區域。
2. ?自動駕駛:識別道路標志、行人和車輛。
3. ?安防監控:識別監控視頻中的異常行為。
4. ?內容推薦:根據圖像內容推薦相關產品或服務。
二、ResNet架構的理論基礎
(一)殘差學習(Residual Learning)
ResNet的核心思想是引入殘差學習,通過跳躍連接(Skip Connections)將輸入直接傳遞到后面的層,從而解決了深層網絡訓練中的梯度消失和梯度爆炸問題。殘差學習的基本形式如下:
?H(x) = F(x) + x?
其中, H(x) ?是目標函數, F(x) ?是殘差函數, x ?是輸入。
(二)ResNet架構
ResNet通過堆疊多個殘差模塊(Residual Blocks)來構建深層網絡。每個殘差模塊包含兩個卷積層和一個跳躍連接,如下所示:
Input --> Conv2d --> ReLU --> Conv2d --> Add --> ReLU --> Output

跳躍連接將輸入直接加到卷積層的輸出上,從而保證了信息的傳遞。
(三)ResNet的優勢
1. ?緩解梯度消失問題:通過跳躍連接,梯度可以直接反向傳播到前面的層。
2. ?提高模型的訓練效率:深層網絡可以通過殘差學習更容易地訓練。
3. ?提高模型的性能:ResNet在多個圖像分類任務中取得了優異的性能。
三、代碼實現
(一)環境準備
在開始之前,確保你已經安裝了以下必要的庫:
? ?PyTorch
? ?torchvision
? ?matplotlib
? ?numpy
如果你還沒有安裝這些庫,可以通過以下命令安裝:

pip install torch torchvision matplotlib numpy

(二)加載數據集
我們將使用CIFAR-10數據集,這是一個經典的小型圖像分類數據集,包含10個類別。

import torch
import torchvision
import torchvision.transforms as transforms# 定義數據預處理
transform = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
])# 加載訓練集和測試集
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

(三)定義ResNet模型
以下是一個基于ResNet的圖像分類模型的實現:

import torch.nn as nn
import torch.nn.functional as Fclass ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super(ResidualBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += self.shortcut(x)out = F.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes=10):super(ResNet, self).__init__()self.in_channels = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)self.bn1 = nn.BatchNorm2d(64)self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)self.fc = nn.Linear(512, num_classes)def _make_layer(self, block, out_channels, num_blocks, stride):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channelsreturn nn.Sequential(*layers)def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = F.avg_pool2d(out, 4)out = out.view(out.size(0), -1)out = self.fc(out)return outdef ResNet18():return ResNet(ResidualBlock, [2, 2, 2, 2])

(四)訓練模型
現在,我們使用訓練集數據來訓練ResNet模型。

import torch.optim as optim# 初始化模型和優化器
model = ResNet18()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 訓練模型
num_epochs = 10
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch in train_loader:inputs, labels = batchoptimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}')

(五)評估模型
訓練完成后,我們在測試集上評估模型的性能。

def evaluate(model, loader, criterion):model.eval()total_loss = 0.0correct = 0total = 0with torch.no_grad():for batch in loader:inputs, labels = batchoutputs = model(inputs)loss = criterion(outputs, labels)total_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalreturn total_loss / len(loader), accuracytest_loss, test_acc = evaluate(model, test_loader, criterion)
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%')

四、總結
通過上述步驟,我們成功實現了一個基于ResNet的圖像分類模型,并在CIFAR-10數據集上進行了訓練和評估。ResNet通過殘差學習解決了深層網絡訓練中的梯度消失問題,顯著提高了模型的性能。你可以嘗試使用其他深度學習模型(如VGG、DenseNet等),或者在更大的數據集上應用圖像分類技術,探索更多有趣的應用場景。
如果你對圖像分類感興趣,或者有任何問題,歡迎在評論區留言!讓我們一起

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/92870.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/92870.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/92870.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

JVM:工具

JVMjpsjstatjmapjhatjstackjconsolejvisualvmjps jps( Java Virtual Machine Process Status Tool ),是 JDK 中的一個命令行工具,用于列出當前正在運行的 JVM 實例的信息。其對于監控和管理運行在多個 JVM 上的 Java 應用程序特別…

Elasticsearch Circuit Breaker 全面解析與最佳實踐

一、Circuit Breaker 簡介 Elasticsearch 是基于 JVM 的搜索引擎,其內存管理十分重要。為了避免單個操作或查詢耗費過多內存導致節點不可用,Elasticsearch 引入了 Circuit Breaker(熔斷器)機制。當內存使用達到熔斷器預設閾值時&a…

ARM-定時器-定時器函數封裝配置

以TIMER7為例&#xff0c;對定時器函數進行封裝注意事項&#xff1a;GD32中TIMER7是高級定時器&#xff0c;相關詳細請參考上一篇文章。main.c//main.c#include "gd32f4xx.h" #include "systick.h" #include <stdio.h> #include "main.h" …

【日志】unity俄羅斯方塊——邊界限制檢測

Bug修復記錄 項目場景 嘗試使用Unity獨自制作俄羅斯方塊&#xff08;也許很沒有必要&#xff0c;網上隨便一搜就有教程&#xff09; 問題描述 俄羅斯方塊的邊緣檢測出錯了&#xff0c;對方塊進行旋轉后&#xff0c;無法到達最左側或者最下側的位置&#xff0c;以及其他問題。演…

C++ string:準 STL Container

歷史STL 最初是一套獨立的泛型庫&#xff08;Alexander Stepanov 等人貢獻&#xff09;&#xff0c;后來被吸納進 C 標準庫&#xff1b;std::basic_string 則是早期 C 標準&#xff08;Cfront / ARM 時代&#xff09;就存在的“字符串類”&#xff0c;并非 STL 原生物。std::st…

Golang學習筆記--語言入門【Go-暑假學習筆記】

目錄 基礎語法部分相關概念 基礎語法部分概念詳解 可見性 導包 內部包 運算符 轉義字符 函數 風格 函數花括號換行 代碼縮進 代碼間隔 花括號省略 三元表達式 數據類型部分相關概念 數據類型部分概念詳解 布爾類型 整型 浮點型 復數類型 字符類型 派生類型…

linux中kill 命令使用詳解

在Linux系統里&#xff0c;kill命令的主要功能是向進程發送信號&#xff0c;以此來控制進程的運行狀態。下面為你詳細介紹它的使用方法&#xff1a; 基礎語法 kill [選項] [進程ID]進程ID也就是PID&#xff0c;可通過ps、pgrep或者top等命令來獲取。 常用信號及其含義 信號可以…

Nginx 安裝與 HTTPS 配置指南:使用 OpenSSL 搭建安全 Web 服務器

Nginx 安裝與 HTTPS 配置指南:使用 OpenSSL 搭建安全 Web 服務器 一、Nginx安裝 1. 安裝依賴項 sudo yum groupinstall "Development Tools" -y # 非必須 sudo yum install pcre pcre-devel zlib zlib-devel openssl openssl-devel -y2.下載Nginx wget http://n…

寫個 flask todo app,簡潔,實用

- 此項目雖然看起來簡單&#xff0c;實際上&#xff0c;修改成自己喜歡的樣子&#xff0c;也是費時間的。 - 別人都搞AI 相關的項目&#xff0c;而我還是搞這種基礎的東西。不要灰心。 - 積累。不論項目大小&#xff0c;不論難易&#xff0c;只看是否有用。項目地址&#xff1a…

4麥 360度定位

要在 ESP32 上用 4 個麥克風實現 360 聲源定位&#xff0c;通常思路是通過 時延估計&#xff08;TDOA&#xff09; 幾何計算&#xff0c;核心流程&#xff1a;陣列布置將 4 個麥克風等間距布置成正方形&#xff08;或圓形&#xff09;。記陣列中心為原點&#xff0c;麥克風編號…

使用yolov10模型檢測視頻中出現的行人,并保存為圖片

一、使用yolov10模型檢測視頻中出現的行人&#xff0c;并保存為圖片&#xff0c;detect_person.py代碼如下&#xff1a;from ultralytics import YOLOv10 import glob import os import cv2 import argparsedef detect_person(videoPath, savePath):if not os.path.exists(save…

現在希望用git將本地文件crawler目錄下的文件更新到遠程倉庫指定crawler目錄下,命名相同的文件本地文件將其覆蓋

git checkout main git pull origin main $source “D:\黑馬大數據學習\crawler” $dest Join-Path (Get-Location) “crawler” if (-not (Test-Path $dest)) { New-Item -ItemType Directory -Path $dest | Out-Null } Copy-Item -Path $source* -Destination $dest -Recur…

網絡調制技術對比表

&#x1f4ca; 網絡調制技術全維度對比表?調制技術??簡稱??頻譜效率??抗噪性??功率效率??復雜度??關鍵特性??典型應用場景??幅度鍵控?ASK低差高低/低電路簡單&#xff0c;易受干擾遙控器、光通信(OOK)?頻移鍵控?FSK低-中中中中/中抗噪較好&#xff0c;頻譜…

優化 Elasticsearch JVM 參數配置指南

一、概述 Elasticsearch 是基于 JVM 的搜索和分析引擎。JVM 參數的合理配置直接影響著 Elasticsearch 的性能和穩定性。盡管 Elasticsearch 已經提供了默認的 JVM 設置&#xff0c;但在某些特定場景下&#xff0c;我們可能需要進行適當的調整和優化。 本文將詳細講述如何安全、…

Python, Go 開發如何進入心流狀態APP

要開發一款基于Python和Go語言、幫助用戶進入“心流”狀態&#xff08;高度專注、高效愉悅的心理狀態&#xff09;的應用&#xff0c;需結合兩種語言的技術優勢&#xff08;Go的高并發與性能、Python的靈活性與AI生態&#xff09;及心流觸發機制&#xff08;清晰目標、即時反饋…

一文詳解手機WiFi模塊與連接

目錄 1 硬件模塊 1.1 Wifi射頻模 1.2 電源管理模塊 2 軟件與協議棧 2.1 系統服務層 2.2 認證與協議處理 3 連接流程 3.1 開啟WiFi與掃描 3.2 選擇網絡與認證 3.3 連接與IP分配 4 特殊連接方式 4.1 WPS快速連接 4.2 熱點模式&#xff08;AP模式&#xff09; 4.3 U…

Java 網絡編程詳解:從基礎到實戰,徹底掌握 TCP/UDP、Socket、HTTP 網絡通信

作為一名 Java 開發工程師&#xff0c;你一定在實際開發中遇到過需要與遠程服務器通信、實現客戶端/服務端架構、處理 HTTP 請求、構建分布式系統等場景。這時&#xff0c;Java 網絡編程&#xff08;Java Networking&#xff09; 就成為你必須掌握的核心技能之一。Java 提供了豐…

Java面試題(中等)

1. 計算機網絡傳輸層有哪些協議&#xff1f;分別適用于什么場景&#xff1f;TCP協議(傳輸控制協議)?&#xff1a;面向連接、可靠傳輸&#xff0c;流量控制、擁塞控制。適用于要求數據完整性的場景&#xff0c;如文件傳輸、網頁瀏覽、電子郵件等。UDP協議 (用戶數據報協議)?&a…

Apache 消息隊列分布式架構與原理

消息隊列 基本概念 定義 消息隊列&#xff08;Message Queue, MQ&#xff09;是一種分布式中間件&#xff0c;通過異步通信、消息暫存和解耦生產消費雙方的機制&#xff0c;提供消息的順序性保證、可靠投遞和流量控制能力&#xff0c;廣泛應用于微服務解耦、大數據流處理等場景…

ModernBERT如何突破BERT局限?情感分析全流程解析

自2018年推出以來&#xff0c;BERT 徹底改變了自然語言處理領域。它在情感分析、問答、語言推理等任務中表現優異。借助雙向訓練和基于Transformer的自注意力機制&#xff0c;BERT 開創了理解文本中單詞關系的新范式。然而&#xff0c;盡管成績斐然&#xff0c;BERT 仍存在局限…