einops和einsum:直接操作張量的利器

einops和einsum:直接操作張量的利器

einops和einsum是Vision Transformer的代碼實現里出現的兩個操作tensor維度和指定tensor計算的神器,在卷積神經網絡里不多見,本文將介紹簡單介紹一下這兩樣工具,方便大家更好地理解Vision Transformer的代碼。

einops:直接操作tensor維度的神器

github地址:https://github.com/arogozhnikov/einops

einops:靈活和強大的張量操作,可讀性強和可靠性好的代碼。支持numpy、pytorch、tensorflow等。

有了他,研究者們可以自如地操作張量的維度,使得研究者們能夠簡單便捷地實現并驗證自己的想法,在Vision Transformer等需要頻繁操作張量維度的代碼實現里極其有用。

這里簡單地介紹幾個最常用的函數。

安裝

einops的安裝非常簡單,直接pip即可:

pip install einops

rearrange

import torch
from einops import rearrangei_tensor = torch.randn(16, 3, 224, 224)		# 在CV中很常見的四維tensor: (N,C,H,W)
print(i_tensor.shape)
o_tensor = rearrange(i_tensor, 'n c h w -> n h w c')
print(o_tensor.shape)

輸出:

torch.Size([16, 3, 224, 224])
torch.Size([16, 224, 224, 3])

在CV中很常見的四維tensor:(N,C,H,W),即表示(批尺寸,通道數,圖像高,圖像寬),在Vision Transformer中,經常需要對tensor的維度進行變換操作,rearrange函數可以很方便地、很直觀地操作tensor的各個維度。

除此之外,rearrange還有稍微進階一點的玩法:

 
i_tensor = torch.randn(16, 3, 224, 224)
o_tensor = rearrange(i_tensor, 'n c h w -> n c (h w)')
print(o_tensor.shape)  
o_tensor = rearrange(i_tensor, 'n c (m1 p1) (m2 p2) -> n c m1 p1 m2 p2', p1=16, p2=16)
print(o_tensor.shape)  

輸出:

torch.Size([16, 3, 50176])
torch.Size([16, 3, 14, 16, 14, 16])

可以進行指定維度的合并和拆分,注意拆分時需要在變換規則后面指定參數。

repeat

from einops import repeati_tensor = torch.randn(3, 224, 224)  
print(i_tensor.shape)
o_tensor = repeat(i_tensor, 'c h w -> n c h w', n=16)  
print(o_tensor.shape)

repeat時記得指定右側repeat之后的維度值

輸出:

torch.Size([3, 224, 224])
torch.Size([16, 3, 224, 224])

reduce

from einops import reducei_tensor = torch.randn((16, 3, 224, 224))
o_tensor = reduce(i_tensor, 'n c h w -> c h w', 'mean')
print(o_tensor.shape)
o_tensor_ = reduce(i_tensor, 'b c (m1 p1) (m2 p2)  -> b c m1 m2 ', 'mean', p1=16, p2=16)
print(o_tensor_.shape)

輸出:

torch.Size([3, 224, 224])
torch.Size([16, 3, 14, 14])

reduce時記得指定左側要被reduce的維度值

Rearrange

import torch
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU
from einops.layers.torch import Rearrangemodel = Sequential(Conv2d(3, 64, kernel_size=3),MaxPool2d(kernel_size=2),Rearrange('b c h w -> b (c h w)'),      # 相當于 flatten 展平的作用Linear(64*15*15, 120), ReLU(),Linear(120, 10)
)i_tensor = torch.randn(16, 3, 32, 32)
o_tensor = model(i_tensor)
print(o_tensor.shape)

輸出:

torch.Size([16, 10])

einops.layers.torch.Rearrange 是nn.Module的子類,可以放在網絡里面直接當作一層。

torch.einsum:愛因斯坦簡記法

愛因斯坦簡記法:是一種由愛因斯坦提出的,對向量、矩陣、張量的求和運算 ∑\sum求和簡記法

在該簡記法當中,省略掉的部分是:

  1. 求和符號 ∑\sum
  2. 求和號的下標 iii

省略規則為:默認成對出現的下標(如下例1中的 iii 和例2中的 kkk )為求和下標,被省略。

1)xiyix_iy_ixi?yi?簡化表示內積 <x,y><\mathbf{x},\mathbf{y}><x,y>
xiyi:=∑ixiyi=ox_iy_i := \sum_i x_iy_i = o xi?yi?:=i?xi?yi?=o

其中o為輸出。

  1. XikYkjX_{ik}Y_{kj}Xik?Ykj? 簡化表示矩陣乘法 XY\mathbf{X}\mathbf{Y}XY
    XikYkj:=∑kXikYkj=OijX_{ik}Y_{kj}:=\sum_k X_{ik}Y_{kj}=\mathbf{O}_{ij} Xik?Ykj?:=k?Xik?Ykj?=Oij?
    其中 Oij\mathbf{O}_{ij}Oij? 為輸出矩陣的第ij個元素。

這樣的求和簡記法,能夠以一種統一的方式表示各種各樣的張量運算(內積、外積、轉置、點乘、矩陣的跡、其他自定義運算),為不同運算的實現提供了一個統一模型。

einsum在numpy和pytorch中都有實現,下面我們以在torch中為例,展示一下最簡單的用法

import torchi_a = torch.randn(16, 32, 4, 8)
i_b = torch.randn(16, 32, 8, 16)out = torch.einsum('b h i j, b h j d -> b h i d', i_a, i_b)
print(out.shape)

輸出:

torch.Size([16, 32, 4, 16])

可以看到,torch.einsum可以簡便地指定tensor運算,輸入的兩個tensor維度分別為 bhijb\ h\ i\ jb?h?i?jbhjdb\ h\ j\ db?h?j?d ,經過tensor運算后,得到的張量維度為 bhidb\ h\ i\ db?h?i?d 。代碼運行結果與我們的預期一致。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/532858.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/532858.shtml
英文地址,請注明出處:http://en.pswp.cn/news/532858.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

php的filter input,php中filter_input函數用法分析

本文實例分析了php中filter_input函數用法。分享給大家供大家參考。具體分析如下&#xff1a;在 php5.2 中,內置了filter 模塊,用于變量的驗證和過濾,過濾變量等操作&#xff0c;這里我們看下如何直接過濾用戶輸入的內容.fliter 模塊對應的 filter_input 函數使用起來非常的簡單…

COCO 數據集格式及mmdetection中的轉換方法

COCO 數據集格式及mmdetection中的轉換方法 COCO格式 CV中的目標檢測任務不同于分類&#xff0c;其標簽的形式稍為復雜&#xff0c;有幾種常用檢測數據集格式&#xff0c;本文將簡要介紹最為常見的COCO數據集的格式。 完整的官方樣例可自行查閱&#xff0c;以下是幾項關鍵的…

php獲取h1,jQuery獲取h1-h6標題元素值方法實例

本文主要介紹了jQuery實現獲取h1-h6標題元素值的方法,涉及$(":header")選擇器操作h1-h6元素及事件響應相關技巧,需要的朋友可以參考下&#xff0c;希望能幫助到大家。1、問題背景&#xff1a;查找到h1-h6&#xff0c;并遍歷它們&#xff0c;打印出內容2、實現代碼&am…

在導入NVIDIA的apex庫時報錯 ImportError cannot import name ‘UnencryptedCookieSessionFactoryConfig‘ from

在導入NVIDIA的apex庫時報錯 ImportError: cannot import name ‘UnencryptedCookieSessionFactoryConfig’ from ‘pyramid.session’ (unknown location) 報錯 在使用NVIDIA的apex庫時報錯 ImportError: cannot import name ‘UnencryptedCookieSessionFactoryConfig’ fro…

php怎么取request,PHP-如何在Guzzle中獲取Request對象?

我需要使用Guzzle檢查數據庫中的很多項目.例如,項目數量為2000-5000.將其全部加載到單個數組中太多了,因此我想將其分成多個塊&#xff1a;SELECT * FROM items LIMIT100.當最后一個項目發送到Guzzle時,則請求下一個100個項目.在“已滿”處理程序中,我應該知道哪個項目得到了響…

[2021-CVPR] Jigsaw Clustering for Unsupervised Visual Representation Learning 論文簡析及關鍵代碼簡析

[2021-CVPR] Jigsaw Clustering for Unsupervised Visual Representation Learning 論文簡析及關鍵代碼簡析 論文&#xff1a;https://arxiv.org/abs/2104.00323 代碼&#xff1a;https://github.com/dvlab-research/JigsawClustering 總結 本文提出了一種單批次&#xff0…

java jps都卡死,java長時間運行后,jps失效

在部署完應用后&#xff0c;原本jps使用的好好的&#xff0c;能正確的查詢到自己正在運行的java程序。但&#xff0c;過了一段時間后&#xff0c;再使用jps來查看運行的應用時&#xff0c;自己運行的程序都看不到&#xff0c;但是自己也沒有關閉這些程序啊&#xff01;然而使用…

指針(*)、取地址()、解引用(*)與引用()

指針(*)、取地址(&)、解引用(*)與引用(&) C 提供了兩種指針運算符&#xff0c;一種是取地址運算符 &&#xff0c;一種是間接尋址運算符 *。 指針是一個包含了另一個變量地址的變量&#xff0c;您可以把一個包含了另一個變量地址的變量說成是"指向"另一…

matlab電類,985電氣研二,有發過考研經驗貼 電氣電力類的有

該樓層疑似違規已被系統折疊 隱藏此樓查看此樓clc;clear;p[2.259;2.257;2.256;2.254;2.252;2.248;2.247;2.245;2.244;2.243;2.239;2.238;2.236;2.235;2.234;2.231;2.229;2.228;2.226;2.225;2.221;2.220;2.219;2.217;2.216;2.211;2.209;2.208;2.207;2.206;2.202;2.201;2.199;2.1…

matlab legend 分塊,matlab?legend?分塊!

matlab legend 分塊&#xff01;(2013-03-26 18:07:38)%%%壓差clc;clear all;figure(55);set (gcf,Position,[116 123 275 210],color,w);P[25 26 27 28 29 30 31 32 33 34 35];%理論q0.00006*pi*28*P*10^(6)*0.03^3/(12*0.028448*5);q1110.00006*pi*28*P*10^(6)*0.03^3/(12*0.…

利用opencv-python繪制多邊形框或(半透明)區域填充(可用于分割任務mask可視化)

利用opencv-python繪制多邊形框或&#xff08;半透明&#xff09;區域填充&#xff08;可用于分割任務mask可視化&#xff09; 本文主要就少opencv中兩個函數polylines和fillPoly分別用于繪制多邊形框或區域填充&#xff0c;并會會以常見用途分割任務mask&#xff08;還是筆者…

matlab與maple互聯,Matlab,Maple和Mathematica三款主流科學計算軟件的互操作

本文根據網上零散的信息以及這三款軟件自帶的說明文檔整理而成&#xff0c;為備忘而記錄。記錄了Matlab和Maple之間的相互調用&#xff0c;以及Matlab和Mathematica之間相互調用的安裝配置方法。為何需要互操作&#xff1f; 數值計算和圖形方面Matlab毫無疑問是最強的&a…

PyTorch中的topk方法以及分類Top-K準確率的實現

PyTorch中的topk方法以及分類Top-K準確率的實現 Top-K 準確率 在分類任務中的類別數很多時&#xff08;如ImageNet中1000類&#xff09;&#xff0c;通常任務是比較困難的&#xff0c;有時模型雖然不能準確地將ground truth作為最高概率預測出來&#xff0c;但通過學習&#…

java高級語言特性,Java高級語言特性之注解

注解的定義Java 注解(Annotation)又稱 Java 標注&#xff0c;是 JDK1.5 引入的一種注釋機制。注解是元數據的一種形式&#xff0c;提供有關于程序但不屬于程序本身的數據。注解對它們注解的代碼的操作沒有直接影響。注解本身沒有任何意義&#xff0c;單獨的注解就是一種注釋&am…

C/C++中的typedef 和 #define

C/C中的typedef 和 #define typedef C/C中的關鍵字typedef允許用戶為類型名來起一個新名字&#xff0c;通常會是縮寫或者能夠清晰表明類型含義的新名字。 例&#xff1a; typedef unsigned int UINT; UINT 100;值得注意的是&#xff0c;typedef除了為C/C內置的數據類型取別…

php3.2.3 升級,thinkphp3.2.3 升級到3.2.4時出錯問題

有些項目最初用OneThink做的&#xff0c;而OneThink 默認使用的TP 是3.2.0 的&#xff0c;沒事的時候就想給升級一下&#xff0c;但是直接復制進去的時候&#xff0c;有錯誤&#xff0c;導致OneThink 不能運行&#xff0c;排查后&#xff0c;需要修改兩個地方1、修改 Applicati…

Positional Encodings in ViTs 近期各視覺Transformer中的位置編碼方法總結及代碼解析 1

Positional Encodings in ViTs 近期各視覺Transformer中的位置編碼方法總結及代碼解析 最近CV領域的Vision Transformer將在NLP領域的Transormer結果借鑒過來&#xff0c;屠殺了各大CV榜單。對其做各種改進的頂會論文也是層出不窮&#xff0c;本文將聚焦于各種最新的視覺trans…

mysql 分析查詢語句,MySQL教程之SQL語句分析查詢優化

怎么獲取有功能問題的SQL1、經過用戶反應獲取存在功能問題的SQL2、經過慢查詢日志獲取功能問題的SQL3、實時獲取存在功能問題的SQL運用慢查詢日志獲取有功能問題的SQL首要介紹下慢查詢相關的參數1、slow_query_log 發動定制記載慢查詢日志設置的辦法&#xff0c;能夠經過MySQL指…

關于PyTorch中的register_forward_hook()函數未能執行其中hook函數的問題

關于PyTorch中的register_forward_hook()函數未能執行其中hook函數的問題 Hook 是 PyTorch 中一個十分有用的特性。利用它&#xff0c;我們可以不必改變網絡輸入輸出的結構&#xff0c;方便地獲取、改變網絡中間層變量的值和梯度。這個功能被廣泛用于可視化神經網絡中間層的 f…

geoda權重矩陣導入matlab,空間計量經濟學-分析解析.ppt

廈門大學 鄧明 空間截面回歸模型 地理加權回歸模型 地理加權回歸模型擴展了普通線性回歸模型。在GWR模型中&#xff0c;特定區位的回歸系數不再是利用全部信息獲得的假定常數&#xff0c;而是利用鄰近觀測值的子樣本數據信息進行局域(Local)回歸估計而得&#xff0c;并隨著空間…