深度學習入門代碼詳細注釋-ResNet18分類螞蟻蜜蜂

? ? ? ?本項目將基于PyTorch平臺遷移ResNet18模型。該模型原采用ImageNet數據集(含1000個圖像類別)進行訓練。我們將嘗試運用該模型對螞蟻和蜜蜂進行分類(這兩個類別未包含在原訓練數據集中)。

? ? ? ?本文的原始代碼參考于博客深度學習入門項目——附代碼(持續更新) - 知乎,但是這位博主只給出了代碼,而沒有對代碼進行一些必要的注釋,這對剛入門的菜鳥新手來說不太友好,所以在這里,我對該代碼做了一些詳細的注釋,希望能夠幫助到和我一樣新入門的菜鳥。也同時感謝原作者對代碼整理所付出的勞動!

#加載所需要的庫
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLRimport torchvision
from torchvision import datasets
from torchvision import models
from torchvision import transformsimport numpy as npfrom io import BytesIO
from urllib.request import urlopen
from zipfile import ZipFileimport matplotlib.pyplot as plt#調用urllib.request.urlopen從下載網址https://pytorch.tips/bee-zip中下載數據,并生成對象#zipresp;
#用IO流來進行操作,無需將zip文件下載到本地磁盤上
zipurl='https://pytorch.tips/bee-zip'
with urlopen(zipurl) as zipresp:with ZipFile(BytesIO(zipresp.read())) as zfile:zfile.extractall('./data')#定義訓練集的變換
train_transforms = transforms.Compose([       #transforms.Compose 是一個工具,用于將多個預
#處理操作組合成一個完整的流程。它接受一個列表,列表中的每個元素是一個預處理操作。# 隨機裁剪并調整大小到 224x224,比如原圖像大小為512x512,是一個隨機操作,每次裁剪的區域可
#能不同,裁剪出224x224transforms.RandomResizedCrop(224),         # 用于訓練集,增加數據的多樣性,幫助模型學習
#到更多的特征。# 隨機水平翻轉,概率為 0.5transforms.RandomHorizontalFlip(),     #它通過0.5的概率隨機水平翻轉圖像,增加了數據的多樣性,有助于提高模型的泛化能力和減少過擬合。一張貓的圖像無論向左看還是向右看,都是#貓通過隨機水平翻轉,模型可以學習到更多樣的圖像特征。# 將圖片轉換為張量transforms.ToTensor(),# 標準化處理,使用預定義的均值和標準差transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet  RGB三個通道的均值std=[0.229, 0.224, 0.225]        # ImageNet  RGB三個通道的標準差)])
#定義驗證集的變換
val_transforms = transforms.Compose([# 將圖片調整為 256x256transforms.Resize(256),               #是一個確定性操作,每次調整大小的結果是相同的。這
#確保了驗證集的圖像在每次運行時都保持一致,從而保證驗證結果的穩定性和可重復性。# 從中心裁剪出 224x224 的區域transforms.CenterCrop(224),    #這種組合操作確保了驗證集的圖像在每次運行時都保持一致,同
#時也能保證輸入到模型的圖像大小一致。# 將圖片轉換為張量transforms.ToTensor(),    #將圖像從 PIL 圖像或 NumPy 數組轉換為 PyTorch 張量。轉換后的張
#量形狀為 (C, H, W),其中 C 是通道數,H 是高度,W 是寬度。像素值會被歸一化到 [0, 1] 范圍。# 標準化處理,使用預定義的均值和標準差transforms.Normalize(mean=[0.485, 0.456, 0.406],    # ImageNet RGB三個通道的均值std=[0.229, 0.224, 0.225]         # ImageNet RGB三個通道的標準差)])
#加載數據
#數據集
train_dataset = datasets.ImageFolder(root = './data/hymenoptera_data/train',transform = train_transforms)val_dataset = datasets.ImageFolder(root = './data/hymenoptera_data/val',transform = val_transforms)#數據加載器
train_loader = DataLoader(train_dataset,batch_size=4,     #每個批次加載 4 個樣本。這意味著每次迭代會返回 4 ##個樣本及其對應的標簽shuffle=True,    #每個 epoch 開始時隨機打亂數據。雖然驗證集通常不需要打亂,但在某些情況下#打亂數據可以避免驗證結果的偏差。num_workers=4                    #如果只用一個進程加載數據,GPU每次只能處理一張圖像,處
#理完一張后再處理下一張。而使用 4 個子進程時,可以同時處理 4 張圖像,這樣可以顯著減少數據加#載的#總時間。
)
val_loader = DataLoader(val_dataset,batch_size=4,shuffle=True,num_workers=4
)#生成ResNet18模型
#模型
model = models.resnet18(pretrained = True)#這是 PyTorch 提供的預定義 ResNet18 模型。pretrained=True:表示加載預訓練的權重。這些權重是在 ImageNet 數據集上訓練得到的,通常 #用于遷#移學習。
print(model.fc)                                                     #fc 是 ResNet18 模型中#的最后一個全連接層。默認情況下,ResNet18 的全連接層有 1000 個輸出節點,對應于 ImageNet 數據
#集的 1000 個類別。
model.fc = nn.Linear(model.fc.in_features, 2) #model.fc.in_features:獲取原全連接層的輸入特#征數量。ResNet18 的全連接層輸入特征數量為 512。nn.Linear(model.fc.in_features, 2):創建一
#個新的全連接層,輸入特征數量保持不變,輸出特征數量改為 2。這通常用于二分類任務。
print(model.fc)#定義超參數
model = model.to("cuda")#將模型的所有參數和緩沖區移動到 GPU 上。這使得模型可以在 GPU 上進行訓#練,從而顯著提高訓練速度。Loss = nn.CrossEntropyLoss()#這是 PyTorch 提供的交叉熵損失函數,通常用于多分類任務。它結合了 LogSoftmax 和 NLLLoss,適用于分類任務。optim = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# optim.SGD:這是隨機梯度下降(SGD)優化器。model.parameters():獲取模型的所有可訓練參#數。model.parameters():lr=0.001:設置學習率為 0.001。momentum=0.9:設置動量為 0.9,動量可#以幫助優化器更快地收斂,并減少振蕩。exp_lr_scheduler = StepLR(optim, step_size=7, gamma=0.1)#StepLR:這是 PyTorch 提供的學習率#調度器,用于在訓練過程中調整學習率。step_size=7:每 7 個 epoch,學習率會調整一次。
#gamma=0.1:每次調整學習率時,學習率會乘以 0.1。例如,如果初始學習率為 0.001,那么在第 7 個 #epoch 時,學習率會變為 0.0001。#可微調
#訓練
num_epochs = 25 #定義了訓練的總輪數,即模型將完整地遍歷訓練數據集的次數。在這里,設置為 25輪。
for epoch in range(num_epochs):#訓練模型model.train()                                   #將模型設置為訓練模式。這會影響某些層的
#行為,如 Dropout 和 BatchNorm,確保它們在訓練時的行為與評估時不同。running_loss = 0.0                         #初始化running_loss:用于累計每個 epoch 的總#損失。running_corrects = 0                    #初始化running_corrects:用于累計每個 epoch 中
#正確預測的數量。for inputs, labels in train_loader: #train_loader:數據加載器,每次迭代返回一個批次的數
#據和標簽。inputs = inputs.to("cuda")           #將輸入數據移動到 GPU 上。labels = labels.to("cuda")            #將標簽數據移動到 GPU 上。outputs = model(inputs) #將輸入數據通過模型進行前向傳播,得到模型的輸出。_, preds = torch.max(outputs, 1) #獲取模型輸出的最大值及其索引。preds 是預測的類別索#引。loss = Loss(outputs, labels)          #計算模型輸出和真實標簽之間的損失值loss.backward()                            #反向傳播,計算損失值的梯度,并將其傳播回#模型的每個參數。optim.step()                                #根據計算得到的梯度更新模型的參數。optim.zero_grad()                       #清空之前的梯度,避免梯度累積# 累計損失和正確預測數running_loss += loss.item()/inputs.size(0)    #獲取損失值的標量值。running_corrects += torch.sum(preds == labels.data)/inputs.size(0) #torch.sum(preds == labels.data):計算預測正確的數量。inputs.size(0):獲取當前批次的樣本數量。exp_lr_scheduler.step() #根據學習率調度器的設置更新學習率。train_epoch_loss = running_loss/len(train_loader) #計算每個 epoch 的平均損失。train_epoch_acc = running_corrects/len(train_loader)#計算每個 epoch 的平均準確率。#測試模型model.eval()running_loss = 0.0running_corrects = 0for inputs, labels in val_loader:inputs = inputs.to("cuda")labels = labels.to("cuda")outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = Loss(outputs, labels)running_loss += loss.item()/inputs.size(0)running_corrects += torch.sum(preds == labels.data)/inputs.size(0)epoch_loss = running_loss/len(val_loader)epoch_acc = running_corrects/len(val_loader)if((epoch+1)%5==0): #每隔 5 個 epoch,輸出當前的訓練和驗證的損失和準確率。print("epoch:{},  ""Train_loss:{:.4f} Train_acc:{:.4f},  ""Loss:{:.4f}, Acc:{:.4f}".format(epoch+1, train_epoch_loss, train_epoch_acc, epoch_loss, epoch_acc))#測試可視化 這段代碼定義了一個函數 imshow,用于將經過預處理的圖像數據可視化,并顯示其預測的類
#別。它還展示了如何使用這個函數來可視化驗證集中的圖像及其預測結果
def imshow(inp, title=None):#從 C-H-W 切換回 H-W-C 圖像格式inp = inp.numpy().transpose((1, 2, 0)) #將輸入的 PyTorch 張量轉換為 NumPy 數組   從通道#優先格式(C-H-W)轉換為高度-寬度-通道格式(H-W-C),以便使用 matplotlib 進行可視化。#撤銷歸一化mean = np.array([0.485, 0.456, 0.406]) #均值數組,用于撤銷歸一化。std = np.array([0.229, 0.224, 0.225])  #標準差數組,用于撤銷歸一化。inp = std * inp + mean  #撤銷歸一化操作,將圖像數據恢復到原始范圍。inp = np.clip(inp, 0, 1) #將圖像數據裁剪到 [0, 1] 范圍內,確保數據有效。plt.imshow(inp) #使用 matplotlib 的 imshow 函數顯示圖像。if title is not None:plt.title(title)inputs , classes = next(iter(val_loader))  #從驗證集加載器中獲取一個批次的數據和標簽。
out = torchvision.utils.make_grid(inputs)    #將一個批次的圖像拼接成一個網格,便于可視化
class_names = val_dataset.classes  #val_dataset.classes:獲取驗證集中的類別名稱列表
outputs = model(inputs.to("cuda"))           #inputs.to("cuda"):將輸入圖像移動到 GPU。model(inputs.to("cuda")):將輸入圖像通過模型進行前向傳播,得到模型的輸出。_, preds = torch.max(outputs, 1)   #獲取模型輸出的最大值及其索引,preds 是預測的類別索引。imshow(out, title=[class_names[x] for x in preds])  #根據預測的類別索引獲取對應的類別名稱

????????????????

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/89056.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/89056.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/89056.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

北京飲馬河科技公司 Java 實習面經

北京飲馬河科技公司 Java 實習面經 本文作者:程序員小白條 本站地址:https://xbt.xiaobaitiao.top 1) 面試官:我看你這塊是有一個開源的項目,這個項目主要是做什么的? 我:主要兩點是亮點&…

java基礎(day07)

目錄 OOP編程 方法 方法的調用: 在main入口函數中調用: 動態參數: 方法重載 OOP編程 方法 概念:指為獲得某種東西或達到某種目的而采取的手段與行為方式。有時候被稱作“方法”,有時候被稱作“函數”。例如UUID.…

使用EasyExcel動態合并單元格(模板方法)

1、導入EasyExcel依賴<dependency><groupId>com.alibaba</groupId><artifactId>easyexcel</artifactId><version>4.0.3</version> </dependency>2、編寫實體類Data publci class Student{ ExcelProperty("姓名")pri…

jenkins 流水線比較簡單直觀的

//全篇沒用自定義變量pipeline {agent any// 使用工具自動配置Node.js環境tools {nodejs nodejs22 // 需在Jenkins全局工具中預配置該名稱的Node.js安裝}//下面拉取代碼通過的是流水線片段生成的stages {stage(Checkout Code) {steps {git branch: release-v1.2.6,credentials…

CV目標檢測中的LetterBox操作

LetterBox類比理解&#xff1a;想象你要把一張任意形狀的照片放進一個正方形的相框里&#xff0c;照片不能變形拉伸&#xff0c;所以你先等比例縮小照片&#xff0c;然后在空余的地方填上灰色背景。第1章 數學原理當我們有一個原始圖像的尺寸為 19201080&#xff08;寬高&#…

Leetcode 3614. Process String with Special Operations II

Leetcode 3614. Process String with Special Operations II 1. 解題思路2. 代碼實現 題目鏈接&#xff1a;3614. Process String with Special Operations II 1. 解題思路 這一題思路上是一個逆推的思路。 首先&#xff0c;我們順序走一輪不難得到最終我們能夠獲得的字符串…

.NET ExpandoObject 技術原理解析

&#x1f31f; .NET ExpandoObject 技術原理解析 引用&#xff1a; .NET 剖析4.0上ExpandoObject動態擴展對象原理風瀟瀟人渺渺快意刀山中草 #mermaid-svg-RtpHctpdchPPN1Xo {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mer…

放蘋果(信息學奧賽一本通-T1192)

【題目描述】把M個同樣的蘋果放在N個同樣的盤子里&#xff0c;允許有的盤子空著不放&#xff0c;問共有多少種不同的分法&#xff1f;&#xff08;用K表示&#xff09;5&#xff0c;1&#xff0c;1和1&#xff0c;5&#xff0c;1 是同一種分法。【輸入】第一行是測試數據的數目…

(懶人救星版)CNN_Kriging_NSGA2_Topsis(多模型融合典范)深度學習+SCI熱點模型+多目標+熵權法 全網首例,完全原創,早用早發SCI

全網首例&#xff0c;完全原創&#xff0c;早用早發SCI&#xff08;多模型融合典范&#xff09;機器學習SCI熱點模型多目標熵權法(懶人救星版)BP_Kriging_NSGA2_Topsis 改進克里金工作量大&#xff1a;多模型融合創新性&#xff1a;首次結合BP神經網絡和克里金多目標利用 BP神…

LeetCode熱題100【第一天】

第一題 兩數之和 給定一個整數數組 nums 和一個整數目標值 target&#xff0c;請你在該數組中找出 和為目標值 target 的那 兩個 整數&#xff0c;并返回它們的數組下標。 你可以假設每種輸入只會對應一個答案&#xff0c;并且你不能使用兩次相同的元素。 你可以按任意順序返回…

AI Linux 運維筆記

運維基本概念 IT運維是指通過專業技術手段&#xff0c;確保企業的IT系統和網絡持續、安全、穩定運行&#xff0c;保障業務的連續性。運維涵蓋計算機網絡、應用系統、硬件環境和服務流程的綜合管理。主要分為: 系統運維、數據庫運維、自動化運維、容器運維、云計算運維、信創運維…

Redis性能基準測試

基準環境 機器&#xff1a;AWS EC2 c4.8xlarge&#xff08;同機部署 Redis Server 與 ReJSONBenchmark 工具&#xff0c;通過網絡棧連接&#xff09;測試工具&#xff1a;ReJSONBenchmark&#xff08;Go 實現、可配置并發&#xff09;模式&#xff1a;非管線&#xff08;non-pi…

XML外部實體注入與修復方案

XML外部實體注入&#xff08;XXE&#xff09;是一種嚴重的安全漏洞&#xff0c;攻擊者利用XML解析器處理外部實體的功能來讀取服務器內部文件、執行遠程請求&#xff08;SSRF&#xff09;、掃描內網端口或發起拒絕服務攻擊。以下是詳細解釋和修復方案&#xff1a;XXE 攻擊原理外…

解決高并發場景中的連接延遲:TCP 優化與隊頭阻塞問題剖析

你是否在高并發場景下遇到過這種情況&#xff1a;系統性能本來不錯&#xff0c;但在請求量大增的時刻&#xff0c;連接延遲暴漲&#xff0c;響應時間直線飆升&#xff0c;甚至整個服務都變得不可用&#xff1f;當你打開監控時&#xff0c;CPU、內存、帶寬都在正常范圍內&#x…

Web學習筆記4

CSS概述1、CSS簡介CSS&#xff0c;層疊樣式表&#xff0c;是一種樣式表語言&#xff0c;用以描述HTML的呈現內容的方式&#xff08;美化網頁&#xff09;。CSS書寫規則是&#xff1a;選擇器{屬性名&#xff1a;屬性值}的鍵值對CSS有三種引入方式&#xff0c;分別為&#xff1a;…

Spring AI 初學者指南:從入門到實踐與常用大模型介紹

作為 Java 開發者&#xff0c;當 AI 浪潮席卷而來時&#xff0c;如何在熟悉的 Spring 生態中快速擁抱大模型開發&#xff1f;Spring AI 的出現給出了答案。本文將從初學者視角出發&#xff0c;帶你了解 Spring AI 的核心概念、使用方法&#xff0c;并介紹與之搭配的常用大模型&…

C#自定義控件

1。C#中控件和組件的區別&#xff1a; 一般組件派生于&#xff1a;Component類&#xff0c;所以從此類派生出的稱之為組件。 一般用戶控件派生于:Control類或UserControl類&#xff0c;所以從該類派生出的稱之為用戶控件。 他們之間的關系主要是&#xff1a;UserControl繼承Con…

網絡資產測繪工具全景解析:七大平臺深度洞察

?一、資產測繪工具的核心價值?網絡資產測繪&#xff08;Cyber Asset Intelligence&#xff09;技術通過主動掃描與被動分析&#xff1a;實時發現全球暴露的網絡設備&#xff08;服務器、攝像頭、IoT設備&#xff09;自動化構建資產指紋庫&#xff08;操作系統/服務/框架版本&…

編程語言設計目的與側重點全解析(主流語言深度總結)

編程語言的設計本質上是對計算邏輯的形式化表達與工程約束的平衡&#xff0c;不同語言因目標場景、時代需求和技術哲學的差異&#xff0c;形成了獨特的設計范式。以下從系統級、應用級、腳本/動態、函數式、并發/安全等維度&#xff0c;選取10種最具代表性的編程語言&#xff0…

重學前端003 --- 響應式網頁設計 CSS 顏色

文章目錄文檔聲明head顏色模型div根據在這里 Freecodecamp 實踐&#xff0c;記錄筆記總結。 文檔聲明 在文檔頂部添加 DOCTYPE html 聲明 <!DOCTYPE html>head title 元素為搜索引擎提供了有關頁面的額外信息。 它還通過以下兩種方式顯示 title 元素的內容&#xff1a…