代碼來自:GitHub - ChuHan89/WSSS-Tissue
借助了一些人工智能
2_generate_PM.py
功能總結
該代碼用于?生成弱監督語義分割(WSSS)所需的偽掩碼(Pseudo-Masks),是 Stage2 訓練的前置步驟。其核心流程為:
-
加載 Stage1 訓練好的分類模型(支持 CAM 生成)。
-
為不同層次的特征圖生成偽掩碼(如?
b4_5
,?b5_2
,?bn7
?對應的不同網絡層)。 -
保存偽掩碼圖像,使用調色板將類別標簽映射為彩色圖像。
代碼解析
1. 導入依賴庫
import os import torch import argparse import importlib from torch.backends import cudnn cudnn.enabled = True # 啟用CUDA加速 from tool.infer_fun import create_pseudo_mask # 自定義函數:生成偽掩碼
-
關鍵依賴:
-
cudnn.enabled = True
:啟用 cuDNN 加速,優化 GPU 計算性能。 -
create_pseudo_mask
:核心函數(用戶需參考其實現),負責生成并保存偽掩碼。
-
2. 主函數與參數解析
if __name__ == '__main__':# 定義命令行參數parser = argparse.ArgumentParser()parser.add_argument("--weights", default='checkpoints/stage1_checkpoint_trained_on_bcss.pth', type=str)parser.add_argument("--network", default="network.resnet38_cls", type=str)parser.add_argument("--dataroot", default="datasets/BCSS-WSSS/", type=str)parser.add_argument("--dataset", default="bcss", type=str)parser.add_argument("--num_workers", default=8, type=int)parser.add_argument("--n_class", default=4, type=int)args = parser.parse_args()print(args) # 打印參數列表
-
參數說明:
-
--weights
:Stage1 訓練好的模型權重文件路徑(默認指向 BCSS 數據集)。 -
--network
:網絡結構定義文件(如?network.resnet38_cls
)。 -
--dataroot
:數據集根目錄(包含訓練/測試數據)。 -
--dataset
:數據集標識(bcss
?或?luad
)。 -
--n_class
:類別數量(BCSS 為 4 類,LUAD 可能不同)。
-
3. 定義調色板(顏色映射)
if args.dataset == 'luad':palette = [0]*15 # 初始化長度為15的列表(每類3個RGB通道)palette[0:3] = [205,51,51] # 類別1:紅色palette[3:6] = [0,255,0] # 類別2:綠色palette[6:9] = [65,105,225] # 類別3:藍色palette[9:12] = [255,165,0] # 類別4:橙色palette[12:15] = [255, 255, 255] # 背景或未標注區域:白色elif args.dataset == 'bcss':palette = [0]*15palette[0:3] = [255, 0, 0] # 類別1:紅色palette[3:6] = [0,255,0] # 類別2:綠色palette[6:9] = [0,0,255] # 類別3:藍色palette[9:12] = [153, 0, 255] # 類別4:紫色palette[12:15] = [255, 255, 255] # 背景:白色
-
作用:將類別標簽映射為 RGB 顏色,用于偽掩碼的可視化。
-
細節:
-
每個類別占 3 個連續位置(RGB 通道)。
-
palette[12:15]
?可能表示背景或未標注區域。 -
不同數據集使用不同的顏色方案(如 BCSS 用紫色表示第4類)。
-
4. 創建偽掩碼保存路徑
PMpath = os.path.join(args.dataroot, 'train_PM') # 路徑示例:datasets/BCSS-WSSS/train_PMif not os.path.exists(PMpath):os.mkdir(PMpath) # 若目錄不存在則創建
-
目的:在數據集根目錄下創建?
train_PM
?文件夾,用于保存生成的偽掩碼。
5. 加載模型
model = getattr(importlib.import_module("network.resnet38_cls"), 'Net_CAM')(n_class=args.n_class)model.load_state_dict(torch.load(args.weights), strict=False)model.eval() # 設置為評估模式(禁用Dropout等隨機操作)model.cuda() # 將模型移至GPU
-
關鍵步驟:
-
動態加載模型:從?
network.resnet38_cls
?模塊加載?Net_CAM
?類(支持 CAM 生成的變體)。 -
加載權重:使用 Stage1 訓練好的模型參數(
strict=False
?允許部分參數不匹配)。 -
評估模式:關閉 BatchNorm 和 Dropout 的隨機性,確保結果一致性。
-
6. 生成多級偽掩碼
##fm = 'b4_5' # 特征模塊名稱(可能對應網絡中的某個中間層)savepath = os.path.join(PMpath, 'PM_' + fm) # 保存路徑:train_PM/PM_b4_5if not os.path.exists(savepath):os.mkdir(savepath)create_pseudo_mask(model, args.dataroot, fm, savepath, args.n_class, palette, args.dataset)## 重復相同流程生成其他層級的偽掩碼fm = 'b5_2'savepath = os.path.join(PMpath, 'PM_' + fm)if not os.path.exists(savepath):os.mkdir(savepath)create_pseudo_mask(model, args.dataroot, fm, savepath, args.n_class, palette, args.dataset)##fm = 'bn7'savepath = os.path.join(PMpath, 'PM_' + fm)if not os.path.exists(savepath):os.mkdir(savepath)create_pseudo_mask(model, args.dataroot, fm, savepath, args.n_class, palette, args.dataset)
-
功能:針對不同特征模塊(
fm
)生成偽掩碼,保存到對應子目錄。 -
關鍵參數:
-
fm
:特征模塊標識,可能對應網絡中的不同層(如 ResNet 的?block4
、block5
?或?bottleneck
)。 -
create_pseudo_mask
:核心函數,推測其功能為:-
加載訓練集圖像。
-
使用模型提取指定層的特征圖。
-
生成類別激活圖(CAM)。
-
根據閾值將 CAM 轉換為二值偽掩碼。
-
應用調色板將掩碼保存為彩色 PNG 圖像。
-
-
代碼執行示例
python generate_pseudo_masks.py \--dataset bcss \--dataroot datasets/BCSS-WSSS/ \--weights checkpoints/stage1_checkpoint_trained_on_bcss.pth
-
輸出:在?
datasets/BCSS-WSSS/train_PM/
?下生成三個子目錄:-
PM_b4_5
:基于?b4_5
?層特征的偽掩碼。 -
PM_b5_2
:基于?b5_2
?層特征的偽掩碼。 -
PM_bn7
:基于?bn7
?層特征的偽掩碼。
-
總結
該代碼是弱監督語義分割流程中?生成多級偽掩碼的關鍵步驟,利用 Stage1 訓練的分類模型提取不同層級的特征,生成偽標簽供 Stage2 的分割模型訓練。通過多級偽掩碼的融合,可以提升最終分割結果的精度和魯棒性。
3_train_stage2.py
功能總結
該代碼是弱監督語義分割(WSSS)的?Stage2 訓練與測試腳本,核心功能為:
-
訓練分割模型:基于 DeepLab v3+ 架構,使用 Stage1 生成的偽掩碼(Pseudo-Masks)進行監督訓練。
-
驗證與測試:評估模型在驗證集和測試集上的性能(如 mIoU、像素準確率等)。
-
門控機制(Gate Mechanism):在測試階段結合 Stage1 的分類結果過濾分割預測,提升精度。
-
多任務損失:融合不同層次偽掩碼的損失(主偽掩碼 + 兩種增強版本)。
代碼結構
# 1. 依賴庫導入 import argparse, os, numpy as np from tqdm import tqdm import torch from tool.GenDataset import make_data_loader from network.sync_batchnorm.replicate import patch_replication_callback from network.deeplab import * from tool.loss import SegmentationLosses from tool.lr_scheduler import LR_Scheduler from tool.saver import Saver from tool.summaries import TensorboardSummary from tool.metrics import Evaluator# 2. 定義訓練器類 class Trainer(object):def __init__(self, args): ... # 初始化模型、數據、優化器等def training(self, epoch): ... # 訓練一個epochdef validation(self, epoch): ... # 驗證集評估def test(self, epoch, Is_GM): ... # 測試集評估(支持門控機制)def load_the_best_checkpoint(self): ... # 加載最佳模型# 3. 主函數 def main(): ... # 解析參數、啟動訓練if __name__ == "__main__":main()
關鍵代碼解析
1.?Trainer
?類初始化
class Trainer(object):def __init__(self, args):self.args = args# 初始化日志記錄與模型保存工具self.saver = Saver(args) # 保存模型檢查點self.summary = TensorboardSummary('logs') # TensorBoard日志self.writer = self.summary.create_summary()# 數據加載kwargs = {'num_workers': args.workers, 'pin_memory': False}self.train_loader, self.val_loader, self.test_loader = make_data_loader(args, **kwargs)# 模型定義(DeepLab v3+)self.nclass = args.n_classmodel = DeepLab(num_classes=self.nclass,backbone=args.backbone, # 骨干網絡(如ResNet)output_stride=args.out_stride, # 輸出步長(控制特征圖分辨率)sync_bn=args.sync_bn, # 多GPU同步BatchNormfreeze_bn=args.freeze_bn # 凍結BN層參數)# 優化器配置(分層學習率)train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr}, # 骨干網絡低學習率{'params': model.get_10x_lr_params(), 'lr': args.lr * 10} # 分類頭高學習率]optimizer = torch.optim.SGD(train_params, momentum=args.momentum,weight_decay=args.weight_decay, nesterov=args.nesterov)# 損失函數(交叉熵或Focal Loss)self.criterion = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)self.model, self.optimizer = model, optimizer# 評估工具(計算mIoU等指標)self.evaluator = Evaluator(self.nclass)# 學習率調度(Poly策略)self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.train_loader))# 加載Stage1的分類模型(用于門控機制)model_stage1 = getattr(importlib.import_module('network.resnet38_cls'), 'Net_CAM')(n_class=4)resume_stage1 = 'checkpoints/stage1_checkpoint_trained_on_'+str(args.dataset)+'.pth'weights_dict = torch.load(resume_stage1)model_stage1.load_state_dict(weights_dict)self.model_stage1 = model_stage1.cuda()self.model_stage1.eval() # 固定Stage1模型參數# GPU并行化if args.cuda:self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)patch_replication_callback(self.model) # 修復多GPU BatchNorm同步問題self.model = self.model.cuda()# 加載預訓練權重(如DeepLab預訓練模型)if args.resume is not None:checkpoint = torch.load(args.resume)# 處理分類頭權重(微調時保留,否則刪除)if args.ft:self.model.load_state_dict(checkpoint['state_dict'])self.optimizer.load_state_dict(checkpoint['optimizer'])else:del checkpoint['state_dict']['decoder.last_conv.8.weight']del checkpoint['state_dict']['decoder.last_conv.8.bias']self.model.load_state_dict(checkpoint['state_dict'], strict=False)# 初始化最佳mIoUself.best_pred = 0.0
2. 訓練階段?training
def training(self, epoch):train_loss = 0.0self.model.train()tbar = tqdm(self.train_loader) # 進度條num_img_tr = len(self.train_loader)for i, sample in enumerate(tbar):# 加載數據(圖像 + 三個偽掩碼)image, target, target_a, target_b = sample['image'], sample['label'], sample['label_a'], sample['label_b']if self.args.cuda:image, target, target_a, target_b = image.cuda(), target.cuda(), target_a.cuda(), target_b.cuda()# 調整學習率self.scheduler(self.optimizer, i, epoch, self.best_pred)self.optimizer.zero_grad()# 前向傳播output = self.model(image)# 添加額外通道處理類別4(背景或忽略類)one = torch.ones((output.shape[0],1,224,224)).cuda()output = torch.cat([output, (100 * one * (target==4).unsqueeze(dim=1)], dim=1)# 計算多任務損失(主偽掩碼 + 兩種增強版本)loss_o = self.criterion(output, target)loss_a = self.criterion(output, target_a)loss_b = self.criterion(output, target_b)loss = 0.6*loss_o + 0.2*loss_a + 0.2*loss_b# 反向傳播loss.backward()self.optimizer.step()# 統計損失train_loss += loss.item()tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))# 記錄TensorBoard日志self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)# 輸出epoch總結self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))print('Loss: %.3f' % train_loss)
3. 驗證階段?validation
def validation(self, epoch):self.model.eval()self.evaluator.reset()tbar = tqdm(self.val_loader, desc='\r')test_loss = 0.0for i, sample in enumerate(tbar):image, target = sample[0]['image'], sample[0]['label']if self.args.cuda:image, target = image.cuda(), target.cuda()with torch.no_grad():output = self.model(image)# 轉換為CPU numpy數組pred = output.data.cpu().numpy()target = target.cpu().numpy()pred = np.argmax(pred, axis=1)# 處理類別4(設為忽略類)pred[target==4] = 4# 更新評估指標self.evaluator.add_batch(target, pred)# 計算并記錄指標Acc = self.evaluator.Pixel_Accuracy()Acc_class = self.evaluator.Pixel_Accuracy_Class()mIoU = self.evaluator.Mean_Intersection_over_Union()ious = self.evaluator.Intersection_over_Union()FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()# 輸出結果print('Validation:')print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))# 保存最佳模型if mIoU > self.best_pred:self.best_pred = mIoUself.saver.save_checkpoint({'state_dict': self.model.module.state_dict(),'optimizer': self.optimizer.state_dict()}, 'stage2_checkpoint_trained_on_'+self.args.dataset+'.pth')
4. 測試階段?test
(含門控機制)
def test(self, epoch, Is_GM):self.load_the_best_checkpoint() # 加載最佳模型self.model.eval()self.evaluator.reset()tbar = tqdm(self.test_loader, desc='\r')for i, sample in enumerate(tbar):image, target = sample[0]['image'], sample[0]['label']if self.args.cuda:image, target = image.cuda(), target.cuda()with torch.no_grad():output = self.model(image)# 門控機制:利用Stage1的分類結果過濾分割預測if Is_GM:_, y_cls = self.model_stage1.forward_cam(image) # Stage1的分類輸出y_cls = y_cls.cpu().datapred_cls = (y_cls > 0.1) # 類別存在性判斷(閾值0.1)# 應用門控機制pred = output.data.cpu().numpy()if Is_GM:pred = pred * pred_cls.unsqueeze(dim=2).unsqueeze(dim=3).numpy()# 處理類別4pred = np.argmax(pred, axis=1)pred[target==4] = 4self.evaluator.add_batch(target, pred)# 計算并輸出指標Acc = self.evaluator.Pixel_Accuracy()Acc_class = self.evaluator.Pixel_Accuracy_Class()mIoU = self.evaluator.Mean_Intersection_over_Union()print('Test:')print("Acc:{}, Acc_class:{}, mIoU:{}".format(Acc, Acc_class, mIoU))
5. 主函數?main
def main():# 解析命令行參數parser = argparse.ArgumentParser(description="WSSS Stage2")# 模型結構參數parser.add_argument('--backbone', default='resnet', choices=['resnet', 'xception', 'drn', 'mobilenet'])parser.add_argument('--out-stride', type=int, default=16) # 輸出步長(控制特征圖下采樣率)parser.add_argument('--Is_GM', type=bool, default=True) # 是否啟用門控機制# 數據集參數parser.add_argument('--dataroot', default='datasets/BCSS-WSSS/')parser.add_argument('--dataset', default='bcss')parser.add_argument('--n_class', type=int, default=4)# 訓練超參數parser.add_argument('--epochs', type=int, default=30)parser.add_argument('--batch-size', type=int, default=20)parser.add_argument('--lr', type=float, default=0.01)parser.add_argument('--lr-scheduler', default='poly', choices=['poly', 'step', 'cos'])# 其他配置parser.add_argument('--gpu-ids', default='0') # 指定使用的GPUparser.add_argument('--resume', default='init_weights/deeplab-resnet.pth.tar') # 預訓練權重args = parser.parse_args()# 配置CUDAargs.cuda = not args.no_cuda and torch.cuda.is_available()if args.cuda:args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]# 自動設置SyncBNif args.sync_bn is None:args.sync_bn = True if args.cuda and len(args.gpu_ids) > 1 else False# 初始化訓練器并啟動訓練trainer = Trainer(args)for epoch in range(trainer.args.epochs):trainer.training(epoch)if epoch % args.eval_interval == 0:trainer.validation(epoch)# 最終測試trainer.test(epoch, args.Is_GM)trainer.writer.close()
關鍵設計解析
-
多任務損失:
-
目標:同時優化主偽掩碼(
target
)及其兩種增強版本(target_a
,?target_b
),提升模型對不同噪聲偽標簽的魯棒性。 -
權重分配:主損失占60%,增強損失各占20%(
0.6*loss_o + 0.2*loss_a + 0.2*loss_b
)。
-
-
門控機制(Gate Mechanism):
-
作用:在測試階段,利用 Stage1 的分類結果過濾分割預測,僅保留分類模型認為存在的類別。
-
實現:若 Stage1 對某類別的預測概率 > 0.1,則保留該類的分割結果,否則置零。
-
-
類別4處理:
-
背景或忽略類:在標簽中,類別4可能表示背景或未標注區域,預測時直接繼承真實標簽的值(
pred[target==4] = 4
),避免錯誤優化。
-
-
模型初始化:
-
預訓練權重:加載 DeepLab 在 ImageNet 上的預訓練權重(
init_weights/deeplab-resnet.pth.tar
),加速收斂。 -
分層學習率:骨干網絡使用較低學習率(
args.lr
),分類頭使用更高學習率(args.lr * 10
)。
-
運行示例
python train_stage2.py \--dataset bcss \--dataroot datasets/BCSS-WSSS/ \--backbone resnet \--Is_GM True \--batch-size 20 \--epochs 30
總結
該代碼實現了弱監督語義分割的第二階段訓練,通過多任務損失融合多級偽標簽,結合門控機制提升測試精度,最終生成高精度分割模型。訓練過程支持多GPU加速、Poly學習率調度及多種評估指標監控,適用于醫學圖像(如BCSS)或自然場景圖像的分割任務。