1. Background: Why Do We Need Model Pruning?
As deep learning has advanced, model parameter counts and compute requirements have grown rapidly. ResNet18, for example, has roughly 11 million parameters on ImageNet. That is no problem on a server, but on mobile or embedded devices, memory and compute constraints make deploying such a model directly difficult. Model pruning, one of the core model-compression techniques, removes redundant neurons or channels to significantly shrink model size and computation while preserving accuracy, making it a key tool for this problem.
A previous article already covered the basic definitions and core principles of model compression: 《深度學習之模型壓縮三駕馬車:模型剪枝、模型量化、知識蒸餾》 (the "three pillars" of model compression: pruning, quantization, and knowledge distillation).
This article uses PyTorch and the ResNet18/CIFAR-10 classification task to walk through a complete structured channel-pruning workflow: model training, the pruning strategy, post-pruning structural adjustment, fine-tuning, and evaluation.
2. Overall Workflow
The code in this article boils down to the following six steps:
- Environment initialization and dataset loading
- Training and evaluating the original model
- Structured pruning of a convolution layer (using the conv1 layer as the example)
- Post-pruning structural adjustment (BN layer, residual downsample layer, etc.)
- Fine-tuning the pruned model
- Comparing the model before and after pruning
Note: conv1 is chosen here purely as an illustrative example, not because pruning this particular layer gives better results.
3. Code Walkthrough of the Key Steps
3.1 Environment Initialization and Dataset Preparation
First, configure the compute device (GPU/CPU) and load the CIFAR-10 dataset. CIFAR-10 contains 10 classes of 32x32 color images: 50,000 for training and 10,000 for testing.
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm

def setup_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_dataset():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # normalize to [-1, 1]
    ])
    train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
    test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
    return train_dataset, test_dataset
```
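The datasets are then wrapped in DataLoaders, which the training and evaluation code below relies on. The original article does not show this step, so the batch size here is an assumption:

```python
train_dataset, test_dataset = load_dataset()
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)   # batch size is an assumption
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
```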
3.2 Training the Original Model
Load a pre-trained ResNet18, replace the fully connected layer so it outputs 10 classes (matching CIFAR-10), and train it for a few epochs:
```python
def create_model(device):
    model = models.resnet18(pretrained=True)  # load ImageNet pre-trained weights
    model.fc = nn.Linear(512, 10)             # replace the output layer with 10 classes
    return model.to(device)

def train_model(model, train_loader, criterion, optimizer, device, epochs=3):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")
    return model
```
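For reference, a minimal sketch of wiring these helpers together; the CrossEntropyLoss criterion and the Adam learning rate are assumptions (they mirror the fine-tuning code later in the article) rather than code shown in the original:

```python
device = setup_device()
model = create_model(device)
criterion = nn.CrossEntropyLoss()                     # assumed loss for 10-class classification
optimizer = optim.Adam(model.parameters(), lr=0.001)  # lr mirrors the fine-tuning section
model = train_model(model, train_loader, criterion, optimizer, device, epochs=5)
```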
3.3 Core Implementation of Structured Channel Pruning
The focus of this article is structured (per-channel) pruning of a convolution layer. The steps are as follows:
3.3.1 Computing Channel Importance
Channel importance is measured by the L2 norm of each convolution kernel. Given a convolution layer whose weight tensor has shape [out_channels, in_channels, kernel_h, kernel_w], flatten each output channel's weights into a 1-D vector and compute its L2 norm: the smaller the norm, the less that channel contributes to the model, and the better a candidate it is for pruning.
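Written out, the importance score of output channel $c$ is simply the L2 norm over all weights belonging to that channel, which is exactly what the code below computes:

$$\|W_c\|_2 = \sqrt{\sum_{i=1}^{C_{\text{in}}} \sum_{h=1}^{k_h} \sum_{w=1}^{k_w} W_{c,i,h,w}^{2}}$$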
```python
layer = dict(model.named_modules())[layer_name]  # fetch the target conv layer
weight = layer.weight.data
channel_norm = torch.norm(weight.view(weight.shape[0], -1), p=2, dim=1)  # L2 norm of each output channel
```
3.3.2 Generating the Pruning Mask
Given a pruning ratio (e.g., 20%), select the channels with the smallest norms and build a boolean mask:
```python
num_channels = weight.shape[0]          # original number of output channels (64 for ResNet18's conv1)
num_prune = int(num_channels * amount)  # number of channels to prune (64 * 0.2 = 12)
_, indices = torch.topk(channel_norm, k=num_prune, largest=False)  # the 12 least important channels
mask = torch.ones(num_channels, dtype=torch.bool)
mask[indices] = False  # kept channels are marked True (52), pruned channels False (12)
```
3.3.3 Replacing the Convolution Layer
Create a new convolution layer that keeps only the channels whose mask value is True:
```python
new_conv = nn.Conv2d(
    in_channels=layer.in_channels,
    out_channels=num_channels - num_prune,  # output channels after pruning (52)
    kernel_size=layer.kernel_size,
    stride=layer.stride,
    padding=layer.padding,
    bias=layer.bias is not None
).to(device)  # move to the model's device

new_conv.weight.data = layer.weight.data[mask]  # keep the weights of channels whose mask is True
if layer.bias is not None:
    new_conv.bias.data = layer.bias.data[mask]  # same for the bias
```
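The new layer must also be swapped into the model in place of the original one; the complete function at the end of Section 3.3 does this with rpartition and get_submodule:

```python
# Swap the pruned layer into the model; for layer_name='conv1' the parent is the model itself
parent_name, _, name = layer_name.rpartition('.')
parent = model.get_submodule(parent_name)  # get_submodule('') returns the model
setattr(parent, name, new_conv)
```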
3.3.4 The Key Step: Structural Adjustment After Pruning
Pruning a layer in isolation leaves downstream modules (the BN layer, the downsample layer in the residual connection, etc.) with mismatched input/output channels, so they must be adjusted in sync:
(1) Adjusting the BN layer
A convolution layer is usually followed by a BN layer, and the BN layer's num_features must match the convolution's output channel count:
```python
if 'conv1' in layer_name:
    bn1 = model.bn1
    new_bn1 = nn.BatchNorm2d(new_conv.out_channels).to(device)  # new BN layer with 52 channels
    with torch.no_grad():
        # copy the original BN parameters, keeping only the channels that were not pruned
        new_bn1.weight.data = bn1.weight.data[mask].clone()
        new_bn1.bias.data = bn1.bias.data[mask].clone()
        new_bn1.running_mean.data = bn1.running_mean.data[mask].clone()
        new_bn1.running_var.data = bn1.running_var.data[mask].clone()
    model.bn1 = new_bn1
```
(2) Adjusting the residual downsample layer
In a ResNet residual block (e.g., layer1.0), once the channels feeding the block have been pruned, the shortcut needs a 1x1 downsample convolution (downsample) so that its channel count matches the main path's output:
```python
block = model.layer1[0]
if not hasattr(block, 'downsample') or block.downsample is None:
    # no downsample originally: create a new 1x1 conv + BN
    downsample_conv = nn.Conv2d(
        in_channels=new_conv.out_channels,       # 52 (pruned conv1 output)
        out_channels=block.conv2.out_channels,   # 64 (output of the main path's conv2)
        kernel_size=1,
        stride=1,
        bias=False
    ).to(device)
    torch.nn.init.kaiming_normal_(downsample_conv.weight, mode='fan_out', nonlinearity='relu')  # initialize weights
    downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels).to(device)
    block.downsample = nn.Sequential(downsample_conv, downsample_bn)  # attach the downsample branch
else:
    # a downsample layer already exists: adjust its input channels
    downsample_conv = block.downsample[0]
    downsample_conv.in_channels = new_conv.out_channels  # input channels become 52
    downsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone())  # select input channels with the mask
```
(3) Forward-pass sanity check
After these adjustments, verify that the model still runs a forward pass, to catch any remaining channel mismatches:
```python
with torch.no_grad():
    test_input = torch.randn(1, 3, 32, 32).to(device)  # dummy input (B, C, H, W)
    try:
        model(test_input)
        print("Forward pass check passed")
    except Exception as e:
        print(f"Forward pass check failed: {str(e)}")
        raise
```
To wrap up Section 3.3, here is the complete function:
```python
def prune_conv_layer(model, layer_name, amount=0.2):
    # get the device the model currently lives on
    device = next(model.parameters()).device

    layer = dict(model.named_modules())[layer_name]
    weight = layer.weight.data
    channel_norm = torch.norm(weight.view(weight.shape[0], -1), p=2, dim=1)

    num_channels = weight.shape[0]          # original channel count (e.g., 64)
    num_prune = int(num_channels * amount)
    _, indices = torch.topk(channel_norm, k=num_prune, largest=False)
    mask = torch.ones(num_channels, dtype=torch.bool)
    mask[indices] = False                   # pruning mask (length 64, 52 entries are True)

    new_conv = nn.Conv2d(
        in_channels=layer.in_channels,
        out_channels=num_channels - num_prune,  # channel count after pruning (e.g., 52)
        kernel_size=layer.kernel_size,
        stride=layer.stride,
        padding=layer.padding,
        bias=layer.bias is not None
    )
    new_conv = new_conv.to(device)  # move to the model's device
    new_conv.weight.data = layer.weight.data[mask]  # keep channels whose mask is True
    if layer.bias is not None:
        new_conv.bias.data = layer.bias.data[mask]

    # replace the original conv layer
    parent_name, sep, name = layer_name.rpartition('.')
    parent = model.get_submodule(parent_name)
    setattr(parent, name, new_conv)

    if 'conv1' in layer_name:
        # 1. update the BN1 layer directly attached to conv1
        bn1 = model.bn1
        new_bn1 = nn.BatchNorm2d(new_conv.out_channels)  # new BN layer with 52 channels
        new_bn1 = new_bn1.to(device)
        with torch.no_grad():
            new_bn1.weight.data = bn1.weight.data[mask].clone()
            new_bn1.bias.data = bn1.bias.data[mask].clone()
            new_bn1.running_mean.data = bn1.running_mean.data[mask].clone()
            new_bn1.running_var.data = bn1.running_var.data[mask].clone()
        model.bn1 = new_bn1

        # 2. handle the downsample branch in the residual connection (key fix: add the missing downsample)
        block = model.layer1[0]
        if not hasattr(block, 'downsample') or block.downsample is None:
            # no downsample originally: create a new 1x1 conv + BN to match the channels
            downsample_conv = nn.Conv2d(
                in_channels=new_conv.out_channels,       # 52
                out_channels=block.conv2.out_channels,   # 64 (main-path output channels)
                kernel_size=1,
                stride=1,
                bias=False
            )
            downsample_conv = downsample_conv.to(device)
            # initialize the 1x1 conv weights (simple default; adjust as needed)
            torch.nn.init.kaiming_normal_(downsample_conv.weight, mode='fan_out', nonlinearity='relu')
            downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels)
            downsample_bn = downsample_bn.to(device)
            with torch.no_grad():
                # initialize BN parameters (defaults are fine, or copy statistics from the original model)
                downsample_bn.weight.fill_(1.0)
                downsample_bn.bias.zero_()
                downsample_bn.running_mean.zero_()
                downsample_bn.running_var.fill_(1.0)
            block.downsample = nn.Sequential(downsample_conv, downsample_bn)
            print("Added a new downsample layer to layer1.0")
        else:
            # a downsample layer already exists: adjust its input channels
            downsample_conv = block.downsample[0]
            downsample_conv.in_channels = new_conv.out_channels  # input channels become 52
            downsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone())  # select input channels with the mask
            downsample_conv = downsample_conv.to(device)
            downsample_bn = block.downsample[1]
            new_downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels)
            new_downsample_bn = new_downsample_bn.to(device)
            with torch.no_grad():
                new_downsample_bn.weight.data = downsample_bn.weight.data.clone()
                new_downsample_bn.bias.data = downsample_bn.bias.data.clone()
                new_downsample_bn.running_mean.data = downsample_bn.running_mean.data.clone()
                new_downsample_bn.running_var.data = downsample_bn.running_var.data.clone()
            block.downsample[1] = new_downsample_bn

        # 3. sync the input channels of layer1.0.conv1 (same logic as before)
        next_convs = ['layer1.0.conv1']
        for conv_path in next_convs:
            try:
                conv = model.get_submodule(conv_path)
                if conv.in_channels != new_conv.out_channels:
                    print(f"Syncing input channels: {conv.in_channels} -> {new_conv.out_channels}")
                    conv.in_channels = new_conv.out_channels
                    conv.weight = nn.Parameter(conv.weight.data[:, mask, :, :].clone())
                    conv = conv.to(device)
            except AttributeError as e:
                print(f"Warning: failed to adjust conv layer {conv_path} ({str(e)})")

    # verify the forward pass
    with torch.no_grad():
        test_input = torch.randn(1, 3, 32, 32).to(device)  # keep the test input on the same device
        try:
            model(test_input)
            print("Forward pass check passed")
        except Exception as e:
            print(f"Forward pass check failed: {str(e)}")
            raise

    return model
```
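A typical call then looks like the following; the deepcopy is an assumption about the surrounding script, used here so the unpruned model stays available for the comparison in Section 4:

```python
import copy

pruned_model = prune_conv_layer(copy.deepcopy(model), 'conv1', amount=0.2)
```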
3.4 Fine-Tuning the Pruned Model
Because pruning removes part of the model's parameters, fine-tuning is needed to recover performance. Our first attempt froze every parameter except the fc layer, and the results were poor. One reason is that we touched conv1: as the model's first convolution layer it extracts the most basic image features (edges, textures, colors), and every later layer builds on those features. The more important reason, though, is that after pruning, the pruned layers themselves must be fine-tuned so their parameters adapt to the new feature dimensions.
Code and results for the first attempt, with everything except the fc layer frozen:
```python
for name, param in pruned_model.named_parameters():
    if 'fc' not in name:
        param.requires_grad = False

optimizer = optim.Adam(pruned_model.fc.parameters(), lr=0.001)
print("Fine-tuning the pruned model")
pruned_model = train_model(pruned_model, train_loader, criterion, optimizer, device, epochs=5)
```
Original model accuracy: 80.07%
Pruned model accuracy: 37.80%
The gap is clearly huge.
This article therefore unfreezes the pruned layer and the layers tied to it (conv1, bn1, layer1.0.conv1, layer1.0.downsample) along with fc, and updates their parameters:
print("開始微調剪枝后的模型")
for name, param in pruned_model.named_parameters():# 僅解凍與剪枝相關的層if 'conv1' in name or 'bn1' in name or 'layer1.0.conv1' in name or 'layer1.0.downsample' in name or 'fc' in name:param.requires_grad = Trueelse:param.requires_grad = False
optimizer = optim.Adam(filter(lambda p: p.requires_grad, pruned_model.parameters()), lr=0.001)
pruned_model = train_model(pruned_model, train_loader, criterion, optimizer, device, epochs=5)
Original model accuracy: 78.94%
Pruned model accuracy: 81.30%
Once the pruned layers are fine-tuned as well, the result changes dramatically.
4. Experimental Results and Analysis
The evaluate_model function in the code measures model accuracy before and after pruning:
```python
def evaluate_model(model, device, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    acc = 100 * correct / total
    return acc
```
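Using the loaders from Section 3.1, the comparison is then one call per model:

```python
orig_acc = evaluate_model(model, device, test_loader)
pruned_acc = evaluate_model(pruned_model, device, test_loader)
print(f"Original model accuracy: {orig_acc:.2f}%")
print(f"Pruned model accuracy:   {pruned_acc:.2f}%")
```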
In the runs above, the original model reached 78.94% accuracy, while the pruned model (20% of conv1's channels removed) recovered to 81.30% after the affected layers were fine-tuned, which validates the pruning strategy.
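To quantify the size change for any pruning configuration, a simple parameter count suffices. Note that pruning only conv1, as in this example, barely moves ResNet18's roughly 11 million parameters (the newly added downsample branch even puts a few back), so meaningful size savings require pruning more, and wider, layers:

```python
def count_parameters(m):
    return sum(p.numel() for p in m.parameters())

print(f"Original parameters: {count_parameters(model):,}")
print(f"Pruned parameters:   {count_parameters(pruned_model):,}")
```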
5. Summary and Directions for Improvement
This article implemented structured pruning based on per-channel L2 norms, focused on fixing the structural inconsistencies pruning introduces (adjusting the BN layer and the residual downsample layer), and recovered model performance through fine-tuning.
The impact of pruning only the conv1 layer in this example
Pruning only conv1 affects the model enormously, for the following reasons:
- Importance of low-level features: conv1 produces the most basic image features, from which every later layer's features are derived. Pruning conv1 directly limits the expressive power of everything downstream.
- Structural chain reaction: reducing conv1's output channels forces adjustments to bn1, layer1.0.conv1, the downsample branch, and more; a mistake in any one of them (mismatched channel counts, poor parameter initialization) drags down overall performance.
In practice, the approach can be improved in the following directions:
Pruning usually targets middle layers (e.g., ResNet's layer2 and layer3) first, rather than the lowest or highest layers, for the reasons below (a layer-selection sketch follows the list):
- Low layers (e.g., conv1): extract basic features; pruning them loses a lot of information and hurts performance noticeably.
- Middle layers (e.g., layer2, layer3): features are somewhat abstract but highly redundant (several channels in the same layer often capture similar patterns), so pruning them costs relatively little performance.
- Top layers (e.g., the fc layer): responsible for the classification decision; parameters are dense but not very redundant, so pruning them easily degrades classification accuracy.
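As a rough starting point for picking candidate layers, one can inspect how the per-channel L2 norms are distributed in each convolution layer; layers with many comparatively weak channels are natural pruning targets. The threshold and the "weak-channel fraction" criterion below are assumptions made for illustration, not part of the original code:

```python
def report_channel_redundancy(model, rel_threshold=0.5):
    """For each conv layer, print the fraction of output channels whose L2 norm
    falls below rel_threshold times the layer's mean channel norm."""
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            w = module.weight.data
            norms = torch.norm(w.view(w.shape[0], -1), p=2, dim=1)
            weak = (norms < rel_threshold * norms.mean()).float().mean().item()
            print(f"{name:25s} out_channels={w.shape[0]:4d} weak-channel fraction={weak:.2f}")

report_channel_redundancy(model)
```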