深度學習之模型壓縮三駕馬車:基于ResNet18的模型剪枝實戰(1)

一、背景:為什么需要模型剪枝?

隨著深度學習的發展,模型參數量和計算量呈指數級增長。以ResNet18為例,其在ImageNet上的參數量約為1100萬,雖然在服務器端運行流暢,但在移動端或嵌入式設備上部署時,內存和計算資源的限制使得直接使用大模型變得困難。模型剪枝(Model Pruning)作為模型壓縮的核心技術之一,通過刪除冗余的神經元或通道,在保持模型性能的前提下顯著降低模型大小和計算量,是解決這一問題的關鍵手段。
在前面一篇文章我們也提到了模型壓縮的一些基本定義和核心原理:《深度學習之模型壓縮三駕馬車:模型剪枝、模型量化、知識蒸餾》。

本文將基于PyTorch框架,以ResNet18在CIFAR-10數據集上的分類任務為例,詳細講解結構化通道剪枝的完整實現流程,包括模型訓練、剪枝策略、剪枝后結構調整、微調及效果評估。

二、整體流程概覽

本文代碼的核心流程可總結為以下6步:

  1. 環境初始化與數據集加載
  2. 原始模型訓練與評估
  3. 卷積層結構化剪枝(以conv1層為例)
  4. 剪枝后模型結構調整(BN層、殘差下采樣層等)
  5. 剪枝模型微調
  6. 剪枝前后模型效果對比
    特地說明:在這里選擇conv1層作為例子,不是因為選擇這個就會效果更好。

三、關鍵步驟代碼解析

3.1 環境初始化與數據集準備

首先需要配置計算設備(GPU/CPU),并加載CIFAR-10數據集。CIFAR-10包含10類32x32的彩色圖像,訓練集5萬張,測試集1萬張。

def setup_device():return torch.device("cuda" if torch.cuda.is_available() else "cpu")def load_dataset():transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # 歸一化到[-1,1]])train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)return train_dataset, test_dataset

3.2 原始模型訓練

使用預訓練的ResNet18模型,修改全連接層輸出為10類(匹配CIFAR-10的類別數),并進行5輪訓練:

def create_model(device):model = models.resnet18(pretrained=True)  # 加載ImageNet預訓練權重model.fc = nn.Linear(512, 10)  # 修改輸出層為10類return model.to(device)def train_model(model, train_loader, criterion, optimizer, device, epochs=3):model.train()for epoch in range(epochs):running_loss = 0.0for images, labels in tqdm(train_loader):images, labels = images.to(device), labels.to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")return model

3.3 結構化通道剪枝核心實現

本文重點是對卷積層進行結構化剪枝(按通道剪枝),具體步驟如下:

3.3.1 計算通道重要性

通過計算卷積核的L2范數評估通道重要性。假設卷積層權重維度為[out_channels, in_channels, kernel_h, kernel_w],將每個輸出通道的權重展平為一維向量,計算其L2范數,范數越小表示該通道對模型性能貢獻越低,越應被剪枝。

layer = dict(model.named_modules())[layer_name]  # 獲取目標卷積層
weight = layer.weight.data
channel_norm = torch.norm(weight.view(weight.shape[0], -1), p=2, dim=1)  # 計算每個輸出通道的L2范數
3.3.2 生成剪枝掩碼

根據剪枝比例(如20%),選擇范數最小的通道生成掩碼:

num_channels = weight.shape[0]  # 原始輸出通道數(如ResNet18的conv1層為64)
num_prune = int(num_channels * amount)  # 需剪枝的通道數(如64*0.2=12)
_, indices = torch.topk(channel_norm, k=num_prune, largest=False)  # 找到最不重要的12個通道mask = torch.ones(num_channels, dtype=torch.bool)
mask[indices] = False  # 掩碼:保留的通道標記為True(52個),剪枝的標記為False(12個)
3.3.3 替換卷積層

創建新的卷積層,僅保留掩碼為True的通道:

new_conv = nn.Conv2d(in_channels=layer.in_channels,out_channels=num_channels - num_prune,  # 剪枝后輸出通道數(52)kernel_size=layer.kernel_size,stride=layer.stride,padding=layer.padding,bias=layer.bias is not None
).to(device)  # 移動到模型所在設備new_conv.weight.data = layer.weight.data[mask]  # 保留掩碼為True的通道權重
if layer.bias is not None:new_conv.bias.data = layer.bias.data[mask]  # 偏置同理
3.3.4 關鍵:剪枝后結構調整

直接剪枝會導致后續層(如BN層、殘差連接中的下采樣層)的輸入/輸出通道不匹配,必須同步調整:

(1) 調整BN層
卷積層后通常接BN層,BN的num_features需與卷積輸出通道數一致:

if 'conv1' in layer_name:bn1 = model.bn1new_bn1 = nn.BatchNorm2d(new_conv.out_channels).to(device)  # 新BN層通道數52with torch.no_grad():# 同步原始BN層的參數(僅保留未被剪枝的通道)new_bn1.weight.data = bn1.weight.data[mask].clone()new_bn1.bias.data = bn1.bias.data[mask].clone()new_bn1.running_mean.data = bn1.running_mean.data[mask].clone()new_bn1.running_var.data = bn1.running_var.data[mask].clone()model.bn1 = new_bn1

(2) 調整殘差下采樣層
ResNet的殘差塊(如layer1.0)中,若主路徑的通道數被剪枝,需要通過1x1卷積的下采樣層(downsample)匹配 shortcut 的通道數:

block = model.layer1[0]
if not hasattr(block, 'downsample') or block.downsample is None:# 原始無downsample,創建新的1x1卷積+BNdownsample_conv = nn.Conv2d(in_channels=new_conv.out_channels,  # 52(剪枝后的conv1輸出)out_channels=block.conv2.out_channels,  # 64(主路徑conv2的輸出)kernel_size=1,stride=1,bias=False).to(device)torch.nn.init.kaiming_normal_(downsample_conv.weight, mode='fan_out', nonlinearity='relu')  # 初始化權重downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels).to(device)block.downsample = nn.Sequential(downsample_conv, downsample_bn)  # 添加downsample層
else:# 原有downsample層,調整輸入通道downsample_conv = block.downsample[0]downsample_conv.in_channels = new_conv.out_channels  # 輸入通道改為52downsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone())  # 輸入通道用掩碼篩選

(3) 前向傳播驗證
調整后需驗證模型能否正常前向傳播,避免通道不匹配導致的錯誤:

with torch.no_grad():test_input = torch.randn(1, 3, 32, 32).to(device)  # 測試輸入(B, C, H, W)try:model(test_input)print("? 前向傳播驗證通過")except Exception as e:print(f"? 驗證失敗: {str(e)}")raise

3.3的總結,直接上代碼

def prune_conv_layer(model, layer_name, amount=0.2):# 獲取模型當前所在設備device = next(model.parameters()).device  # 新增:獲取設備layer = dict(model.named_modules())[layer_name]weight = layer.weight.datachannel_norm = torch.norm(weight.view(weight.shape[0], -1), p=2, dim=1)num_channels = weight.shape[0]  # 原始通道數(如 64)num_prune = int(num_channels * amount)_, indices = torch.topk(channel_norm, k=num_prune, largest=False)mask = torch.ones(num_channels, dtype=torch.bool)mask[indices] = False  # 生成剪枝掩碼(長度 64,52 個 True)new_conv = nn.Conv2d(in_channels=layer.in_channels,out_channels=num_channels - num_prune,  # 剪枝后通道數(如 52)kernel_size=layer.kernel_size,stride=layer.stride,padding=layer.padding,bias=layer.bias is not None)new_conv = new_conv.to(device)  # 新增:移動到模型所在設備new_conv.weight.data = layer.weight.data[mask]  # 保留 mask 為 True 的通道if layer.bias is not None:new_conv.bias.data = layer.bias.data[mask]# 替換原始卷積層parent_name, sep, name = layer_name.rpartition('.')parent = model.get_submodule(parent_name)setattr(parent, name, new_conv)if 'conv1' in layer_name:# 1. 更新與 conv1 直接關聯的 BN1 層bn1 = model.bn1new_bn1 = nn.BatchNorm2d(new_conv.out_channels)  # 新 BN 層通道數 52new_bn1 = new_bn1.to(device)  # 新增:移動到模型所在設備with torch.no_grad():new_bn1.weight.data = bn1.weight.data[mask].clone()new_bn1.bias.data = bn1.bias.data[mask].clone()new_bn1.running_mean.data = bn1.running_mean.data[mask].clone()new_bn1.running_var.data = bn1.running_var.data[mask].clone()model.bn1 = new_bn1# 2. 處理殘差連接中的 downsample(關鍵修正:添加缺失的 downsample)block = model.layer1[0]if not hasattr(block, 'downsample') or block.downsample is None:# 原始無 downsample,需創建新的 1x1 卷積+BN 來匹配通道downsample_conv = nn.Conv2d(in_channels=new_conv.out_channels,  # 52out_channels=block.conv2.out_channels,  # 64(主路徑輸出通道數)kernel_size=1,stride=1,bias=False)downsample_conv = downsample_conv.to(device)  # 新增:移動到模型所在設備# 初始化 1x1 卷積權重(這里簡單復制原模型可能的統計量,實際可根據需求調整)torch.nn.init.kaiming_normal_(downsample_conv.weight, mode='fan_out', nonlinearity='relu')downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels)downsample_bn = downsample_bn.to(device)  # 新增:移動到模型所在設備with torch.no_grad():# 初始化 BN 參數(可保持默認,或根據原模型統計量調整)downsample_bn.weight.fill_(1.0)downsample_bn.bias.zero_()downsample_bn.running_mean.zero_()downsample_bn.running_var.fill_(1.0)block.downsample = nn.Sequential(downsample_conv, downsample_bn)print("? 為 layer1.0 添加新的 downsample 層")else:# 原有 downsample 層,調整輸入通道downsample_conv = block.downsample[0]downsample_conv.in_channels = new_conv.out_channels  # 輸入通道調整為 52downsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone())  # 輸入通道用 mask 篩選downsample_conv = downsample_conv.to(device)  # 新增:移動到模型所在設備downsample_bn = block.downsample[1]new_downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels)new_downsample_bn = new_downsample_bn.to(device)  # 新增:移動到模型所在設備with torch.no_grad():new_downsample_bn.weight.data = downsample_bn.weight.data.clone()new_downsample_bn.bias.data = downsample_bn.bias.data.clone()new_downsample_bn.running_mean.data = downsample_bn.running_mean.data.clone()new_downsample_bn.running_var.data = downsample_bn.running_var.data.clone()block.downsample[1] = new_downsample_bn# 3. 同步 layer1.0.conv1 的輸入通道(保持原有邏輯)next_convs = ['layer1.0.conv1']for conv_path in next_convs:try:conv = model.get_submodule(conv_path)if conv.in_channels != new_conv.out_channels:print(f"同步輸入通道: {conv.in_channels}{new_conv.out_channels}")conv.in_channels = new_conv.out_channelsconv.weight = nn.Parameter(conv.weight.data[:, mask, :, :].clone())conv = conv.to(device)  # 新增:移動到模型所在設備except AttributeError as e:print(f"?? 卷積層調整失敗: {conv_path} ({str(e)})")# 驗證前向傳播with torch.no_grad():test_input = torch.randn(1, 3, 32, 32).to(device)  # 確保測試輸入也在相同設備try:model(test_input)print("? 前向傳播驗證通過")except Exception as e:print(f"? 驗證失敗: {str(e)}")raisereturn model

3.4 剪枝模型微調

剪枝后模型的部分參數被刪除,需要通過微調恢復性能。一開始,我們只是在微調時凍結了除 fc 層外的所有參數,但是效果并不好,當然分析原因,除了動了conv1的原因(conv1 是模型的第一個卷積層,負責提取最基礎的圖像特征(如邊緣、紋理、顏色等)。這些底層特征對后續所有層的特征提取至關重要。),最重要的是裁剪后,需要對裁剪的層進行微調,確保參數適應新的特征維度。
微調時凍結了除 fc 層外的所有參數的代碼和結果:

for name, param in pruned_model.named_parameters():if 'fc' not in name:param.requires_grad = Falseoptimizer = optim.Adam(pruned_model.fc.parameters(), lr=0.001)print("微調剪枝后的模型")pruned_model = train_model(pruned_model, train_loader, criterion, optimizer, device,epochs=5)
原始模型準確率: 80.07%
剪枝后模型準確率: 37.80%

可以看到這個相差很大
本文選擇解凍被剪枝的層(如conv1bn1)及相關層(如layer1.0.conv1downsample)進行參數更新:

print("開始微調剪枝后的模型")
for name, param in pruned_model.named_parameters():# 僅解凍與剪枝相關的層if 'conv1' in name or 'bn1' in name or 'layer1.0.conv1' in name or 'layer1.0.downsample' in name or 'fc' in name:param.requires_grad = Trueelse:param.requires_grad = False
optimizer = optim.Adam(filter(lambda p: p.requires_grad, pruned_model.parameters()), lr=0.001)
pruned_model = train_model(pruned_model, train_loader, criterion, optimizer, device, epochs=5)
原始模型準確率: 78.94%
剪枝后模型準確率:  81.30%

重新微調了裁剪后的層后,結果有了很大改變。

四、實驗結果與分析

通過代碼中的evaluate_model函數評估剪枝前后的模型準確率:

def evaluate_model(model, device, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()acc = 100 * correct / totalreturn acc

假設原始模型準確率為88.5%,剪枝20%通道后(模型大小降低約20%),通過微調可恢復至87.2%,驗證了剪枝策略的有效性。

五、總結與改進方向

本文實現了基于通道L2范數的結構化剪枝,重點解決了剪枝后模型結構不一致的問題(如BN層、殘差下采樣層的調整),并通過微調恢復了模型性能。
在這個例子中,僅裁剪 conv1 層的影響
僅裁剪 conv1 層對模型的影響極大,原因如下:

  • 底層特征的重要性 : conv1 輸出的是最基礎的圖像特征,所有后續層的特征均基于此生成。裁剪 conv1 會直接限制后續所有層的特征表達能力。
  • 結構連鎖反應 : conv1 的輸出通道減少會觸發 bn1 、 layer1.0.conv1 、 downsample 等多個模塊的調整,任何一個模塊的調整失誤(如通道數不匹配、參數初始化不當)都會導致整體性能下降。
    實際應用中可從以下方向改進:

模型裁剪通常優先選擇 中間層(如ResNet的 layer2 、 layer3 ) ,而非底層或頂層,原因如下:

  • 底層(如 conv1 ) :負責基礎特征提取,裁剪后特征損失大,對性能影響顯著。
  • 中間層(如 layer2 、 layer3 ) :特征具有一定抽象性但冗余度高(同一層的多個通道可能提取相似特征),裁剪后對性能影響較小。
  • 頂層(如 fc 層) :負責分類決策,參數密度高但冗余度低,裁剪易導致分類能力下降。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/908873.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/908873.shtml
英文地址,請注明出處:http://en.pswp.cn/news/908873.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

uni-app學習筆記二十四--showLoading和showModal的用法

showLoading(OBJECT) 顯示 loading 提示框, 需主動調用 uni.hideLoading 才能關閉提示框。 OBJECT參數說明 參數類型必填說明平臺差異說明titleString是提示的文字內容,顯示在loading的下方maskBoolean否是否顯示透明蒙層,防止觸摸穿透,默…

【大模型RAG】六大 LangChain 支持向量庫詳細對比

摘要 向量數據庫已經成為檢索增強生成(RAG)、推薦系統和多模態檢索的核心基礎設施。本文從 Chroma、Elasticsearch、Milvus、Redis、FAISS、Pinecone 六款 LangChain 官方支持的 VectorStore 出發,梳理它們的特性、典型應用場景與性能邊界&a…

【MySQL】數據庫三大范式

目錄 一. 什么是范式 二. 第一范式 三. 第二范式 不滿足第二范式時可能出現的問題 四. 第三范式 一. 什么是范式 在數據庫中范式其實就是一組規則,在我們設計數據庫的時候,需要遵守不同的規則要求,設計出合理的關系型數據庫,…

Coze工作流-語音故事創作-文本轉語音的應用

教程簡介 本教程將帶著大家去了解怎么樣把文本轉換成語音,例如說我們要做一些有聲故事,我們可能會用上一些語音的技術,來把你創作的故事朗讀出來 首先我們創建一個工作流 對各個模塊進行編輯,如果覺得系統提示詞寫的不好&#xf…

5.子網劃分及分片相關計算

某公司網絡使用 IP 地址空間 192.168.2.0/24,現需將其均分給 市場部 和 研發部 兩個子網。已知: 🏢 市場部子網 🖥? 已分配 IP 地址范圍:192.168.2.1 ~ 192.168.2.30🌐 路由器接口 IP:192.16…

三體問題詳解

從物理學角度,三體問題之所以不穩定,是因為三個天體在萬有引力作用下相互作用,形成一個非線性耦合系統。我們可以從牛頓經典力學出發,列出具體的運動方程,并說明為何這個系統本質上是混沌的,無法得到一般解…

機器學習算法時間復雜度解析:為什么它如此重要?

時間復雜度的重要性 雖然scikit-learn等庫讓機器學習算法的實現變得異常簡單(通常只需2-3行代碼),但這種便利性往往導致使用者忽視兩個關鍵方面: 算法核心原理的理解缺失 忽視算法的數據適用條件 典型算法的時間復雜度陷阱 SV…

uniapp 對接騰訊云IM群組成員管理(增刪改查)

UniApp 實戰:騰訊云IM群組成員管理(增刪改查) 一、前言 在社交類App開發中,群組成員管理是核心功能之一。本文將基于UniApp框架,結合騰訊云IM SDK,詳細講解如何實現群組成員的增刪改查全流程。 權限校驗…

OPENCV圖形計算面積、弧長API講解(1)

一.OPENCV圖形面積、弧長計算的API介紹 之前我們已經把圖形輪廓的檢測、畫框等功能講解了一遍。那今天我們主要結合輪廓檢測的API去計算圖形的面積,這些面積可以是矩形、圓形等等。圖形面積計算和弧長計算常用于車輛識別、橋梁識別等重要功能,常用的API…

一.設計模式的基本概念

一.核心概念 對軟件設計中重復出現問題的成熟解決方案,提供代碼可重用性、可維護性和擴展性保障。核心原則包括: 1.1. 單一職責原則? ?定義?:一個類只承擔一個職責,避免因職責過多導致的代碼耦合。 1.2. 開閉原則? ?定義?&#xf…

React第五十七節 Router中RouterProvider使用詳解及注意事項

前言 在 React Router v6.4 中&#xff0c;RouterProvider 是一個核心組件&#xff0c;用于提供基于數據路由&#xff08;data routers&#xff09;的新型路由方案。 它替代了傳統的 <BrowserRouter>&#xff0c;支持更強大的數據加載和操作功能&#xff08;如 loader 和…

Opencv中的addweighted函數

一.addweighted函數作用 addweighted&#xff08;&#xff09;是OpenCV庫中用于圖像處理的函數&#xff0c;主要功能是將兩個輸入圖像&#xff08;尺寸和類型相同&#xff09;按照指定的權重進行加權疊加&#xff08;圖像融合&#xff09;&#xff0c;并添加一個標量值&#x…

C++ 基礎特性深度解析

目錄 引言 一、命名空間&#xff08;namespace&#xff09; C 中的命名空間? 與 C 語言的對比? 二、缺省參數? C 中的缺省參數? 與 C 語言的對比? 三、引用&#xff08;reference&#xff09;? C 中的引用? 與 C 語言的對比? 四、inline&#xff08;內聯函數…

關于面試找工作的總結(四)

不同情況下收到offer后的處理方法 1.不會去的,只是面試練手2.還有疑問,考慮中3.offer/職位不滿足期望的4.已確認,但又收到更好的5.還想挽回之前的offer6.確認,準備入職7.還想拖一下的1.不會去的,只是面試練手 HR您好,非常榮幸收到貴司的offer,非常感謝一直以來您的幫助,…

什么是高考?高考的意義是啥?

能見到這個文章的群體&#xff0c;應該都經歷過高考&#xff0c;突然想起“什么是高考&#xff1f;意義何在&#xff1f;” 一、高考的定義與核心功能 **高考&#xff08;普通高等學校招生全國統一考試&#xff09;**是中國教育體系的核心選拔性考試&#xff0c;旨在為高校選拔…

L1和L2核心區別 !!--part 2

哈嘍&#xff0c;我是 我不是小upper~ 昨天&#xff0c;咱們分享了關于 L1 正則化和 L2 正則化核心區別的精彩內容。今天我來進一步補充和拓展。 首先&#xff0c;咱們先來聊聊 L1 和 L2 正則化&#xff0c;方便剛接觸的同學理解。 L1 正則化&#xff08;Lasso&#xff09;&…

字節推出統一多模態模型 BAGEL,GPT-4o 級的圖像生成能力直接開源了!

字節推出的 BAGEL 是一個開源的統一多模態模型&#xff0c;他們直接開源了GPT-4o級別的圖像生成能力。&#xff08;輕松拿捏“萬物皆可吉卜力”玩法~&#xff09;。可以在任何地方對其進行微調、提煉和部署&#xff0c;它以開放的形式提供與 GPT-4o 和 Gemini 2.0 等專有系統相…

互聯網大廠Java面試:從Spring Cloud到Kafka的技術考察

場景&#xff1a;互聯網大廠Java求職者面試 面試官與謝飛機的對話 面試官&#xff1a;我們先從基礎開始&#xff0c;謝飛機&#xff0c;你能簡單介紹一下Java SE和Java EE的區別嗎&#xff1f; 謝飛機&#xff1a;哦&#xff0c;這個簡單。Java SE是標準版&#xff0c;適合桌…

18-Oracle 23ai JSON二元性顛覆傳統

在當今百花齊放的多模型數據庫時代&#xff0c;開發人員常在關系型與文檔型數據庫間艱難取舍。Oracle Database 23ai推出的JSON關系二元性&#xff08;JSON Relational Duality&#xff09;?? 和二元性視圖&#xff08;Duality Views&#xff09;?? 創新性地統一了兩者優勢…

藍橋杯 冶煉金屬

原題目鏈接 &#x1f527; 冶煉金屬轉換率推測題解 &#x1f4dc; 原題描述 小藍有一個神奇的爐子用于將普通金屬 O O O 冶煉成為一種特殊金屬 X X X。這個爐子有一個屬性叫轉換率 V V V&#xff0c;是一個正整數&#xff0c;表示每 V V V 個普通金屬 O O O 可以冶煉出 …