FLOPs、FLOPS、Params的含義及PyTorch中的計算方法

FLOPs、FLOPS、Params的含義及PyTorch中的計算方法

含義解釋

  1. FLOPS:注意全大寫,是floating point operations per second的縮寫(這里的大S表示second秒),表示每秒浮點運算次數,理解為計算速度。是一個衡量硬件性能的指標。

  2. FLOPs:注意s小寫,是floating point operations的縮寫(這里的小s則表示復數),表示浮點運算數,理解為計算量。可以用來衡量算法/模型的復雜度。

  3. Params:沒有固定的名稱,大小寫均可,表示模型的參數量,也是用來衡量算法/模型的復雜度。通常我們在論文中見到的是這樣:# Params,那個井號是表示 number of 的意思,因此 # Params 的意思就是:參數的數量。

在這里插入圖片描述

FLOPs與模型時間復雜度、GPU利用率有關,Params與模型空間復雜度、顯存占用有關。即我們常見的nvidia-smi命令中的GPU利用率(紅框)和顯存占用(籃框)。

MAC

MAC:Multiply Accumulate,乘加運算。乘積累加運算(英語:Multiply Accumulate, MAC)是在數字信號處理器或一些微處理器中的特殊運算。實現此運算操作的硬件電路單元,被稱為“乘數累加器”。這種運算的操作,是將乘法的乘積結果和累加器的值相加,再存入累加器:
a←a+b×ca\leftarrow a+b\times c aa+b×c
使用MAC可以將原本需要的兩個指令操作減少到一個指令操作,從而提高運算效率。

FLOPs的計算

以下不考慮激活函數的計算量。

卷積層

(2×Ci×K2?1)×H×W×C0(2\times C_i\times K^2-1)\times H\times W\times C_0(2×Ci?×K2?1)×H×W×C0?

CiC_iCi?=輸入通道數, KKK=卷積核尺寸,H,WH,WH,W=輸出特征圖空間尺寸,CoC_oCo?=輸出通道數。

一個MAC算兩個個浮點運算,所以在最前面×2\times 2×2。不考慮bias時有?1-1?1,有bias時沒有?1-1?1。由于考慮的一般是模型推理時的計算量,所以上述公式是針對一個輸入樣本的情況,即batch size=1。

理解上面這個公式分兩步,括號內是第一步,計算出輸出特征圖的一個pixel的計算量,然后再乘以 H×W×CoH\times W\times C_oH×W×Co? 拓展到整個輸出特征圖。

括號內的部分又可以分為兩步,(2?Ci?K2?1)=(Ci?K2)+(Ci?K2?1)(2\cdot C_i\cdot K^2-1)=(C_i\cdot K^2)+(C_i\cdot K^2-1)(2?Ci??K2?1)=(Ci??K2)+(Ci??K2?1)。第一項是乘法運算數,第二項是加法運算數,因為 nnn 個數相加,要加 n?1n-1n?1 次,所以不考慮bias,會有一個?1-1?1,如果考慮bias,剛好中和掉,括號內變為 2?Ci?K22\cdot C_i\cdot K^22?Ci??K2

全連接層

全連接層: (2×I?1)×O(2\times I-1)\times O(2×I?1)×O

III=輸入層神經元個數 ,OOO=輸出層神經元個數。

還是因為一個MAC算兩個個浮點運算,所以在最前面×2\times 2×2。同樣不考慮bias時有?1-1?1,有bias時沒有?1-1?1。分析同理,括號內是一個輸出神經元的計算量,拓展到OOO了輸出神經元。

NVIDIA Paper [2017-ICLR]

筆者在這里放上 NVIDIA 在 【2017-ICLR】的論文:PRUNING CONVOLUTIONAL NEURAL NETWORKS FOR RESOURCE EFFICIENT INFERENCE 的附錄部分FLOPs計算方法截圖放在下面供讀者參考。
在這里插入圖片描述

使用PyTorch直接輸出模型的Params(參數量)

完整統計參數量

import torch 
from torchvision.models import resnet50
import numpy as npTotal_params = 0
Trainable_params = 0
NonTrainable_params = 0model = resnet50()
for param in model.parameters():mulValue = np.prod(param.size())  # 使用numpy prod接口計算參數數組所有元素之積Total_params += mulValue  # 總參數量if param.requires_grad:Trainable_params += mulValue  # 可訓練參數量else:NonTrainable_params += mulValue  # 非可訓練參數量print(f'Total params: {Total_params / 1e6}M')
print(f'Trainable params: {Trainable_params/ 1e6}M')
print(f'Non-trainable params: {NonTrainable_params/ 1e6}M')

輸出:

Total params: 25.557032M
Trainable params: 25.557032M
Non-trainable params: 0.0M

簡單統計可訓練的參數量

通常,我們想知道的只是可訓練的參數量,我們也可以簡單地直接一行統計出可訓練的參數量:

import torchvision.models as modelsmodel = models.resnet50(pretrained=False)Trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Trainable params: {Trainable_params/ 1e6}M')

輸出:

Trainable params: 25.557032M

統計每一層的參數量

倘若想要統計每一層的參數量,參考代碼如下:

model = vgg16()
for name, parameters in model.named_parameters():print(name, ':', np.prod(parameters.size()))

會打印出每一層的名稱及參數量:

features.0.weight : 1728
features.0.bias : 64
features.2.weight : 36864
features.2.bias : 64
features.5.weight : 73728
...

使用thop庫來獲取模型的FLOPs(計算量)和Params(參數量)

安裝

直接pypi安裝即可

pip install thop

使用

我們使用thop庫來計算vgg16模型的計算量和參數量。

import torch
from thop import profile
from archs.ViT_model import get_vit, ViT_Aes
from torchvision.models import resnet50model = resnet50()
input1 = torch.randn(4, 3, 224, 224) 
flops, params = profile(model, inputs=(input1, ))
print('FLOPs = ' + str(flops/1000**3) + 'G')
print('Params = ' + str(params/1000**2) + 'M')

輸出:

FLOPs = 16.446058496G
Params = 25.557032M

Ref:

https://openreview.net/forum?id=SJGCiw5gl

https://www.zhihu.com/question/65305385/answer/451060549

https://www.cnblogs.com/chuqianyu/p/14254702.html

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/532797.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/532797.shtml
英文地址,請注明出處:http://en.pswp.cn/news/532797.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

設置中文linux輸入ubuntu,Linux_ubuntu怎么設置成中文?ubuntu中文設置圖文方法,  很多朋友安裝ubuntu后,發 - phpStudy...

ubuntu怎么設置成中文?ubuntu中文設置圖文方法很多朋友安裝ubuntu后,發現都是英文,看不懂要怎么辦?其實ubuntu是可以設置成中文的,下文小編就為大家帶來ubuntu中文的設置方法,一起去看下設置方法吧。ubuntu中文設置方…

科普 | 單精度、雙精度、多精度和混合精度計算的區別是什么?

科普 | 單精度、雙精度、多精度和混合精度計算的區別是什么? 轉自:https://zhuanlan.zhihu.com/p/93812784 我們提到圓周率 π 的時候,它有很多種表達方式,既可以用數學常數3.14159表示,也可以用一長串1和0的二進制長串表示。 …

linux 磁盤分配 簡書,linux 磁盤分區

1物理磁盤的構成: 盤面:由一圈一圈的磁道組成機械手臂:讀取數據主軸馬達:幫助機械手臂轉動2 扇區:磁盤上存取數據的最小單位512字節按照扇區分配大小,如果數據只有一字節也會占用512字節簇:用若…

條件控制與條件傳送詳解

條件控制與條件傳送詳解 提要 CSAPP3e中文譯本 3.6.5 用條件控制來實現條件分支 3.6.6 用條件傳送來實現條件分支 CSAPP3e第三章前面主要是介紹了機器級代碼的二進制形式和匯編形式、反匯編、x86匯編的基礎指令、條件碼及其訪問方式等。 在介紹到匯編語言的條件分支時分了兩…

聯合體(union)的使用方法及其本質

聯合體(union)的使用方法及其本質 轉自:https://blog.csdn.net/huqinwei987/article/details/23597091 有些基礎知識快淡忘了,所以有必要復習一遍,在不借助課本死知識的前提下做些推理判斷,溫故知新。 1…

linux設備驅動之串口移植,Linux設備驅動之UART驅動結構

一、對于串口驅動Linux系統中UART驅動屬于終端設備驅動,應該說是實現串口驅動和終端驅動來實現串口終端設備的驅動。要了解串口終端的驅動在Linux系統的結構就先要了解終端設備驅動在Linux系統中的結構體系,一方面自己了解的不夠,另一發面關于…

linux python復制安裝,復制一個Python全部環境到另一個環境,python另一個,導出此環境下安裝的包...

復制一個Python全部環境到另一個環境,python另一個,導出此環境下安裝的包導出此環境下安裝的包的版本信息清單pipfreeze>requirements.txt聯網,下載清單中的包到all-packet文件夾[[email protected] ~]# pip download -d ./all-packet -r requirement…

NVIDIA英偉達的Multi-GPU多卡通信框架NCCL

NVIDIA英偉達的Multi-GPU多卡通信框架NCCL 筆者注:NCCL 開源項目地址:https://github.com/NVIDIA/nccl 轉自:https://www.zhihu.com/question/63219175/answer/206697974 NCCL是Nvidia Collective multi-GPU Communication Library的簡稱&…

C語言n個坐標點間的最大距離,c語言已知兩點坐標,求另一點到穿過這兩點的直線最短距離。...

c語言已知兩點坐標,求另一點到穿過這兩點的直線最短距離。以下文字資料是由(歷史新知網www.lishixinzhi.com)小編為大家搜集整理后發布的內容,讓我們趕快一起來看一下吧!c語言已知兩點坐標,求另一點到穿過這兩點的直線最短距離。#…

[分布式訓練] 單機多卡的正確打開方式:理論基礎

[分布式訓練] 單機多卡的正確打開方式:理論基礎 轉自:https://fyubang.com/2019/07/08/distributed-training/ 瓦礫由于最近bert-large用的比較多,踩了很多分布式訓練的坑,加上在TensorFlow和PyTorch之間更換,算是熟…

s3c2416開發板 linux,S3C2416移植內核Linux3.1的wm9713聲卡過程

移植內核的聲卡驅動。原因沒有聲卡驅動,WM9713聲卡驅動移植(原來的內核有UDA1341聲卡驅動,我們再次基礎上直接修改)1、直接復制內核得到三個文件:s3c2416_wm9713.c , wm9713.c , s3c2416_ac97.c.linux-3.1\sound\soc\codecs\Wm9713.c---->wm9713.c;li…

Linux查看文件內容命令:cat, tail, head, more, less

Linux查看文件內容命令:cat, tail, head, more, less cat 直接顯示整個文件。 cat直接顯示全部文件內容,沒有換頁等交互。 cat filenamemore more命令,功能類似 cat ,cat命令是整個文件的內容從上到下顯示在屏幕上。 more會…

linux查看隊列 msg,linux第10天 msg消息隊列

cat /proc/sys/kernel/msgmax最大消息長度限制cat /proc/sys/kernel/msgmnb消息隊列總的字節數cat /proc/sys/kernel/msgmni消息條目數消息隊列綜合案例//server#include #include #include #include #include #include #include #include #define ERR_EXIT(m)do{perror(m);}wh…

Linux中 C++ main函數參數argc和argv含義及用法

Linux中 C main函數參數argc和argv含義及用法 簡介 argc 是 argument count的縮寫,表示傳入main函數的參數個數; argv 是 argument vector的縮寫,表示傳入main函數的參數序列或指針,并且第一個參數argv[0]一定是程序的名稱&…

c語言六位搶答器課程設計,51單片機八路搶答器課程設計

;說明:本人的這個設計改進后解決了前一個版本中1號搶答優先的問題,并增加了錦囊的設置,當參賽選手在回答問題時要求使用錦囊,則主持人按下搶答開始鍵,計時重新開始。;八路搶答器電路請看下圖是用ps仿真的,已…

ELF文件詳解—初步認識

ELF文件詳解—初步認識 轉自:https://blog.csdn.net/daide2012/article/details/73065204 一、 引言 在講解ELF文件格式之前,我們來回顧一下,一個用C語言編寫的高級語言程序是從編寫到打包、再到編譯執行的基本過程,我們知道在C…

埃及分數問題c語言,埃及分數問題(轉)

今日,小雨和小明來到網絡中心,繼續與劉老師討論“數的認識”問題。劉老師說:“還有一種‘埃及分數’需要認識。這是一類分裂分數的思維題,對思維能力的訓練很有價值。”小明說:“有意思,愿洗耳恭聽。”劉老…

linux常用命令--開發調試篇

前言 Linux常用命令中有一些命令可以在開發或調試過程中起到很好的幫助作用,有些可以幫助了解或優化我們的程序,有些可以幫我們定位疑難問題。本文將簡單介紹一下這些命令。 轉自:https://www.yanbinghu.com/2018/09/26/61877.html 示例程序…

簡單有趣的c語言小程序,一個有趣的小程序

該樓層疑似違規已被系統折疊 隱藏此樓查看此樓源碼:#include #include #include #include #include HINSTANCE g_hInstance 0;LRESULT CALLBACK WndProc(HWND, UINT, WPARAM, LPARAM);int WINAPI WinMain(HINSTANCE hInstance,HINSTANCE hPreInstance,LPSTR lpCmdLine,int nSh…

linux下ora 01110,ORA-01003ORA-01110

Oracle 9i數據庫登錄時,提示ORA-01003&ORA-01110,大概意思是數據文件存儲介質損壞。startup nomount,正常;alter database mount,也正常;alter database open,提示如下:alter database open*ERROR 位于第 1 行:ORA…