2-2 MATLAB鮣魚優化算法ROA優化CNN超參數回歸預測

本博客來源于CSDN機器魚,未同意任何人轉載。

更多內容,歡迎點擊本專欄目錄,查看更多內容。

目錄

0.引言

1.ROA優化CNN

2.主程序調用

3.結語


0.引言

在博客【ROA優化LSTM超參數回歸】中,我們采用ROA對LSTM的學習率、迭代次數、batchsize、兩個lstmlayer的節點數進行尋優,在優化過程中我們不必知道ROA的具體優化原理,只需要修改lb、ub、維度D、邊界判斷、適應度函數即可。今天這邊博客,我們依舊采用此前提到的步驟對CNN的超參數進行回歸,話不多說,首先我們定義一個超級簡單的CNN網絡進行回歸預測,代碼如下:

clc;clear;close all;rng(0)
%% 數據的提取
load data%數據是4輸入1輸出的簡單數據
train_x;%4*98
train_y;%1*98
test_x;%4*42
test_y;%1*98
%轉成CNN的輸入格式
feature=size(train_x,1);
num_train=size(train_x,2);
num_test=size(test_x,2);
trainD=reshape(train_x,[feature,1,1,num_train]);
testD=reshape(test_x,[feature,1,1,num_test]);
targetD = train_y';
targetD_test  = test_y';%% 網絡構建
layers = [imageInputLayer([size(trainD,1) size(trainD,2) size(trainD,3)]) % 輸入convolution2dLayer(3,4,'Stride',1,'Padding','same')%核3*1 數量4 步長1 填充為samereluLayer%relu激活convolution2dLayer(3,8,'Stride',1,'Padding','same')%核3*1 數量8 步長1 填充為samereluLayer%relu激活fullyConnectedLayer(20) % 全連接層1 20個神經元reluLayerfullyConnectedLayer(20) % 全連接層2 20個神經元reluLayerfullyConnectedLayer(size(targetD,2)) %輸出層regressionLayer];
%% 網絡訓練
options = trainingOptions('adam', ...'ExecutionEnvironment','cpu', ...'MaxEpochs',30, ...'MiniBatchSize',16, ...'InitialLearnRate',0.01, ...'GradientThreshold',1, ...'shuffle','every-epoch',...'Verbose',false);
train_again=1;% 為1就代碼重新訓練模型,為0就是調用訓練好的網絡
if train_again==1[net,traininfo] = trainNetwork(trainD,targetD,layers,options);save result/cnn_net net traininfo
elseload result/cnn_net
end
figure;
plot(traininfo.TrainingLoss,'b')
hold on;grid on
ylabel('損失')
xlabel('訓練次數')
title('CNN')
%% 結果評價
YPred = predict(net,testD);YPred=double(YPred);

觀察網絡構建與訓練,我們發現至少有9個參數需要優化,分別是:迭代次數MaxEpochs、MiniBatchSize、第一層卷積層的核大小和數量、第2層卷積層的核大小和數量,以及兩個全連接層的神經元數量,還有學習率InitialLearnRate(學習率放最后是因為其他的都是整數,只有這個是小數,要么放最前要么放最后,方便我們寫邊界判斷函數與初始化種群的程序)

1.ROA優化CNN

步驟1:知道要優化的參數與優化范圍。顯然就是上面提到的9個參數。代碼如下,首先改寫lb與ub,然后初始化的時候注意除了學習率,其他的都是整數。并將原來里面的邊界判斷,改成了Bounds函數,方便在計算適應度函數值的時候轉化成整數與小數。如果學習率的位置不在最后,而是在其他位置,就需要改隨機初始化位置和Bounds函數與fitness函數里對應的地方,具體怎么改就不說了,很簡單。

function [Rbest,Convergence_curve,process]= roa_cnn(X1,y1,Xt,yt)
D=9;%一共有9個參數需要優化,分別是迭代次數、batchsize、第一層卷積層的核大小、和數量、第2層卷積層的核大小、和數量,以及兩個全連接層的神經元數量,學習率
lb= [10 16  1 1 1 1 1 1 0.001];    % 下邊界
ub= [50 256 3 20 3 20 50 50 0.01];    % 上邊界
% 迭代次數的范圍是10-50 batchsize的范圍是16-256 核大小的范圍是1-3 核數量的范圍是1-20 全連接層的范圍是1-50% 學習率的范圍是0.001-0.01
sizepop=5;
maxgen=10;% maxgen 為最大迭代次數,
% sizepop 為種群規模
%記D為維度,lb、 ub分別為搜索上、下限
R=ones(sizepop,D);%預設種群
for i=1:sizepop%隨機初始化位置for j=1:Dif j==D%除了學習率 其他的都是整數R( i, j ) = (ub(j)-lb(j))*rand+lb(j);elseR( i, j ) = round((ub(j)-lb(j))*rand+lb(j));endend
endfor k= 1:sizepopFitness(k)=fitness(R(k,:),X1,y1,Xt,yt);%個體適應度
end
[Fbest,elite]= min(Fitness);%Fbest為最優適應度值
Rbest= R(elite,:);%最優個體位置
H=zeros(1,sizepop);%控制因子%主循環
for iter= 1:maxgenRpre= R;%記錄上一代的位置V=2*(1-iter/maxgen);B= 2*V*rand-V;a=-(1 + iter/maxgen);alpha=rand*(a-1)+ 1;for i= 1:sizepopif H(i)==0dis = abs(Rbest-R(i,:));R(i,:)= R(i,:)+ dis* exp(alpha)*cos(2*pi* alpha);elseRAND= ceil(rand*sizepop);%隨機選擇一個個體R(i,:)= Rbest -(rand*0.5*(Rbest + R(RAND,:))- R(RAND,:));endRatt= R(i,:)+ (R(i,:)- Rpre(i,:))*randn;%作出小幅度移動%邊界吸收R(i, : ) = Bounds( R(i, : ), lb, ub );%對超過邊界的變量進行去除Ratt = Bounds( Ratt, lb, ub );%對超過邊界的變量進行去除Fitness(i)=fitness(R(i,:),X1,y1,Xt,yt);Fitness_Ratt= fitness(Ratt,X1,y1,Xt,yt);if Fitness_Ratt < Fitness(i)%改變寄主if H(i)==1H(i)=0;elseH(i)=1;endelse %不改變寄主A= B*(R(i,:)-rand*0.3*Rbest);R(i,:)=R(i,:)+A;endR(i, : ) = Bounds( R(i, : ), lb, ub );%對超過邊界的變量進行去除end%更新適應度值、位置[fbest,elite] = min(Fitness);%更新最優個體if fbest< FbestFbest= fbest;Rbest= R(elite,:);endprocess(iter,:)=Rbest;Convergence_curve(iter)= Fbest;iter,Fbest,Rbest
endendfunction s = Bounds( s, Lb, Ub)
temp = s;
dim=length(Lb);
for i=1:length(s)if i==dim%除了學習率 其他的都是整數temp(:,i) =temp(:,i);elsetemp(:,i) =round(temp(:,i));end
end% 判斷參數是否超出設定的范圍for i=1:length(s)if temp(:,i)>Ub(i) | temp(:,i)<Lb(i) if i==dim%除了學習率 其他的都是整數temp(:,i) =rand*(Ub(i)-Lb(i))+Lb(i);elsetemp(:,i) =round(rand*(Ub(i)-Lb(i))+Lb(i));endend
end
s = temp;
end
function s = Bounds( s, Lb, Ub)
temp = s;
for i=1:length(s)if i==1%除了學習率 其他的都是整數temp(:,i) =temp(:,i);elsetemp(:,i) =round(temp(:,i));end
end% 判斷參數是否超出設定的范圍for i=1:length(s)if temp(:,i)>Ub(i) | temp(:,i)<Lb(i) if i==1%除了學習率 其他的都是整數temp(:,i) =rand*(Ub(i)-Lb(i))+Lb(i);elsetemp(:,i) =round(rand*(Ub(i)-Lb(i))+Lb(i));endend
end
s = temp;
end

步驟2:知道優化的目標。優化的目標是提高的網絡的準確率,而ROA代碼我們這個代碼是最小值優化的,所以我們的目標可以是最小化CNN的預測誤差。預測誤差具體是,測試集(或驗證集)的預測值與真實值之間的均方差。

步驟3:構建適應度函數。通過步驟2我們已經知道目標,即采用ROA去找到9個值,用這9個值構建的CNN網絡,誤差最小化。觀察下面的代碼,首先我們將ROA的值傳進來,然后轉成需要的9個值,然后構建網絡,訓練集訓練、測試集預測,計算預測值與真實值的mse,將mse作為結果傳出去作為適應度值。

function y=fitness(x,trainD,targetD,testD,targetD_test)
rng(0)
%% 將傳進來的值 轉換為需要的超參數
iter=x(1);
minibatch=x(2);
kernel1_size=x(3);
kernel1_num=x(4);
kernel2_size=x(5);
kernel2_num=x(6);
fc1_num=x(7);
fc2_num=x(8);
lr=x(9);feature=size(trainD,1);
%% 利用尋優得到參數重新訓練CNN與預測 
layers = [imageInputLayer([size(trainD,1) size(trainD,2) size(trainD,3)]) convolution2dLayer(kernel1_size,kernel1_num,'Stride',1,'Padding','same')reluLayerconvolution2dLayer(kernel2_size,kernel2_num,'Stride',1,'Padding','same')reluLayerfullyConnectedLayer(fc1_num) reluLayerfullyConnectedLayer(fc2_num) reluLayerfullyConnectedLayer(size(targetD,2))regressionLayer];
options = trainingOptions('adam', ...'ExecutionEnvironment','cpu', ...'MaxEpochs',iter, ...'MiniBatchSize',minibatch, ...'InitialLearnRate',lr, ...'GradientThreshold',1, ...'shuffle','every-epoch',...'Verbose',false);
net = trainNetwork(trainD,targetD,layers,options);
YPred = predict(net,testD);
%% 適應度值計算
YPred=double(YPred);
%以CNN的預測值與實際值的均方誤差最小化作為適應度函數,SSA的目的就是找到一組超參數
%用這組超參數訓練得到的CNN的誤差能夠最小化
[m,n]=size(YPred);
YPred=reshape(YPred,[1,m*n]);
targetD_test=reshape(targetD_test,[1,m*n]);
y=mse(YPred,targetD_test);rng(sum(100*clock))

2.主程序調用

clc;clear;close all;format compact;rng(0)%% 數據的提取
load data
load data%數據是4輸入1輸出的簡單數據
train_x;%4*98
train_y;%1*98
test_x;%4*42
test_y;%1*98
feature=size(train_x,1);
num_train=size(train_x,2);
num_test=size(test_x,2);
trainD=reshape(train_x,[feature,1,1,num_train]);
testD=reshape(test_x,[feature,1,1,num_test]);
targetD = train_y';
targetD_test  = test_y';%% ROA優化CNN的超參數
%一共有9個參數需要優化,分別是學習率、迭代次數、batchsize、第一層卷積層的核大小、和數量、第2層卷積層的核大小、和數量,以及兩個全連接層的神經元數量
optimaztion=1;  
if optimaztion==1[x,trace,process]=roa_cnn(trainD,targetD,testD,targetD_test);save result/roa_result x trace process
elseload result/roa_result
end
%%figure
plot(trace)
title('適應度曲線')
xlabel('優化次數')
ylabel('適應度值')disp('優化后的各超參數')iter=x(1)%迭代次數
minibatch=x(2)%batchsize 
kernel1_size=x(3)
kernel1_num=x(4)%第一層卷積層的核大小與核數量
kernel2_size=x(5)
kernel2_num=x(6)%第2層卷積層的核大小與核數量
fc1_num=x(7)
fc2_num=x(8)%兩個全連接層的神經元數量
lr=x(9)%學習率%% 利用尋優得到參數重新訓練CNN與預測
rng(0)
layers = [imageInputLayer([size(trainD,1) size(trainD,2) size(trainD,3)])convolution2dLayer(kernel1_size,kernel1_num,'Stride',1,'Padding','same')reluLayerconvolution2dLayer(kernel2_size,kernel2_num,'Stride',1,'Padding','same')reluLayerfullyConnectedLayer(fc1_num)reluLayerfullyConnectedLayer(fc2_num)reluLayerfullyConnectedLayer(size(targetD,2))regressionLayer];
options = trainingOptions('adam', ...'ExecutionEnvironment','cpu', ...'MaxEpochs',iter, ...'MiniBatchSize',minibatch, ...'InitialLearnRate',lr, ...'GradientThreshold',1, ...'Verbose',false);train_again=1;% 為1就重新訓練模型,為0就是調用訓練好的網絡 
if train_again==1[net,traininfo] = trainNetwork(trainD,targetD,layers,options);save result/roacnn_net net traininfo
elseload result/roacnn_net
endfigure;
plot(traininfo.TrainingLoss,'b')
hold on;grid on
ylabel('損失')
xlabel('訓練次數')
title('roa-CNN')%% 結果評價
YPred = predict(net,testD);YPred=double(YPred);

3.結語

優化網絡超參數的格式都是這樣的!只要會改一種,那么隨便拿一份能跑通的優化算法,在不管原理的情況下,都能用來優化網絡的超參數。更多內容【點擊專欄】目錄。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/75625.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/75625.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/75625.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

企業入駐成都國際數字影像產業園,可享150多項專業服務

企業入駐成都國際數字影像產業園&#xff0c;可享150多項專業服務 全方位賦能&#xff0c;助力影像企業騰飛 入駐成都國際數字影像產業園&#xff0c;企業將獲得一個涵蓋超過150項專業服務的全周期、一站式支持體系&#xff0c;旨在精準解決企業發展各階段的核心需求&#xf…

線路板元器件介紹及選型指南:提高電路設計效率

電路板&#xff08;PCB&#xff09;是現代電子設備的核心&#xff0c;其上安裝了各類電子元器件&#xff0c;這些元器件通過PCB的導電線路彼此連接&#xff0c;實現信號傳輸與功能執行。 元器件的選擇與安裝直接決定了電子產品的性能與穩定性。本文將為大家詳細介紹電路板上的…

探究 Arm Compiler for Embedded 6 的 Clang 版本

原創標題&#xff1a;Arm Compiler for Embedded 6 的 Clang 版本 原創作者&#xff1a;莊曉立&#xff08;LIIGO&#xff09; 原創日期&#xff1a;20250218&#xff08;首發日期20250326&#xff09; 原創連接&#xff1a;https://blog.csdn.net/liigo/article/details/14653…

RedHat7.6_x86_x64服務器(最小化安裝)搭建使用記錄(二)

PostgreSQL數據庫部署管理 1.rpm方式安裝 掛載系統安裝鏡像&#xff1a; [rootlocalhost ~]# mount /dev/cdrom /mnt 進入安裝包路徑&#xff1a; [rootlocalhost ~]# cd /mnt/Packages 依次安裝如下程序包&#xff1a; [rootlocalhost Packages]# rpm -ihv postgresql-libs-9…

瀏覽器存儲 IndexedDB

IndexedDB 1. 什么是 IndexedDB&#xff1f; IndexedDB 是一種 基于瀏覽器的 NoSQL 數據庫&#xff0c;用于存儲大量的結構化數據&#xff0c;包括文件和二進制數據。它比 localStorage 和 sessionStorage 更強大&#xff0c;支持索引查詢、事務等特性。 IndexedDB 主要特點…

panda3d 渲染

目錄 安裝 設置渲染寬高&#xff1a; 渲染3d 安裝 pip install Panda3D 設置渲染寬高&#xff1a; import panda3d.core as pdmargin 100 screen Tk().winfo_screenwidth() - margin, Tk().winfo_screenheight() - margin width, height (screen[0], int(screen[0] / 1…

Node.js 包管理工具 - NPM 與 PNPM 清理緩存

NPM 清理緩存 1、基本介紹 npm 緩存是 npm 用來存儲已下載包的地方&#xff0c;以加快后續安裝速度 但是&#xff0c;有時緩存可能會損壞或占用過多磁盤空間&#xff0c;這時可以清理 npm 緩存 2、清理操作 執行如下指令&#xff0c;清理 npm 緩存 npm cache clean --for…

STM32F103_LL庫+寄存器學習筆記05 - GPIO輸入模式,捕獲上升沿進入中斷回調

導言 GPIO設置輸入模式后&#xff0c;一般會用輪詢的方式去查看GPIO的電平狀態。比如&#xff0c;最常用的案例是用于檢測按鈕的當前狀態&#xff08;是按下還是沒按下&#xff09;。中斷的使用一般用于計算脈沖的頻率與計算脈沖的數量。 項目地址&#xff1a;https://github.…

【C++進階二】string的模擬實現

【C進階二】string的模擬實現 1.構造函數和C_strC_str: 2.operator[]3.拷貝構造3.1淺拷貝3.2深拷貝 4.賦值5.迭代器6.比較ascll碼值的大小7.reverse擴容8.push_back尾插和append尾插9.10.insert10.1在pos位置前插入字符ch10.2在pos位置前插入字符串str 11.resize12.erase12.1從…

wokwi arduino mega 2560 - 點亮LED案例

截圖&#xff1a; 點亮LED案例仿真截圖 代碼&#xff1a; unsigned long t[20]; // 定義一個數組t&#xff0c;用于存儲20個LED的上次狀態切換時間&#xff08;單位&#xff1a;毫秒&#xff09;void setup() {pinMode(13, OUTPUT); // 將引腳13設置為輸出模式&#xff08;此…

vue3項目使用 python +flask 打包成桌面應用

server.py import os import sys from flask import Flask, send_from_directory# 獲取靜態文件路徑 if getattr(sys, "frozen", False):# 如果是打包后的可執行文件base_dir sys._MEIPASS else:# 如果是開發環境base_dir os.path.dirname(os.path.abspath(__file…

后端學習day1-Spring(八股)--還剩9個沒看

一、Spring 1.請你說說Spring的核心是什么 參考答案 Spring框架包含眾多模塊&#xff0c;如Core、Testing、Data Access、Web Servlet等&#xff0c;其中Core是整個Spring框架的核心模塊。Core模塊提供了IoC容器、AOP功能、數據綁定、類型轉換等一系列的基礎功能&#xff0c;…

LeetCode 第34、35題

LeetCode 第34題&#xff1a;在排序數組中查找元素的第一個和最后一個位置 題目描述 給你一個按照非遞減順序排列的整數數組nums&#xff0c;和一個目標值target。請你找出給定目標值在數組中的開始位置和結束位置。如果數組中不存在目標值target&#xff0c;返回[-1,1]。你必須…

告別分庫分表,時序數據庫 TDengine 解鎖燃氣監控新可能

達成效果&#xff1a; 從 MySQL 遷移至 TDengine 后&#xff0c;設備數據自動分片&#xff0c;運維更簡單。 列式存儲可減少 50% 的存儲占用&#xff0c;單服務器即可支撐全量業務。 毫秒級漏氣報警響應時間控制在 500ms 以內&#xff0c;提升應急管理效率。 新架構支持未來…

第十四屆藍橋杯真題

一.LED 先配置LED的八個引腳為GPIO_OutPut,鎖存器PD2也是,然后都設置為起始高電平,生成代碼時還要去解決引腳沖突問題 二.按鍵 按鍵配置,由原理圖按鍵所對引腳要GPIO_Input 生成代碼,在文件夾中添加code文件夾,code中添加fun.c、fun.h、headfile.h文件,去資源包中把lc…

《基于機器學習發電數據電量預測》開題報告

個人主頁&#xff1a;大數據蟒行探索者 目錄 一、選題背景、研究意義及文獻綜述 &#xff08;一&#xff09;選題背景 &#xff08;二&#xff09;選題意義 &#xff08;三&#xff09;文獻綜述 1. 國內外研究現狀 2. 未來方向展望 二、研究的基本內容&#xff0c;擬解…

UWP程序用多頁面實現應用實例多開

Windows 10 IoT ARM64平臺下&#xff0c;UWP應用和MFC程序不一樣&#xff0c;同時只能打開一個應用實例。以串口程序為例&#xff0c;如果用戶希望同時打開多個應用實例&#xff0c;一個應用實例打開串口1&#xff0c;一個應用實例打開串口2&#xff0c;那么我們可以加載多個頁…

Springboot整合Netty簡單實現1對1聊天(vx小程序服務端)

本文功能實現較為簡陋&#xff0c;demo內容僅供參考&#xff0c;有不足之處還請指正。 背景 一個小項目&#xff0c;用于微信小程序的服務端&#xff0c;需要實現小程序端可以和他人1對1聊天 實現功能 Websocket、心跳檢測、消息持久化、離線消息存儲 Netty配置類 /*** au…

GitLab 中文版17.10正式發布,27項重點功能解讀【二】

GitLab 是一個全球知名的一體化 DevOps 平臺&#xff0c;很多人都通過私有化部署 GitLab 來進行源代碼托管。極狐GitLab 是 GitLab 在中國的發行版&#xff0c;專門為中國程序員服務。可以一鍵式部署極狐GitLab。 學習極狐GitLab 的相關資料&#xff1a; 極狐GitLab 官網極狐…

好消息!軟航文檔控件(NTKO WebOffice)在Chrome 133版本上提示擴展已停用的解決方案

軟航文檔控件現有版本依賴Manifest V2擴展技術支持才能正常運行&#xff0c;然而這個擴展技術到2025年6月在Chrome高版本上就徹底不支持了&#xff0c;現在Chrome 133開始的版本已經開始彈出警告&#xff0c;必須手工開啟擴展支持才能正常運行。那么如何解決這個技術難題呢&…