模型剪枝知識點整理
剪枝是深度學習模型優化的兩種常見技術,用于減少模型復雜度和提升推理速度,適用于資源受限的環境。
剪枝(Pruning)
剪枝是一種通過移除模型中不重要或冗余的參數來減少模型大小和計算量的方法。剪枝通常分為以下幾種類型:
1. 權重剪枝(Weight Pruning)
權重剪枝是通過移除權重矩陣中接近零的元素來減少模型的參數數量。常見的方法有:
- 非結構化剪枝(Unstructured Pruning):逐個移除權重矩陣中的小權重。
- 結構化剪枝(Structured Pruning):按特定結構(如整行或整列)移除權重。
示例:
import torch# 假設有一個全連接層
fc = torch.nn.Linear(100, 100)# 獲取權重矩陣
weights = fc.weight.data.abs()# 設定剪枝閾值
threshold = 0.01# 應用剪枝
mask = weights > threshold
fc.weight.data *= mask
2. 通道剪枝(Channel Pruning)
通道剪枝主要用于卷積神經網絡,通過移除卷積層中不重要的通道來減少計算量。常見的方法有:
- 基于重要性評分:計算每個通道的重要性分數,移除分數較低的通道。
- 基于稀疏性:通過增加稀疏正則項,訓練過程中自然使某些通道稀疏,再進行剪枝。
import torch
import torch.nn as nnclass ConvNet(nn.Module):def __init__(self):super(ConvNet, self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)def forward(self, x):x = self.conv1(x)x = self.conv2(x)return xmodel = ConvNet()# 獲取卷積層的權重
weights = model.conv1.weight.data.abs()# 計算每個通道的L1范數
channel_importance = torch.sum(weights, dim=[1, 2, 3])# 設定剪枝閾值
threshold = torch.topk(channel_importance, k=32, largest=True).values[-1]# 應用剪枝
mask = channel_importance > threshold
model.conv1.weight.data *= mask.view(-1, 1, 1, 1)
3. 層剪枝(Layer Pruning)
層剪枝是移除整個網絡層,以減少模型的計算深度。這種方法較為激進,通常結合模型架構搜索(NAS)使用。
import torch.nn as nnclass LayerPrunedNet(nn.Module):def __init__(self, use_layer=True):super(LayerPrunedNet, self).__init__()self.use_layer = use_layerself.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)def forward(self, x):x = self.conv1(x)if self.use_layer:x = self.conv2(x)return x# 初始化網絡,選擇是否使用第二層
model = LayerPrunedNet(use_layer=False)