torch.cat
?和?torch.stack
?是 PyTorch 中用于組合張量的兩個常用函數,它們的核心區別在于輸入張量的維度和輸出張量的維度變化。以下是詳細對比:
1.?torch.cat?(Concatenate)
- 作用:沿現有維度拼接多個張量,不創建新維度
-
輸入要求:所有張量的形狀必須除拼接維度外完全相同。
-
語法:
torch.cat(tensors, dim=0) # dim 指定拼接的維度
-
示例:
a = torch.tensor([[1, 2], [3, 4]]) # shape (2, 2) b = torch.tensor([[5, 6]]) # shape (1, 2)# 沿 dim=0 拼接(行方向) c = torch.cat([a, b], dim=0) print(c) # tensor([[1, 2], # [3, 4], # [5, 6]]) # shape (3, 2)
- 特點:
-
拼接后的張量在指定維度上的大小是輸入張量該維度大小的總和。
-
其他維度必須完全一致。
-
2. torch.stack
-
作用:沿新維度堆疊多個張量,創建新維度。
-
輸入要求:所有張量的形狀必須完全相同。
-
語法:
torch.stack(tensors, dim=0) # dim 指定新維度的位置
-
示例:
a = torch.tensor([1, 2]) # shape (2,) b = torch.tensor([3, 4]) # shape (2,)# 沿新維度 dim=0 堆疊 c = torch.stack([a, b], dim=0) print(c) # tensor([[1, 2], # [3, 4]]) # shape (2, 2)# 沿新維度 dim=1 堆疊 d = torch.stack([a, b], dim=1) print(d) # tensor([[1, 3], # [2, 4]]) # shape (2, 2)
-
特點:
-
輸出張量比輸入張量多一個維度。
-
適用于將多個相同形狀的張量合并為批次(如?
batch_size
?維度)。
-
3. 關鍵區別總結
4. 直觀對比示例
假設有兩個張量:
x = torch.tensor([1, 2]) # shape (2,)
y = torch.tensor([3, 4]) # shape (2,)
torch.cat
?結果:
torch.cat([x, y], dim=0) # tensor([1, 2, 3, 4]), shape (4,)
torch.stack
?結果:
torch.stack([x, y], dim=0) # tensor([[1, 2], [3, 4]]), shape (2, 2)
5. 如何選擇?
-
用?
torch.cat
?當需要擴展現有維度(如拼接多個特征圖)。 -
用?
torch.stack
?當需要創建新維度(如構建批次數據或堆疊不同模型的輸出)
通過理解兩者的維度變化邏輯,可以避免常見的形狀錯誤(如?size mismatch
)。?