文章目錄
- 張量拼接操作
- 1. torch.cat 函數的使用
- 1.1. torch.cat 定義
- 1.2. 語法
- 1.3. 關鍵規則
- 1.4. 示例代碼
- 1.4.1. 沿行拼接(dim=0)
- 1.4.2. 沿列拼接(dim=1)
- 1.4.3. 高維拼接(dim=2)
- 1.5. 錯誤場景分析
- 1.5.1. 維度數不一致
- 1.5.2. 非拼接維度大小不匹配
- 1.5.3. 設備或數據類型不一致
- 1.6. 與 torch.stack 的區別
- 1.7. 高級用法
- 1.7.1. 批量拼接(Batch-wise Concatenation)
- 1.7.2. 自動廣播支持
- 1.8. 總結
- 2. torch.stack 函數的使用
- 2.1. 函數定義
- 2.2. 核心規則
- 2.3. 使用示例
- 2.4. 與 torch.cat 的對比
- 2.4. 常見錯誤與調試
- 2.5. 工程實踐技巧
- 2.7. 性能優化建議
- 2.8. 總結
張量拼接操作
1. torch.cat 函數的使用
在 PyTorch 中,torch.cat 是用于沿指定維度拼接多個張量的核心函數
1.1. torch.cat 定義
功能: 將多個張量沿指定維度(dim)拼接,生成新張量。
輸入要求:
所有輸入張量的 維度數必須相同。
非拼接維度的大小必須一致。
張量必須位于 同一設備 且 數據類型相同。
1.2. 語法
torch.cat(tensors, dim=0, *, out=None) → Tensor
參數:
tensors (sequence of Tensors):需拼接的張量序列(列表或元組)。
dim (int, optional):拼接的維度索引,默認為 0。
out (Tensor, optional):可選輸出張量。
1.3. 關鍵規則
規則 | 示例 |
---|---|
輸入張量維度數必須相同 | 不允許將 2D 張量與 3D 張量拼接 |
非拼接維度大小必須一致 | 若 dim=1,所有張量的 dim=0、dim=2 等大小必須相同 |
拼接維度大小可以不同 | 沿 dim=0 拼接形狀為 (2, 3) 和 (3, 3) 的張量,結果為 (5, 3) |
輸出維度數與輸入相同 | 輸入均為 3D 張量,輸出仍為 3D 張量 |
1.4. 示例代碼
1.4.1. 沿行拼接(dim=0)
import torchA = torch.tensor([[1, 2], [3, 4]]) # shape: (2, 2)
B = torch.tensor([[5, 6], [7, 8]]) # shape: (2, 2)
C = torch.cat([A, B], dim=0) # shape: (4, 2)
print(C)
# 輸出:
# tensor([[1, 2],
# [3, 4],
# [5, 6],
# [7, 8]])
1.4.2. 沿列拼接(dim=1)
D = torch.tensor([[9], [10]]) # shape: (2, 1)
E = torch.cat([A, D], dim=1) # shape: (2, 3)
print(E)
# 輸出:
# tensor([[ 1, 2, 9],
# [ 3, 4, 10]])
1.4.3. 高維拼接(dim=2)
F = torch.randn(2, 3, 4) # shape: (2, 3, 4)
G = torch.randn(2, 3, 5) # shape: (2, 3, 5)
H = torch.cat([F, G], dim=2) # shape: (2, 3, 9)
1.5. 錯誤場景分析
1.5.1. 維度數不一致
A_2D = torch.randn(2, 3)
B_3D = torch.randn(2, 3, 4)
try:torch.cat([A_2D, B_3D], dim=0) # 報錯:維度數不同
except RuntimeError as e:print("錯誤:", e)
1.5.2. 非拼接維度大小不匹配
A = torch.randn(2, 3)
B = torch.randn(3, 3) # dim=0 大小不同
try:torch.cat([A, B], dim=1) # 報錯:非拼接維度大小不一致
except RuntimeError as e:print("錯誤:", e)
1.5.3. 設備或數據類型不一致
if torch.cuda.is_available():A_cpu = torch.randn(2, 3)B_gpu = torch.randn(2, 3).cuda()try:torch.cat([A_cpu, B_gpu], dim=0) # 報錯:設備不一致except RuntimeError as e:print("錯誤:", e)
1.6. 與 torch.stack 的區別
函數 | 輸入維度 | 輸出維度 | 核心用途 |
---|---|---|---|
torch.cat | 所有張量維度相同 | 維度數與輸入相同 | 沿現有維度擴展張量 |
torch.stack | 所有張量形狀嚴格相同 | 新增一個維度 | 創建新維度合并張量 |
示例對比:
A = torch.tensor([1, 2]) # shape: (2)
B = torch.tensor([3, 4]) # shape: (2)# cat 沿 dim=0
C_cat = torch.cat([A, B]) # shape: (4)# stack 沿 dim=0
C_stack = torch.stack([A, B]) # shape: (2, 2)
1.7. 高級用法
1.7.1. 批量拼接(Batch-wise Concatenation)
# 批量數據拼接(batch_size=2)
batch_A = torch.randn(2, 3, 4) # shape: (2, 3, 4)
batch_B = torch.randn(2, 3, 5) # shape: (2, 3, 5)
batch_C = torch.cat([batch_A, batch_B], dim=2) # shape: (2, 3, 9)
1.7.2. 自動廣播支持
torch.cat 不支持廣播,必須顯式匹配形狀:
A = torch.randn(3, 1) # shape: (3, 1)
B = torch.randn(1, 3) # shape: (1, 3)
try:torch.cat([A, B], dim=1) # 報錯:非拼接維度大小不一致
except RuntimeError as e:print("錯誤:", e)
1.8. 總結
適用場景:合并同維度的特征、批量數據拼接等。
核心規則:
1、輸入張量維度數相同。2、非拼接維度大小嚴格一致。3、設備與數據類型一致。
優先使用 torch.cat:當需要在現有維度擴展時;需新增維度時選擇 torch.stack。
2. torch.stack 函數的使用
2.1. 函數定義
torch.stack(tensors, dim=0, *, out=None) → Tensor
功能:將多個張量沿新維度堆疊(非拼接),要求所有輸入張量形狀嚴格相同。
- 輸入:
- tensors (sequence of Tensors):形狀相同的張量序列(列表/元組)。
- dim (int):新維度的插入位置(支持負數索引)。
- 輸出:
- 比輸入張量多一維的新張量。
2.2. 核心規則
規則 | 示例 |
---|---|
輸入張量形狀必須完全相同 | (3, 4) 只能與 (3, 4) 堆疊,不能與 (3, 5) 堆疊 |
輸出維度 = 輸入維度 + 1 | 輸入(3, 4) → 輸出 (n, 3, 4)(n為堆疊數量) |
新維度大小 = 張量數量 | 堆疊3個張量 → 新維度大小為3 |
設備/數據類型必須一致 | 所有張量需在同一設備(CPU/GPU)且 dtype 相同 |
2.3. 使用示例
(1) 基礎用法
import torch
# 定義兩個相同形狀的張量
A = torch.tensor([1, 2, 3]) # shape: (3,)
B = torch.tensor([4, 5, 6]) # shape: (3,)# 沿新維度0堆疊
C = torch.stack([A, B]) # shape: (2, 3)
print(C)
# tensor([[1, 2, 3],
# [4, 5, 6]])# 沿新維度1堆疊
D = torch.stack([A, B], dim=1) # shape: (3, 2)
print(D)
# tensor([[1, 4],
# [2, 5],
# [3, 6]])
(2) 高維張量堆疊
# 形狀為 (2, 3) 的張量
X = torch.randn(2, 3)
Y = torch.randn(2, 3)# 沿dim=0堆疊(新增最外層維度)
Z0 = torch.stack([X, Y]) # shape: (2, 2, 3)# 沿dim=1堆疊(插入到第二維)
Z1 = torch.stack([X, Y], dim=1) # shape: (2, 2, 3)# 沿dim=-1堆疊(插入到最后一維)
Z2 = torch.stack([X, Y], dim=-1) # shape: (2, 3, 2)
(3) 批量數據構建
# 模擬批量圖像數據(單張圖像shape: (3, 32, 32))
image1 = torch.randn(3, 32, 32)
image2 = torch.randn(3, 32, 32)
image3 = torch.randn(3, 32, 32)# 構建batch維度(batch_size=3)
batch = torch.stack([image1, image2, image3]) # shape: (3, 3, 32, 32)
2.4. 與 torch.cat 的對比
特性 torch.stack torch.cat
輸入要求 所有張量形狀嚴格相同 僅需非拼接維度相同
輸出維度 比輸入多1維 與輸入維度相同
內存開銷 更高(新增維度) 更低(復用現有維度)
典型場景 構建batch、新增序列維度 合并特征、擴展現有維度
示例對比:
A = torch.tensor([1, 2])
B = torch.tensor([3, 4])# stack -> 新增維度
stacked = torch.stack([A, B]) # shape: (2, 2)# cat -> 沿現有維度擴展
concatenated = torch.cat([A, B]) # shape: (4)
2.4. 常見錯誤與調試
(1) 形狀不匹配
A = torch.randn(2, 3)
B = torch.randn(2, 4) # 第二維不同
try:torch.stack([A, B])
except RuntimeError as e:print("Error:", e) # Sizes of tensors must match
(2) 設備不一致
A_cpu = torch.randn(3, 4)
B_gpu = torch.randn(3, 4).cuda()
try:torch.stack([A_cpu, B_gpu])
except RuntimeError as e:print("Error:", e) # Expected all tensors to be on the same device
(3) 空張量處理
empty_tensors = [torch.tensor([]) for _ in range(3)]
try:torch.stack(empty_tensors) # 可能引發未定義行為
except RuntimeError as e:print("Error:", e)
2.5. 工程實踐技巧
(1) 批量數據預處理
# 從數據加載器中逐批讀取數據并堆疊
batch_images = []
for image in dataloader:batch_images.append(image)if len(batch_images) == batch_size:batch = torch.stack(batch_images) # shape: (batch_size, C, H, W)process_batch(batch)batch_images = []
(2) 序列建模中的時間步堆疊
# RNN輸入序列構建(T個時間步,每個步長特征dim=D)
time_steps = [torch.randn(1, D) for _ in range(T)]
input_seq = torch.stack(time_steps, dim=1) # shape: (1, T, D)
(3) 多任務輸出合并
# 多任務學習中的輸出堆疊
task1_out = torch.randn(batch_size, 10)
task2_out = torch.randn(batch_size, 5)
multi_out = torch.stack([task1_out, task2_out], dim=1) # shape: (batch_size, 2, ...)
2.7. 性能優化建議
避免循環中頻繁堆疊:優先在內存中收集所有張量后一次性堆疊。
# 低效做法
result = None
for x in data_stream:if result is None:result = x.unsqueeze(0)else:result = torch.stack([result, x.unsqueeze(0)])# 高效做法
tensor_list = [x for x in data_stream]
result = torch.stack(tensor_list)
顯存不足時考慮分塊處理:
chunk_size = 1000
for i in range(0, len(big_list), chunk_size):chunk = torch.stack(big_list[i:i+chunk_size])process(chunk)
2.8. 總結
核心用途:構建batch、新增維度、多任務輸出整合。
關鍵檢查點:
- 輸入張量形狀完全一致。
- 設備與數據類型統一。
- 合理選擇 dim 參數控制維度擴展位置。
優先選擇場景:當需要顯式創建新維度時使用;若僅需擴展現有維度,用 torch.cat 更高效。