深度學習中的模型剪枝工具Torch-Pruning的使用

? ? ? Torch-Pruning(TP)是一個結構化剪枝框架,源碼地址:https://github.com/VainF/Torch-Pruning,最新發布版本v1.6.0,License為MIT。

? ? ? TP支持對各種深度神經網絡進行結構化剪枝。與通過掩碼將參數設置為零的torch.nn.utils.prune不同,TP部署了一種名為DepGraph的算法來分組和移除耦合參數(coupled parameter)。

? ? ? TP僅依賴PyTorch和Numpy,并且與PyTorch 1.x和2.x兼容,在Anaconda虛擬環境上通過pip安裝v1.6.0版本,執行以下命令:

pip install torch-pruning==1.6.0

? ? ? 在結構化剪枝中,移除單個參數可能會影響多個層。例如,剪枝線性層的輸出維度將需要移除下一個線性層中相應的輸入維度。層之間的這種依賴關系使得手動剪枝復雜網絡變得非常困難。TP通過引入一種名為DepGraph的基于圖的算法來解決這個問題,該算法可以自動識別依賴關系并收集需要剪枝的組。

? ? ? 這里以?https://blog.csdn.net/fengbingchun/article/details/149307432?中的數據集為例,使用DenseNet進行分類,測試代碼如下:

? ? ? 1. 對之前生成的分類模型進行剪枝::保存剪枝后的模型使用torch.save(model,name),不能使用torch.save(model.state_dict(),name)

def model_pruning(model_name, classes_number, prune_amount):# https://github.com/VainF/Torch-Pruning/blob/master/examples/torchvision_models/torchvision_global_pruning.pymodel = models.densenet121(weights=None)model.classifier = nn.Linear(model.classifier.in_features, classes_number)# print("before pruning, model:", model)model.load_state_dict(torch.load(model_name, weights_only=False, map_location="cpu"))orininal_size = tp.utils.count_params(model)model.cpu().eval()for p in model.parameters():p.requires_grad_(True)ignored_layers = []for m in model.modules():if isinstance(m, nn.Linear):ignored_layers.append(m)print(f"ignored_layers: {ignored_layers}")example_inputs = torch.randn(1, 3, 224, 224)# build network prunersimportance = tp.importance.MagnitudeImportance(p=1)pruner = tp.pruner.MagnitudePruner(model,example_inputs=example_inputs,importance=importance,iterative_steps=1,pruning_ratio=prune_amount,global_pruning=True,round_to=None,unwrapped_parameters=None,ignored_layers=ignored_layers,channel_groups={})# pruninglayer_channel_cfg = {}for module in model.modules():if module not in pruner.ignored_layers:if isinstance(module, nn.Conv2d):layer_channel_cfg[module] = module.out_channelselif isinstance(module, nn.Linear):layer_channel_cfg[module] = module.out_featurespruner.step()# print("after pruning, model", model)result_size = tp.utils.count_params(model)print(f"model: original size: {orininal_size}; result_size: {result_size}")# testingwith torch.no_grad():out = model(example_inputs)print("test out:", out)torch.save(model, "new_structured_prune_melon_classify.pt") # cann't bu used: torch.save(model.state_dict(), "")

? ? ? 剪枝前后,模型的改動如下圖所示:

? ? ? 剪枝前模型大小約為27.1MB,剪枝后模型大小約為14.0M。

? ? ? 2. 模型剪枝后需要對其進行微調,即重新訓練:

def _load_dataset(dataset_path, mean, std, batch_size):mean = _str2tuple(mean)std = _str2tuple(std)train_transform = transforms.Compose([transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=mean, std=std), # RGB])train_dataset = ImageFolder(root=dataset_path+"/train", transform=train_transform)print(f"train dataset length: {len(train_dataset)}; classes: {train_dataset.class_to_idx}; number of categories: {len(train_dataset.class_to_idx)}")train_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=0)val_transform = transforms.Compose([transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=mean, std=std), # RGB])val_dataset = ImageFolder(root=dataset_path+"/val", transform=val_transform)print(f"val dataset length: {len(val_dataset)}; classes: {val_dataset.class_to_idx}")assert len(train_dataset.class_to_idx) == len(val_dataset.class_to_idx), f"the number of categories int the train set must be equal to the number of categories in the validation set: {len(train_dataset.class_to_idx)} : {len(val_dataset.class_to_idx)}"val_loader = DataLoader(val_dataset, batch_size, shuffle=True, num_workers=0)return len(train_dataset), len(val_dataset), train_loader, val_loaderdef fine_tuning(dataset_path, epochs, mean, std, model_name):device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = torch.load(model_name, weights_only=False)model.to(device)train_dataset_num, val_dataset_num, train_loader, val_loader = _load_dataset(dataset_path, mean, std, 4)optimizer = optim.Adam(model.parameters(), lr=0.00001) # set the optimizercriterion = nn.CrossEntropyLoss() # set the losshighest_accuracy = 0.minimum_loss = 100.new_model_name = "fine_tuning_melon_classify.pt"for epoch in range(epochs):epoch_start = time.time()train_loss = 0.0train_acc = 0.0val_loss = 0.0val_acc = 0.0model.train() # set to training modefor _, (inputs, labels) in enumerate(train_loader):inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad() # clean existing gradientsoutputs = model(inputs) # forward passloss = criterion(outputs, labels) # compute lossloss.backward() # backpropagate the gradientsoptimizer.step() # update the parameterstrain_loss += loss.item() * inputs.size(0) # compute the total loss_, predictions = torch.max(outputs.data, 1) # compute the accuracycorrect_counts = predictions.eq(labels.data.view_as(predictions))acc = torch.mean(correct_counts.type(torch.FloatTensor)) # convert correct_counts to floattrain_acc += acc.item() * inputs.size(0) # compute the total accuracy# print(f"train batch number: {i}; train loss: {loss.item():.4f}; accuracy: {acc.item():.4f}")model.eval() # set to evaluation modewith torch.no_grad():for _, (inputs, labels) in enumerate(val_loader):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs) # forward passloss = criterion(outputs, labels) # compute lossval_loss += loss.item() * inputs.size(0) # compute the total loss_, predictions = torch.max(outputs.data, 1) # compute validation accuracycorrect_counts = predictions.eq(labels.data.view_as(predictions))acc = torch.mean(correct_counts.type(torch.FloatTensor)) # convert correct_counts to floatval_acc += acc.item() * inputs.size(0) # compute the total accuracyavg_train_loss = train_loss / train_dataset_num # average training lossavg_train_acc = train_acc / train_dataset_num # average training accuracyavg_val_loss = val_loss / val_dataset_num # average validation lossavg_val_acc = val_acc / val_dataset_num # average validation accuracyepoch_end = time.time()print(f"epoch:{epoch+1}/{epochs}; train loss:{avg_train_loss:.6f}, accuracy:{avg_train_acc:.6f}; validation loss:{avg_val_loss:.6f}, accuracy:{avg_val_acc:.6f}; time:{epoch_end-epoch_start:.2f}s")if highest_accuracy < avg_val_acc and minimum_loss > avg_val_loss:torch.save(model, new_model_name)highest_accuracy = avg_val_accminimum_loss = avg_val_lossif avg_val_loss < 0.0001 or avg_val_acc > 0.9999:print(colorama.Fore.YELLOW + "stop training early")torch.save(model, new_model_name)break

? ? ? 微調時迭代幾次即可滿足要求,執行結果如下圖所示:

? ? ? 3. 使用剪枝后的模型和微調后的模型進行預測::加載模型使用torch.load(model_name, weights_only=False),不能使用model.load_state_dict(torch.load(model_name, weights_only=False, map_location="cpu"))

def _parse_labels_file(labels_file):classes = {}with open(labels_file, "r") as file:for line in file:idx_value = []for v in line.split(" "):idx_value.append(v.replace("\n", "")) # remove line breaks(\n) at the end of the lineassert len(idx_value) == 2, f"the length must be 2: {len(idx_value)}"classes[int(idx_value[0])] = idx_value[1]return classesdef _get_images_list(images_path):image_names = []p = Path(images_path)for subpath in p.rglob("*"):if subpath.is_file():image_names.append(subpath)return image_namesdef predict(model_name, labels_file, images_path, mean, std):classes = _parse_labels_file(labels_file)assert len(classes) != 0, "the number of categories can't be 0"image_names = _get_images_list(images_path)assert len(image_names) != 0, "no images found"mean = _str2tuple(mean)std = _str2tuple(std)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = torch.load(model_name, weights_only=False)model.to(device)model.eval()with torch.no_grad():for image_name in image_names:input_image = Image.open(image_name)preprocess = transforms.Compose([transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=mean, std=std) # RGB])input_tensor = preprocess(input_image) # (c,h,w)input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model, (1,c,h,w)input_batch = input_batch.to(device)output = model(input_batch)probabilities = torch.nn.functional.softmax(output[0], dim=0) # the output has unnormalized scores, to get probabilities, you can run a softmax on itmax_value, max_index = torch.max(probabilities, dim=0)print(f"{image_name.name}\t{classes[max_index.item()]}\t{max_value.item():.4f}")

? ? ? 執行結果如下圖所示:微調前的模型準確率非常低,微調后的模型準確率非常高

? ? ? GitHub:https://github.com/fengbingchun/NN_Test

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/89577.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/89577.shtml
英文地址,請注明出處:http://en.pswp.cn/web/89577.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

力扣-121.買賣股票的最佳時機

121.買賣股票的最佳時機 class Solution {public int maxProfit(int[] prices) {int min prices[0];int max 0;for (int i 1; i < prices.length; i) {max Math.max(prices[i] - min, max);if (prices[i] < min) {min prices[i];}}return max;} }小結&#xff1a;貪…

lvs原理及實戰部署

一、集群與分布式系統 1 集群 1-1概念 集群式架構是將多個相同或相似的節點組合在一起&#xff0c;形成一個邏輯上的 “整體”&#xff0c;對外提供統一的服務或資源。節點之間通常具有較高的同構性&#xff08;硬件、軟件配置相似&#xff09;&#xff0c;且緊密協作。 1-2 三…

[Linux]如何設置靜態IP位址?

自從將Ubuntu Server 24.04 LTS作業系統建置在VM上後&#xff0c;逐漸導入一些容器和微服務器並使可由其他Client端來連接使用&#xff0c;其中包含AIGC模型和自動化工作流等服務&#xff0c;例如Open-WebUI和n8n。然而&#xff0c;若VM重新開機或路由器因故斷電等等狀態&#…

【Leecode 隨筆】

文章目錄題目一&#xff1a;盛最多水的容器題目描述&#xff1a;題目分析&#xff1a;解題思路&#xff1a;示例代碼&#xff1a;深入剖析&#xff1a;題目二&#xff1a;最長無重復字符的子串題目描述&#xff1a;題目分析&#xff1a;解題思路&#xff1a;示例代碼&#xff1…

Springboot項目應用PageInfo分頁問題失效

使用github的pagehelper分頁依賴<!-- 分頁控件 --><dependency><groupId>com.github.pagehelper</groupId><artifactId>pagehelper</artifactId><version>5.3.0</version><scope>compile</scope></dependency&…

【無標題】標準模型粒子行為與11維拓撲量子色動力學模型嚴格對應的全面論述

標準模型粒子行為與11維拓撲量子色動力學模型嚴格對應的全面論述標準模型粒子與拓撲結構的嚴格對應 mermaid graph LRsubgraph 標準模型粒子A[費米子] --> A1[夸克]A --> A2[輕子]B[玻色子] --> B1[規范玻色子]B --> B2[希格斯]endsubgraph 11維拓撲模型C[實體頂點…

SQL一些關于存儲過程和使用的總結

存儲過程&#xff1a;數據庫里的 "定制工具箱"存儲過程就像一個裝滿工具的箱子&#xff0c;你需要什么功能&#xff0c;就調用對應的工具。它是用 SQL 語句寫好的一段程序&#xff0c;存儲在數據庫里&#xff0c;隨時可以調用。創建存儲過程 就像在工具箱里放新工具。…

springCloud -- 微服務01

目錄 一、認識微服務 1.單體架構 2.微服務 3.SpringCloud 二、微服務拆分 1.服務拆分原則 2.服務調用 3. RestTemplate 三、服務注冊和發現 1. 注冊中心原理 2. 服務發現 2.1 服務注冊 2.2 服務發現 四、OpenFeign 一、認識微服務 1.單體架構 單體架構就是整個項目中所有功能…

Deep Multi-scale Convolutional Neural Network for Dynamic Scene Deblurring 論文閱讀

用于動態場景去模糊的深度多尺度卷積神經網絡 摘要 針對一般動態場景的非均勻盲去模糊是一個具有挑戰性的計算機視覺問題&#xff0c;因為模糊不僅來源于多個物體運動&#xff0c;還來源于相機抖動和場景深度變化。為了去除這些復雜的運動模糊&#xff0c;傳統的基于能量優化的…

PDF 拆分合并PDFSam:開源免費 多文件合并 + 按頁碼拆分 本地處理

各位打工人和學生黨們&#xff0c;你知道嗎&#xff0c;處理PDF文件簡直是咱們的日常噩夢啊&#xff0c;尤其是遇到要合并好幾個文件&#xff0c;或者從中摳幾頁出來的時候&#xff0c;簡直頭大如斗&#xff01;今天給你們安利一個神仙工具&#xff0c;PDFSam&#xff0c;聽我的…

AI產品經理面試寶典第32天:AI+工業場景落地核心問題與應答策略

一、AI+工業落地價值怎么答? 面試官:AI在工業領域能創造哪些核心價值?請用具體案例說明 你的回答: AI在工業領域創造價值的底層邏輯是"數據閉環"。以阿里云ET工業大腦為例,通過采集生產線3000+傳感器數據,構建出影響良品率的60個關鍵變量模型。當數據流經AI…

【09】MFC入門到精通——MFC 屬性頁對話框的 CPropertyPage類 和 CPropertySheet 類

文章目錄九、屬性頁對話框的類CPropertyPage類 和 CPropertySheet 類。9.1 CPropertyPage 類&#xff08;1&#xff09;構造函數&#xff08;2&#xff09;CancelToClose()函數&#xff08;3&#xff09;SetModified()函數&#xff08;4&#xff09;可重載函數9.2 CPropertyShe…

Python學習筆記4

時間:2025.7.18學習內容&#xff1a;【語法基礎】if判斷、比較運算符與邏輯運算符一、if判斷if判斷基本格式&#xff1a;if要判斷的條件&#xff0c;條件成立時要做的事情注意&#xff1a;input內默認存儲的是字符串age17 if age<18:print(未成年不能上網) scoreinput(你的成…

20250718-2-Kubernetes 應用程序生命周期管理-Pod對象:基本概念(豌豆莢)_筆記

二、Kubernetes應用程序生命周期管理&#xfeff;1. 課程內容概述主要內容&#xff1a;Pod資源共享實現機制管理命令應用自修復&#xff08;重啟策略健康檢查&#xff09;環境變量Init container靜態Pod2. Pod對象介紹&#xfeff;1&#xff09;Pod基本概念&#xfeff;&#x…

為Notepad++插上JSON格式化的翅膀

文章目錄概要安裝步驟效果展示概要 JSMinNPP.dll 是一個 Notepad 插件&#xff0c;用于壓縮 JavaScript 代碼和格式化JSON字符床。以下是安裝和使用的詳細步驟&#xff1a; 安裝步驟 下載 JSMinNPP.dll 插件 https://pan.quark.cn/s/73dd0ac225be 放置 DLL 文件 打開 Notepa…

STM32-第七節-TIM定時器-3(輸入捕獲)

一、簡介&#xff1a;1.名稱&#xff1a;IC&#xff0c;輸入捕獲2.電路&#xff1a;如圖為通用定時器框圖&#xff0c;下半部分的左半模塊&#xff0c;與輸出比較部分共用捕獲/比較寄存器與引腳。3.功能&#xff1a;當通道輸入引腳出現電平跳變時&#xff0c;當前CNT的值&#…

Console 納管 Elasticsearch 9(二):日志監控

前面介紹過 INFINI Console 納管 Elasticsearch 9&#xff08;一&#xff09;&#xff0c;進行指標監控、數據管理、DSL 語句執行&#xff0c;但日志監控功能需要結合 Agent 才能使用。現在來實現一下&#xff1a; Agent 需要和 ES 部署到同一機器上&#xff0c;這里是在我本地…

實訓十——路由器與TCP/IP模型

補充拓撲圖&#xff08;交換機串聯通信&#xff09;電腦A——交換機S1——交換機S2——電腦B問&#xff1a;A和B如何通信&#xff1f;首先A會將通信的數據封裝好&#xff0c;將源端口、目標端口&#xff0c;源地址、目標地址&#xff0c;源MAC、目標MAC封裝起來&#xff0c;但是…

【Android】ViewBinding(視圖綁定)

一、什么是ViewBindingViewBinding是Android Studio 3.6推出的新特性&#xff0c;旨在替代findViewById(內部實現還是使用findViewById)。通過ViewBinding&#xff0c;可以更輕松地編寫可與視圖交互的代碼。在模塊中啟用ViewBinding之后&#xff0c;系統會為該模塊中的每個 XML…

泛型與類型安全深度解析及響應式API實戰

一、泛型通配符&#xff1a;靈活與安全的平衡術 在Java動物收容所系統中&#xff0c;我們常需要處理不同動物類型的集合。通過泛型通配符&#xff0c;可以構建更靈活的API&#xff1a; class Shelter<T extends Animal> {private List<T> animals new ArrayList&l…