聲明:筆記是做項目時根據B站博主視頻學習時自己編寫,請勿隨意轉載!
一、站在巨人的肩膀上
SE模塊即Squeeze-and-Excitation 模塊,這是一種常用于卷積神經網絡中的注意力機制!!
借鑒代碼的代碼鏈接如下:
注意力機制-SEhttps://github.com/ZhugeKongan/Attention-mechanism-implementation
需要model里面的SE_block.py文件
# -*- coding: UTF-8 -*-
"""
SE structure"""import torch.nn as nn # 導入PyTorch的神經網絡模塊
import torch.nn.functional as F # 導入PyTorch的神經網絡功能函數模塊 class SE(nn.Module): # 定義一個名為SE的類,該類繼承自PyTorch的nn.Module,表示一個神經網絡模塊 def __init__(self, in_chnls, ratio): # 初始化函數,in_chnls表示輸入通道數,ratio表示壓縮比率 super(SE, self).__init__() # 調用父類nn.Module的初始化函數 # 使用AdaptiveAvgPool2d將輸入的空間維度壓縮為1x1,即全局平均池化 self.squeeze = nn.AdaptiveAvgPool2d((1, 1)) # 使用1x1卷積將通道數壓縮為原來的1/ratio,實現特征壓縮 self.compress = nn.Conv2d(in_chnls, in_chnls // ratio, 1, 1, 0) # 使用1x1卷積將通道數擴展回原來的in_chnls,實現特征激勵 self.excitation = nn.Conv2d(in_chnls // ratio, in_chnls, 1, 1, 0) def forward(self, x): # 定義前向傳播函數 out = self.squeeze(x) # 對輸入x進行全局平均池化 out = self.compress(out) # 對池化后的輸出進行特征壓縮 out = F.relu(out) # 對壓縮后的特征進行ReLU激活 out = self.excitation(out) # 對激活后的特征進行特征激勵 # 對激勵后的特征應用sigmoid函數,然后與原始輸入x進行逐元素相乘,實現特征重標定 return x*F.sigmoid(out)
代碼后面有附注的注釋(GPT解釋的,很好用),理解即可。對于使用者來說,重要關注點還是它的輸入通道、輸出通道、需要傳入的參數等!!這個函數整體傳入in_chnls, ratio兩個參數。
二、開始修改網絡結構
與上節的C2f修改基本流程一致,但稍有不同
- model/common.py加入新增的SE網絡結構,直接復制粘貼如下,這里加在了上節的C2f之前:
上面說到這個函數整體傳入in_chnls, ratio兩個參數!!
- model/yolo.py設定網絡結構的傳參細節
上期的C2f模塊之所以可以參照原本存在的C3模塊屬性,是因為兩者相似,但這里的SE模塊就不可簡單的在C3x后加SE,而是需要在下面加入一段elif代碼:
elif m is SE:c1 = ch[f]c2 = args[0]if c2 != no: # if not outputc2 = make_divisible(c2 * gw, 8)args = [c1, args[1]]
即當新引入的模塊中存在輸入輸出維度時,需要使用gw調整輸出維度!!
- model/yolov5s.yaml設定現有模型結構配置文件
老樣子,復制一份新的配置文件命名為yolov5s-se.yaml。首先需要在backbone的最后加上SE模塊(相當于多了一層為第10層);其次考慮到backbone里多了一層,且在head里的輸入層來源不止上一層(-1)一個,所以輸入層來源大于等于第10層的都需要改為往后遞推+1層。下圖左邊為原始的yaml配置文件,右側為修改后的:
即當yaml文件引入新的層后,需要修改模型結構的from參數(上期是將C3替換為C2f模塊,所以不涉及這一點)!!
- train.py訓練時指定模型結構配置文件
這次將parse_model函數里的第二個參數cfg改為yolov5s-se.yaml即可,運行train.py開始訓練!!
可見訓練時第10層已經引入了SE注意力機制模塊:
100次迭代后結果如下,結果保存在runs\train\exp12文件夾,文件夾里有很多指標曲線可對比分析:
?往期精彩
STM32專欄(9.9)http://t.csdnimg.cn/A3BJ2
OpenCV-Python專欄(9.9)http://t.csdnimg.cn/jFJWe
AI底層邏輯專欄(9.9)http://t.csdnimg.cn/6BVhM
機器學習專欄(免費)http://t.csdnimg.cn/ALlLlSimulink專欄(免費)
http://t.csdnimg.cn/csDO4電機控制專欄(免費)
http://t.csdnimg.cn/FNWM7