實踐教程|基于 pytorch 實現模型剪枝

PyTorch剪枝方法詳解,附詳細代碼。

  • 一,剪枝分類

  • 1.1,非結構化剪枝

  • 1.2,結構化剪枝

  • 1.3,本地與全局修剪

  • 二,PyTorch 的剪枝

  • 2.1,pytorch 剪枝工作原理

  • 2.2,局部剪枝

  • 2.3,全局非結構化剪枝

  • 三,總結

  • 參考資料

一,剪枝分類

所謂模型剪枝,其實是一種從神經網絡中移除"不必要"權重或偏差(weigths/bias)的模型壓縮技術。關于什么參數才是“不必要的”,這是一個目前依然在研究的領域。

1.1,非結構化剪枝

非結構化剪枝(Unstructured Puning)是指修剪參數的單個元素,比如全連接層中的單個權重、卷積層中的單個卷積核參數元素或者自定義層中的浮點數(scaling floats)。其重點在于,剪枝權重對象是隨機的,沒有特定結構,因此被稱為非結構化剪枝

1.2,結構化剪枝

與非結構化剪枝相反,結構化剪枝會剪枝整個參數結構。比如,丟棄整行或整列的權重,或者在卷積層中丟棄整個過濾器(Filter)。

1.3,本地與全局修剪

剪枝可以在每層(局部)或多層/所有層(全局)上進行。

二,PyTorch 的剪枝

目前 PyTorch 框架支持的權重剪枝方法有:

  • Random: 簡單地修剪隨機參數。

  • Magnitude: 修剪權重最小的參數(例如它們的 L2 范數)

以上兩種方法實現簡單、計算容易,且可以在沒有任何數據的情況下應用。

2.1,pytorch 剪枝工作原理

剪枝功能在 torch.nn.utils.prune 類中實現,代碼在文件 torch/nn/utils/prune.py 中,主要剪枝類如下圖所示。

圖片

pytorch_pruning_api_file.png

剪枝原理是基于張量(Tensor)的掩碼(Mask)實現。掩碼是一個與張量形狀相同的布爾類型的張量,掩碼的值為 True 表示相應位置的權重需要保留,掩碼的值為 False 表示相應位置的權重可以被刪除。

Pytorch 將原始參數 <param> 復制到名為 <param>_original 的參數中,并創建一個緩沖區來存儲剪枝掩碼 <param>_mask。同時,其也會創建一個模塊級的 forward_pre_hook 回調函數(在模型前向傳播之前會被調用的回調函數),將剪枝掩碼應用于原始權重。

pytorch 剪枝的 api 和教程比較混亂,我個人將做了如下表格,希望能將 api 和剪枝方法及分類總結好。

圖片

pytorch_pruning_api

pytorch 中進行模型剪枝的工作流程如下:

  1. 選擇剪枝方法(或者子類化 BasePruningMethod 實現自己的剪枝方法)。

  2. 指定剪枝模塊和參數名稱。

  3. 設置剪枝方法的參數,比如剪枝比例等。

2.2,局部剪枝

Pytorch 框架中的局部剪枝有非結構化和結構化剪枝兩種類型,值得注意的是結構化剪枝只支持局部不支持全局。

2.2.1,局部非結構化剪枝

1,局部非結構化剪枝(Locall Unstructured Pruning)對應函數原型如下:

def random_unstructured(module, name, amount)  

1,函數功能:用于對權重參數張量進行非結構化剪枝。該方法會在張量中隨機選擇一些權重或連接進行剪枝,剪枝率由用戶指定。2,函數參數定義:

  • module (nn.Module): 需要剪枝的網絡層/模塊,例如 nn.Conv2d() 和 nn.Linear()。

  • name (str): 要剪枝的參數名稱,比如 “weight” 或 “bias”。

  • amount (int or float): 指定要剪枝的數量,如果是 0~1 之間的小數,則表示剪枝比例;如果是證書,則直接剪去參數的絕對數量。比如amount=0.2 ,表示將隨機選擇 20% 的元素進行剪枝。

3,下面是 random_unstructured 函數的使用示例。

import torch  
import torch.nn.utils.prune as prune  
conv = torch.nn.Conv2d(1, 1, 4)  
prune.random_unstructured(conv, name="weight", amount=0.5)  
conv.weight  
"""  
tensor([[[[-0.1703,  0.0000, -0.0000,  0.0690],  [ 0.1411,  0.0000, -0.0000, -0.1031],  [-0.0527,  0.0000,  0.0640,  0.1666],  [ 0.0000, -0.0000, -0.0000,  0.2281]]]], grad_fn=<MulBackward0>)  
"""  

可以看出輸出的 conv 層中權重值有一半比例為 0

2.2.2,局部結構化剪枝

局部結構化剪枝(Locall Structured Pruning)有兩種函數,對應函數原型如下:

def random_structured(module, name, amount, dim)  
def ln_structured(module, name, amount, n, dim, importance_scores=None)  

1,函數功能

與非結構化移除的是連接權重不同,結構化剪枝移除的是整個通道權重。

2,參數定義

與局部非結構化函數非常相似,唯一的區別是您必須定義 dim 參數(ln_structured 函數多了 n 參數)。

n 表示剪枝的范數,dim 表示剪枝的維度。

對于 torch.nn.Linear:

  • dim = 0:移除一個神經元。

  • dim = 1:移除與一個輸入的所有連接。

對于 torch.nn.Conv2d:

  • dim = 0(Channels) : 通道 channels 剪枝/過濾器 filters 剪枝

  • dim = 1(Neurons): 二維卷積核 kernel 剪枝,即與輸入通道相連接的 kernel

2.2.3,局部結構化剪枝示例代碼

在寫示例代碼之前,我們先需要理解 Conv2d 函數參數、卷積核 shape、軸以及張量的關系。首先,Conv2d 函數原型如下;

class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)  

而 pytorch 中常規卷積的卷積核權重 shape 都為(C_out, C_in, kernel_height, kernel_width),所以在代碼中卷積層權重 shape[3, 2, 3, 3],dim = 0 對應的是 shape [3, 2, 3, 3] 中的 3。這里我們 dim 設定了哪個軸,那自然剪枝之后權重張量對應的軸機會發生變換。

圖片

dim

理解了前面的關鍵概念,下面就可以實際使用了,dim=0 的示例如下所示。

conv = torch.nn.Conv2d(2, 3, 3)  
norm1 = torch.norm(conv.weight, p=1, dim=[1,2,3])  
print(norm1)  
"""  
tensor([1.9384, 2.3780, 1.8638], grad_fn=<NormBackward1>)  
"""  
prune.ln_structured(conv, name="weight", amount=1, n=2, dim=0)  
print(conv.weight)  
"""  
tensor([[[[-0.0005,  0.1039,  0.0306],  [ 0.1233,  0.1517,  0.0628],  [ 0.1075, -0.0606,  0.1140]],  [[ 0.2263, -0.0199,  0.1275],  [-0.0455, -0.0639, -0.2153],  [ 0.1587, -0.1928,  0.1338]]],  [[[-0.2023,  0.0012,  0.1617],  [-0.1089,  0.2102, -0.2222],  [ 0.0645, -0.2333, -0.1211]],  [[ 0.2138, -0.0325,  0.0246],  [-0.0507,  0.1812, -0.2268],  [-0.1902,  0.0798,  0.0531]]],  [[[ 0.0000, -0.0000, -0.0000],  [ 0.0000, -0.0000, -0.0000],  [ 0.0000, -0.0000,  0.0000]],  [[ 0.0000,  0.0000,  0.0000],  [-0.0000,  0.0000,  0.0000],  [-0.0000, -0.0000, -0.0000]]]], grad_fn=<MulBackward0>)  
"""  

從運行結果可以明顯看出,卷積層參數的最后一個通道參數張量被移除了(為 0 張量),其解釋參見下圖。

圖片

dim_understand

dim = 1 的情況:

conv = torch.nn.Conv2d(2, 3, 3)  
norm1 = torch.norm(conv.weight, p=1, dim=[0, 2,3])  
print(norm1)  
"""  
tensor([3.1487, 3.9088], grad_fn=<NormBackward1>)  
"""  
prune.ln_structured(conv, name="weight", amount=1, n=2, dim=1)  
print(conv.weight)  
"""  
tensor([[[[ 0.0000, -0.0000, -0.0000],  [-0.0000,  0.0000,  0.0000],  [-0.0000,  0.0000, -0.0000]],  [[-0.2140,  0.1038,  0.1660],  [ 0.1265, -0.1650, -0.2183],  [-0.0680,  0.2280,  0.2128]]],  [[[-0.0000,  0.0000,  0.0000],  [ 0.0000,  0.0000, -0.0000],  [-0.0000, -0.0000, -0.0000]],  [[-0.2087,  0.1275,  0.0228],  [-0.1888, -0.1345,  0.1826],  [-0.2312, -0.1456, -0.1085]]],  [[[-0.0000,  0.0000,  0.0000],  [ 0.0000, -0.0000,  0.0000],  [ 0.0000, -0.0000,  0.0000]],  [[-0.0891,  0.0946, -0.1724],  [-0.2068,  0.0823,  0.0272],  [-0.2256, -0.1260, -0.0323]]]], grad_fn=<MulBackward0>)  
"""  

很明顯,對于 dim=1的維度,其第一個張量的 L2 范數更小,所以shape 為 [2, 3, 3] 的張量中,第一個 [3, 3] 張量參數會被移除(即張量為 0 矩陣) 。

2.3,全局非結構化剪枝

前文的 local 剪枝的對象是特定網絡層,而 global 剪枝是將模型看作一個整體去移除指定比例(數量)的參數,同時 global 剪枝結果會導致模型中每層的稀疏比例是不一樣的。

全局非結構化剪枝函數原型如下:

# v1.4.0 版本  
def global_unstructured(parameters, pruning_method, **kwargs)  
# v2.0.0-rc2版本  
def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs):  

1,函數功能

隨機選擇全局所有參數(包括權重和偏置)的一部分進行剪枝,而不管它們屬于哪個層。

2,參數定義

  • parameters((Iterable of (module, name) tuples)): 修剪模型的參數列表,列表中的元素是 (module, name)。

  • pruning_method(function): 目前好像官方只支持 pruning_method=prune.L1Unstuctured,另外也可以是自己實現的非結構化剪枝方法函數。

  • importance_scores: 表示每個參數的重要性得分,如果為 None,則使用默認得分。

  • **kwargs: 表示傳遞給特定剪枝方法的額外參數。比如 amount 指定要剪枝的數量。

3,global_unstructured 函數的示例代碼如下所示。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  class LeNet(nn.Module):  def __init__(self):  super(LeNet, self).__init__()  # 1 input image channel, 6 output channels, 3x3 square conv kernel  self.conv1 = nn.Conv2d(1, 6, 3)  self.conv2 = nn.Conv2d(6, 16, 3)  self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension  self.fc2 = nn.Linear(120, 84)  self.fc3 = nn.Linear(84, 10)  def forward(self, x):  x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))  x = F.max_pool2d(F.relu(self.conv2(x)), 2)  x = x.view(-1, int(x.nelement() / x.shape[0]))  x = F.relu(self.fc1(x))  x = F.relu(self.fc2(x))  x = self.fc3(x)  return x  model = LeNet().to(device=device)  model = LeNet()  parameters_to_prune = (  (model.conv1, 'weight'),  (model.conv2, 'weight'),  (model.fc1, 'weight'),  (model.fc2, 'weight'),  (model.fc3, 'weight'),  
)  prune.global_unstructured(  parameters_to_prune,  pruning_method=prune.L1Unstructured,  amount=0.2,  
)  
# 計算卷積層和整個模型的稀疏度  
# 其實調用的是 Tensor.numel 內內函數,返回輸入張量中元素的總數  
print(  "Sparsity in conv1.weight: {:.2f}%".format(  100. * float(torch.sum(model.conv1.weight == 0))  / float(model.conv1.weight.nelement())  )  
)  
print(  "Global sparsity: {:.2f}%".format(  100. * float(  torch.sum(model.conv1.weight == 0)  + torch.sum(model.conv2.weight == 0)  + torch.sum(model.fc1.weight == 0)  + torch.sum(model.fc2.weight == 0)  + torch.sum(model.fc3.weight == 0)  )  / float(  model.conv1.weight.nelement()  + model.conv2.weight.nelement()  + model.fc1.weight.nelement()  + model.fc2.weight.nelement()  + model.fc3.weight.nelement()  )  )  
)  
# 程序運行結果  
"""  
Sparsity in conv1.weight: 3.70%  
Global sparsity: 20.00%  
"""  

運行結果表明,雖然模型整體(全局)的稀疏度是 20%,但每個網絡層的稀疏度不一定是 20%。

三,總結

另外,pytorch 框架還提供了一些幫助函數:

  1. torch.nn.utils.prune.is_pruned(module): 判斷模塊 是否被剪枝。

  2. torch.nn.utils.prune.remove(module, name):用于將指定模塊中指定參數上的剪枝操作移除,從而恢復該參數的原始形狀和數值。

雖然 PyTorch 提供了內置剪枝 API ,也支持了一些非結構化和結構化剪枝方法,但是 API 比較混亂,對應文檔描述也不清晰,所以后面我還會結合微軟的開源 nni 工具來實現模型剪枝功能。

更多剪枝方法實踐,可以參考這個 github 倉庫:Model-Compression。

參考資料

  1. How to Prune Neural Networks with PyTorch

  2. PRUNING TUTORIAL

  3. PyTorch Pruning

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/42100.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/42100.shtml
英文地址,請注明出處:http://en.pswp.cn/news/42100.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

前端如何安全的渲染HTML字符串?

在現代的Web 應用中&#xff0c;動態生成和渲染 HTML 字符串是很常見的需求。然而&#xff0c;不正確地渲染HTML字符串可能會導致安全漏洞&#xff0c;例如跨站腳本攻擊&#xff08;XSS&#xff09;。為了確保應用的安全性&#xff0c;我們需要采取一些措施來在安全的環境下渲染…

QString常用函數介紹

此篇博客核心介紹QT中的QString類型的常用函數&#xff0c;介紹到的函數均從幫助手冊或其他博客中看到 QString 字符串類 Header: #include qmake: QT core 一、QString字符串轉換 1、QString類字符串轉換為整數 int toInt(bool *ok Q_NULLPTR, int base 10) cons…

Python 基礎 -- Tutorial(二)

5、數據結構 本章更詳細地描述了一些你已經學過的東西&#xff0c;并添加了一些新的東西。 5.1. 更多關于Lists 列表(list)數據類型有更多的方法。下面是列表對象的所有方法: list.append(x) 在列表末尾添加一項。相當于a[len(a):] [x]。 list.extend(iterable) 通過添加可…

如何使用SpringBoot 自定義轉換器

&#x1f600;前言 本篇博文是關于SpringBoot 自定義轉換器的使用&#xff0c;希望你能夠喜歡&#x1f60a; &#x1f3e0;個人主頁&#xff1a;晨犀主頁 &#x1f9d1;個人簡介&#xff1a;大家好&#xff0c;我是晨犀&#xff0c;希望我的文章可以幫助到大家&#xff0c;您的…

02-前端基礎第二天-HTML5

01-HTML標簽&#xff08;下&#xff09;導讀 目標&#xff1a; 能夠書寫表格能夠寫出無序列表能夠寫出3~4個常用input表單類型能夠寫出下拉列表表單能夠使用表單元素實現注冊頁面能夠獨立查閱W3C文檔 目錄&#xff1a; 表格標簽列表標簽表單標簽綜合案例查閱文檔 02-表格標…

Nginx搭建本地服務器,無需購買服務器即可測試vue項目打包后的效果

一.前言 本文是在windows環境&#xff08;Linux環境下其實也大同小異&#xff09;下基于Nginx實現搭建本地服務器&#xff0c;手把手教你部署vue項目。 二.Nginx入門 1&#xff09;下載安裝 進入Nginx官網下載&#xff0c;選擇stable版本下的windows版本下載即可 2&#xff09;…

Ubuntu 20.04配置靜態ip

ip配置文件 cd /etc/netplan配置 根據需求增加 # Let NetworkManager manage all devices on this system network:version: 2renderer: NetworkManager # 管理 不是必須ethernets:enp4s0: #網卡名dhcp4: no #關閉ipv4動態分配ip地址dhcp6: no #關閉ipv6動態分配…

Arrays.asList() 返回的list不能add,remove

一.Arrays.asList() 返回的list不能add,remove Arrays.asList()返回的是List,而且是一個定長的List&#xff0c;所以不能轉換為ArrayList&#xff0c;只能轉換為AbstractList 原因在于asList()方法返回的是某個數組的列表形式,返回的列表只是數組的另一個視圖,而數組本身并沒…

Wireshark 抓包過濾命令匯總

Wireshark 抓包過濾命令匯總 Wireshark 是一個強大的網絡分析工具&#xff0c;它可以幫助網絡管理員和安全專家監控和分析網絡流量。通過捕獲網絡數據包&#xff0c;Wireshark 能夠幫助我們識別網絡中的問題、瓶頸以及潛在的安全威脅。在使用 Wireshark 進行網絡數據包分析時&…

SQL Server基礎之游標

一&#xff1a;認識游標 游標是SQL Server的一種數據訪問機制&#xff0c;它允許用戶訪問單獨的數據行。用戶可以對每一行進行單獨的處理&#xff0c;從而降低系統開銷和潛在的阻隔情況&#xff0c;用戶也可以使用這些數據生成的SQL代碼并立即執行或輸出。 1.游標的概念 游標是…

DELL PowerEdge R720XD 磁盤RAID及Hot Spare熱備盤配置

一臺DELL PowerEdge R720XD服務器&#xff0c;需進行磁盤RAID及Hot Spare熱備盤配置&#xff0c;本文記錄配置過程示例。 一、設備環境 服務器型號&#xff1a;DELL PowerEdge R720XD 硬盤配置&#xff1a;800G硬盤共24塊 二、配置計劃 1、當前狀態&#xff1a;2塊盤配置RAID…

AIGC+游戲:一個被忽視的長賽道

&#xff08;圖片來源&#xff1a;Pixels&#xff09; AIGC徹底變革了游戲&#xff0c;但還不夠。 數科星球原創 作者丨苑晶 編輯丨大兔 消費還沒徹底復蘇&#xff0c;游戲卻已經出現拐點。 在游戲熱度猛增的背后&#xff0c;除了版號的利好因素外&#xff0c;AIGC技術的廣泛…

js下載后端返回的文件

文件流下載 后端返回文件流形式&#xff0c;前端下載 // res 為請求返回的數據對象const file_data res.data // 后端返回的文件流const blob new Blob([file_data]) const href window.URL.createObjectURL(blob) // 創建下載的鏈接 const file_name decodeURI(res.header…

4. 軟件開發的環境搭建

目錄 1. 搭建環境 1.1 檢查 JDK 1.2 檢查 MySQL 數據庫 1.3 檢查 Maven 1.4 檢查 GITEEGIT 1.5 安裝插件 1.5.1 安裝 Spring Boot Helper 1.5.2 安裝 lombok 1.6 創建倉庫 1.6.1 登錄 GITEE 創建倉庫并復制倉庫地址 1.6.2 克隆到本地 1.7 創建工程 1.7.1 設置編碼…

【Spring】Bean的實例化

1、簡介 在容器中的Bean要實例化為對象有三種方式 1、構造方法 2、靜態工廠 3、實例工廠 4、實現工廠接口 2、構造方法 構造方法實例化Bean即是直接通過構造方法創建對象 <bean id"bookDao" class"com.wn.spring.dao.impl.BookDaoImpl"/> 當不存在…

怎么把pdf壓縮到5m以內?壓縮辦法非常多

怎么把pdf壓縮到5m以內&#xff1f;PDF文件是我們辦公過程中較為常用的文件格式&#xff0c;PDF文件所包含的內容通常較多&#xff0c;比如文本、圖像以及音視頻等等。這樣的話&#xff0c;PDF文件占用內存也較大。如果需要對PDF文件進行使用、傳輸、分享等的話&#xff0c;可能…

單片機之從C語言基礎到專家編程 - 4 C語言基礎 - 4.8 運算符

1.算術運算符 運算符名稱備注加法運算符雙目運算&#xff0c;a b-減法運算符雙目運算&#xff0c;a - b*乘法運算符雙目運算&#xff0c;a * b/除法運算符雙目運算&#xff0c;a / b%求余運算符雙目運算, a % b自增運算符單目運算, a–自減運算符單目運算, a– 2.關系運算符…

Vue2集成Echarts實現可視化圖表

一、依賴配置 1、引入echarts相關依賴 也可以卸載原有的&#xff0c;重新安裝 卸載&#xff1a;npm uninstall echarts --save 安裝&#xff1a;npm install echarts4.8.0 --save 引入水球圖形依賴 npm install echarts-liquidfill2.0.2 --save 水球圖可參考文檔&#xff1…

MySQL索引(Index)

Index 數據庫中的索引&#xff08;Index&#xff09;是一種數據結構&#xff0c;用于提高數據庫查詢性能和加速數據檢索過程。索引可以看作是數據庫表中某個或多個列的數據結構&#xff0c;類似于書中的目錄&#xff0c;可以幫助數據庫管理系統更快地定位和訪問數據。它們是數…

Linux——KVM虛擬化

目錄標題 虛擬化技術虛擬化技術發展案例KVM簡介KVM架構及原理KVM原理KVM虛擬化架構/三種模式虛擬化前、虛擬化后對比KVM蓋中蓋套娃實驗 虛擬化技術 通過虛擬化技術將一臺計算機虛擬為多臺邏輯計算機&#xff0c;在一臺計算機上同時運行多個邏輯計算機&#xff0c;同時每個邏輯…