unsqueeze(-1)
?是 PyTorch 中的一個張量操作,用于?在指定維度上增加一個長度為1的維度?(即擴展維度)。具體解析如下:
功能說明
-
?作用位置?
-1
?表示在張量的?最后一個維度?后添加新維度。
(等價于?dim=len(tensor.shape)
) -
?輸入輸出對比?
- 假設原張量?
train_X
?形狀為?(N,)
(一維向量) - 執行后形狀變為?
(N, 1)
(二維矩陣)
- 假設原張量?
-
?典型用途?
- 適配神經網絡層輸入要求(如全連接層需要二維輸入)
- 廣播機制(Broadcasting)前的維度對齊
- 處理單通道數據(如時間序列、灰度圖像)
示例演示
import torch# 原始數據(一維張量)
data = torch.tensor([1, 2, 3]) # shape: (3,)# 添加維度后
expanded = data.unsqueeze(-1) # shape: (3, 1)
print(expanded)
輸出:
tensor([[1],[2],[3]])
其他等價寫法
unsqueeze(1)
:當輸入為一維時效果與?unsqueeze(-1)
?相同data[:, None]
:Python 切片語法實現相同功能