模型剪枝----ResNet18剪枝實戰

剪枝

模型剪枝(Model Pruning) 是一種 模型壓縮(Model Compression) 技術,主要思想是:
深度神經網絡里有很多 冗余參數(對預測結果貢獻很小)。
通過去掉這些冗余連接/通道/卷積核,能讓模型更小、更快,同時盡量保持精度。

非結構化剪枝(Unstructured Pruning)

對單個權重參數設置閾值,小于閾值的直接置零。
優點:保留了原始網絡結構,容易實現。
缺點:稀疏矩陣計算對普通硬件加速有限(需要專門稀疏庫)。

#將所有的卷積層通道減掉30%
for module in pruned_model.modules():if isinstance(module,nn.Conv2d):#這行代碼的作用是對指定模塊按照L2范數的標準,沿著輸出通道維度剪去30%的不重要通道,prune.ln_structured(module,name = "weight",amount = 0.3,n=2,dim = 0)

對ResNet18減和不減的效果差不多,一個是精度,另一個是一輪推理的時間
在這里插入圖片描述
分析原因 確實把 30% 卷積核置零,但是模塊結構沒變:Conv2d 還是原來那么大,只是部分權重被置零, PyTorch 的默認實現不會自動跳過這些“無效通道”, 所以 FLOPs 還是一樣,ptflops 統計出來的數字沒減少, GPU 上仍然執行全量卷積,推理時間幾乎不會變化

結構化剪枝(Structured Pruning)

刪除整個卷積核、通道、層。
優點:能直接減少計算量和推理時間。
缺點:剪掉的多了容易掉精度。

完整代碼

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.prune as prune
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import time
from tqdm import tqdm
from ptflops import get_model_complexity_info
import torch_pruning as tp# ======================
# 1. 數據準備
# ======================
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)),
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)),
])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,shuffle=False, num_workers=2)device = torch.device( "cuda" if torch.cuda.is_available() else "cpu" )
# ======================
# 2. 定義訓練和測試函數
# ======================
def train(model,optimizer,criterion,epoch):model.train()for inx,(inputs,targets) in enumerate(trainloader):inputs,targets = inputs.to(device),targets.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs,targets)loss.backward()optimizer.step()def test(model,criterion,epoch,tag = ""):model.eval()start = time.time()correct,total,loss_sum = 0,0,0.0with torch.no_grad():for inputs, targets in testloader:inputs,targets = inputs.to(device), targets.to(device)outputs = model(inputs)loss_sum = criterion(outputs,targets).item()_, predicted = outputs.max(1)total += targets.size(0)correct += predicted.eq(targets).sum().item()acc = 100. * correct / totalend = time.time()time_cost = end - startprint(f"{tag} Epoch {epoch}: Loss={loss_sum:.4f}, Acc={acc:.2f}%, Time={time_cost:.2f}s")return acc,time_costdef print_model_stats(model,tag = ""):#統計模型參數和flopsmac, params = get_model_complexity_info(model,(3,32,32),as_strings = True,print_per_layer_stat = False,verbose = False)print(f"{tag} Params:{params},FLOPs:{mac}")# ======================
# 3. 訓練基線模型
# ======================
print("===============BaseLine ResNet18")
baseline_model = models.resnet18(pretrained = True)
baseline_model.fc = nn.Linear(baseline_model.fc.in_features,10)
baseline_model = baseline_model.to(device)
print_model_stats(baseline_model,"Baseline")criterion = nn.CrossEntropyLoss()
optimer = optim.SGD(baseline_model.parameters(),lr = 0.01,momentum = 0.9,weight_decay = 5e-4)
baseline_acc = []
baseline_time = []
for epoch in tqdm(range(10)):train(baseline_model,optimer,criterion,epoch)acc,time_cost = test(baseline_model,criterion,epoch,"Baseline")baseline_acc.append(acc)baseline_time.append(time_cost)# ======================
# 4. 剪枝 + 微調
# ======================
pruned_model = models.resnet18(pretrained = True)
pruned_model.fc = nn.Linear(pruned_model.fc.in_features,10)
pruned_model = pruned_model.to(device)#===============非結構化剪枝=====================
# #將所有的卷積層通道減掉30%
# for module in pruned_model.modules():
#     if isinstance(module,nn.Conv2d):
#         #這行代碼的作用是對指定模塊按照L2范數的標準,沿著輸出通道維度剪去30%的不重要通道,
#         prune.ln_structured(module,name = "weight",amount = 0.3,n=2,dim = 0)#==========================結構化剪枝=====================
# 創建依賴圖對象,用于處理剪枝時各層之間的依賴關系
DG = tp.DependencyGraph()
# 構建模型的依賴關系圖,需要提供示例輸入來追蹤計算圖
# example_inputs用于追蹤模型的前向傳播路徑,確定各層之間的依賴關系
DG.build_dependency(pruned_model,example_inputs = torch.randn(1,3,32,32).to(device))def prune_conv_by_ratio(conv, ratio=0.3):# 計算每個輸出通道的L1范數(絕對值求和),用于評估通道的重要性# conv.weight.data.abs().sum((1, 2, 3)) 對卷積核的后三維(H, W, C_in)求和,得到每個輸出通道的L1范數weight = conv.weight.data.abs().sum((1, 2, 3))  # 根據指定的剪枝比例計算需要移除的通道數量num_remove = int(weight.numel() * ratio)# 找到L1范數最小的num_remove個通道的索引# torch.topk返回最大的k個元素,設置largest=False后返回最小的k個元素_, idxs = torch.topk(weight, k=num_remove, largest=False)# 獲取剪枝組,指定要剪枝的層、剪枝方式和剪枝索引# tp.prune_conv_out_channels表示沿輸出通道維度進行剪枝group = DG.get_pruning_group(conv, tp.prune_conv_out_channels, idxs=idxs.tolist())# 執行剪枝操作,物理移除指定的通道group.prune()# 遍歷剪枝模型的所有模塊
for m in pruned_model.modules():# 檢查模塊是否為卷積層if isinstance(m, nn.Conv2d):# 對該卷積層執行剪枝操作,移除30%的輸出通道prune_conv_by_ratio(m, ratio=0.3)#=======================================================print_model_stats(pruned_model,"Pruned")
criterion1 = nn.CrossEntropyLoss()
optimer1 = optim.SGD(pruned_model.parameters(),lr = 0.01,momentum = 0.9,weight_decay = 5e-4)
pruned_acc = []
pruned_time = []for epoch in tqdm(range(10)):train(pruned_model,optimer1,criterion1,epoch)acc,time_cost = test(pruned_model,criterion1,epoch,"Pruned")pruned_acc.append(acc)pruned_time.append(time_cost)# ======================
# 5. 對比結果
# ======================
print("\n==== Final Accuracy Comparison ====")print(f" Baseline={max(baseline_acc):.2f}% time={sum(baseline_time)/len(baseline_time):.2f}, Pruned={max(pruned_acc):.2f}% time={sum(pruned_time)/len(pruned_time):.2f}")

最終訓練10輪的情況下精度下降7%,模型參數量減少4倍,感覺能夠接受
Params:11.18 M – > 2.7M
FLOPs:37.25 MMac --> 9.48 MMac
acc : 82.86% —> 75.77%
time : 1.20 ----> 1.12
在這里插入圖片描述

基于正則化/稀疏約束

在訓練時加上稀疏正則項,讓網絡自動學習出“重要性低”的權重趨近于零,再做剪枝。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/98441.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/98441.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/98441.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

K8S-Pod(上)

Pod概念 Pod 是可以在 Kubernetes 中創建和管理的、最小的可部署的計算單元。 Pod是一組(一個或多個)容器;這些容器共享存儲、網絡、以及怎樣運行這些容器的規約。Pod 中的內容總是并置(colocated)的并且一同調度&am…

Flink TaskManager日志時間與實際時間有偏差

Flink 啟動一個任務后,發現TaskManager上日志時間與實際時間相差約 15 小時。 核心原因可能是: 1、 服務器(或容器)的系統時間配置錯誤2、 Flink 日志組件(如 Logback/Log4j)的時間配置未使用系統默認時區…

Webug3.0通關筆記18 中級進階第06關 實戰練習:DisCuz論壇SQL注入漏洞

目錄 一、環境搭建 1、服務啟動 2、源碼解壓 3、構造訪問靶場URL 4、靶場安裝 5、訪問論壇首頁 二、代碼分析 1、源碼分析 2、SQL注入分析 三、滲透實戰 (1)判斷是否有SQL注入風險 (2)查詢賬號密碼 Discuz! 作為國內知…

SWEET:大語言模型的選擇性水印

摘要背景與問題大語言模型出色的生成能力引發了倫理與法律層面的擔憂,于是通過嵌入水印來檢測機器生成文本的方法逐漸發展起來。但現有工作在代碼生成任務中無法良好發揮作用,原因在于代碼生成任務本身的特性(代碼有其特定的語法、邏輯結構&a…

FastDFS V6雙IP特性及配置

FastDFS V6.0開始支持雙IP,tracker server和storage server均支持雙IP。V6.0新增特性說明如下:支持雙IP,一個內網IP,一個外網IP,可以支持NAT方式的內網和外網兩個IP,解決跨機房或混合云部署問題。FastDFS雙…

筆記本、平板如何成為電腦拓展屏?向日葵16成為副屏功能一鍵實現

向日葵16重磅上線,本次更新新增了諸多實用功能,提升遠控效率,實現應用融合突破設備邊界,同時全面提升遠控性能,操作更順滑、畫質更清晰!無論遠程辦公、設計、IT運維、開發還是游戲娛樂,向日葵16…

基于Spring Boot + MyBatis的用戶管理系統配置

我來為您詳細分析這兩個配置文件的功能和含義。 一、文件整體概述 這是一個基于Spring Boot MyBatis的用戶管理系統配置: UserMapper.xml:MyBatis的SQL映射文件,定義了用戶表的增刪改查操作application.yml:Spring Boot的核心配置…

80(HTTP默認端口)和8080端口(備用HTTP端口)區別

文章目錄**1. 用途**- **80端口**- **8080端口****2. 默認配置**- **80端口**- **8080端口****3. 聯系**- **邏輯端口**:兩者都是TCP/IP協議中的邏輯端口,用于標識不同的網絡服務。- **可配置性**:端口號可以根據需要修改(例如將T…

【開題答辯全過程】以 汽車知名品牌信息管理系統為例,包含答辯的問題和答案

個人簡介一名14年經驗的資深畢設內行人,語言擅長Java、php、微信小程序、Python、Golang、安卓Android等開發項目包括大數據、深度學習、網站、小程序、安卓、算法。平常會做一些項目定制化開發、代碼講解、答辯教學、文檔編寫、也懂一些降重方面的技巧。感謝大家的…

從全棧工程師視角解析Java與前端技術在電商場景中的應用

從全棧工程師視角解析Java與前端技術在電商場景中的應用 面試背景介紹 面試官:你好,很高興見到你。我叫李明,是這家電商平臺的資深架構師。今天我們會聊聊你的技術能力和項目經驗。你可以先簡單介紹一下自己嗎? 應聘者&#xff1a…

【python】python進階——多線程

引言在現代軟件開發中,程序的執行效率至關重要。無論是處理大量數據、響應用戶交互,還是與外部系統通信,常常需要讓程序同時執行多個任務。Python作為一門功能強大且易于學習的編程語言,提供了多種并發編程方式,其中多…

【JavaEE】(23) 綜合練習--博客系統

一、功能描述 用戶登錄后,可查看所有人的博客。點擊 “查看全文” 可查看該博客完整內容。如果該博客作者是登錄用戶,可以編輯或刪除博客。發表博客的頁面同編輯頁面。 本練習的博客網站,并沒有添加注冊功能,以及上傳作者頭像功能…

MySQL全庫檢索關鍵詞 - idea 工具 Full-Text Search分享

我們經常要在庫中查找一個數據,又不知道在哪個表、哪個字段;或者想找到哪里有在用這個數據。我們可以用:idea 的 Database工具 - Full-Text Search打開idea,在工具欄找到 Database 然后新建自己的連接,然后右鍵&#x…

銀行卡號識別案例

代碼實現:import cv2 import numpy as np import argparse import myutils-i moban.png -t card1.pngap argparse.ArgumentParser() ap.add_argument("-i","--image", requiredTrue,help"path to input image") ap.add_argument(&quo…

云管平臺上線只是開始:從“建好”到“用好”的運營、推廣與深化指南

項目上線的喜悅轉瞬即逝,隨之而來的是一個更為現實和復雜的階段:運營。云管平臺(CMP)的成功,不再僅僅取決于其技術架構的先進性,更在于它能否融入組織的肌理,為不同角色持續創造價值。本文將從管理者、平臺團隊、開發者、運維和財務五個核心角色的視角,深入探討平臺上線…

distributed.client.Client 用戶可調用函數分析

distributed.client.Client 用戶可調用函數分析 1. 核心計算函數 任務提交和執行submit(func, *args, keyNone, workersNone, resourcesNone, retriesNone, priority0, fifo_timeout60s, allow_other_workersFalse, actorFalse, actorsFalse, pureNone, **kwargs) 提交單個函數…

數字圖像處理——信用卡識別

在數字支付時代,信用卡處理自動化技術日益重要。本文介紹如何利用Python和OpenCV實現信用卡數字的自動識別,結合圖像處理與模式識別技術,具有顯著實用價值。系統概述與工作原理信用卡數字識別系統包含兩大核心模塊:模板數字預處理…

嵌入式ARM64 基于RK3588原生SDK添加用戶配置選項./build lunch debian

1 背景 在我們正常拿到SDK后會有一些配置選項,在使用./build.sh lunch之后會輸出一些defautconfig讓我們選擇,瑞芯微的原廠sdk會提供一些主板的配置選項,但是我們的如果是一塊新的主板就需要添加自己的配置選項,本文就討論如何來添…

專為石油和天然氣檢測而開發的基于無人機的OGI相機

專為石油和天然氣檢測而開發的基于無人機的OGI相機基于無人機的 OGI 相機:(Optical Gas Imaging,光學氣體成像)其實是近幾年油氣、電力、化工等行業里非常熱門的應用方向。什么是 OGI 相機OGI(Optical Gas Imaging)&am…

iPhone17全系優缺點分析,加持遠程控制讓你的手機更好用!

知名數碼廠商蘋果,不久前已官宣將于北京時間9月10日凌晨1點開啟發布會,主打對于iPhone 17系列產品介紹,并且和以往不同的是,今年會在購物平臺上開啟線上直播,還是很有新意的。9.13全平臺渠道將開啟預售模式&#xff0c…