文章目錄
- 以YoloV4-tiny為例
- 要加入的注意力機制代碼
- 模型中插入注意力機制
以YoloV4-tiny為例
解釋一下各個部分:
- 最左邊這部分為主干提取網絡,功能為特征提取
- 中間這邊部分為FPN,功能是加強特征提取
- 最后一部分為yolo head,功能為獲得我們具體的一個預測結果
需要明白幾個點:
- 注意力機制模塊是一個即插即用的模塊,理論上是可以添加到任何一個特征圖后面
- 但是,不建議添加到主干部分(即最左邊的那部分),主干部分所用的特征是我們后面處理所用的基礎,故不建議添加到主干部分
- 如果添加到主干部分,由于注意力機制模塊 它的權值模塊是隨機初始化的,那主干部分的權值就被破壞了,最開始提取出來的特征就不好用了。
- 故建議把注意力機制模塊添加到主干以外的部分
本節把注意力機制添加到加強網絡里面,即上圖的中間部分。
添加注意力機制可以添加到上圖標注的部分。
要加入的注意力機制代碼
這一部分為要加入的注意力機制模塊,文件名為attention.py
import torch
from torch import nn
# 通道注意力機制
class channel_attention(nn.Module):def __init__(self,channel,ration=16): #因為要進行全連接,故需要傳入通道數量,及縮放比例super(channel_attention,self).__init__() #初始化#定義最大池化層self.max_pool = nn.AdaptiveMaxPool2d(1) #輸出層的高和寬是1#定義平均池化self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(#定義第一次全連接nn.Linear(channel,channel // ration ,False),nn.ReLU(),# 定義第二次全連接nn.Linear(channel//ration,channel,False))#由于圖中的通道注意力機制是連個全連接層相加之后再取sigmoidself.sigmoid=nn.Sigmoid()#前傳部分def forward(self,x):b,c,h,w=x.size()#首先對輸入進來的x先進行一個全局最大池化 在進行一個全局平均池化max_pool_out=self.max_pool(x).view([b,c])avg_pool_out=self.avg_pool(x).view([b,c])#然后對兩次池化后的結果用共享的全連接層fc進行處理max_fc_out=self.fc(max_pool_out)avg_fc_out=self.fc(avg_pool_out)#最后將上面的兩個結果進行相加out=max_fc_out + avg_fc_outout=self.sigmoid(out).view([b,c,1,1])#print(out)return out * x
# 空間注意力機制
class spacial_attention(nn.Module):def __init__(self,kernel_size=7): #空間注意力沒有通道數,故不用傳入channel和ration#但是空間注意力會進行一次卷積,故我們需要關注卷積核大小,一般為3或7super(spacial_attention,self).__init__() #初始化padding=7//2 #卷積核大小整除輸入通道數self.conv=nn.Conv2d(2,1,kernel_size,1,padding,bias=False)#由圖可知輸入通道數是2,輸出通道數為1,卷積核大小默認設置為7,步長為1,因為不需要壓縮特征層阿高和寬#由于圖中的通道注意力機制是連個全連接層相加之后再取sigmoidself.sigmoid=nn.Sigmoid()#空間注意力機制前傳部分def forward(self,x):b,c,h,w=x.size()max_pool_out,_= torch.max(x,dim=1,keepdim=True)#需要把通道這一維度保留下來,故設置keepdim為True#對于pytorch來講,它的通道是在第一維度,也就是batchsize后面的那個維度故定義dim為1mean_pool_out = torch.mean(x,dim = 1,keepdim=True)#對最大值和平均值進行一個堆疊pool_out = torch.cat([max_pool_out, mean_pool_out],dim=1)#對堆疊后的結果取一個卷積out=self.conv(pool_out)out=self.sigmoid(out)print(out)return out * x#把空間注意力機制和通道注意力機制進行一個融合
class Cbam(nn.Module):def __init__(self,channel,ratio=16,kernel_size=7):super(Cbam,self).__init__()#調用已經定義好的2個注意力機制self.channel_attention=channel_attention(channel,ratio)self.spacial_attention = spacial_attention(kernel_size)#融合后機制的前傳部分def forward(self,x):x=self.channel_attention(x)x=self.spacial_attention(x)return x
在模型文件(yolo.py)中,首行添加如下部分
from .attention import se_block,cbam_block,eca_block
attention_blocks=[se_block,cbam_block,eca_block]
為何要設置成上面的形式?
為了方便調用,到時候可以直接編寫下面的代碼調用具體的注意力機制模塊
attention_blocks[0]
之后,需要找到yolo.py里面的模型主體部分,大概形式如下代碼
class YoloBody(nn.Module):def __init__(self,anchors_mask,num_classes,phi=0)#在原來的代碼上只是添加了phi,代表我們選用的注意力機制模塊,默認情況下為0super(YoloBody, self).__init__()self.backbone = darknet53_tiny(None)self.conv_for_P5 = BasicConv(512,256,1)self.yolo_headP5 = yolo_head([512, len(anchors_mask[0]) * (5 + num_classes)],256)self.upsample = Upsample(256,128)self.yolo_headP4 = yolo_head([256, len(anchors_mask[1]) * (5 + num_classes)],384)#下面這部分為自己填寫self.phi = phi #這個是自己添加的if 1 <= self.phi and self.phi <= 3:self.feat1_att = attention_block[self.phi - 1](256) #通道數為256self.feat2_att = attention_block[self.phi - 1](512)#通道數為512self.upsample_att = attention_block[self.phi - 1](128)#通道數為128#通道數到底是多少看這個模型的前傳部分的通道數為多少def forward(self, x):#---------------------------------------------------## 生成CSPdarknet53_tiny的主干模型# feat1的shape為26,26,256# feat2的shape為13,13,512#---------------------------------------------------#feat1, feat2 = self.backbone(x)#下面代碼為自己填寫if 1 <= self.phi and self.phi <= 3:#如果滿足條件就添加具體的注意力機制feat1 = self.feat1_att(feat1)feat2 = self.feat2_att(feat2)#下面代碼模型自帶# 13,13,512 -> 13,13,256P5 = self.conv_for_P5(feat2)# 13,13,256 -> 13,13,512 -> 13,13,255out0 = self.yolo_headP5(P5) # 13,13,256 -> 13,13,128 -> 26,26,128P5_Upsample = self.upsample(P5)# 26,26,256 + 26,26,128 -> 26,26,384#上面代碼模型自帶,下面代碼自己編寫if 1 <= self.phi and self.phi <= 3:P5_Upsample = self.upsample_att(P5_Upsample)#下面代碼模型自帶P4 = torch.cat([P5_Upsample,feat1],axis=1)# 26,26,384 -> 26,26,256 -> 26,26,255out1 = self.yolo_headP4(P4)return out0, out1