為了提升推理速度并降低部署成本,模型剪枝已成為關鍵技術。本文將結合實踐操作,講解YOLOv8模型剪枝的方法原理、實施步驟及注意事項。
雖然YOLOv8n版本本身參數量少、推理速度快,能滿足大多數工業檢測需求,但谷歌研究表明:通過對大模型進行裁剪得到的小模型往往性能更優。
本文基于其他博客的剪枝方法的代碼實現,專門針對YOLOv8模型進行剪枝優化,能夠理解模型剪枝的底層操作。其核心創新點在于利用BN層(Batch Normalization)的特性,實現高效的通道級剪枝操作。
一、剪枝的理論基礎
- BN參數的重要性:BN層中的縮放參數(γ)代表了卷積核的重要程度,通過裁剪γ值較小的卷積核,可以實現剪枝。
- 剪枝流程總體架構:
- 訓練稀疏模型(引入BN正則化)
- 計算剪枝閾值
- 剪除冗余卷積核
- 微調模型,恢復性能
二、YOLOv8剪枝的具體步驟
1. 預備工作
- 模型訓練:?先進行完整訓練,獲得基準性能指標。
將LL_pruning.py
、LL_train.py這兩個文件放在根目錄下
LL_train.py代碼如下所示:from ultralytics import YOLO # 導入YOLO模型庫 import os # 導入os模塊,用于處理文件路徑 root = os.getcwd() # 獲取當前工作目錄 ## 配置文件路徑 name_yaml = os.path.join(root, "ultralytics/datasets/VOC.yaml") # 數據集配置文件路徑 name_pretrain = os.path.join(root, r"D:\practice_demo\ultralytics\runs\detect\jueyuanzi_yolov8m\best.pt") # 預訓練模型路徑 ## 原始訓練路徑 path_train = os.path.join(root, "runs/detect/VOC") # 原始訓練結果保存路徑 name_train = os.path.join(path_train, "weights/last.pt") # 原始訓練模型文件路徑 ## 約束訓練路徑、剪枝模型文件 path_constraint_train = os.path.join(root, "runs/detect/VOC_Constraint") # 約束訓練結果保存路徑 name_prune_before = os.path.join(path_constraint_train, "weights/last.pt") # 剪枝前模型文件路徑 name_prune_after = os.path.join(path_constraint_train, "weights/last_prune.pt") # 剪枝后模型文件路徑 ## 微調路徑 path_fineturn = os.path.join(root, "runs/detect/VOC_finetune") # 微調結果保存路徑 def step1_train(): model = YOLO(name_pretrain) # 加載預訓練模型 model.train(data=name_yaml, imgsz=640, epochs=300, batch=32, name=path_train) # 訓練模型 ## 一定要添加【amp=False】 def step2_Constraint_train(): model = YOLO(name_train) # 加載原始訓練模型 model.train(data=name_yaml, imgsz=640, epochs=50, batch=32, amp=False, save_period=1, name=path_constraint_train) # 訓練模型 def step3_pruning(): from LL_pruning import do_pruning # 導入剪枝函數 do_pruning(name_prune_before, name_prune_after) # 執行剪枝操作 def step4_finetune(): model = YOLO(name_prune_after) # 加載剪枝后的模型 model.train(data=name_yaml, imgsz=640, epochs=100, batch=32, save_period=1, name=path_fineturn) # 微調模型 # 執行訓練、約束訓練、剪枝和微調步驟 step1_train() # 訓練模型 # step2_Constraint_train() # 進行稀疏訓練 # step3_pruning() # 執行剪枝 # step4_finetune() # 微調模型
LL_pruning.py代碼如下所示:
?
from ultralytics import YOLO # 導入YOLO模型
import torch # 導入PyTorch庫
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect # 導入YOLO模型中的模塊
import os # 導入os模塊,用于處理文件路徑# os.environ["CUDA_VISIBLE_DEVICES"] = "2" # 可選:指定使用的GPU設備class PRUNE():def __init__(self) -> None:self.threshold = None # 初始化閾值def get_threshold(self, model, factor=0.8):"""計算剪枝閾值:param model: YOLO模型:param factor: 剪枝比例,默認0.8"""ws = [] # 存儲權重bs = [] # 存儲偏置for name, m in model.named_modules():if isinstance(m, torch.nn.BatchNorm2d): # 僅處理BatchNorm2d層w = m.weight.abs().detach() # 獲取權重的絕對值b = m.bias.abs().detach() # 獲取偏置的絕對值ws.append(w) # 添加權重bs.append(b) # 添加偏置print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item()) # 打印權重和偏置的最大最小值# 合并所有權重ws = torch.cat(ws)# 計算剪枝閾值self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]def prune_conv(self, conv1: Conv, conv2: Conv):"""對卷積層的“相鄰”卷積做通道級剪枝。參數----:param conv1: 第一個卷積層: Conv(Ultralytics封裝的Conv模塊,內部含 nn.Conv2d + BN + 激活)*上游* 被剪枝的卷積。刪除它的某些 輸出 通道。:param conv2: 第二個卷積層: Conv 或 Conv列表 / 純 nn.Conv2d / None*下游* 接收 conv1 輸出的卷積(可能有多支分支)。需要把 輸入 通道同步刪除。剪枝規則--------1. 用 conv1 中 BatchNorm 的縮放系數 γ 的絕對值做“重要性”指標。2. 選出 |γ| >= 全局閾值 的通道索引 keep_idxs(若太少則降低閾值,至少保留8個,防止結構非法)。3. 在 conv1 中:刪掉其它通道 → 需要同時修改 BN 的各種統計量與 nn.Conv2d 的權重/偏置/out_channels。4. 在 conv2 中:這些被刪的只是“輸入特征圖”,因此只更新 in_channels。"""# a. 根據BN中的參數,獲取需要保留的indexgamma = conv1.bn.weight.data.detach() # 獲取BN層的權重beta = conv1.bn.bias.data.detach() # 獲取BN層的偏置keep_idxs = [] # 存儲需要保留的索引local_threshold = self.threshold # 使用全局閾值while len(keep_idxs) < 8: # 確保至少保留8個卷積核keep_idxs = torch.where(gamma.abs() >= local_threshold)[0] # 獲取滿足條件的索引local_threshold = local_threshold * 0.5 # 如果不足8個,降低閾值n = len(keep_idxs) # 保留的卷積核數量print(n / len(gamma)) # 打印保留的比例# b. 利用index對BN進行剪枝conv1.bn.weight.data = gamma[keep_idxs] # 更新BN權重conv1.bn.bias.data = beta[keep_idxs] # 更新BN偏置conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs] # 更新BN的方差conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs] # 更新BN的均值conv1.bn.num_features = n # 更新BN的特征數量conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs] # 更新卷積層的權重conv1.conv.out_channels = n # 更新卷積層的輸出通道數# c. 利用index對conv1進行剪枝if conv1.conv.bias is not None:conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs] # 更新卷積層的偏置# d. 利用index對conv2進行剪枝if not isinstance(conv2, list):conv2 = [conv2] # 確保conv2是列表for item in conv2:if item is None: continue # 跳過Noneif isinstance(item, Conv):conv = item.conv # 獲取卷積層else:conv = itemconv.in_channels = n # 更新輸入通道數conv.weight.data = conv.weight.data[:, keep_idxs] # 更新卷積層的權重def prune(self, m1, m2):"""對模塊進行剪枝:param m1: 第一個模塊:param m2: 第二個模塊"""if isinstance(m1, C2f): # 如果m1是C2f模塊,獲取其cv2m1 = m1.cv2if not isinstance(m2, list): # 確保m2是列表m2 = [m2]for i, item in enumerate(m2):if isinstance(item, C2f) or isinstance(item, SPPF):m2[i] = item.cv1 # 獲取C2f或SPPF的cv1self.prune_conv(m1, m2) # 對卷積層進行剪枝def do_pruning(modelpath, savepath):"""執行剪枝操作:param modelpath: 原始模型路徑:param savepath: 剪枝后模型保存路徑"""pruning = PRUNE() # 創建PRUNE實例### 0. 加載模型yolo = YOLO(modelpath) # 從指定路徑加載YOLO模型pruning.get_threshold(yolo.model, 0.8) # 獲取剪枝閾值,0.8為剪枝率### 1. 剪枝c2f中的Bottleneckfor name, m in yolo.model.named_modules():if isinstance(m, Bottleneck): # 僅處理Bottleneck模塊pruning.prune_conv(m.cv1, m.cv2) # 對Bottleneck中的卷積層進行剪枝### 2. 指定剪枝不同模塊之間的卷積核seq = yolo.model.model # 獲取模型的序列for i in [3, 5, 7, 8]: # 指定需要剪枝的模塊pruning.prune(seq[i], seq[i + 1]) # 對相鄰模塊進行剪枝### 3. 對檢測頭進行剪枝detect: Detect = seq[-1] # 獲取檢測頭last_inputs = [seq[15], seq[18], seq[21]] # 獲取最后輸入的模塊colasts = [seq[16], seq[19], None] # 獲取與最后輸入相連的模塊for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):pruning.prune(last_input, [colast, cv2[0], cv3[0]]) # 對輸入模塊和檢測頭進行剪枝pruning.prune(cv2[0], cv2[1]) # 對檢測頭的卷積層進行剪枝pruning.prune(cv2[1], cv2[2]) # 對檢測頭的卷積層進行剪枝pruning.prune(cv3[0], cv3[1]) # 對檢測頭的卷積層進行剪枝pruning.prune(cv3[1], cv3[2]) # 對檢測頭的卷積層進行剪枝### 4. 模型梯度設置與保存for name, p in yolo.model.named_parameters():p.requires_grad = True # 設置所有參數的梯度為可計算# yolo.val() # 驗證模型性能torch.save(yolo.ckpt, savepath) # 保存剪枝后的模型yolo.model.pt_path = yolo.model.pt_path.replace("last.pt", os.path.basename(savepath)) # 更新模型路徑yolo.export(format="onnx") # 導出為ONNX格式## 重新加載模型,修改保存命名,用以比較剪枝前后的onnx的大小yolo = YOLO(modelpath) # 從指定路徑加載YOLO模型yolo.export(format="onnx") # 導出為ONNX格式if __name__ == "__main__":modelpath = "runs/detect1/14_Constraint/weights/last.pt" # 原始模型路徑savepath = "runs/detect1/14_Constraint/weights/last_prune.pt" # 剪枝后模型保存路徑do_pruning(modelpath, savepath) # 執行剪枝操作?
2. 稀疏正則訓練
- 使用帶有 BN正則的訓練方式,促進BN參數稀疏化。
首先加載一個正常訓練的yolov8模型權重(.pt文件),在ultralytics/engine/trainer.py
中添加如下代碼,使得bn參數在訓練時變得稀疏。
代碼中對所有 BatchNorm 層加了 L1 正則,以便自動把不重要的通道“壓”成零,后面再統一按閾值剪枝。關鍵代碼如下:
...## add start=============================## add l1 regulation for step2_Constraint_trainl1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)for k, m in self.model.named_modules():if isinstance(m, nn.BatchNorm2d):m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))## add end ==============================...
-
為什么只對 BN 做正則?
BatchNorm 的 γ(scale)系數直接影響通道輸出強度:γ ≈ 0 時,該通道幾乎不參與后續計算,用它來衡量“重要性”最直觀。 -
L1 正則如何“稀疏”?
在反向傳播時,為每個 γ/β 的梯度額外加上 ±λ,這會讓本就小的 γ 更快被拉向 0,從而在訓練中自然分化出大 γ(保留通道)和小 γ(待剪通道)。 -
λ 為何隨 epoch 遞減?
訓練初期靠強正則快速分離;后期減弱正則,避免過度壓榨保留通道,給微調留下空間。 -
bias 也正則嗎?
雖然偏置對通道篩選作用不如 γ 強,但適度收斂 β 能進一步去除邊緣特征,提高稀疏度。
之后在LL_pruning.py中運行方框中的代碼
注意事項:
稀疏訓練需要關閉混合精度(amp=False)
剪枝依賴于 BatchNorm 的 γ 值作為排序閾值,γ 越小越容易被剪除。若使用 FP16(混合精度),許多接近 0 的 γ 會被量化到同一值甚至下溢為 0,導致排序失真,同時 L1 正則梯度也容易消失,后續剪枝的閾值選擇會變得不穩定。而使用 FP32(amp=False)能精確表示這些微小差異,確保稀疏模式可控。
稀疏訓練的 batch size 不宜過大
由于關閉了混合精度,模型采用全精度計算,顯存占用顯著增加。若 batch size 設置過大,可能導致顯存溢出(OOM),進而引發訓練失敗。
稀疏訓練階段要將 patience 設為 0 或較大值
稀疏訓練的目標并非短期提升 mAP,而是讓 BN 的 γ 在多個 epoch 內逐步被 L1 正則“壓縮”。在此期間,驗證集指標可能停滯甚至下降。若啟用常規早停機制(默認 patience 為幾十),訓練可能在 γ 尚未充分分化前被提前終止,導致剪枝時閾值模糊、可剪通道不足。
3. 剪枝
執行以下代碼;
剪枝中的注意點:
在 YOLOv8 中,當進行 split 和 concat 操作時,若剪枝后的通道數不匹配會報錯。LL_pruning.py 的剪枝代碼怎么避免這一問題,暫時還沒研究透,有大佬知道請不吝指教。
關于 do_pruning 方法啟用 yolo.val() 后保存的剪枝模型缺失 BN 層的原因:
Ultralytics 的驗證?/?導出流程會將 Conv + BatchNorm 靜態融合到卷積權重和偏置中,從而提升推理速度和輕量化。這一過程會直接移除 BN 層,因此保存的 yolo.ckpt 是已融合的模型。
對比剪枝前后的模型文件(last.pt/last_prune.pt)及其 ONNX 轉換結果:
剪枝后的 .pt 文件增大,而 ONNX 文件從 43MB 縮減至 36MB。這是因為 .pt 文件包含完整的 checkpoint 元數據,而 ONNX 僅保存精簡的推理圖結構,因此只需關注 ONNX 文件大小的優化即可。
4. 微調
在第二步稀疏正則訓練中將BN約束注釋
需要注意的是明明加載的是剪枝后的模型,但訓練啟動時打印的日志卻顯示為標準版模型的參數。并且經過驗證,微調后的模型參數就是標準的yolo模型。所以需要進行一些修改,詳細的講解可以看YOLOv8 剪枝模型加載踩坑記:解決 YAML 覆蓋剪枝結構的問題-CSDN博客
修改ultralytics/engine/model.py文件內容:
self.trainer.model包含從YAML文件加載的原始模型配置信息,以及從PT文件加載的剪枝后權重。只需將該變量的網絡結構更新為剪枝后的網絡結構就行,否則訓練后的模型參數不會改變。
運行下面的代碼
yolov8模型的剪枝到這就結束了。