深度學習訓練配置參數詳解
1. 啟動初始化參數 說明 CUDA_VISIBLE_DEVICES
指定使用的GPU設備編號("0"表示單卡) seed
隨機種子(1777777),保證實驗可復現性 cuda
是否啟用GPU加速(True) benchmark
是否啟用cudnn基準測試(False),輸入尺寸固定時可設為True加速 deterministic
是否強制確定性算法(True),保證可復現性但可能降低性能
2. 數據預處理參數 說明 resample_spacing
體數據重采樣間距([0.5,0.5,0.5]毫米) clip_lower_bound
灰度值截斷下限(-1412) clip_upper_bound
灰度值截斷上限(17943) samples_train
每張訓練圖像的采樣點數(2048) crop_size
訓練裁剪尺寸(160×160×96體素) crop_threshold
有效裁剪的最小前景占比(0.5)
3. 數據增強參數 說明 augmentation_probability
數據增強應用概率(30%) augmentation_method
增強策略("Choice"表示隨機選擇一種) open_elastic_transform
是否啟用彈性形變(True) elastic_transform_sigma
彈性形變強度(20) elastic_transform_alpha
彈性形變縮放系數(1) open_gaussian_noise
是否添加高斯噪聲(True) gaussian_noise_mean
噪聲均值(0) gaussian_noise_std
噪聲標準差(0.01) open_random_flip
是否啟用隨機翻轉(True) open_random_rescale
是否啟用隨機縮放(True) random_rescale_min_percentage
最小縮放比例(0.5倍) random_rescale_max_percentage
最大縮放比例(1.5倍) open_random_rotate
是否啟用隨機旋轉(True) random_rotate_min_angle
最小旋轉角度(-50°) random_rotate_max_angle
最大旋轉角度(50°) normalize_mean
數據標準化均值(0.050) normalize_std
數據標準化標準差(0.028)
4. 數據加載參數 說明 dataset_name
數據集名稱(“3D-CBCT-Tooth”) dataset_path
數據集存儲路徑 create_data
是否重新生成預處理數據(False) batch_size
批大小(1) num_workers
數據加載線程數(4)
5. 模型配置參數 說明 model_name
模型名稱(“KanNet”) in_channels
輸入通道數(1表示灰度圖像) classes
分類數量(2類:背景/前景) index_to_class_dict
類別索引映射字典 resume
斷點續訓模型路徑(None表示不啟用) pretrain
預訓練權重路徑(None表示不啟用) high_frequency
高頻成分權重(0.9) low_frequency
低頻成分權重(0.1)
6. 優化器參數 說明 optimizer_name
優化器類型(“AdamW”) learning_rate
初始學習率(0.0005) weight_decay
L2正則化系數(0.00005) momentum
動量參數(0.8)
7. 學習率調度參數 說明 lr_scheduler_name
學習率調度器類型(“ReduceLROnPlateau”) mode
監控指標方向("max"表示越大越好) factor
學習率衰減系數(0.5) patience
等待epoch數(1輪不提升后衰減) milestones
多步學習率調整時機([1,3,5,7,8,9]epoch)
8. 損失函數與評估參數 說明 metric_names
評估指標列表([“DSC”]) loss_function_name
損失函數(“DiceLoss”) class_weight
類別權重(背景0.005,前景0.995) dice_loss_mode
Dice損失變體(“extension”) sigmoid_normalization
是否使用Sigmoid歸一化(False)
9. 訓練設置參數 說明 optimize_params
是否優化超參數(False) use_amp
是否使用混合精度(False) run_dir
實驗日志保存目錄 start_epoch
起始epoch(0) end_epoch
終止epoch(20) best_dice
初始最佳Dice分數(0.60) save_epoch_freq
模型保存頻率(每4個epoch) crop_stride
預測時的滑動窗口步長([32,32,32])
關鍵說明:
GPU相關參數:需根據實際硬件調整CUDA_VISIBLE_DEVICES
數據增強:所有open_*參數控制是否啟用對應增強方法
類別不平衡:通過class_weight參數顯著提高前景權重(牙科結構)
訓練控制:deterministic=True保證可復現性,但會禁用benchmark優化
注:實際使用時需根據數據集特性和硬件條件調整參數值。對于醫學圖像分割任務,建議優先保證deterministic和精細的數據預處理。