深度學習之模型壓縮三駕馬車:基于ResNet18的模型剪枝實戰(2)

前言

《深度學習之模型壓縮三駕馬車:基于ResNet18的模型剪枝實戰(1)》里面我只是提到了對conv1層進行剪枝,只是為了驗證這個剪枝的整個過程,但是后面也有提到:僅裁剪 conv1層的影響極大,原因如下:

  • 底層特征的重要性 : conv1輸出的是最基礎的圖像特征,所有后續層的特征均基于此生成。裁剪 conv1 會直接限制后續所有層的特征表達能力。
  • 結構連鎖反應 : conv1的輸出通道減少會觸發 bn1layer1.0.conv1downsample 等多個模塊的調整,任何一個模塊的調整失誤(如通道數不匹配、參數初始化不當)都會導致整體性能下降。
    雖然,在例子中,我們只是簡單的進行了驗證,發現效果也不是很差,但是如果具體到自己的數據,或者更加復雜的特征或者模型,可能就會影響到了整體的性能,因此,我們在原有的基礎上做了如下的改動:
  1. 剪枝目標層調整 :將 conv1 改為 layer2.0.conv1 ,減少對底層特征的破壞。
  2. 通道評估優化 :通過前向傳播收集激活值,優先剪枝激活值低的通道,更符合實際特征貢獻。
  3. 微調策略改進 :動態解凍剪枝層及關聯的BN、downsample層,學習率降低(0.0001),微調輪次增加(10輪),確保參數充分適應。

這些修改可顯著提升剪枝后模型的穩定性和準確率。建議運行時觀察微調階段的Loss是否持續下降,若下降緩慢可進一步降低學習率(如0.00001)。
所有代碼都在這:https://gitee.com/NOON47/model_prune

詳細改動

  1. 剪枝目標層調整 :將 conv1 改為 layer2.0.conv1 ,減少對底層特征的破壞。
    layer_to_prune = 'layer2.0.conv1'  # 顯式定義要剪枝的層名pruned_model = prune_conv_layer(model, layer_to_prune, amount=0.2)
  1. 通道評估優化 :通過前向傳播收集激活值,優先剪枝激活值低的通道,更符合實際特征貢獻。
    model.eval()with torch.no_grad():test_input = torch.randn(128, 3, 32, 32).to(device)  # 模擬 CIFAR10 輸入features = []def hook_fn(module, input, output):features.append(output)handle = layer.register_forward_hook(hook_fn)model(test_input)handle.remove()activation = features[0]  # shape: [128, out_channels, H, W]channel_importance = activation.mean(dim=(0, 2, 3))  # 按通道求平均激活值num_channels = weight.shape[0]num_prune = int(num_channels * amount)_, indices = torch.topk(channel_importance, k=num_prune, largest=False)mask = torch.ones(num_channels, dtype=torch.bool)mask[indices] = False  # 生成剪枝掩碼
  1. 微調策略改進 :動態解凍剪枝層及關聯的BN、downsample層,學習率降低(0.0001),微調輪次增加(10輪),確保參數充分適應。
    print("開始微調剪枝后的模型")# 新增:根據剪枝層動態解凍相關層(假設剪枝層為layer2.0.conv1)pruned_layer_prefix = layer_to_prune.rpartition('.')[0]  # 例如 'layer2.0'for name, param in pruned_model.named_parameters():if (pruned_layer_prefix in name) or ('fc' in name) or ('bn' in name):  # 解凍剪枝層、BN層和fc層param.requires_grad = Trueelse:param.requires_grad = Falseoptimizer = optim.Adam(filter(lambda p: p.requires_grad, pruned_model.parameters()), lr=0.0001)  # 微調學習率降低pruned_model = train_model(pruned_model, train_loader, criterion, optimizer, device, epochs=10)  # 增加微調輪次

完整的裁剪函數:

def prune_conv_layer(model, layer_name, amount=0.2):device = next(model.parameters()).devicelayer = dict(model.named_modules())[layer_name]weight = layer.weight.data# 基于激活值的通道重要性評估model.eval()with torch.no_grad():test_input = torch.randn(128, 3, 32, 32).to(device)  # 模擬 CIFAR10 輸入features = []def hook_fn(module, input, output):features.append(output)handle = layer.register_forward_hook(hook_fn)model(test_input)handle.remove()activation = features[0]  # shape: [128, out_channels, H, W]channel_importance = activation.mean(dim=(0, 2, 3))  # 按通道求平均激活值num_channels = weight.shape[0]num_prune = int(num_channels * amount)_, indices = torch.topk(channel_importance, k=num_prune, largest=False)mask = torch.ones(num_channels, dtype=torch.bool)mask[indices] = False  # 生成剪枝掩碼# 創建并替換新卷積層new_conv = nn.Conv2d(in_channels=layer.in_channels,out_channels=num_channels - num_prune,kernel_size=layer.kernel_size,stride=layer.stride,padding=layer.padding,bias=layer.bias is not None).to(device)new_conv.weight.data = layer.weight.data[mask]  # 應用掩碼剪枝權重if layer.bias is not None:new_conv.bias.data = layer.bias.data[mask]# 替換原始卷積層parent_name, sep, name = layer_name.rpartition('.')parent = model.get_submodule(parent_name)setattr(parent, name, new_conv)# 僅處理首層 conv1 的特殊邏輯if layer_name == 'conv1':# 更新首層 BN 層(bn1)bn1 = model.bn1new_bn1 = nn.BatchNorm2d(new_conv.out_channels).to(device)with torch.no_grad():new_bn1.weight.data = bn1.weight.data[mask].clone()new_bn1.bias.data = bn1.bias.data[mask].clone()new_bn1.running_mean.data = bn1.running_mean.data[mask].clone()new_bn1.running_var.data = bn1.running_var.data[mask].clone()model.bn1 = new_bn1# 處理 layer1.0 的 downsample 層(若不存在則創建)block = model.layer1[0]if not hasattr(block, 'downsample') or block.downsample is None:# 創建 1x1 卷積 + BN 用于通道匹配downsample_conv = nn.Conv2d(in_channels=new_conv.out_channels,out_channels=block.conv2.out_channels,  # 與主路徑輸出通道一致(ResNet18 為 64)kernel_size=1,stride=1,bias=False).to(device)# 初始化權重(使用原卷積層的統計量)with torch.no_grad():downsample_conv.weight.data = layer.weight.data.mean(dim=(2,3), keepdim=True)  # 原卷積核均值初始化downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels).to(device)with torch.no_grad():downsample_bn.weight.data.fill_(1.0)downsample_bn.bias.data.zero_()downsample_bn.running_mean.data.zero_()downsample_bn.running_var.data.fill_(1.0)block.downsample = nn.Sequential(downsample_conv, downsample_bn)print("? 為 layer1.0 添加新的 downsample 層")else:# 調整已有 downsample 層的輸入通道downsample_conv = block.downsample[0]downsample_conv.in_channels = new_conv.out_channelsdownsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone()).to(device)# 更新對應的 BN 層downsample_bn = block.downsample[1]new_downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels).to(device)with torch.no_grad():new_downsample_bn.weight.data = downsample_bn.weight.data.clone()new_downsample_bn.bias.data = downsample_bn.bias.data.clone()new_downsample_bn.running_mean.data = downsample_bn.running_mean.data.clone()new_downsample_bn.running_var.data = downsample_bn.running_var.data.clone()block.downsample[1] = new_downsample_bn# 同步 layer1.0.conv1 的輸入通道target_conv = model.layer1[0].conv1if target_conv.in_channels != new_conv.out_channels:print(f"同步 layer1.0.conv1 輸入通道: {target_conv.in_channels}{new_conv.out_channels}")target_conv.in_channels = new_conv.out_channelstarget_conv.weight = nn.Parameter(target_conv.weight.data[:, mask, :, :].clone()).to(device)else:# 中間層剪枝邏輯(如 layer2.0.conv1)block_prefix = layer_name.rsplit('.', 1)[0]  # 提取 block 前綴(如 'layer2.0')block = model.get_submodule(block_prefix)     # 獲取對應的 block(如 layer2.0)# 更新當前 block 內的 BN 層(conv1 對應 bn1,conv2 對應 bn2)target_bn_name = f"{block_prefix}.bn1" if 'conv1' in layer_name else f"{block_prefix}.bn2"try:target_bn = model.get_submodule(target_bn_name)new_bn = nn.BatchNorm2d(new_conv.out_channels).to(device)with torch.no_grad():new_bn.weight.data = target_bn.weight.data[mask].clone()new_bn.bias.data = target_bn.bias.data[mask].clone()new_bn.running_mean.data = target_bn.running_mean.data[mask].clone()new_bn.running_var.data = target_bn.running_var.data[mask].clone()setattr(block, target_bn_name.split('.')[-1], new_bn)  # 替換原 BN 層print(f"? 更新剪枝層 {layer_name} 對應的 BN 層 {target_bn_name}")except AttributeError:print(f"?? 未找到剪枝層 {layer_name} 對應的 BN 層,跳過 BN 更新")# 新增:同步后續卷積層的輸入通道(如 conv1 后調整 conv2)if 'conv1' in layer_name:next_conv = block.conv2if next_conv.in_channels != new_conv.out_channels:print(f"同步 {block_prefix}.conv2 輸入通道: {next_conv.in_channels}{new_conv.out_channels}")next_conv.in_channels = new_conv.out_channelsnext_conv.weight = nn.Parameter(next_conv.weight.data[:, mask, :, :].clone()).to(device)  # 按剪枝掩碼篩選輸入通道權重# 可選:如果存在 downsample 層,調整其輸入通道(根據實際需求啟用)# if hasattr(block, 'downsample') and block.downsample is not None:#     downsample_conv = block.downsample[0]#     downsample_conv.in_channels = new_conv.out_channels#     downsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone()).to(device)#     print(f"? 調整剪枝層 {layer_name} 關聯的 downsample 層輸入通道")# 驗證前向傳播with torch.no_grad():test_input = torch.randn(1, 3, 32, 32).to(device)try:model(test_input)print("? 前向傳播驗證通過")except Exception as e:print(f"? 驗證失敗: {str(e)}")raisereturn model

改動后結果

經過改動后, 增加微調輪次,得到的結果如下:

剪枝前模型大小信息:
==========================================================================================
Total params: 11,181,642
Trainable params: 11,181,642
Non-trainable params: 0
Total mult-adds (M): 37.03
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.81
Params size (MB): 44.73
Estimated Total Size (MB): 45.55
==========================================================================================
原始模型準確率: 81.42%剪枝后模型大小信息:
==========================================================================================
Total params: 11,138,392
Trainable params: 11,138,392
Non-trainable params: 0
Total mult-adds (M): 36.33
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.80
Params size (MB): 44.55
Estimated Total Size (MB): 45.37
==========================================================================================
剪枝后模型準確率: 83.28%

個人認為,這個才是比較符合實際應用的。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/86533.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/86533.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/86533.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

傳輸層協議:UDP

目錄 1、概念 2、報文結構 3、核心特性 3.1 無連接 3.2 不可靠交付 3.3 面向數據報 3.4 輕量級&高效 3.5 支持廣播和組播 4、典型應用場景 5、優缺點分析 6、與TCP的區別 1、概念 UDP(User Datagram Protocol,用戶數據報協議&#xff09…

JVM虛擬機:內存結構、垃圾回收、性能優化

1、JVM虛擬機的簡介 Java 虛擬機(Java Virtual Machine 簡稱:JVM)是運行所有 Java 程序的抽象計算機,是 Java 語言的運行環境,實現了 Java 程序的跨平臺特性。JVM 屏蔽了與具體操作系統平臺相關的信息,使得 Java 程序只需生成在 JVM 上運行的目標代碼(字節碼),就可以…

c++ 面試題(1)-----深度優先搜索(DFS)實現

操作系統:ubuntu22.04 IDE:Visual Studio Code 編程語言:C11 題目描述 地上有一個 m 行 n 列的方格,從坐標 [0,0] 起始。一個機器人可以從某一格移動到上下左右四個格子,但不能進入行坐標和列坐標的數位之和大于 k 的格子。 例…

【匯編逆向系列】七、函數調用包含多個參數之浮點型- XMM0-3寄存器

目錄 1. 匯編代碼 1.1 debug編譯 1.2 release編譯 2. 匯編分析 2.1 浮點參數傳遞規則 2.2 棧幀rsp的變化時序 2.3 參數的訪問邏輯 2.4 返回值XMM0寄存器 3. 匯編轉化 3.1 Debug編譯 3.2 Release 編譯 3.3 C語言轉化 1. 匯編代碼 上一節介紹了整型的函數傳參&#x…

華為云Flexus+DeepSeek征文 | 從零到一:用Flexus云服務打造低延遲聯網搜索Agent

作者簡介 我是摘星,一名專注于云計算和AI技術的開發者。本次通過華為云MaaS平臺體驗DeepSeek系列模型,將實際使用經驗分享給大家,希望能幫助開發者快速掌握華為云AI服務的核心能力。 目錄 作者簡介 前言 1. 項目背景與技術選型 1.1 項目…

【多智能體】受木偶戲啟發實現多智能體協作編排

😊你好,我是小航,一個正在變禿、變強的文藝傾年。 🔔本專欄《人工智能》旨在記錄最新的科研前沿,包括大模型、具身智能、智能體等相關領域,期待與你一同探索、學習、進步,一起卷起來叭&#xff…

Java八股文——Spring篇

文章目錄 Java八股文專欄其它文章Java八股文——Spring篇SpringSpring的IoC和AOPSpring IoC實現機制Spring AOP實現機制 動態代理JDK ProxyCGLIBByteBuddy Spring框架中的單例Bean是線程安全的嗎?什么是AOP,你們項目中有沒有使用到AOPSpring中的事務是如…

NineData數據庫DevOps功能全面支持百度智能云向量數據庫 VectorDB,助力企業 AI 應用高效落地

NineData 的數據庫 DevOps 解決方案已完成對百度智能云向量數據庫 VectorDB 的全鏈路適配,成為國內首批提供 VectorDB 原生操作能力的服務商。此次合作聚焦 AI 開發核心場景,通過標準化 SQL 工作臺與細粒度權限管控兩大能力,助力企業安全高效…

開源技術驅動下的上市公司財務主數據管理實踐

開源技術驅動下的上市公司財務主數據管理實踐 —— 以人造板制造業為例 引言:財務主數據的戰略價值與行業挑戰 在資本市場監管日益嚴格與企業數字化轉型的雙重驅動下,財務主數據已成為上市公司財務治理的核心基礎設施。對于人造板制造業而言&#xff0…

借助它,普轉也能獲得空轉信息?

在生命科學研究領域,轉錄組技術是探索基因表達奧秘的有力工具,在疾病機制探索、生物發育進程解析等諸多方面取得了顯著進展。然而,隨著研究的深入,研究人員發現普通轉錄組只能提供整體樣本中的基因表達水平信息,卻無法…

synchronized 學習

學習源: https://www.bilibili.com/video/BV1aJ411V763?spm_id_from333.788.videopod.episodes&vd_source32e1c41a9370911ab06d12fbc36c4ebc 1.應用場景 不超賣,也要考慮性能問題(場景) 2.常見面試問題: sync出…

Java事務回滾詳解

一、什么是事務回滾? 事務回滾指的是:當執行過程中發生異常時,之前對數據庫所做的更改全部撤銷,數據庫狀態恢復到事務開始前的狀態。這是數據庫“原子性”原則的體現。 二、Spring 中的 Transactional 默認行為 在 Spring 中&am…

云災備數據復制技術研究

云災備數據復制技術:數字時代的“安全氣囊” 在當今信息化時代,數據就像城市的“生命線”,一旦中斷,后果不堪設想。想象一下,如果政務系統突然崩潰,成千上萬的市民服務將陷入癱瘓。這就是云災備技術的重要…

如何處理Shopify主題的顯示問題:實用排查與修復指南

在Shopify店鋪運營過程中,主題顯示問題是影響用戶體驗與品牌形象的常見痛點。可能是字體錯位、圖片無法加載、移動端顯示混亂、功能失效等,這些都可能造成客戶流失和轉化下降。 本文將從問題識別、原因分析、修復方法到開發者建議全方位解讀如何高效解決…

前端監控方案詳解

一、前端監控方案是什么? 前端監控方案是一套系統化的工具和流程,用于收集、分析和報告網站或Web應用在前端運行時的各種性能指標、錯誤日志、用戶行為等數據。它通常包括以下幾個核心模塊: 性能監控:頁面加載時間、資源加載時間…

Camera相機人臉識別系列專題分析之十二:人臉特征檢測FFD算法之libvega_face.so數據結構詳解

【關注我,后續持續新增專題博文,謝謝!!!】 上一篇我們講了: Camera相機人臉識別系列專題分析之十一:人臉特征檢測FFD算法之低功耗libvega_face.so人臉屬性(年齡,性別,膚…

如何配置HarmonyOS 5與React Native的開發環境?

配置 HarmonyOS 5 與 React Native 的開發環境需遵循以下步驟 一、基礎工具安裝 ?DevEco Studio 5.0? 從 HarmonyOS 開發者官網 下載安裝勾選組件: HarmonyOS SDK (API 12)ArkTS 編譯器JS/ArkTS 調試工具HarmonyOS 本地模擬器 ?Node.js 18.17 # 安裝后驗證版…

kotlin kmp 副作用函數 effect

在 Kotlin Multiplatform (KMP) Compose 中,“effect functions”(或“effect handlers”)是專門的可組合函數,用于在 UI 中管理副作用。 在 Compose 中,可組合函數應該是“純”的和聲明式的。這意味著它們應該理想地…

3.3.1_1 檢錯編碼(奇偶校驗碼)

從這節課開始,我們會探討數據鏈路層的差錯控制功能,差錯控制功能的主要目標是要發現并且解決一個幀內部的位錯誤,我們需要使用特殊的編碼技術去發現幀內部的位錯誤,當我們發現位錯誤之后,通常來說有兩種解決方案。第一…

【Pandas】pandas DataFrame isna

Pandas2.2 DataFrame Missing data handling 方法描述DataFrame.fillna([value, method, axis, …])用于填充 DataFrame 中的缺失值(NaN)DataFrame.backfill(*[, axis, inplace, …])用于**使用后向填充(即“下一個有效觀測值”&#xff09…