目錄
思考
一、代碼功能分析
1. 構建 shortcut 分支(殘差連接的旁路)
2. 主路徑的第一層卷積(1×1)
4. 主路徑的第三層卷積(1×1)
5. 殘差連接 + 激活函數
二、問題分析總結:殘差結構中通道數不一致的風險
1.深度學習框架中的張量加法規則
2.為何代碼可能未報錯的原因分析
3.代碼中的潛在風險與不足
改進建議
顯式檢查通道維度
自動調整 shortcut 的通道數
統一殘差結構接口,增強模塊的通用性和健壯性
?
?
- 🍨 本文為🔗365天深度學習訓練營?中的學習記錄博客
- 🍖 原作者:K同學啊
📌你需要解決的疑問:這個代碼是否有錯?對錯與否都請給出你的思考
📌打卡要求:請查找相關資料、逐步推理模型、詳細寫下你的思考過程
# 定義殘差單元
def block(x, filters, strides=1, groups=32, conv_shortcut=True): if conv_shortcut: shortcut = Conv2D(filters * 2, kernel_size=(1, 1), strides=strides, padding='same', use_bias=False)(x) # epsilon為BN公式中防止分母為零的值 shortcut = BatchNormalization(epsilon=1.001e-5)(shortcut) else: # identity_shortcut shortcut = x # 三層卷積層 x = Conv2D(filters=filters, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)(x) x = BatchNormalization(epsilon=1.001e-5)(x) x = ReLU()(x) # 計算每組的通道數 g_channels = int(filters / groups) # 進行分組卷積 x = grouped_convolution_block(x, strides, groups, g_channels) x = Conv2D(filters=filters * 2, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)(x) x = BatchNormalization(epsilon=1.001e-5)(x) x = Add()([x, shortcut]) x = ReLU()(x) return x
思考
一、代碼功能分析
1. 構建 shortcut 分支(殘差連接的旁路)
-
目的:shortcut 分支處理路徑
-
如果 conv_shortcut = True,表示輸入的維度(通道或空間)與主路徑輸出不一致,因此需要用 1x1 卷積調整它。
-
若為 False,說明維度匹配,直接將輸入作為 shortcut。
-
-
1×1卷積:調整維度、實現下采樣(若 strides=2);
-
BatchNorm:標準化,有助于加速訓練與收斂。
? ? 作用:為殘差連接做準備。
2. 主路徑的第一層卷積(1×1)
-
目的:降維,減少通道數,從輸入通道數 C_in → filters
-
保持空間尺寸不變,設置 stride=1
? ?作用:減少參數數量,提高計算效率,為分組卷積做準備。
3. 分組卷積(3×3)
-
groups:將輸入通道劃分為多個小組,每組獨立卷積。
-
g_channels:每組處理的通道數。
-
grouped_convolution_block 負責執行分組卷積。
?作用:提升網絡的表達力,降低計算量,是 ResNeXt 的關鍵創新。
4. 主路徑的第三層卷積(1×1)
-
將通道數從 filters → filters * 2,以與 shortcut 的輸出通道對齊。
-
保持空間尺寸不變。
作用:恢復原通道數,為后續殘差連接準備。
5. 殘差連接 + 激活函數
-
Add():將主路徑輸出與 shortcut 相加,實現殘差學習。
-
ReLU():激活輸出,非線性映射。
作用:引入殘差連接,緩解深層網絡退化,提高訓練效率與性能。
二、問題分析總結:殘差結構中通道數不一致的風險
在所提供的殘差單元 block 函數中,存在一個潛在的問題點,即當 conv_shortcut=False 時,shortcut 分支直接使用輸入張量 x,而沒有經過任何通道數調整操作。與此同時,主路徑經過卷積操作之后,其輸出通道數被顯式設定為 filters * 2。這樣,在執行 Add() 操作時,如果輸入張量 x 的通道數并不等于 filters * 2,就會出現形狀不匹配的錯誤。
1.深度學習框架中的張量加法規則
以 TensorFlow/Keras 框架為例,在執行張量加法時,要求輸入張量的 shape 必須完全一致,包括 batch、height、width 和 channels 四個維度。盡管在某些操作中支持 broadcasting(廣播機制),但通道維度是不能自動廣播的。因此,如果主路徑輸出和 shortcut 的通道維度不一致,Add() 操作會直接報錯,通常為 InvalidArgumentError。
2.為何代碼可能未報錯的原因分析
盡管代碼存在上述邏輯風險,但在某些特定條件下,運行時并不會出錯,可能原因包括:
-
測試階段未觸發問題分支
代碼運行過程中,conv_shortcut 參數始終為 True,因此始終使用了 1×1 卷積來調整 shortcut 的通道數,沒有觸發錯誤分支。
-
輸入張量的通道數剛好等于目標通道數
如果傳入的張量 x 的通道數恰好等于 filters * 2,即使沒有卷積調整,兩個分支的通道數也能對齊,因此不會出錯。
-
grouped_convolution_block 可能改變了主路徑輸出通道數
主路徑中的分組卷積函數 grouped_convolution_block 未給出定義。如果它在內部改變了特征圖的通道數,可能導致主路徑輸出與 shortcut 在某些情況下恰好匹配,也可能在某些輸入下失配。
3.代碼中的潛在風險與不足
當前的 block 函數存在以下幾點風險:
-
對通道數匹配的依賴是隱式的,未做任何斷言或自動調整;
-
conv_shortcut=False 的分支缺乏魯棒性,容易在模型設計中因輸入不符而觸發錯誤;
-
沒有對 filters 與 groups 之間的關系進行合法性檢查(如 filters 不可被 groups 整除時分組卷積將出錯)。
if not conv_shortcut:shortcut = Conv2D(filters * 2, kernel_size=1, strides=strides, padding='same', use_bias=False)(x)shortcut = BatchNormalization(epsilon=1.001e-5)(shortcut)
改進建議
-
顯式檢查通道維度
在 conv_shortcut=False 的分支中添加通道數一致性斷言:
if not conv_shortcut:assert x.shape[-1] == filters * 2, "Shortcut and main path channel mismatch."
-
自動調整 shortcut 的通道數
即使 conv_shortcut=False,也建議通過 1×1 卷積層使 shortcut 與主路徑對齊:
if not conv_shortcut:shortcut = Conv2D(filters * 2, kernel_size=1, strides=strides, padding='same', use_bias=False)(x)shortcut = BatchNormalization(epsilon=1.001e-5)(shortcut)
-
統一殘差結構接口,增強模塊的通用性和健壯性
建議將 shortcut 的處理邏輯統一封裝,避免依賴外部輸入的特殊情況。