Experiment Description
Add the GlobalM module into stage2 of ChangerEx.
The relevant config snippet:
```python
model = dict(
    backbone=dict(
        interaction_cfg=(
            None,
            dict(type='GlobalM', embed_dim=128, num_heads=32,
                 axial_strategy='row'),
            dict(type='ChannelExchange', p=1/2),
            dict(type='ChannelExchange', p=1/2))),
    decode_head=dict(
        num_classes=2,
        sampler=dict(type='mmseg.OHEMPixelSampler', thresh=0.7,
                     min_kept=100000)),
    # test_cfg=dict(mode='slide', crop_size=crop_size,
    #               stride=(crop_size[0]//2, crop_size[1]//2)),
)
```
The definition of GlobalM:
```python
@ITERACTION_LAYERS.register_module()
class GlobalM(nn.Module):
    """Global spatial multi-head interaction self-attention module (Global-M).

    Purpose: mine global spatial correlations in high-dimensional
    spatial-spectral features.

    Args:
        embed_dim: embedding dimension (number of feature channels)
        num_heads: number of attention heads
        axial_strategy: axial split strategy ('row' / 'column')
    """

    def __init__(self, embed_dim, num_heads, axial_strategy):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.axial_strategy = axial_strategy
        # 1. QKV projection (implemented with a 1x1 convolution)
        #    Input: B×C×H×W → Output: B×3C×H×W (for Q, K, V)
        self.qkv_proj = nn.Conv2d(embed_dim, embed_dim * 3, kernel_size=1)
        # 2. Multi-head interaction convolution (the ω^{3×3} operation in the
        #    paper), used to fuse features across attention heads
        self.mh_interaction = nn.Conv2d(embed_dim, embed_dim,
                                        kernel_size=3, padding=1)
        # 3. Feed-forward network (FFN):
        #    1x1 conv expansion → GELU → 1x1 conv compression
        self.ffn = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim * 4, kernel_size=1),
            nn.GELU(),
            nn.Conv2d(embed_dim * 4, embed_dim, kernel_size=1))
        # 4. Layer normalization
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        # Attention scaling factor
        self.scale = self.head_dim ** -0.5

    def forward(self, x1, x2):
        """Forward pass.

        Modified to be compatible with the dual-input structure of
        IA_ResNet and to return two outputs.

        Args:
            x1: primary input features (B, C, H, W).
            x2: secondary input; unused by this module, kept only for
                interface compatibility.
        Returns:
            (out, x2): the enhanced features and the original x2
                (preserving the dual-output structure).
        """
        B, C, H, W = x1.shape
        residual = x1  # residual connection

        # === Stage 1: multi-head self-attention ===
        # 1. Layer normalization
        x_norm = self.norm1(x1.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        # 2. Generate Q, K, V
        qkv = self.qkv_proj(x_norm)    # B×3C×H×W
        q, k, v = qkv.chunk(3, dim=1)  # each B×C×H×W
        # 3. Global axial split (GAS strategy)
        if self.axial_strategy == 'row':
            # Row split: each of the H rows becomes a token sequence of size W×C
            q = q.reshape(B, self.num_heads, self.head_dim, H, W).permute(0, 3, 1, 4, 2)  # B×H×Nh×W×Dh
            k = k.reshape(B, self.num_heads, self.head_dim, H, W).permute(0, 3, 1, 4, 2)
            v = v.reshape(B, self.num_heads, self.head_dim, H, W).permute(0, 3, 1, 4, 2)
        else:  # column
            # Column split: each of the W columns becomes a token sequence of size H×C
            q = q.reshape(B, self.num_heads, self.head_dim, H, W).permute(0, 4, 1, 3, 2)  # B×W×Nh×H×Dh
            k = k.reshape(B, self.num_heads, self.head_dim, H, W).permute(0, 4, 1, 3, 2)
            v = v.reshape(B, self.num_heads, self.head_dim, H, W).permute(0, 4, 1, 3, 2)
        # 4. Attention weights (QK^T / sqrt(d_k))
        attn = (q @ k.transpose(-2, -1)) * self.scale  # row: B×H×Nh×W×W, column: B×W×Nh×H×H
        attn = attn.softmax(dim=-1)
        # 5. Attention-weighted sum
        out = attn @ v  # row: B×H×Nh×W×Dh, column: B×W×Nh×H×Dh
        # 6. Restore the original shape
        if self.axial_strategy == 'row':
            out = out.permute(0, 2, 4, 1, 3)  # B×Nh×Dh×H×W
        else:
            out = out.permute(0, 2, 4, 3, 1)  # B×Nh×Dh×H×W
        out = out.reshape(B, C, H, W)
        # 7. Multi-head interaction (3x3 conv fuses the multi-head features)
        out = self.mh_interaction(out)
        out += residual  # residual connection

        # === Stage 2: feed-forward network ===
        residual = out
        out = self.norm2(out.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        out = self.ffn(out)
        out += residual
        # Return (out, x2) even though x2 is unmodified;
        # if x2 is None, return x1 as a placeholder
        return out, x2 if x2 is not None else x1
```
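To sanity-check the tensor bookkeeping in the 'row' branch above, here is a standalone walk-through of the same reshape/permute/attention sequence, with NumPy standing in for torch (an assumption for illustration only; the reshape/transpose semantics match):

```python
import numpy as np

# Standalone shape walk-through of the 'row' axial strategy in GlobalM,
# using NumPy in place of torch.
B, C, H, W = 2, 128, 8, 8
num_heads = 32
head_dim = C // num_heads   # 4
scale = head_dim ** -0.5

rng = np.random.default_rng(0)
q = rng.standard_normal((B, C, H, W))
k = rng.standard_normal((B, C, H, W))
v = rng.standard_normal((B, C, H, W))

def split_rows(t):
    # B×C×H×W → B×H×Nh×W×Dh: each of the H rows becomes a length-W token sequence
    return t.reshape(B, num_heads, head_dim, H, W).transpose(0, 3, 1, 4, 2)

q, k, v = split_rows(q), split_rows(k), split_rows(v)
attn = (q @ np.swapaxes(k, -2, -1)) * scale     # B×H×Nh×W×W
attn = np.exp(attn - attn.max(-1, keepdims=True))
attn /= attn.sum(-1, keepdims=True)             # softmax over the row axis
out = attn @ v                                  # B×H×Nh×W×Dh
out = out.transpose(0, 2, 4, 1, 3).reshape(B, C, H, W)
print(out.shape)  # (2, 128, 8, 8)
```

The round trip confirms that attention is computed only within each row: pixels in different rows never interact in this branch.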
Experiment Results
After training on S2Looking, the results on the validation set (validation every 8k iterations):
2025/04/16 01:10:45 - mmengine - INFO - Iter(val) [ 50/500] eta: 0:00:17 time: 0.0331 data_time: 0.0058 memory: 578
2025/04/16 01:10:47 - mmengine - INFO - Iter(val) [100/500] eta: 0:00:14 time: 0.0402 data_time: 0.0130 memory: 578
2025/04/16 01:10:49 - mmengine - INFO - Iter(val) [150/500] eta: 0:00:12 time: 0.0317 data_time: 0.0045 memory: 578
2025/04/16 01:10:50 - mmengine - INFO - Iter(val) [200/500] eta: 0:00:10 time: 0.0318 data_time: 0.0046 memory: 578
2025/04/16 01:10:52 - mmengine - INFO - Iter(val) [250/500] eta: 0:00:08 time: 0.0319 data_time: 0.0046 memory: 578
2025/04/16 01:10:54 - mmengine - INFO - Iter(val) [300/500] eta: 0:00:06 time: 0.0321 data_time: 0.0047 memory: 578
2025/04/16 01:10:55 - mmengine - INFO - Iter(val) [350/500] eta: 0:00:05 time: 0.0318 data_time: 0.0045 memory: 578
2025/04/16 01:10:57 - mmengine - INFO - Iter(val) [400/500] eta: 0:00:03 time: 0.0361 data_time: 0.0087 memory: 578
2025/04/16 01:10:59 - mmengine - INFO - Iter(val) [450/500] eta: 0:00:01 time: 0.0320 data_time: 0.0047 memory: 578
2025/04/16 01:11:00 - mmengine - INFO - Iter(val) [500/500] eta: 0:00:00 time: 0.0310 data_time: 0.0041 memory: 578
2025/04/16 01:11:00 - mmengine - INFO - per class results:
2025/04/16 01:11:00 - mmengine - INFO -
+-----------+--------+-----------+--------+-------+-------+
| Class | Fscore | Precision | Recall | IoU | Acc |
+-----------+--------+-----------+--------+-------+-------+
| unchanged | 98.96 | 99.07 | 98.85 | 97.94 | 98.85 |
| changed | 30.29 | 28.2 | 32.71 | 17.85 | 32.71 |
+-----------+--------+-----------+--------+-------+-------+
2025/04/16 01:11:00 - mmengine - INFO - Iter(val) [500/500] aAcc: 97.9500 mFscore: 64.6200 mPrecision: 63.6300 mRecall: 65.7800 mIoU: 57.8900 mAcc: 65.7800 data_time: 0.0062 time: 0.0340
Results on the test set:
2025/04/22 19:40:01 - mmengine - WARNING - The prefix is not set in metric class IoUMetric.
2025/04/22 19:40:01 - mmengine - INFO - Load checkpoint from changer_r18_globalM_stage2/iter_80000.pth
2025/04/22 19:40:17 - mmengine - INFO - Iter(test) [ 50/1000] eta: 0:04:56 time: 0.1928 data_time: 0.1650 memory: 17217
2025/04/22 19:40:26 - mmengine - INFO - Iter(test) [ 100/1000] eta: 0:03:40 time: 0.1764 data_time: 0.1476 memory: 479
2025/04/22 19:40:35 - mmengine - INFO - Iter(test) [ 150/1000] eta: 0:03:10 time: 0.1709 data_time: 0.1438 memory: 479
2025/04/22 19:40:43 - mmengine - INFO - Iter(test) [ 200/1000] eta: 0:02:46 time: 0.1655 data_time: 0.1373 memory: 479
2025/04/22 19:40:51 - mmengine - INFO - Iter(test) [ 250/1000] eta: 0:02:27 time: 0.1496 data_time: 0.1225 memory: 479
2025/04/22 19:40:59 - mmengine - INFO - Iter(test) [ 300/1000] eta: 0:02:14 time: 0.1564 data_time: 0.1296 memory: 479
2025/04/22 19:41:07 - mmengine - INFO - Iter(test) [ 350/1000] eta: 0:02:02 time: 0.1574 data_time: 0.1296 memory: 479
2025/04/22 19:41:16 - mmengine - INFO - Iter(test) [ 400/1000] eta: 0:01:51 time: 0.1477 data_time: 0.1202 memory: 479
2025/04/22 19:41:24 - mmengine - INFO - Iter(test) [ 450/1000] eta: 0:01:41 time: 0.1874 data_time: 0.1595 memory: 479
2025/04/22 19:41:33 - mmengine - INFO - Iter(test) [ 500/1000] eta: 0:01:32 time: 0.1641 data_time: 0.1348 memory: 479
2025/04/22 19:41:41 - mmengine - INFO - Iter(test) [ 550/1000] eta: 0:01:21 time: 0.1415 data_time: 0.1143 memory: 479
2025/04/22 19:41:49 - mmengine - INFO - Iter(test) [ 600/1000] eta: 0:01:12 time: 0.1890 data_time: 0.1607 memory: 479
2025/04/22 19:41:58 - mmengine - INFO - Iter(test) [ 650/1000] eta: 0:01:02 time: 0.1698 data_time: 0.1416 memory: 479
2025/04/22 19:42:06 - mmengine - INFO - Iter(test) [ 700/1000] eta: 0:00:53 time: 0.1320 data_time: 0.1043 memory: 479
2025/04/22 19:42:13 - mmengine - INFO - Iter(test) [ 750/1000] eta: 0:00:43 time: 0.1528 data_time: 0.1256 memory: 479
2025/04/22 19:42:22 - mmengine - INFO - Iter(test) [ 800/1000] eta: 0:00:35 time: 0.1697 data_time: 0.1419 memory: 479
2025/04/22 19:42:31 - mmengine - INFO - Iter(test) [ 850/1000] eta: 0:00:26 time: 0.1735 data_time: 0.1456 memory: 479
2025/04/22 19:42:39 - mmengine - INFO - Iter(test) [ 900/1000] eta: 0:00:17 time: 0.1672 data_time: 0.1386 memory: 479
2025/04/22 19:42:49 - mmengine - INFO - Iter(test) [ 950/1000] eta: 0:00:08 time: 0.1929 data_time: 0.1649 memory: 479
2025/04/22 19:42:58 - mmengine - INFO - Iter(test) [1000/1000] eta: 0:00:00 time: 0.1733 data_time: 0.1454 memory: 479
2025/04/22 19:42:58 - mmengine - INFO - per class results:
2025/04/22 19:42:58 - mmengine - INFO -
+-----------+--------+-----------+--------+-------+-------+
| Class | Fscore | Precision | Recall | IoU | Acc |
+-----------+--------+-----------+--------+-------+-------+
| unchanged | 99.39 | 98.79 | 100.0 | 98.79 | 100.0 |
| changed | 0.08 | 90.77 | 0.04 | 0.04 | 0.04 |
+-----------+--------+-----------+--------+-------+-------+
2025/04/22 19:42:58 - mmengine - INFO - Iter(test) [1000/1000] aAcc: 98.7900 mFscore: 49.7400 mPrecision: 94.7800 mRecall: 50.0200 mIoU: 49.4100 mAcc: 50.0200 data_time: 0.1437 time: 0.1761
Analysis of Experiment Results
Key observations:
1. Large performance gap between classes:
   Validation: unchanged Fscore = 98.96 vs. changed Fscore = 30.29
   Test: unchanged Fscore = 99.39 vs. changed Fscore = 0.08
2. Performance collapse on the test set:
   Changed-class recall drops from 32.71 on validation to 0.04 on test, meaning the model detects essentially no changed regions.
3. Train-test generalization gap:
   Validation mIoU = 57.89 → test mIoU = 49.41, indicating a risk of overfitting.
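As a cross-check, the changed-class validation numbers are mutually consistent: F-score and IoU follow directly from precision and recall.

```python
# F-score and IoU derived from precision/recall; both identities follow from
# TP/FP/FN counting (IoU = TP / (TP + FP + FN) = 1 / (1/p + 1/r - 1)).
def fscore(p, r):
    return 2 * p * r / (p + r)

def iou_from_pr(p, r):
    return 1 / (1 / p + 1 / r - 1)

p, r = 0.2820, 0.3271  # changed class, validation set
print(round(fscore(p, r) * 100, 2))       # 30.29
print(round(iou_from_pr(p, r) * 100, 2))  # 17.85
```

Both values match the logged table, so the collapse on the test set is not a metric-computation artifact.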
Possible causes:
- Extreme class imbalance
  The validation results suggest that unchanged pixels dominate the data distribution (likely well above 99%).
  A trivial "always predict unchanged" strategy already yields high overall accuracy, so the model can converge to it.
  The test set may contain even fewer changed samples, or follow a different distribution.
- Poor fit of the GlobalM module
  Axial attention (row/column splitting) may break local spatial relationships, hurting the fine localization that change detection needs.
  32 attention heads is likely too many (head_dim = 128/32 = 4, whereas head_dim >= 32 is usually recommended), which can make attention overly diffuse.
  The 3x3 multi-head interaction convolution may introduce an unwanted positional bias.
- Flawed training strategy
  The OHEM setting is poorly chosen (thresh=0.7 is too high), so hard-example mining fails to capture changed samples.
  No class-balancing loss function (e.g. Focal Loss) is used.
  Data augmentation may be insufficient, especially for changed regions.
- Feature-interaction design issues
  Alternating GlobalM and ChannelExchange within stage2 may confuse the features.
  The dual-stream information fusion is weak: the x2 features are never actually used by GlobalM.
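The imbalance argument can be made concrete with a toy computation. Assuming a 99:1 pixel ratio (an inferred number, not measured from the dataset), a degenerate all-unchanged predictor already scores 99% overall accuracy:

```python
import numpy as np

# Toy illustration (assumed 99:1 pixel ratio) of the "always predict
# unchanged" failure mode described above.
labels = np.zeros(10000, dtype=int)
labels[:100] = 1                  # 1% of pixels are 'changed'
preds = np.zeros_like(labels)     # degenerate all-'unchanged' prediction

aacc = float((preds == labels).mean())
changed_recall = float((preds[labels == 1] == 1).mean())
print(aacc, changed_recall)  # 0.99 0.0
```

This mirrors the test-set table almost exactly: near-perfect aAcc with changed-class recall close to zero.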
Improvement Suggestions
1. Data level:
   - Rebalance the dataset:
     Use weighted sampling (oversample the changed class).
     Use copy-paste augmentation to synthesize additional changed regions.
   - Augmentation strategy:
     Design spatial-transform augmentations tailored to change detection (e.g. asymmetric deformations).
     Use MixUp to balance the classes.
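As a concrete illustration of the copy-paste idea, a minimal sketch, assuming H×W×C image arrays and a binary H×W change mask (a real pipeline would also handle blending and boundary artifacts):

```python
import numpy as np

# Minimal copy-paste augmentation sketch for bi-temporal change detection:
# paste the changed pixels of sample A into sample B (both image pairs),
# and union the masks.
def copy_paste(img_a1, img_a2, mask_a, img_b1, img_b2, mask_b):
    m = mask_a.astype(bool)
    out1, out2 = img_b1.copy(), img_b2.copy()
    out1[m], out2[m] = img_a1[m], img_a2[m]
    new_mask = np.logical_or(mask_b, mask_a).astype(mask_b.dtype)
    return out1, out2, new_mask

# Tiny demo: one changed pixel pasted into an all-unchanged sample.
img = np.zeros((4, 4, 3)); img_b = np.ones((4, 4, 3))
mask = np.zeros((4, 4)); mask[0, 0] = 1
o1, o2, m = copy_paste(img, img, mask, img_b, img_b, np.zeros((4, 4)))
print(int(m.sum()))  # 1
```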
2. Model level:
   - Attention mechanism:

```python
# Modified GlobalM configuration
dict(type='GlobalM',
     embed_dim=128,
     num_heads=8,               # fewer attention heads
     axial_strategy='hybrid',   # alternate row/column splitting
     use_local_attn=True)       # add a local-attention branch
```

   - Feature interaction:
     Add a spatial attention gating mechanism after GlobalM.
     Replace ChannelExchange with a learnable ChannelAttention.
   - Decode head:

```python
decode_head=dict(
    num_classes=2,
    loss_decode=dict(type='FocalLoss', loss_weight=[1.0, 5.0]),  # class weighting
    sampler=dict(type='mmseg.OHEMPixelSampler', thresh=0.5,
                 min_kept=150000))  # adjusted threshold
```
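To see why focal weighting should help here, a minimal binary focal-loss sketch (only the core formula FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t); mmseg's FocalLoss is more general):

```python
import math

# Minimal binary focal-loss sketch illustrating the down-weighting of easy
# examples; not mmseg's actual implementation.
def focal_loss(p_changed, label, alpha=0.25, gamma=2.0):
    p_t = p_changed if label == 1 else 1.0 - p_changed
    alpha_t = alpha if label == 1 else 1.0 - alpha
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

# A hard 'changed' pixel (predicted only 0.05 changed) contributes far more
# than an easy 'unchanged' pixel with the same score:
print(focal_loss(0.05, 1) > focal_loss(0.05, 0))  # True
```

With the dominant unchanged pixels down-weighted, the rare changed pixels drive a much larger share of the gradient than under plain cross-entropy.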
3. Training strategy:
   - Progressive training:
     Stage 1: freeze GlobalM and train the base feature extractor.
     Stage 2: unfreeze the attention modules layer by layer.
   - Loss function:
     Combine Dice Loss with Focal Loss.
     Add a boundary-aware loss.
   - Post-processing:

```python
test_cfg=dict(
    mode='slide',
    crop_size=(512, 512),
    stride=(256, 256),
    post_process=dict(type='CRFPostProcess',  # add CRF post-processing
                      win_size=5,
                      iter_max=10))
```
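For the Dice part of the proposed Dice + Focal combination, a minimal soft Dice loss sketch (assumes probabilities in [0, 1]; Dice is region-based and therefore largely insensitive to the unchanged/changed pixel ratio):

```python
import numpy as np

# Minimal soft Dice loss over a binary mask; the eps term guards against
# empty masks.
def dice_loss(prob, target, eps=1e-6):
    inter = (prob * target).sum()
    return 1.0 - (2.0 * inter + eps) / (prob.sum() + target.sum() + eps)

p = np.array([0.9, 0.8, 0.1, 0.2])
t = np.array([1.0, 1.0, 0.0, 0.0])
print(round(float(dice_loss(p, t)), 3))  # 0.15
```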
Directions for verification:
- Visualization analysis:
  Use Grad-CAM to visualize the regions GlobalM attends to.
  Compare feature response maps with and without GlobalM.
- Ablation studies:
  Remove components one by one to verify their contribution.
  Test different numbers of attention heads (8/16/32).
- Data diagnosis:
  Count the number of changed samples in the test set.
  Check annotation quality (there may be label noise).
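The data-diagnosis step can be sketched as a small mask-statistics helper (assumes binary NumPy label masks; real use would load the actual annotation files from the test split):

```python
import numpy as np

# Count how many label masks contain any change, and the mean changed-pixel
# fraction across masks.
def changed_stats(masks):
    fracs = [float(m.mean()) for m in masks]
    n_with_change = sum(f > 0 for f in fracs)
    return n_with_change, sum(fracs) / len(fracs)

# Toy input: one empty mask and one mask with 2/16 changed pixels.
masks = [np.zeros((4, 4)),
         (np.arange(16).reshape(4, 4) < 2).astype(float)]
n, mean_frac = changed_stats(masks)
print(n, mean_frac)  # 1 0.0625
```

If the test split turns out to have a far lower changed-pixel fraction than the validation split, that alone could explain much of the observed collapse.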