【深度學習】卷積神經網絡(CNN)的參數優化方法

卷積神經網絡(CNN)的參數優化方法

著名:?本文是從 Michael Nielsen的電子書Neural Network and Deep Learning的深度學習那一章的卷積神經網絡的參數優化方法的一些總結和摘錄,并不是我自己的結論和做實驗所得到的結果。我想Michael的實驗結果更有說服力一些。本書在github上有中文翻譯的版本,

前言

最近卷積神經網絡(CNN)很火熱,它在圖像分類領域的卓越表現引起了大家的廣泛關注。本文總結和摘錄了Michael Nielsen的那本Neural Network and Deep Learning一書中關于深度學習一章中關于提高泛化能力的一些概述和實驗結果。力爭用數據給大家一個關于正則化,增加卷積層/全連接數,棄權技術,拓展訓練集等參數優化方法的效果。

本文并不會介紹正則化,棄權(Dropout), 池化等方法的原理,只會介紹它們在實驗中的應用或者起到的效果,更多的關于這些方法的解釋請自行查詢。

mnist數據集介紹

本文的實驗是基于mnist數據集合的,mnist是一個從0到9的手寫數字集合,共有60,000張訓練圖片,10,000張測試圖片。每張圖片大小是28*28大小。我們的實驗就是構建一個神經網絡來高精度的分類圖片,也就是提高泛化能力。

卷積神經網絡(CNN)的參數優化方法
提高泛化能力的方法

一般來說,提高泛化能力的方法主要有以下幾個:

?

  • 正則化
  • 增加神經網絡層數
  • 使用正確的代價函數
  • 使用好的權重初始化技術
  • 人為拓展訓練集
  • 棄權技術

?

下面我們通過實驗結果給這些參數優化理論一個直觀的結果

1. 普通的全連接神經網絡的效果

我們使用一個隱藏層,包含100個隱藏神經元,輸入層是784,輸出層是one-hot編碼的形式,最后一層是Softmax層。訓練過程采用對數似然代價函數,60次迭代,學習速率η=0.1,隨機梯度下降的小批量數據大小為10,沒有正則化。在測試集上得到的結果是97.8%,代碼如下:

  
  1. >>> import network3
  2. >>> from network3 import Network
  3. >>> from network3 import ConvPoolLayer, FullyConnectedLayer, SoftmaxLayer
  4. >>> training_data, validation_data, test_data = network3.load_data_shared()
  5. >>> mini_batch_size = 10
  6. >>> net = Network([
  7. FullyConnectedLayer(n_in=784, n_out=100),
  8. SoftmaxLayer(n_in=100, n_out=10)], mini_batch_size)
  9. >>> net.SGD(training_data, 60, mini_batch_size, 0.1,
  10. validation_data, test_data)

2.使用卷積神經網絡 — 僅一個卷積層

輸入層是卷積層,5*5的局部感受野,也就是一個5*5的卷積核,一共20個特征映射。最大池化層選用2*2的大小。后面是100個隱藏神經元的全連接層。結構如圖所示

卷積神經網絡(CNN)的參數優化方法
在這個架構中,我們把卷積層和chihua層看做是學習輸入訓練圖像中的局部感受野,而后的全連接層則是一個更抽象層次的學習,從整個圖像整合全局信息。也是60次迭代,批量數據大小是10,學習率是0.1.代碼如下,

  
  1. >>> net = Network([
  2. ConvPoolLayer(image_shape=(mini_batch_size, 1, 28, 28),
  3. filter_shape=(20, 1, 5, 5),
  4. poolsize=(2, 2)),
  5. FullyConnectedLayer(n_in=20*12*12, n_out=100),
  6. SoftmaxLayer(n_in=100, n_out=10)], mini_batch_size)
  7. >>> net.SGD(training_data, 60, mini_batch_size, 0.1,
  8. validation_data, test_data)

經過三次運行取平均后,準確率是98.78%,這是相當大的改善。錯誤率降低了1/3,。卷積神經網絡開始顯現威力。

3.使用卷積神經網絡 — 兩個卷積層

我們接著插入第二個卷積-混合層,把它插入在之前的卷積-混合層和全連接層之間,同樣的5*5的局部感受野,2*2的池化層。

  
  1. >>> net = Network([
  2. ConvPoolLayer(image_shape=(mini_batch_size, 1, 28, 28),
  3. filter_shape=(20, 1, 5, 5),
  4. poolsize=(2, 2)),
  5. ConvPoolLayer(image_shape=(mini_batch_size, 20, 12, 12),
  6. filter_shape=(40, 20, 5, 5),
  7. poolsize=(2, 2)),
  8. FullyConnectedLayer(n_in=40*4*4, n_out=100),
  9. SoftmaxLayer(n_in=100, n_out=10)], mini_batch_size)
  10. >>> net.SGD(training_data, 60, mini_batch_size, 0.1,
  11. validation_data, test_data)

這一次,我們擁有了99.06%的準確率。

4.使用卷積神經網絡 — 兩個卷積層+線性修正單元(ReLU)+正則化

上面我們使用的Sigmod激活函數,現在我們換成線性修正激活函數ReLU
f(z)=max(0,z),我們選擇60個迭代期,學習速率η=0.03, ,使用L2正則化,正則化參數λ=0.1,代碼如下

  
  1. >>> from network3 import ReLU
  2. >>> net = Network([
  3. ConvPoolLayer(image_shape=(mini_batch_size, 1, 28, 28),
  4. filter_shape=(20, 1, 5, 5),
  5. poolsize=(2, 2),
  6. activation_fn=ReLU),
  7. ConvPoolLayer(image_shape=(mini_batch_size, 20, 12, 12),
  8. filter_shape=(40, 20, 5, 5),
  9. poolsize=(2, 2),
  10. activation_fn=ReLU),
  11. FullyConnectedLayer(n_in=40*4*4, n_out=100, activation_fn=ReLU),
  12. SoftmaxLayer(n_in=100, n_out=10)], mini_batch_size)
  13. >>> net.SGD(training_data, 60, mini_batch_size, 0.03,
  14. validation_data, test_data, lmbda=0.1)

這一次,我們獲得了99.23%的準確率,超過了S型激活函數的99.06%. ReLU的優勢是max(0,z)中z取最大極限時不會飽和,不像是S函數,這有助于持續學習。

5.使用卷積神經網絡 — 兩個卷基層+線性修正單元(ReLU)+正則化+拓展數據集

拓展訓練集數據的一個簡單方法是將每個訓練圖像由一個像素來代替,無論是上一個像素,下一個像素,或者左右的像素。其他的方法也有改變亮度,改變分辨率,圖片旋轉,扭曲,位移等。

我們把50,000幅圖像人為拓展到250,000幅圖像。使用第4節一樣的網絡,因為我們是在訓練5倍的數據,所以減少了過擬合的風險。

  
  1. >>> expanded_training_data, _, _ = network3.load_data_shared(
  2. "../data/mnist_expanded.pkl.gz")
  3. >>> net = Network([
  4. ConvPoolLayer(image_shape=(mini_batch_size, 1, 28, 28),
  5. filter_shape=(20, 1, 5, 5),
  6. poolsize=(2, 2),
  7. activation_fn=ReLU),
  8. ConvPoolLayer(image_shape=(mini_batch_size, 20, 12, 12),
  9. filter_shape=(40, 20, 5, 5),
  10. poolsize=(2, 2),
  11. activation_fn=ReLU),
  12. FullyConnectedLayer(n_in=40*4*4, n_out=100, activation_fn=ReLU),
  13. SoftmaxLayer(n_in=100, n_out=10)], mini_batch_size)
  14. >>> net.SGD(expanded_training_data, 60, mini_batch_size, 0.03,
  15. validation_data, test_data, lmbda=0.1)

這次的到了99.37的訓練正確率。

6.使用卷積神經網絡 — 兩個卷基層+線性修正單元(ReLU)+正則化+拓展數據集+繼續插入額外的全連接層

繼續上面的網絡,我們拓展全連接層的規模,300個隱藏神經元和1000個神經元的額精度分別是99.46%和99.43%.
我們插入一個額外的全連接層

  
  1. >>> net = Network([
  2. ConvPoolLayer(image_shape=(mini_batch_size, 1, 28, 28),
  3. filter_shape=(20, 1, 5, 5),
  4. poolsize=(2, 2),
  5. activation_fn=ReLU),
  6. ConvPoolLayer(image_shape=(mini_batch_size, 20, 12, 12),
  7. filter_shape=(40, 20, 5, 5),
  8. poolsize=(2, 2),
  9. activation_fn=ReLU),
  10. FullyConnectedLayer(n_in=40*4*4, n_out=100, activation_fn=ReLU),
  11. FullyConnectedLayer(n_in=100, n_out=100, activation_fn=ReLU),
  12. SoftmaxLayer(n_in=100, n_out=10)], mini_batch_size)
  13. >>> net.SGD(expanded_training_data, 60, mini_batch_size, 0.03,
  14. validation_data, test_data, lmbda=0.1)

這次取得了99.43%的精度。拓展后的網絡并沒有幫助太多。

7.使用卷積神經網絡 — 兩個卷基層+線性修正單元(ReLU)+拓展數據集+繼續插入額外的全連接層+棄權技術

棄權的基本思想就是在訓練網絡時隨機的移除單獨的激活值,使得模型對單獨的依據丟失更為強勁,因此不太依賴于訓練數據的特質。我們嘗試應用棄權技術到最終的全連接層(不是在卷基層)。這里,減少了迭代期的數量為40個,全連接層使用1000個隱藏神經元,因為棄權技術會丟棄一些神經元。Dropout是一種非常有效有提高泛化能力,降低過擬合的方法!

  
  1. >>> net = Network([
  2. ConvPoolLayer(image_shape=(mini_batch_size, 1, 28, 28),
  3. filter_shape=(20, 1, 5, 5),
  4. poolsize=(2, 2),
  5. activation_fn=ReLU),
  6. ConvPoolLayer(image_shape=(mini_batch_size, 20, 12, 12),
  7. filter_shape=(40, 20, 5, 5),
  8. poolsize=(2, 2),
  9. activation_fn=ReLU),
  10. FullyConnectedLayer(
  11. n_in=40*4*4, n_out=1000, activation_fn=ReLU, p_dropout=0.5),
  12. FullyConnectedLayer(
  13. n_in=1000, n_out=1000, activation_fn=ReLU, p_dropout=0.5),
  14. SoftmaxLayer(n_in=1000, n_out=10, p_dropout=0.5)],
  15. mini_batch_size)
  16. >>> net.SGD(expanded_training_data, 40, mini_batch_size, 0.03,
  17. validation_data, test_data)

使用棄權技術,的到了99.60%的準確率。

8.使用卷積神經網絡 — 兩個卷基層+線性修正單元(ReLU)+正則化+拓展數據集+繼續插入額外的全連接層+棄權技術+組合網絡

組合網絡類似于隨機森林或者adaboost的集成方法,創建幾個神經網絡,讓他們投票來決定最好的分類。我們訓練了5個不同的神經網絡,每個都大到了99.60%的準去率,用這5個網絡來進行投票表決一個圖像的分類。

采用這個方法,達到了99.67%的準確率。

總結

卷積神經網絡 的一些技巧總結如下:

1. 使用卷積層極大地減小了全連接層中的參數的數目,使學習的問題更容易

2. 使用更多強有力的規范化技術(尤其是棄權和卷積)來減小過度擬合,

3. 使用修正線性單元而不是S型神經元,來加速訓練-依據經驗,通常是3-5倍,

4. 使用GPU來計算

5. 利用充分大的數據集,避免過擬合

6. 使用正確的代價函數,避免學習減速

7. 使用好的權重初始化,避免因為神經元飽和引起的學習減速

?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/167229.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/167229.shtml
英文地址,請注明出處:http://en.pswp.cn/news/167229.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【不同請求方式在springboot中對應的注解】

GET 請求方法&#xff1a;用于獲取資源。使用 GetMapping 注解來處理 GET 請求。 示例代碼&#xff1a; RestController public class MyController {GetMapping("/resource")public ResponseEntity<String> getResource() {// 處理 GET 請求邏輯} }POST 請求方…

喜訊!云起無垠成為國家信息安全漏洞庫(CNNVD)技術支撐單位

近日&#xff0c;云起無垠憑借其在漏洞挖掘、漏洞檢測以及漏洞修復等領域的卓越表現&#xff0c;榮獲“國家信息安全漏洞庫&#xff08;CNNVD&#xff09;技術支撐單位等級證書&#xff08;三級&#xff09;”&#xff0c;正式成為CNNVD技術支撐單位。 中國國家信息安全漏洞庫&…

MTK聯發科MT6762/MT6763/MT6765安卓核心板參數規格比較

MT6762安卓核心板 MTK6762安卓核心板是一款工業級高性能、可運行 android9.0 操作系統的 4G智能模塊。 CPU&#xff1a;4xCortex-A53 up to 2.0Ghz/4xCortex-A53 up to 1.5GhzGraphics&#xff1a;IMG GE8320 Up to 650MhzProcess&#xff1a;12nmMemory&#xff1a;1xLP3 9…

【正點原子STM32連載】 第六十章 串口IAP實驗(Julia分形)實驗 摘自【正點原子】APM32F407最小系統板使用指南

1&#xff09;實驗平臺&#xff1a;正點原子APM32F407最小系統板 2&#xff09;平臺購買地址&#xff1a;https://detail.tmall.com/item.htm?id609294757420 3&#xff09;全套實驗源碼手冊視頻下載地址&#xff1a; http://www.openedv.com/thread-340252-1-1.html## 第六十…

CMake使用file(GLOB ...)需要注意的問題

文章目錄 基本語法使用例子潛在的問題大型項目中推薦的用法 file(GLOB ...) 命令用于獲取匹配指定模式的文件列表。在 CMake 中&#xff0c;file(GLOB ...) 命令的一種常見用法是用于收集源文件列表&#xff0c;例如 C 源文件&#xff08;.cpp&#xff09;和 C 源文件&#xff…

html頁面加載json數據,在html中顯示JSON數據的方法

html頁面加載json數據,在html中顯示JSON數據的方法 export const mixin {methods: {syntaxHighlight(json) {if (typeof json ! string) {json JSON.stringify(json, undefined, 2);}json json.replace(/&/g, &).replace(/</g, <).replace(/>/g, >);re…

實例分割12篇頂會論文及代碼合集,含2023最新

同學們&#xff0c;你們覺得視覺經典四個任務中哪個最難&#xff1f;我個人覺得是實例分割。 因為它既具備語義分割的特點&#xff0c;需要做到像素層面上的分類&#xff0c;也具備目標檢測的一部分特點&#xff0c;即需要定位出不同實例&#xff0c;即使它們是同一種類。 但…

LangChain的函數,工具和代理(一):OpenAI的函數調用

一、什么是函數調用功能 幾個月前OpenAI官方發布了其API的函數調用功能(Function calling), 在 API 調用中&#xff0c;您可以描述函數&#xff0c;并讓模型智能地選擇輸出包含調用一個或多個函數的參數的 JSON 對象。API函數“ChatCompletion” 雖然不會實際調用該函數&#…

C語言變量和常量

變量和常量 標識符 在計算機高級語言中&#xff0c;用來對變量、符號常量、函數、數組、類型等命名的有效字符序列統稱為標識符&#xff08;identifier&#xff09;。 C語言規定標識符&#xff1a; 只能由字母&#xff0c;數字和下劃線組成。不能以數字開頭。字母區分大小寫…

一站式企業快遞管理平臺使用教程

因公寄件在企業中重要性的提升&#xff0c;催生出了企業快遞管理平臺。為什么這么說呢&#xff1f; 隨著經濟和快遞行業的發展&#xff0c;因公寄件在企業中成了一件“常事”&#xff0c;寄文件合同、發票、節假日慰問品、樣品等等&#xff0c;這種情況之下&#xff0c;因公寄件…

Vue3 設置點擊后滾動條移動到固定的位置

需求&#xff1a; 點擊不通過按鈕&#xff0c;顯示紅框中表單&#xff0c;且滾動條滾動到底部 &#xff08;顯示紅框中表單默認不顯示&#xff09; <el-button click"onApprovalPass">不通過</el-button> <div class"item" v-if"app…

vue打包優化

vue.config.js文件中 module.exports defineConfig({ productionSourceMap: false,//去掉mapjs文件 });

pwn:[SWPUCTF 2021 新生賽]nc簽到

題目 linux環境下顯示為 配合題目的下載附件&#xff0c;發現過濾了一些&#xff0c;一旦輸入這些會自動關閉程序 ls被過濾了&#xff0c;可以使用l\s cat和空格都被過濾了&#xff0c;cat可以換成c\at ,空格可以換成$IFS$9

<HarmonyOS第一課>1·運行Hello World【課后考核】

【習題】運行Hello World工程 判斷題 1.DevEco Studio是開發HarmonyOS應用的一站式集成開發環境。 正確(True) 2.main_pages.json存放頁面page路徑配置信息。 正確(True) 單選題 1.在stage模型中&#xff0c;下列配置文件屬于AppScope文件夾的是&#xff1f;&#xff08;…

Youtube0播放?運營教你需要的技巧、策略與工具!

對于有跨境意向的內容創作者或者品牌企業來說&#xff0c;YouTube是因其巨大的潛在受眾群和商業價值成為最值得投入變現與營銷計劃的平臺。 據統計&#xff0c;98% 的美國人每月訪問 YouTube&#xff0c;近三分之二的人每天訪問。但是&#xff0c;YouTube還遠未達到過度飽和的…

酵母雙雜交服務專題(一)

酵母雙雜交系統是一種在酵母這種真核生物模型中執行的實驗方法&#xff0c;用于探索活細胞內部蛋白質間的相互作用。這種技術能夠敏感地捕捉蛋白質間的細微和短暫相互作用&#xff0c;通過檢測報告基因的表達產物來實現。作為一種高度靈敏的技術&#xff0c;酵母雙雜交系統被廣…

Spring Cloud LoadBalancer 簡單介紹與實戰

前言 本文為SpringCloud的學習筆記&#xff0c;如有錯誤&#xff0c;希望各位高手能指出&#xff0c;主要介紹SpringCloudLoadBalancer的基本概念和實戰 文章目錄 前言什么是LoadBalancer負載均衡分類服務端負載均衡客戶端負載均衡服務端負載均衡和客戶端負載均衡的優缺點 常見…

評測|PolarDB MySQL 版 Serverless

評測&#xff5c;PolarDB MySQL 版 Serverless 目錄 一、測試背景 1.1、云原生數據庫 PolarDB Serverless新架構概念 1.2、Serverless資源彈性擴縮觸發條件 二、PolarDB的Serverless能力與同類型產品進行對比 三、動態彈性升降資源的能力測試 3.1、測試資源 3.2、測試一…

ubuntu22.04在線安裝redis,可選擇版本

安裝腳本7.0.5版本 在線安裝腳本&#xff0c;默認版本號是7.0.5&#xff0c;可以根據需要選擇需要的版本進行下載編譯安裝 sudo apt-get install gcc -y sudo apt-get install pkg-config -y sudo apt-get install build-essential -y#安裝redis rm -rf ./tmp.log systemctl …