用matlab搭建一個簡單的圖像分類網絡

文章目錄

  • 1、數據集準備
  • 2、網絡搭建
  • 3、訓練網絡
  • 4、測試神經網絡
  • 5、進行預測
  • 6、完整代碼

1、數據集準備

首先準備一個包含十個數字文件夾的DigitsData,每個數字文件夾里包含1000張對應這個數字的圖片,圖片的尺寸都是 28×28×1 像素的,如下圖所示

在這里插入圖片描述

matlab 中imageDatastore 函數會根據文件夾名稱自動為圖像進行分類標注。該數據集包含 10 個類別。

% 創建一個圖像數據存儲對象 `imds`,用于從名為 "DigitsData" 的文件夾中加載圖像數據
imds = imageDatastore("DigitsData", ...IncludeSubfolders=true, ...  % 指定在加載數據時包含子文件夾中的圖像LabelSource="foldernames");  % 使用子文件夾的名稱作為圖像的標簽(自動分類)% 獲取數據集中所有的類別名稱(即文件夾名),并將其存儲在變量 classNames 中
classNames = categories(imds.Labels);  % 將 imds.Labels

將數據劃分為訓練集、驗證集和測試集。使用 70% 的圖像作為訓練數據,15% 作為驗證數據,15% 作為測試數據。指定使用 “randomized”(隨機化),以便從每個類別中按指定比例隨機分配圖像到新的數據集中。
splitEachLabel 函數用于將圖像數據存儲對象劃分成三個新的數據存儲對象。

% 使用 splitEachLabel 函數將原始圖像數據集 imds 隨機劃分為訓練集、驗證集和測試集
[imdsTrain, imdsValidation, imdsTest] = splitEachLabel(imds, 0.7, 0.15, 0.15, "randomized");
  • splitEachLabel:MATLAB 中的函數,用于根據每個標簽(類別)分別劃分圖像數據集。這樣可以確保每個類別在訓練集、驗證集和測試集中都有代表性。
  • imds:原始的圖像數據存儲對象,包含所有圖像和對應的標簽。
  • 0.7:表示將每個類別中 70% 的圖像用于訓練集
  • 0.15:表示每個類別中 15% 的圖像用于驗證集
  • 0.15:表示每個類別中 15% 的圖像用于測試集
  • "randomized":表示在劃分數據集時使用隨機抽樣,避免按文件順序導致劃分不均衡。
  • [imdsTrain, imdsValidation, imdsTest]:返回三個新的 imageDatastore 對象,分別代表:
    • imdsTrain:訓練數據集
    • imdsValidation:驗證數據集
    • imdsTest:測試數據集

2、網絡搭建

這里,我們需要借用到matlab工具欄里APPS里的Deep Network Designer,如下圖所示

在這里插入圖片描述

在Deep Network Designer, 我們創建一個空白Designer畫布

在這里插入圖片描述

然后我們可以拖動相應的層到Designer里,并連接各個層,如下圖所示

在這里插入圖片描述

這里,我們只需要改一下輸入層的InputSize就行,如下圖

在這里插入圖片描述

然后,我們可以檢查這個網絡可行不可行,通過Analyze按鈕,就會得到這個網絡的分析結果,如下圖

在這里插入圖片描述

沒有錯誤,就可以通過Export按鈕輸出這個網絡到Matlab工作區,這個網絡被自動被命名為net_1。
在這里插入圖片描述

3、訓練網絡

指定訓練選項。不同選項的選擇需要依賴實驗分析(即通過反復試驗和比較來確定最優配置)。

% 設置用于網絡訓練的選項,這里使用的是隨機梯度下降動量法(SGDM)
% 最大訓練輪數(epoch):訓練過程中將整個訓練集完整迭代 4 次
% 指定驗證數據集,用于在訓練過程中評估模型的泛化能力
% 每訓練 30 個 mini-batch 執行一次驗證評估
% 在訓練過程中顯示實時圖形界面,包括損失值和準確率的變化曲線
% 指定訓練期間關注的評估指標為準確率(accuracy)
% 禁止在命令行窗口輸出詳細訓練信息(安靜模式)
options = trainingOptions("sgdm", ...  MaxEpochs = 4, ...  ValidationData = imdsValidation, ... ValidationFrequency = 30, ...  Plots = "training-progress", ...  Metrics = "accuracy", ...  Verbose = false); 

trainingOptions 是 MATLAB 中用于設置神經網絡訓練參數的函數。

"sgdm" 是一種常用優化算法,適用于多數分類問題。

MaxEpochs=4 設置為 4 是為了快速試驗,實際訓練中可以設置更大,比如 10、20 甚至更多。

ValidationFrequency=30 表示每 30 次 mini-batch 后在驗證集上評估一次性能,值越小越頻繁,但也會增加驗證的耗時。

Plots="training-progress" 是非常有用的調試和可視化工具,能幫助你觀察訓練是否收斂。

Verbose=false 適合在圖形界面中查看結果時使用;如果希望看到文字日志,可以設置為 true

使用 trainnet 函數訓練神經網絡。由于目標是分類任務,因此使用交叉熵損失函數(cross-entropy loss)

% 使用 trainnet 函數對神經網絡進行訓練
net = trainnet(imdsTrain, net_1, "crossentropy", options);
  • imdsTrain:訓練數據集,是一個圖像數據存儲對象(imageDatastore),包含用于訓練的圖像和對應標簽。
  • net_1:要訓練的神經網絡結構(可由 layerGraphdlnetwork 等方式定義的網絡)。
  • "crossentropy":指定損失函數為交叉熵損失函數(cross-entropy loss),這是分類任務中最常用的損失函數,特別適用于多類分類問題。
  • options:訓練選項,由前面設置的 trainingOptions 定義,包含訓練輪數、驗證數據、優化器、可視化等信息。

返回值:

  • net:訓練完成后的神經網絡,包含了優化后的權重和結構,可用于后續的預測或評估。

在這里插入圖片描述

4、測試神經網絡

使用 testnet 函數對神經網絡進行測試。對于單標簽分類任務,評估指標為準確率(accuracy),即預測正確的百分比。默認情況下,testnet 函數會在可用時自動使用 GPU。如果希望手動選擇執行環境,可以使用 testnet 函數的 ExecutionEnvironment 參數進行設置。

% 使用 testnet 函數對訓練好的神經網絡進行驗證,并評估其準確率
accuracy = testnet(net, imdsTest, "accuracy");
  • net:已訓練好的神經網絡模型,是前面通過 trainnet 得到的結果。
  • imdsTest:測試數據集,是一個圖像數據存儲對象(imageDatastore),用于測試模型的性能。
  • "accuracy":評估指標,這里指定為準確率,即預測正確的樣本數量占總樣本數量的百分比。

返回值:

  • accuracy:一個介于 0 和 1 之間的小數,表示模型在測試集上的準確率。例如,accuracy = 0.93 表示模型在測試集中有 93% 的預測是正確的。

testnet 函數自動根據你的硬件情況選擇在 CPU 還是 GPU 上運行。如果你想手動指定環境,比如使用 CPU,可以這樣寫:

accuracy = testnet(net, imdsTest, "accuracy", ExecutionEnvironment="cpu");

5、進行預測

使用 minibatchpredict 函數進行預測,并通過 scores2label 函數將預測得分轉換為類別標簽。默認情況下,如果有可用的 GPU,minibatchpredict 會自動使用 GPU 進行計算。

% 對測試集進行批量預測,輸出每個圖像對應的類別得分(概率)
scores = minibatchpredict(net, imdsValidation);% 將得分(scores)轉換為類別標簽,使用 classNames 映射到原始類名
YValidation = scores2label(scores, classNames);

可視化部分預測結果:

% 獲取測試集圖像的總數量
numTestObservations = numel(imdsTest.Files);% 從測試集中隨機選取 9 個樣本用于可視化
idx = randi(numTestObservations, 9, 1);% 創建一個新的圖形窗口
figure
tiledlayout("flow")  % 使用自動流式布局排列子圖(tiled layout)% 遍歷 9 張圖像,顯示圖像并在標題中標注預測類別
for i = 1:9nexttile  % 在下一個網格位置準備繪圖img = readimage(imdsTest, idx(i));  % 讀取第 idx(i) 張圖像imshow(img)  % 顯示圖像title("Predicted Class: " + string(YTest(idx(i))))  % 設置標題,顯示預測類別
end

在這里插入圖片描述

6、完整代碼

% 創建一個圖像數據存儲對象 `imds`,用于從名為 "DigitsData" 的文件夾中加載圖像數據
imds = imageDatastore("DigitsData", ...IncludeSubfolders=true, ...  % 指定在加載數據時包含子文件夾中的圖像LabelSource="foldernames");  % 使用子文件夾的名稱作為圖像的標簽(自動分類)% 獲取數據集中所有的類別名稱(即文件夾名),并將其存儲在變量 classNames 中
classNames = categories(imds.Labels);  % 將 imds.Labels%%
% 使用 splitEachLabel 函數將原始圖像數據集 imds 隨機劃分為訓練集、驗證集和測試集
[imdsTrain, imdsValidation, imdsTest] = splitEachLabel(imds, 0.7, 0.15, 0.15, "randomized");% 設置用于網絡訓練的選項,這里使用的是隨機梯度下降動量法(SGDM)
% 最大訓練輪數(epoch):訓練過程中將整個訓練集完整迭代 4 次
% 指定驗證數據集,用于在訓練過程中評估模型的泛化能力
% 每訓練 30 個 mini-batch 執行一次驗證評估
% 在訓練過程中顯示實時圖形界面,包括損失值和準確率的變化曲線
% 指定訓練期間關注的評估指標為準確率(accuracy)
% 禁止在命令行窗口輸出詳細訓練信息(安靜模式)
options = trainingOptions("sgdm", ...  MaxEpochs = 4, ...  ValidationData = imdsValidation, ... ValidationFrequency = 30, ...  Plots = "training-progress", ...  Metrics = "accuracy", ...  Verbose = false); % 使用 trainnet 函數對神經網絡進行訓練
net = trainnet(imdsTrain, net_1, "crossentropy", options);%%
% 使用 testnet 函數對訓練好的神經網絡進行驗證,并評估其準確率
accuracy = testnet(net, imdsTest, "accuracy");%%
% 對測試集進行批量預測,輸出每個圖像對應的類別得分(概率)
scores = minibatchpredict(net, imdsTest);% 將得分(scores)轉換為類別標簽,使用 classNames 映射到原始類名
YTest = scores2label(scores, classNames);% 獲取測試集圖像的總數量
numTestObservations = numel(imdsTest.Files);% 從測試集中隨機選取 9 個樣本用于可視化
idx = randi(numTestObservations, 9, 1);% 創建一個新的圖形窗口
figure
tiledlayout("flow")  % 使用自動流式布局排列子圖(tiled layout)% 遍歷 9 張圖像,顯示圖像并在標題中標注預測類別
for i = 1:9nexttile  % 在下一個網格位置準備繪圖img = readimage(imdsTest, idx(i));  % 讀取第 idx(i) 張圖像imshow(img)  % 顯示圖像title("Predicted Class: " + string(YTest(idx(i))))  % 設置標題,顯示預測類別
end

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/74972.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/74972.shtml
英文地址,請注明出處:http://en.pswp.cn/web/74972.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Go 語言語法精講:從 Java 開發者的視角全面掌握

《Go 語言語法精講:從 Java 開發者的視角全面掌握》 一、引言1.1 為什么選擇 Go?1.2 適合 Java 開發者的原因1.3 本文目標 二、Go 語言環境搭建2.1 安裝 Go2.2 推薦 IDE2.3 第一個 Go 程序 三、Go 語言基礎語法3.1 變量與常量3.1.1 聲明變量3.1.2 常量定…

如何選擇優質的安全工具柜:材質、結構與功能的考量

在工業生產和實驗室環境中,安全工具柜是必不可少的設備。它不僅承擔著工具的存儲任務,還直接影響工作環境的安全和效率。那么,如何選擇一個優質的安全工具柜呢?關鍵在于對材質、結構和功能的考量。 01材質:耐用與防腐 …

系統與網絡安全------Windows系統安全(11)

資料整理于網絡資料、書本資料、AI,僅供個人學習參考。 制作U啟動盤 U啟動程序 下載制作U啟程序 Ventoy是一個制作可啟動U盤的開源工具,只需要把ISO等類型的文件拷貝到U盤里面就可以啟動了 同時支持x86LegacyBIOS、x86_64UEFI模式。 支持Windows、L…

【5】搭建k8s集群系列(二進制部署)之安裝master節點組件(kube-controller-manager)

注&#xff1a;承接專欄上一篇文章 一、創建配置文件 cat > /opt/kubernetes/cfg/kube-controller-manager.conf << EOF KUBE_CONTROLLER_MANAGER_OPTS"--logtostderrfalse \\ --v2 \\ --log-dir/opt/kubernetes/logs \\ --leader-electtrue \\ --kubeconfig/op…

C#里第一個WPF程序

WPF程序對界面進行優化,但是比WINFORMS的程序要復雜很多, 并且界面UI基本上不適合拖放,所以需要比較多的時間來布局界面, 產且需要開發人員編寫更多的代碼。 即使如此,在面對誘人的界面表現, 隨著客戶對界面的需求提高,還是需要采用這樣的方式來實現。 界面的樣式采…

createContext+useContext+useReducer組合管理React復雜狀態

createContext、useContext 和 useReducer 的組合是 React 中管理全局狀態的一種常見模式。這種模式非常適合在不引入第三方狀態管理庫&#xff08;如 Redux&#xff09;的情況下&#xff0c;管理復雜的全局狀態。 以下是一個經典的例子&#xff0c;展示如何使用 createContex…

記一次常規的網絡安全滲透測試

目錄&#xff1a; 前言 互聯網突破 第一層內網 第二層內網 總結 前言 上個月根據領導安排&#xff0c;需要到本市一家電視臺進行網絡安全評估測試。通過對內外網進行滲透測試&#xff0c;網絡和安全設備的使用和部署情況&#xff0c;以及網絡安全規章流程出具安全評估報告。本…

el-table,新增、復制數據后,之前的勾選狀態丟失

需要考慮是否為 更新數據的方式不對 如果新增數據的方式是直接替換原數據數組&#xff0c;而不是通過正確的響應式數據更新方式&#xff08;如使用 Vue 的 this.$set 等方法 &#xff09;&#xff0c;也可能導致勾選狀態丟失。 因為 Vue 依賴數據的響應式變化來準確更新視圖和…

第15屆藍橋杯java-c組省賽真題

目錄 一.拼正方形 1.題目 2.思路 3.代碼 二.勁舞團 1.題目 2.思路 3.代碼 三.數組詩意 1.題目 2.思路 3.代碼 四.封閉圖形個數 1.題目 2.思路 3.代碼 五.吊墜 1.題目 六.商品庫存管理 1.題目 2.思路 3.代碼 七.挖礦 1.題目 2.思路 3.代碼 八.回文字…

玄機-應急響應-入侵排查

靶機排查目標&#xff1a; 1.web目錄存在木馬&#xff0c;請找到木馬的密碼提交 查看/var/www/html。 使用find命令查找 find ./ -type f -name "*.php | xargs grep "eval("查看到1.php里面存在無條件一句話木馬。 2.服務器疑似存在不死馬&#xff0c;請找…

usbip學習記錄

USB/IP: USB device sharing over IP make menuconfig配置&#xff1a; Device Drivers -> Staging drivers -> USB/IP support Device Drivers -> Staging drivers -> USB/IP support -> Host driver 如果還有作為客戶端的需要&#xff0c;繼續做以下配置&a…

愛普生高精度車規晶振助力激光雷達自動駕駛

在自動駕駛技術快速落地的今天&#xff0c;激光雷達作為車輛的“智慧之眼”&#xff0c;其測距精度與可靠性直接決定了自動駕駛系統的安全上限。而在這雙“眼睛”的核心&#xff0c;愛普生&#xff08;EPSON&#xff09;的高精度車規晶振以卓越性能成為激光雷達實現毫米級感知的…

28--當路由器開始“宮斗“:設備控制面安全配置全解

當路由器開始"宮斗"&#xff1a;設備控制面安全配置全解 引言&#xff1a;路由器的"大腦保衛戰" 如果把網絡世界比作一座繁忙的城市&#xff0c;那么路由器就是路口執勤的交通警察。而控制面&#xff08;Control Plane&#xff09;就是警察的大腦&#xf…

58.基于springboot老人心理健康管理系統

目錄 1.系統的受眾說明 2.相關技術 2.1 B/S結構 2.2 MySQL數據庫 3.系統分析 3.1可行性分析 3.1.1時間可行性 3.1.2 經濟可行性 3.1.3 操作可行性 3.1.4 技術可行性 3.1.5 法律可行性 3.2系統流程分析 3.3系統功能需求分析 3.4 系統非功能需求分析 4.系統設計 …

去中心化固定利率協議

核心機制與分類 協議類型&#xff1a; 借貸協議&#xff08;如Yield、Notional&#xff09;&#xff1a;通過零息債券模型&#xff08;如fyDai、fCash&#xff09;鎖定固定利率。 收益聚合器&#xff08;如Saffron、BarnBridge&#xff09;&#xff1a;通過風險分級或博弈論…

反射率均值與RCS均值的計算方法差異

1. 反射率均值&#xff08;Mean Reflectance&#xff09; 定義&#xff1a; 反射率是物體表面反射的電磁波能量與入射能量的“比例”&#xff0c;通常以百分比或小數表示。 反射率均值是對多個測量點反射率的算術平均&#xff0c;反映目標區域整體的平均反射特性。 特點&a…

[MySQL初階]MySQL(8)索引機制:下

標題&#xff1a;[MySQL初階]MySQL&#xff08;8&#xff09;索引機制&#xff1a;下 水墨不寫bug 文章目錄 四、從問題到底層&#xff0c;從現象到本質1.為什么插入的數據默認排好序2.MySQL的Page&#xff08;1&#xff09;為什么選擇用Page&#xff1f;&#xff08;2&#x…

Access:在移動互聯網與AI時代煥發新生

Microsoft Access&#xff1a;在移動互聯網與AI時代煥發新生 在移動互聯網和人工智能&#xff08;AI&#xff09;技術快速發展的今天&#xff0c;許多傳統工具被認為已經過時。然而&#xff0c;Microsoft Access&#xff0c;這款曾經風靡一時的數據庫&#xff0c;真的已經被淘…

【無人機】無人機PX4飛控系統高級軟件架構

目錄 1、概述&#xff08;圖解&#xff09; 一、數據存儲層&#xff08;Storage&#xff09; 二、外部通信層&#xff08;External Connectivity&#xff09; 三、核心通信樞紐&#xff08;Message Bus&#xff09; 四、硬件驅動層&#xff08;Drivers&#xff09; 五、飛…

【項目日記】高并發服務器項目總結

生活總是讓我們遍體鱗傷&#xff0c; 但到后來&#xff0c; 那些受傷的地方一定會變成我們最強壯的地方。 -- 《老人與海》-- 高并發服務器項目總結 模塊關系圖項目工具模塊緩沖區模塊通用類型模塊套接字socket模塊信道Channel模塊多路轉接Poller模塊 Reactor模塊時間輪Tim…