文章目錄
- 一、寫在前面
- 二、Reshape
- (一)用法
- (二)代碼展示
- 三、Unfold
- (一)torch.unfold 的基本概念
- (二)torch.unfold 的工作原理
- (三) 示例代碼
- (四)torch.unfold 的應用場景
- (五)注意事項
- (六)總結
- 四、View
- (一)用法
- (二)注意事項
- (三)其他方法
- 五、Flatten
- (一)torch.flatten 的基本概念
- (二)torch.flatten 的工作原理
- (三)示例代碼
- 六、Permute
- (一)torch.permute 的基本概念
- (二)torch.permute 的工作原理
- (三) 示例代碼
- (四) torch.permute 的應用場景
- 七、總結
一、寫在前面
最近在解析transformer源碼的時候突然看到了unfold?我在想unfold是什么意思?為什么不用reshape,他們的底層邏輯有什么區別呢?于是便相對比一下他們之間的區別,便有了本篇博客,希望對大家有幫助!
二、Reshape
(一)用法
1. torch.reshape(input, shape)
輸入是tensor和shape,其中原始shape和目標shape的元素數量要一致。
2. Tensor.reshape(shape) → Tensor
與上述用法一致,只不過這個是直接在tensor的基礎上進行reshape。
reshpe是按照順序進行重新排列組合的。
[16,2] 其實與 [4,2,2,2] 是一樣的,只要最后一維的數字是一樣,其實結果都是一樣的。
(二)代碼展示
三、Unfold
(一)torch.unfold 的基本概念
torch.unfold 的作用是將輸入張量的某個維度展開為多個滑動窗口。每個窗口包含一個局部區域,這些區域可以用于后續的計算。
- 語法:
torch.Tensor.unfold(dimension, size, step)
- 參數:
- dimension:要展開的維度(整數)。
- size:滑動窗口的大小(整數)。
- step:滑動窗口的步長(整數)。
- 返回值:返回一個新的張量,其中指定維度的每個元素被展開為多個滑動窗口。
(二)torch.unfold 的工作原理
假設我們有一個形狀為 (N, C, H, W) 的張量(例如圖像數據),我們希望在高度維度(H)上提取滑動窗口。
- 輸入張量:(N, C, H, W)
- 展開維度:dimension=2(即高度維度 H)
- 窗口大小:size=k(例如 k=3)
- 步長:step=s(例如 s=1)
torch.unfold 會將高度維度 H 展開為多個大小為 k 的滑動窗口,每個窗口之間間隔 s。
(三) 示例代碼
- 示例 1:一維張量的展開
import torch
# 創建一個一維張量
x = torch.arange(10)
print("原始張量:", x)# 使用 unfold 展開
unfolded = x.unfold(dimension=0, size=3, step=1)
print("展開后的張量:\n", unfolded)
- 輸出:
原始張量: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
展開后的張量:tensor([[0, 1, 2],[1, 2, 3],[2, 3, 4],[3, 4, 5],[4, 5, 6],[5, 6, 7],[6, 7, 8],[7, 8, 9]])
- 示例 2:二維張量的展開(圖像處理)
import torch
# 創建一個二維張量(模擬圖像)x = torch.arange(16).reshape(1, 1, 4, 4) # 形狀為 (1, 1, 4, 4)print("原始張量:\n", x)# 使用 unfold 展開unfolded_1 = x.unfold(dimension=2, size=3, step=1)unfolded_2 = unfolded_1.unfold(dimension=3, size=3, step=1)print("展開后的張量形狀:", unfolded_1.shape)print("展開后的張量:\n", unfolded_1)print("展開后的張量形狀:", unfolded_2.shape)print("展開后的張量:\n", unfolded_2)
- 輸出:
原始張量:tensor([[[[ 0, 1, 2, 3],[ 4, 5, 6, 7],[ 8, 9, 10, 11],[12, 13, 14, 15]]]])
展開后的張量形狀: torch.Size([1, 1, 2, 4, 3])
展開后的張量:tensor([[[[[ 0, 4, 8],[ 1, 5, 9],[ 2, 6, 10],[ 3, 7, 11]],[[ 4, 8, 12],[ 5, 9, 13],[ 6, 10, 14],[ 7, 11, 15]]]]])
展開后的張量形狀: torch.Size([1, 1, 2, 2, 3, 3])
展開后的張量:tensor([[[[[[ 0, 1, 2],[ 4, 5, 6],[ 8, 9, 10]],[[ 1, 2, 3],[ 5, 6, 7],[ 9, 10, 11]]],[[[ 4, 5, 6],[ 8, 9, 10],[12, 13, 14]],[[ 5, 6, 7],[ 9, 10, 11],[13, 14, 15]]]]]])
數組的運算主要看最后兩維,倒數第二維代表行,倒數第一維代表列。
(四)torch.unfold 的應用場景
- 卷積操作
在卷積神經網絡(CNN)中,卷積核通過滑動窗口的方式提取圖像的局部特征。torch.unfold 可以用于手動實現卷積操作。
- 圖像處理
在圖像處理中,torch.unfold 可以用于提取圖像的局部區域(例如提取圖像的滑動窗口)。
- 時間序列分析
在時間序列分析中,torch.unfold 可以用于提取時間序列的滑動窗口,用于特征提取或模型訓練。
(五)注意事項
- 維度選擇:需要明確指定要展開的維度(dimension)。
- 窗口大小和步長:窗口大小(size)和步長(step)的選擇會影響展開后的張量形狀。
- 內存消耗:torch.unfold 可能會生成較大的張量,尤其是在高維數據上使用時,需要注意內存消耗。
(六)總結
- torch.unfold 的作用:從張量的某個維度提取滑動窗口。
- 常用參數:dimension(展開維度)、size(窗口大小)、step(步長)。
- 應用場景:卷積操作、圖像處理、時間序列分析等。
- 注意事項:選擇合適的維度、窗口大小和步長,避免內存消耗過大。
四、View
torch.view 用于返回一個與原始張量共享相同數據存儲的新張量,但具有不同的形狀。換句話說,view 只是改變了張量的視圖(view),而不會復制數據。
(一)用法
- Tensor.view(*shape):
- 參數:
*shape:新的形狀(可以是整數或元組)。
- 返回值:
返回一個新的張量,具有指定的形狀,并與原始張量共享相同的數據存儲。
(二)注意事項
- view 返回的張量與原始張量共享相同的數據存儲。
- 如果原始張量的數據發生變化,view 返回的張量也會隨之變化。
- view 要求張量的內存必須是連續的(即張量在內存中是連續存儲的)。如果內存不連續,view 會拋出錯誤。
(三)其他方法
- view_as
- 作用:將當前張量轉換為與另一個張量相同的形狀。
- 語法:
torch.Tensor.view_as(other)
- 參數:
other:目標張量,當前張量的形狀將被轉換為與 other 相同的形狀。
- 返回值:
返回一個新的張量,具有與 other 相同的形狀,并與原始張量共享數據存儲。
- 示例:
import torch
# 創建一個張量
x = torch.arange(12)
print("原始張量:", x)
# 創建目標張量
other = torch.empty(3, 4)
# 使用 view_as 改變形狀
y = x.view_as(other)
print("改變形狀后的張量:\n", y)
- 輸出:
原始張量: tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
改變形狀后的張量:tensor([[ 0, 1, 2, 3],[ 4, 5, 6, 7],[ 8, 9, 10, 11]])
- view_as_real
- 作用:將復數張量轉換為實數張量。
- 語法:
torch.Tensor.view_as_real()
- 返回值:
返回一個新的實數張量,形狀為 (…, 2),其中最后一個維度包含復數的實部和虛部。
- 示例:
import torch
# 創建一個復數張量
x = torch.tensor([1 + 2j, 3 + 4j])
print("原始張量:", x)
# 使用 view_as_real 轉換為實數張量
y = x.view_as_real()
print("轉換后的張量:\n", y)
- 輸出:
原始張量: tensor([1.+2.j, 3.+4.j])
轉換后的張量:tensor([[1., 2.],[3., 4.]])
- view_as_complex
- 作用:將實數張量轉換為復數張量。
- 語法:
torch.Tensor.view_as_complex()
- 返回值:
返回一個新的復數張量,形狀為 (…,),其中最后一個維度被解釋為復數的實部和虛部。
- 示例:
import torch
# 創建一個實數張量
x = torch.tensor([[1., 2.], [3., 4.]])
print("原始張量:\n", x)
# 使用 view_as_complex 轉換為復數張量
y = x.view_as_complex()
print("轉換后的張量:", y)
- 輸出:
原始張量:tensor([[1., 2.],[3., 4.]])
轉換后的張量: tensor([1.+2.j, 3.+4.j])
- view_as_strided
- 作用:返回一個具有指定步長和內存布局的張量視圖。
- 語法:
torch.Tensor.view_as_strided(size, stride)
- 參數:
size:新的形狀(元組)。
stride:新的步長(元組)。
- 返回值:
返回一個新的張量,具有指定的形狀和步長,并與原始張量共享數據存儲。
- 示例:
import torch
# 創建一個張量
x = torch.arange(9).view(3, 3)
print("原始張量:\n", x)
# 使用 view_as_strided 改變形狀和步長
y = x.view_as_strided((2, 2), (1, 2))
print("改變形狀和步長后的張量:\n", y)
- 輸出:
原始張量:tensor([[0, 1, 2],[3, 4, 5],[6, 7, 8]])
改變形狀和步長后的張量:tensor([[0, 2],[1, 3]])
- view_as_real 和 view_as_complex 的應用場景
- 復數計算:
在涉及復數計算的任務中,view_as_real 和 view_as_complex 可以用于在復數和實數之間進行轉換。
- 信號處理:
在信號處理中,復數張量常用于表示頻域信號,view_as_real 和 view_as_complex 可以用于頻域和時域之間的轉換。
- view_as_strided 的應用場景
- 自定義內存布局:
在需要自定義內存布局的場景中,view_as_strided 可以用于創建具有特定步長和形狀的張量視圖。
- 高效內存訪問:
在需要高效訪問內存的場景中,view_as_strided 可以用于優化內存訪問模式。
五、Flatten
(一)torch.flatten 的基本概念
torch.flatten 的作用是將輸入張量的指定維度展平為一維。它可以展平整個張量,也可以只展平部分維度。
- 語法:
torch.flatten(input, start_dim=0, end_dim=-1)
-
參數:
- input:輸入張量。
- start_dim:開始展平的維度(整數),默認為 0。
- end_dim:結束展平的維度(整數),默認為 -1。
- 返回值:返回一個新的張量,具有展平后的形狀。
(二)torch.flatten 的工作原理
- torch.flatten 會將指定范圍內的維度展平為一維。
- 如果 start_dim=0 且 end_dim=-1,則整個張量會被展平為一維。
- 如果只展平部分維度,則其他維度保持不變。
(三)示例代碼
- 示例 1:展平整個張量
import torch
# 創建一個二維張量
x = torch.arange(12).view(3, 4)
print("原始張量:\n", x)
# 使用 flatten 展平整個張量
y = torch.flatten(x)
print("展平后的張量:", y)
- 輸出:
原始張量:tensor([[ 0, 1, 2, 3],[ 4, 5, 6, 7],[ 8, 9, 10, 11]])
展平后的張量: tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
- 示例 2:展平部分維度
import torch
# 創建一個三維張量
x = torch.arange(24).view(2, 3, 4)
print("原始張量:\n", x)
# 使用 flatten 展平第二個維度到第三個維度
y = torch.flatten(x, start_dim=1, end_dim=2)
print("展平后的張量:\n", y)
- 輸出:
原始張量:tensor([[[ 0, 1, 2, 3],[ 4, 5, 6, 7],[ 8, 9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])
展平后的張量:tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],[12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]])
- 示例 3:展平特定維度
import torch
# 創建一個四維張量
x = torch.arange(24).view(2, 2, 3, 2)
print("原始張量:\n", x)
# 使用 flatten 展平第三個維度
y = torch.flatten(x, start_dim=2, end_dim=2)
print("展平后的張量:\n", y)
- 輸出:
原始張量:tensor([[[[ 0, 1],[ 2, 3],[ 4, 5]],[[ 6, 7],[ 8, 9],[10, 11]]],[[[12, 13],[14, 15],[16, 17]],[[18, 19],[20, 21],[22, 23]]]])
展平后的張量:tensor([[[ 0, 1, 2, 3, 4, 5],[ 6, 7, 8, 9, 10, 11]],[[12, 13, 14, 15, 16, 17],[18, 19, 20, 21, 22, 23]]])
六、Permute
(一)torch.permute 的基本概念
torch.permute 的作用是將輸入張量的維度按照指定的順序重新排列。它類似于 NumPy 中的 transpose,但更加靈活,可以同時對多個維度進行排列。
- 語法:
torch.Tensor.permute(*dims)
- 參數:
- *dims:新的維度順序(元組或多個整數)。
- 返回值:返回一個新的張量,具有重新排列后的維度順序,并與原始張量共享數據存儲。
(二)torch.permute 的工作原理
torch.permute 會將輸入張量的維度按照指定的順序重新排列。
新的維度順序必須與原始張量的維度數量相同,并且每個維度索引只能出現一次。
(三) 示例代碼
- 示例 1:二維張量的轉置
import torch
# 創建一個二維張量
x = torch.arange(12).view(3, 4)
print("原始張量:\n", x)# 使用 permute 轉置
y = x.permute(1, 0) # 將維度 0 和 1 交換
print("轉置后的張量:\n", y)
- 輸出:
原始張量:tensor([[ 0, 1, 2, 3],[ 4, 5, 6, 7],[ 8, 9, 10, 11]])
轉置后的張量:tensor([[ 0, 4, 8],[ 1, 5, 9],[ 2, 6, 10],[ 3, 7, 11]])
- 示例 2:三維張量的維度重排
import torch
# 創建一個三維張量
x = torch.arange(24).view(2, 3, 4)
print("原始張量:\n", x)# 使用 permute 重排維度
y = x.permute(2, 0, 1) # 將維度 0, 1, 2 重排為 2, 0, 1
print("重排后的張量形狀:", y.shape)
print("重排后的張量:\n", y)
- 輸出:
原始張量:tensor([[[ 0, 1, 2, 3],[ 4, 5, 6, 7],[ 8, 9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])
重排后的張量形狀: torch.Size([4, 2, 3])
重排后的張量:tensor([[[ 0, 4, 8],[12, 16, 20]],[[ 1, 5, 9],[13, 17, 21]],[[ 2, 6, 10],[14, 18, 22]],[[ 3, 7, 11],[15, 19, 23]]])
- 示例 3:四維張量的維度重排
import torch
# 創建一個四維張量
x = torch.arange(24).view(2, 2, 3, 2)
print("原始張量:\n", x)# 使用 permute 重排維度
y = x.permute(3, 1, 2, 0) # 將維度 0, 1, 2, 3 重排為 3, 1, 2, 0
print("重排后的張量形狀:", y.shape)
print("重排后的張量:\n", y)
- 輸出:
原始張量:tensor([[[[ 0, 1],[ 2, 3],[ 4, 5]],[[ 6, 7],[ 8, 9],[10, 11]]],[[[12, 13],[14, 15],[16, 17]],[[18, 19],[20, 21],[22, 23]]]])
重排后的張量形狀: torch.Size([2, 2, 3, 2])
重排后的張量:tensor([[[[ 0, 12],[ 2, 14],[ 4, 16]],[[ 6, 18],[ 8, 20],[10, 22]]],[[[ 1, 13],[ 3, 15],[ 5, 17]],[[ 7, 19],[ 9, 21],[11, 23]]]])
(四) torch.permute 的應用場景
- 圖像處理
在圖像處理中,torch.permute 可以用于調整圖像的通道順序。例如,將 (C, H, W) 的圖像張量轉換為 (H, W, C)。
- 深度學習模型輸入
在深度學習中,模型的輸入張量通常需要特定的維度順序。torch.permute 可以用于調整輸入張量的維度順序。
- 數據預處理
在數據預處理中,torch.permute 可以用于調整數據的維度順序,以便進行后續的計算或操作。
如果每一維度都有特定的含義,此時想改變維度的時候用permute。
七、總結
- reshape 與 view 幾乎一致, 甚至可以說reshape可以代替view;
- reshape 與 permute 的區別在于,reshape是按照順序重新進行排列組合,permute是按照維度進行重新排列組合,如果各個維度都有特定的意義,那么permute會更合適。
- unfold是按照某一維度滑動選取數據,新增加一維,新增加的維度的大小為滑動窗口的大小,原始維度會根據滑動窗口和step的大小而變化。
- flatten也是按照順序進行展開,且展開的是某個范圍的維度,不是特定的維度。
- 不管是幾維數組,主要看最后兩維,倒數第二維代表行,倒數第一維代表列。