😎😎😎物體檢測-系列教程 總目錄
有任何問題歡迎在下面留言
本篇文章的代碼運行界面均在Pycharm中進行
本篇文章配套的代碼資源已經上傳
點我下載源碼
14、Model類
14.2 前向傳播
def forward(self, x, augment=False, profile=False):if augment:img_size = x.shape[-2:] # height, widths = [1, 0.83, 0.67] # scalesf = [None, 3, None] # flips (2-ud, 3-lr)y = [] # outputsfor si, fi in zip(s, f):xi = scale_img(x.flip(fi) if fi else x, si)yi = self.forward_once(xi)[0] # forwardyi[..., :4] /= si # de-scaleif fi == 2:yi[..., 1] = img_size[0] - yi[..., 1] # de-flip udelif fi == 3:yi[..., 0] = img_size[1] - yi[..., 0] # de-flip lry.append(yi)return torch.cat(y, 1), None # augmented inference, trainelse:return self.forward_once(x, profile) # single-scale inference, train
這段代碼是forward
方法的實現,它定義了模型的前向傳播過程,支持正常和增強兩種推理模式:
- 前向傳播函數,輸入
x
,是否進行數據增強augment
,是否分析性能profile
- 是否使用數據增強
- img_size ,獲取輸入圖像的長寬
- s,定義縮放尺度
- f,定義翻轉模式,這里
None
表示不翻轉,3
表示左右翻轉 - y,初始化輸出列表
- 使用zip函數將尺度因子列表s和翻轉指示列表f組合起來,然后遍歷每一對尺度因子和翻轉指示
- xi,如果fi不為None,先根據fi的值對圖像進行翻轉,然后調用scale_img函數根據si的值縮放處理圖像;否則直接調用scale_img函數根據si的值縮放處理圖像
- yi,將xi進行一次前向傳播,取第一個輸出
- 對輸出yi的前四個維度進行縮放調整,以恢復到原始的尺度。這通常是對邊界框坐標的調整
- 如果使用了上下翻轉
- 則調整y的坐標
- 如果使用了左右翻轉
- 則調整x坐標
- 將處理后的輸出添加到列表
- 將list y的所有輸出按照第一個維度進行拼接
- 如果在當前循環中沒有使用數據增強
- 直接進行一次正常的前向傳播
前向傳播方法,包括了一個可選的圖像增強步驟。在增強模式下,通過對輸入圖像應用不同的尺度和翻轉,生成多個變體,對每個變體單獨進行前向傳播,并對輸出進行調整以適應原始圖像的尺寸和方向,最后將所有變體的輸出合并。這種方法可以增加模型的泛化能力,因為它讓模型在訓練時見到更多的數據變化。如果不進行圖像增強,它將執行一次標準的前向傳播。通過這種設計,模型可以更靈活地應對不同的輸入和訓練需求
14.3 forward_once函數
def forward_once(self, x, profile=False):y, dt = [], [] # outputsfor m in self.model:if m.f != -1: # if not from previous layerx = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]if profile:try:import thopo = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # FLOPSexcept:o = 0t = time_synchronized()for _ in range(10):_ = m(x)dt.append((time_synchronized() - t) * 100)print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))x = m(x) # runy.append(x if m.i in self.save else None) # save outputif profile:print('%.1fms total' % sum(dt))return x
- forward_once函數,輸入和forward函一樣
- y, dt ,初始化兩個空列表,y用于存儲每一層的輸出,dt用于在性能分析模式下存儲每一層的執行時間
- 遍歷模型的每一層
- 如果當前層的輸入不是來自上一層的輸出
- 如果m.f是整數,則直接從y中獲取對應的層輸出作為輸入。如果m.f是一個列表,則根據列表中的索引從y中選擇輸入,如果索引為-1,則使用原始輸入x
- 是否開啟性能分析模式
- try
- 導入thop庫,用于計算浮點運算數(FLOPS)
- o,使用thop.profile計算當前層m的FLOPS,結果除以1E9轉換為GigaFLOPS,并乘以2。這里假設thop.profile返回的是一個元組,其第一個元素是所需的FLOPS
- 如果嘗試執行失敗
- 則將o(FLOPS)設置為0
- t,調用time_synchronized函數,獲取當前精確的時間
- 循環10次
- 為了穩定測量時間,通過多次執行減少偶然誤差
- 調用time_synchronized函數計算執行當前層操作的總時間,并將其添加到dt列表中
- 打印當前層的FLOPS、參數數量、執行時間和層類型。為性能分析提供詳細信息
- 執行當前層的前向傳播,并更新x為該層的輸出
- 如果當前層的索引m.i在保存列表self.save中,則將輸出x保存到y列表中;否則,保存
None
. 這樣做可以減少內存占用,只保存那些后續步驟中需要的層的輸出 - 再次檢查是否開啟了性能分析模式。這個檢查是為了在性能分析完成后打印總的執行時間
- 如果開啟了性能分析,計算所有層執行時間的總和并打印。這提供了整個前向傳播過程的總執行時間,幫助了解模型的性能瓶頸
- 返回最后一層的輸出
14.4 _initialize_biases函數
def _initialize_biases(self, cf=None):m = self.model[-1] # Detect() modulefor mi, s in zip(m.m, m.stride): # fromb = mi.bias.data.view(m.na, -1).clone()obj_add = math.log(8 / (640 / s) ** 2) # 計算obj層需要增加的值cls_add = math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum())b[:, 4] = b[:, 4] + obj_addb[:, 5:] = b[:, 5:] + cls_addmi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
- 初始化偏執的函數,接受一個可選的參數,這個參數用于根據數據集中各類別出現的頻率來調整分類(cls)層的偏置
- m,獲取模型中的最后一個模塊,檢測層(Detect模塊),用于目標檢測
- 遍歷檢測層中的每個子模塊mi及其對應的步長stride,這里的步長是指輸入圖像被縮減的尺度,對目標尺寸預測非常關鍵
- b,獲取子模塊mi的偏置項,并將其重塑(reshape)成(m.na, -1)的形狀,其中m.na是每個特征圖位置預測的錨框數量。.clone()確保在修改b時不會影響原始的偏置值
- obj_add ,計算對象(obj)層偏置需要增加的值。這個公式基于假設每640像素的圖像中有8個對象,并根據特征圖的尺度(通過步長s計算)來調整。目的是調整檢測層對于不同尺寸特征圖上對象數量預測的偏置
- cls_add ,計算分類(cls)層偏置需要增加的值。如果沒有提供類頻率(cf為None),則使用一個基于類數量m.nc的固定公式。如果提供了類頻率,那么使用類頻率來計算每個類的偏置調整值,以此反映數據集中類別的分布
- 將計算出的對象層偏置調整值加到b的第4列上,這是因為在目標檢測中,偏置項通常包括4個坐標偏置和一個對象存在的偏置,后者位于第5個位置(索引為4)
- 將計算出的分類層偏置調整值加到b的第5列及之后的所有列上,對應于每個類別的偏置
- 將調整后的偏置b重塑回原始形狀并設置為mi的偏置,確保這些偏置在訓練過程中可以被進一步調整(requires_grad=True)
14.5 其他輔助函數
def _print_biases(self):m = self.model[-1] # Detect() modulefor mi in m.m: # fromb = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)print(('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
- 獲取模型的最后一個模塊,這里假設是一個目標檢測模塊(Detect模塊)
- 遍歷檢測模塊中的每個子模塊mi
- 取得當前子模塊mi的偏置,通過.detach()確保不會影響梯度計算,.view(m.na, -1)調整形狀以匹配錨點數量m.na和偏置的其它維度,最后進行轉置以便于處理
- 打印當前子模塊卷積層的輸入通道數和偏置的統計信息,包括前五個偏置的平均值和之后所有偏置的平均值
fuse函數,用于融合模型中的卷積層(Conv2d)和批歸一化層(BatchNorm2d)
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layersprint('Fusing layers... ')for m in self.model.modules():if type(m) is Conv:m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatabilitym.conv = fuse_conv_and_bn(m.conv, m.bn) # update convm.bn = None # remove batchnormm.forward = m.fuseforward # update forwardself.info()return self
- 遍歷模型中的所有模塊
- 檢查當前模塊是否為卷積層
- 為了兼容PyTorch 1.6.0,清空非持久性緩沖區集合
- 使用fuse_conv_and_bn函數來融合當前卷積層和其后的批歸一化層
- 將批歸一化層設為None,表示移除批歸一化層
- 更新模塊的前向傳播函數為融合后的版本
- 在完成融合后,調用info方法打印模型信息
- 返回更新后的模型實例
def info(self): # print model informationmodel_info(self)
調用一個model_info函數,傳入當前模型實例,用于收集和打印模型的詳細信息,如參數數量、層的類型等