PyTorch常用Tensor形狀變換函數詳解
在PyTorch中,對張量(Tensor)進行形狀變換是深度學習模型構建中不可或缺的一環。無論是為了匹配網絡層的輸入要求,還是為了進行數據預處理和維度調整,都需要靈活運用各種形狀變換函數。本文將系統介紹幾個核心的形狀變換函數,并深入剖析它們的用法區別與關鍵點。
一、改變形狀與元素數量:view()
與 reshape()
view()
和 reshape()
是最常用的兩個用于重塑張量形狀的函數。它們都可以改變張量的維度,但前提是新舊張量的元素總數必須保持一致。盡管功能相似,但它們在工作機制上存在關鍵差異。
view()
view()
函數返回一個具有新形狀的張量,這個新張量與原始張量 共享底層數據。這意味著修改其中一個張量的數據,另一個也會隨之改變。
關鍵點:
- 內存共享:
view()
保證返回的張量與原張量共享數據,不會創建新的內存副本,因此效率很高。 - 連續性要求:
view()
只能作用于在內存中 連續 (contiguous) 的張量。對于一個非連續的張量(例如通過transpose
操作后得到的張量),直接使用view()
會引發錯誤。 在這種情況下,需要先調用.contiguous()
方法將其變為連續的,然后再使用view()
。
reshape()
reshape()
函數同樣用于改變張量的形狀,但它更加靈活和安全。
關鍵點:
- 智能處理:
reshape()
可以處理連續和非連續的張量。 - 視圖或副本:當作用于連續張量時,
reshape()
的行為類似于view()
,返回一個共享數據的視圖。 然而,當作用于非連續張量時,reshape()
會創建一個新的、具有所需形狀的連續張量,并復制原始數據,此時返回的是一個副本,與原張量不再共享內存。 - 不確定性:
reshape()
的語義是它 可能 會也 可能不會 共享存儲空間,事先無法確定。
用法區別與選擇
特性 | view() | reshape() |
---|---|---|
內存共享 | 總是共享(返回視圖) | 可能共享(視圖),也可能不共享(副本) |
對非連續張量 | 拋出錯誤 | 自動創建副本 |
推薦使用場景 | 當確定張量是連續的,并且需要保證內存共享以提升效率時。 | 當不確定張量的連續性,或希望代碼更健干時,reshape() 是更安全的選擇。 |
使用示例
import torch# 創建一個連續的張量
x = torch.arange(12) # tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])# 使用 view() 和 reshape()
x_view = x.view(3, 4)
x_reshape = x.reshape(3, 4)print("Original is contiguous:", x.is_contiguous()) # True
print("x_view:\n", x_view)
print("x_reshape:\n", x_reshape)# 創建一個非連續的張量
y = torch.arange(12).reshape(3, 4).t() # .t() 是 transpose(0, 1) 的簡寫
print("\nOriginal y is contiguous:", y.is_contiguous()) # False# 對非連續張量使用 reshape() - 成功
y_reshape = y.reshape(3, 4)
print("y_reshape:\n", y_reshape)
print("y_reshape is contiguous:", y_reshape.is_contiguous()) # True# 對非連續張量使用 view() - 報錯
try:y_view = y.view(3, 4)
except RuntimeError as e:print("\nError with view():", e)# 先用 .contiguous() 再用 view() - 成功
y_view_contiguous = y.contiguous().view(3, 4)
print("y_view_contiguous:\n", y_view_contiguous)
二、交換維度:transpose()
與 permute()
與 reshape
或 view
不同,transpose
和 permute
用于重新排列張量的維度,而不是像“拉伸”或“壓縮”數據那樣改變形狀。
transpose()
transpose()
函數專門用于 交換 張量的兩個指定維度。
用法: tensor.transpose(dim0, dim1)
關鍵點:
- 兩兩交換:每次只能交換兩個維度。
- 共享數據:返回的張量與原張量共享底層數據,但通常會導致張量在內存中變為非連續。
permute()
permute()
函數則提供了更強大的維度重排能力,可以一次性對所有維度進行任意順序的重新排列。
用法: tensor.permute(dims)
,其中 dims
是一個包含所有原始維度索引的新順序。
關鍵點:
- 任意重排:必須為所有維度提供新的順序。
- 通用性:
transpose(dim0, dim1)
可以看作是permute
的一個特例。 - 共享數據與非連續性:同樣地,
permute
返回的也是一個共享數據的視圖,并且通常會使張量變為非連續。
用法區別與選擇
- 當只需要交換兩個維度時,使用
transpose()
更直觀。 - 當需要進行更復雜的維度重排,例如將
(B, C, H, W)
變為(B, H, W, C)
時,必須使用permute()
。
重要提示:由于 transpose()
和 permute()
經常產生非連續的張量,如果后續需要使用 view()
,必須先調用 .contiguous()
方法。
使用示例
import torch# 假設張量形狀為 (batch, channel, height, width)
x = torch.randn(2, 3, 4, 5) # Shape: [2, 3, 4, 5]# 使用 transpose() 交換 height 和 width 維度
# 原始維度: 0, 1, 2, 3 -> 交換維度 2 和 3
x_transposed = x.transpose(2, 3)
print("Original shape:", x.shape) # torch.Size([2, 3, 4, 5])
print("Transposed shape:", x_transposed.shape) # torch.Size([2, 3, 5, 4])# 使用 permute() 將 (B, C, H, W) 變為 (B, H, W, C)
# 原始維度: 0, 1, 2, 3 -> 新維度順序: 0, 2, 3, 1
x_permuted = x.permute(0, 2, 3, 1)
print("Permuted shape:", x_permuted.shape) # torch.Size([2, 4, 5, 3])
三、增減維度:unsqueeze()
與 squeeze()
這兩個函數用于添加或移除長度為 1 的維度,這在處理批處理數據或需要廣播時非常有用。
unsqueeze()
unsqueeze()
用于在指定位置 添加 一個長度為 1 的維度。
用法: tensor.unsqueeze(dim)
關鍵點:
- 它會在
dim
參數指定的位置插入一個新維度。 - 常用于為單個樣本數據添加
batch
維度,或為二維張量添加channel
維度,以符合模型的輸入格式。
squeeze()
squeeze()
用于 移除 所有長度為 1 的維度。
用法:
tensor.squeeze()
: 移除所有長度為 1 的維度。tensor.squeeze(dim)
: 只在指定dim
位置移除長度為 1 的維度,如果該維度長度不為 1,則張量不變。
關鍵點:
- 這是一個降維操作,可以方便地去除多余的、長度為1的維度。
unsqueeze()
和squeeze()
互為逆操作。- 返回的張量同樣與原張量共享數據。
使用示例
import torch# 創建一個形狀為 (3, 4) 的張量
x = torch.randn(3, 4)
print("Original shape:", x.shape) # torch.Size([3, 4])# 使用 unsqueeze() 在第 0 維添加 batch 維度
x_unsqueezed_0 = x.unsqueeze(0)
print("Unsqueeze at dim 0:", x_unsqueezed_0.shape) # torch.Size([1, 3, 4])# 使用 unsqueeze() 在第 1 維添加 channel 維度
x_unsqueezed_1 = x.unsqueeze(1)
print("Unsqueeze at dim 1:", x_unsqueezed_1.shape) # torch.Size([3, 1, 4])# --- squeeze ---
y = torch.randn(1, 3, 1, 4)
print("\nOriginal y shape:", y.shape) # torch.Size([1, 3, 1, 4])# 使用 squeeze() 移除所有長度為 1 的維度
y_squeezed_all = y.squeeze()
print("Squeeze all ones:", y_squeezed_all.shape) # torch.Size([3, 4])# 使用 squeeze(dim) 只移除指定位置的維度
y_squeezed_dim = y.squeeze(0) # 移除第 0 維
print("Squeeze at dim 0:", y_squeezed_dim.shape) # torch.Size([3, 1, 4])
四、view
/reshape
與 squeeze
/unsqueeze
的關系:可替代性討論
從最終的形狀結果來看,view
和 reshape
在很多情況下確實可以實現與 squeeze
和 unsqueeze
相同的效果。然而,它們的設計理念和使用場景存在顯著差異,一般不建議混用。
使用 reshape
替代 unsqueeze
將一個形狀為 (A, B)
的張量通過 unsqueeze(0)
變為 (1, A, B)
,可以等價地使用 reshape(1, A, B)
來實現。
關鍵區別:
- 可讀性與意圖:
unsqueeze(dim)
的意圖非常明確——在指定位置插入一個新維度。這使得代碼更易于理解。而reshape()
需要提供完整的最終形狀,閱讀者需要通過對比新舊形狀才能理解其操作意圖。 - 便利性:使用
unsqueeze
時,無需知道張量的其他維度尺寸。而使用reshape
,則必須知道所有維度的大小才能構建新的形狀參數。
使用 reshape
替代 squeeze
將一個形狀為 (1, A, B, 1)
的張量通過 squeeze()
變為 (A, B)
,可以等價地使用 reshape(A, B)
實現。
關鍵區別:
- 自動化與便利性:
squeeze()
的核心優勢在于其自動化。它會自動移除所有大小為 1 的維度,使用者無需預先知道哪些維度是 1。如果使用reshape
,則必須手動計算出目標形狀,這在處理動態或未知的輸入形狀時會非常繁瑣。 - 條件性操作:
squeeze(dim)
只在指定維度大小為 1 時才執行操作,否則張量保持不變。reshape
不具備這種條件判斷能力,它會強制改變形狀,如果元素總數不匹配則會報錯。
雖然 reshape
功能更強大,理論上可以模擬 squeeze
和 unsqueeze
的操作,但強烈建議使用專用的函數。
- 當你的意圖是添加或移除單個維度時,請使用
unsqueeze
和squeeze
。這不僅使代碼更清晰、更具可讀性,還能利用squeeze
的自動檢測和條件操作特性,讓代碼更健壯。 - 只有當你需要進行更復雜的、非增減單一維度的形狀重塑時,才應使用
reshape
或view
。
使用示例
import torch# --- unsqueeze vs reshape ---
x = torch.randn(3, 4) # Shape: [3, 4]# 目標: 添加 batch 維度 -> (1, 3, 4)
x_unsqueezed = x.unsqueeze(0)
x_reshaped = x.reshape(1, 3, 4)print("Unsqueeze result:", x_unsqueezed.shape) # torch.Size([1, 3, 4])
print("Reshape result:", x_reshaped.shape) # torch.Size([1, 3, 4])
# 結果相同,但 unsqueeze(0) 意圖更明確# --- squeeze vs reshape ---
y = torch.randn(1, 3, 1, 4) # Shape: [1, 3, 1, 4]# 目標: 移除所有大小為1的維度 -> (3, 4)
y_squeezed = y.squeeze()
y_reshaped = y.reshape(3, 4) # 需要手動知道結果是 (3, 4)print("\nSqueeze result:", y_squeezed.shape) # torch.Size([3, 4])
print("Reshape result:", y_reshaped.shape) # torch.Size([3, 4])
# squeeze() 自動完成,reshape() 需要手動計算
五、復制與擴展數據:repeat()
與 expand()
除了改變張量的形狀,有時還需要沿著某些維度復制數據,以生成一個更大的張量。repeat()
和 expand()
函數都可以實現這一目的,但它們在實現方式和內存使用上有本質的區別。
expand()
expand()
函數通過擴展長度為 1 的維度來創建一個新的、更高維度的張量 視圖。它并不會實際分配新的內存來存儲復制的數據,因此非常高效。
關鍵點:
- 內存高效:
expand()
返回的是一個視圖,與原張量共享底層數據,不產生數據拷貝。 - 參數含義:
expand()
的參數指定的是張量的 最終目標形狀。 - 使用限制:
expand()
只能用于擴展大小為 1 的維度(也稱為“單例維度”)。對于大小不為 1 的維度,新尺寸必須與原尺寸相同。如果想保持某個維度不變,可以傳入-1
作為該維度的尺寸。
repeat()
repeat()
函數通過在物理內存中 真正地復制 數據來構造一個新張量。它會沿著指定的維度將張量重復指定的次數。
關鍵點:
- 數據拷貝:
repeat()
會創建一個全新的張量,其內容是原張量數據的重復,因此內存占用會相應增加。 - 參數含義:
repeat()
的參數指定的是每個維度需要 重復的次數。 - 無限制:
repeat()
可以對任意維度的張量進行重復,無論其原始大小是否為 1。
用法區別與選擇
特性 | expand() | repeat() |
---|---|---|
內存使用 | 高效,不分配新內存(返回視圖) | 內存消耗大,會創建數據的完整副本 |
參數含義 | 擴展后的 目標尺寸 | 每個維度的 重復次數 |
使用限制 | 只能擴展大小為 1 的維度 | 可以重復任意大小的維度 |
推薦使用場景 | 當需要進行廣播(Broadcasting)操作且注重內存效率時,例如將一個偏置向量擴展以匹配一個批次的數據。 | 當需要一個獨立的、數據重復的張量副本,并且后續可能需要就地修改其中的部分數據時。 |
使用示例
import torch# 創建一個包含單例維度的張量
x = torch.tensor([[1], [2], [3]]) # Shape: [3, 1]
print("Original tensor x:\n", x)
print("Original shape:", x.shape)# 使用 expand() 將大小為 1 的維度擴展到 4
# 目標形狀是 (3, 4),-1 表示該維度大小不變
expanded_x = x.expand(-1, 4)
print("\nExpanded x shape:", expanded_x.shape) # torch.Size([3, 4])
print("Expanded x:\n", expanded_x)# 使用 repeat() 實現類似效果
# 維度0重復1次(不變),維度1重復4次
repeated_x = x.repeat(1, 4)
print("\nRepeated x (1, 4) shape:", repeated_x.shape) # torch.Size([3, 4])
print("Repeated x (1, 4):\n", repeated_x)# 使用 repeat() 進行更復雜的復制
# 維度0重復2次,維度1重復3次
complex_repeated_x = x.repeat(2, 3)
print("\nRepeated x (2, 3) shape:", complex_repeated_x.shape) # torch.Size([6, 3])
print("Repeated x (2, 3):\n", complex_repeated_x)# expand() 無法作用于大小不為1的維度
try:# 嘗試將大小為3的維度擴展到6,會報錯x.expand(6, 4)
except RuntimeError as e:print("\nError with expand():", e)
六、組合技巧:先升維再擴展
在實際應用中,一個常見的需求是將一個一維向量復制多次,以構建一個二維矩陣。例如,將一個權重向量應用到批次中的每一個樣本上。這個操作可以通過組合“升維”和“擴展/復制”函數來高效實現。這里介紹兩種主流的組合方法:view
/expand
和 reshape
/repeat
。
方法一:view
/reshape
+ expand
(內存高效)
這個組合利用了 expand
函數不復制數據、只創建視圖的特性,是實現廣播操作的首選。
- 升維:首先,使用
view(1, -1)
或reshape(1, -1)
(或者更直觀的unsqueeze(0)
) 將一維向量(N)
變為二維的行向量(1, N)
。 - 擴展:然后,調用
expand(M, -1)
。這會將大小為 1 的第 0 維“擴展” M 次,得到一個(M, N)
的張量。-1
表示該維度的大小保持不變。
關鍵點:
- 整個過程沒有發生數據拷貝,返回的是一個共享原始數據的視圖,內存效率極高。
- 由于返回的是視圖,并且多個位置共享同一塊內存,因此不適合對結果進行就地修改(in-place modification)。
方法二:reshape
/view
+ repeat
(數據獨立)
這個組合會創建數據的完整物理副本,適用于需要一個獨立的、可修改的新張量的場景。
- 升維:與方法一相同,先將一維向量
(N)
變為(1, N)
。 - 復制:然后,調用
repeat(M, 1)
。這會將張量在第 0 維上復制 M 次,在第 1 維上復制 1 次(即不復制),最終得到一個(M, N)
的張量。
關鍵點:
repeat
會實際分配新的內存并復制數據,生成的新張量與原始張量完全獨立。- 內存開銷是
M * N
,但好處是你可以自由地修改新張量中的任何元素,而不會影響原始數據。
對比與選擇
特性 | view/reshape + expand | reshape/view + repeat |
---|---|---|
內存使用 | 高效,共享數據,不創建副本 | 消耗大,創建完整的數據副本 |
數據獨立性 | 不獨立,是原始數據的視圖 | 完全獨立,是全新的張量 |
修改數據 | 通常不應修改,可能會導致錯誤 | 可以自由、安全地修改 |
推薦場景 | 廣播、只讀操作、對內存敏感的場景 | 需要一個可修改的、數據獨立的副本時 |
使用示例
import torch# 1. 創建一個初始的一維向量
x = torch.arange(4) # tensor([0, 1, 2, 3]), Shape: [4]
num_repeats = 3
print(f"Original 1D tensor: {x}\n")# --- 方法一: reshape + expand (內存高效) ---
# 先升維 (4) -> (1, 4),再擴展 (1, 4) -> (3, 4)
x_expanded = x.reshape(1, -1).expand(num_repeats, -1)
print("--- reshape + expand ---")
print("Expanded shape:", x_expanded.shape)
print("Expanded tensor:\n", x_expanded)
# 注意:x_expanded 與 x 共享內存# --- 方法二: reshape + repeat (數據獨立) ---
# 先升維 (4) -> (1, 4),再復制 (1, 4) -> (3, 4)
x_repeated = x.reshape(1, -1).repeat(num_repeats, 1)
print("\n--- reshape + repeat ---")
print("Repeated shape:", x_repeated.shape)
print("Repeated tensor:\n", x_repeated)# 驗證數據獨立性
# 修改 repeated tensor 的一個元素
x_repeated[0, 1] = 99
print("\nModified repeated tensor:\n", x_repeated)
print("Original tensor after modifying repeated:", x) # 原始張量不受影響# 嘗試修改 expanded tensor 會引發問題,因為它是一個視圖
try:x_expanded[0, 1] = 99
except RuntimeError as e:print(f"\nError modifying expanded tensor: {e}")