在 PyTorch 中,tensor.view()
是一個常用的方法,用于改變張量(Tensor)的形狀(shape),但不會改變其數據本身。它類似于 NumPy 的 reshape()
,但有一些關鍵區別。
1. 基本用法
import torchx = torch.arange(1, 10) # shape: [9]
print(x)
# tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])# 改變形狀為 (3, 3)
y = x.view(3, 3)
print(y)
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
關鍵點:
- 不改變數據,只是重新排列維度。
- 新形狀的元素數量必須與原張量一致,否則會報錯:
x.view(2, 5) # ? 錯誤!因為 2×5=10,但 x 只有 9 個元素
2. 自動推斷維度(-1 的作用)
如果不想手動計算某個維度的大小,可以用 -1
,PyTorch 會自動計算:
x = torch.arange(1, 10) # shape: [9]# 自動計算行數,確保列數是 3
y = x.view(-1, 3) # shape: [3, 3]
print(y)
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])# 自動計算列數,確保行數是 3
z = x.view(3, -1) # shape: [3, 3]
print(z)
# 輸出同上
3. view()
vs reshape()
方法 | 是否共享內存 | 是否適用于非連續存儲 | 適用場景 |
---|---|---|---|
view() | ? 共享內存(修改會影響原張量) | ? 僅適用于連續存儲的張量 | 高效改變形狀(推薦優先使用) |
reshape() | ? 可能共享內存(如果可能) | ? 適用于非連續存儲 | 更通用,但可能額外復制數據 |
示例對比
x = torch.arange(1, 10) # 連續存儲# view() 可以正常工作
y = x.view(3, 3) # 如果張量不連續(如轉置后),view() 會報錯
x_transposed = x.t() # 轉置后存儲不連續
# z = x_transposed.view(9) # ? RuntimeError: view size is not compatible with input tensor's size and stride# reshape() 可以處理非連續存儲
z = x_transposed.reshape(9) # ?
4. 常見用途
(1) 展平張量(Flatten)
x = torch.randn(4, 5) # shape: [4, 5]
flattened = x.view(-1) # shape: [20]
(2) 調整 CNN 特征圖維度
# 假設 CNN 輸出是 [batch_size, channels, height, width]
features = torch.randn(32, 64, 7, 7) # shape: [32, 64, 7, 7]# 展平成 [batch_size, channels * height * width] 用于全連接層
flattened = features.view(32, -1) # shape: [32, 64*7*7] = [32, 3136]
(3) 交換維度(類似 permute
)
x = torch.randn(2, 3, 4) # shape: [2, 3, 4]
y = x.view(2, 4, 3) # shape: [2, 4, 3](相當于交換最后兩維)
5. 注意事項
view()
只適用于連續存儲的張量,否則會報錯,此時應該用reshape()
或先.contiguous()
:x_non_contiguous = x.t() # 轉置后不連續 x_contiguous = x_non_contiguous.contiguous() # 變成連續存儲 y = x_contiguous.view(...) # 現在可以用 view()
view()
返回的新張量與原張量共享內存,修改其中一個會影響另一個:x = torch.arange(1, 10) y = x.view(3, 3) y[0, 0] = 100 # 修改 y 會影響 x print(x) # tensor([100, 2, 3, 4, 5, 6, 7, 8, 9])
總結
操作 | 推薦方法 |
---|---|
改變形狀(連續張量) | view() |
改變形狀(非連續張量) | reshape() 或 .contiguous().view() |
展平張量 | x.view(-1) 或 torch.flatten(x) |
調整 CNN 特征圖維度 | features.view(batch_size, -1) |
view()
是 PyTorch 中高效調整張量形狀的首選方法,但要注意內存共享和連續性限制! 🚀
展平(Flatten)或改變形狀(如 view
、reshape
)的核心原則是保持張量的總元素個數(numel()
)不變,只是重新排列這些元素的維度。
1. 元素總數不變原則
無論原始張量是幾維的(1D、2D、3D 或更高維),轉換后的新形狀必須滿足:
原形狀的元素總數 = 新形狀的元素總數 \text{原形狀的元素總數} = \text{新形狀的元素總數} 原形狀的元素總數=新形狀的元素總數
即:
元素總數需滿足:
dim 1 × dim 2 × ? × dim n = new_dim 1 × new_dim 2 × ? × new_dim m \text{dim}_1 \times \text{dim}_2 \times \dots \times \text{dim}_n = \text{new\_dim}_1 \times \text{new\_dim}_2 \times \dots \times \text{new\_dim}_m dim1?×dim2?×?×dimn?=new_dim1?×new_dim2?×?×new_dimm?
示例:
import torchx = torch.arange(24) # 1D 張量,24 個元素
print(x.numel()) # 輸出:24# 轉換為 2D 張量:4 行 × 6 列(4×6=24)
y = x.view(4, 6) # 形狀 [4, 6]# 轉換為 3D 張量:2×3×4(2×3×4=24)
z = x.view(2, 3, 4) # 形狀 [2, 3, 4]
2. 自動推斷維度(-1
的作用)
在指定新形狀時,可以用 -1
代表“自動計算該維度大小”,PyTorch 會根據總元素數和其他已知維度推導出 -1
的值。
規則:
推斷的維度 = 總元素數 已知維度的乘積 \text{推斷的維度} = \frac{\text{總元素數}}{\text{已知維度的乘積}} 推斷的維度=已知維度的乘積總元素數?
示例:
x = torch.arange(24) # 24 個元素# 自動計算行數,確保列數為 6
y = x.view(-1, 6) # 形狀 [4, 6](因為 24/6=4)# 自動計算列數,確保行數為 3
z = x.view(3, -1) # 形狀 [3, 8](因為 24/3=8)
錯誤示例:
如果維度乘積不匹配總元素數,會報錯:
x.view(5, -1) # ? 報錯!24 無法被 5 整除
3. 展平(Flatten)的本質
展平是將任意維度的張量轉換為一維或二維的形式:
- 一維展平:
x.view(-1)
或x.flatten()
將所有元素排成一行,形狀變為[num_elements]
。 - 二維展平(保留批處理維度):
x.view(batch_size, -1)
或nn.Flatten()
保持batch_size
不變,其余維度合并為第二維,形狀變為[batch_size, features]
。
示例:
x = torch.randn(2, 3, 4) # 形狀 [2, 3, 4],總元素數=24# 一維展平
flatten_1d = x.view(-1) # 形狀 [24]# 二維展平(保留第0維 batch_size=2)
flatten_2d = x.view(2, -1) # 形狀 [2, 12](因為 3×4=12)
4. 為什么需要手動指定部分維度?
- 全連接層的輸入要求:
通常需要二維張量[batch_size, features]
,因此需明確保留batch_size
,其余維度展平。 - 避免歧義:
例如,若張量形狀為[32, 64, 7, 7]
,想展平成[32, 3136]
,需明確第二維是3136
(即64×7×7
),而-1
讓 PyTorch 自動計算。
代碼對比:
# 明確指定第二維
flatten_explicit = x.view(32, 64*7*7) # 形狀 [32, 3136]# 用 -1 自動計算
flatten_auto = x.view(32, -1) # 形狀 [32, 3136](推薦)
5. 關鍵總結
- 元素總數不變:形狀變換的本質是重新排列數據,不增刪元素。
-1
的作用:自動計算該維度大小,確保總元素數匹配。- 展平的應用場景:
- 全連接層前必須將多維特征轉換為一維向量(如 CNN 的
[batch, C, H, W]
→[batch, C*H*W]
)。 - 數據預處理時調整輸入形狀(如圖像展平為向量)。
- 全連接層前必須將多維特征轉換為一維向量(如 CNN 的