pytorch小記(十七):PyTorch 中的 `expand` 與 `repeat`:詳解廣播機制與復制行為(附詳細示例)
- 🚀 PyTorch 中的 `expand` 與 `repeat`:詳解廣播機制與復制行為(附詳細示例)
- 🔍 一、基礎定義
- 1. `tensor.expand(*sizes)`
- 2. `tensor.repeat(*sizes)`
- 📌 二、維度行為詳解
- 使用 `expand`
- 使用 `repeat`
- ?? 三、重點報錯案例解釋
- 📌 示例 1:`expand(1, 4)` 報錯
- ? 示例 2:`expand(2, 4)` 正確
- 🔁 四、repeat 的多種使用場景舉例
- 🔍 五、輸入維度對 `expand` 和 `repeat` 的影響總結
- 🎯 六、常見錯誤總結
- ? 七、維度補齊技巧
- 🎓 八、結語:如何選擇?
- 問題
- 1. PyTorch 自動**廣播一維 tensor**
- 2. 和二維 `[1, 2, 3]` 效果一樣?
- 🔎 為什么以前會報錯?
- 📌 總結規律(適用于新版本 PyTorch)
🚀 PyTorch 中的 expand
與 repeat
:詳解廣播機制與復制行為(附詳細示例)
在使用 PyTorch 構建神經網絡時,經常會遇到不同維度張量需要對齊的問題,expand()
和 repeat()
就是兩種非常常用的方式來處理張量的形狀變化。本博客將詳細解釋兩者的區別、作用、使用規則以及典型的報錯原因,配合實際例子,幫助你深入理解廣播機制。
🔍 一、基礎定義
1. tensor.expand(*sizes)
- 功能:沿指定維度進行“虛擬復制”,不占用額外內存。
- 要求:只能擴展 原始維度中為1的維度,否則會報錯。
2. tensor.repeat(*sizes)
- 功能:真正復制數據,生成新的內存區域。
- 不限制是否為1的維度,任意維度都能復制。
📌 二、維度行為詳解
以一個張量為例:
a = torch.tensor([[1], [2]]) # shape: (2, 1)
使用 expand
print(a.expand(2, 3))
結果:
tensor([[1, 1, 1],[2, 2, 2]])
- 第1維為 1,可以擴展成3列。
- 數據并沒有真實復制,只是通過 廣播機制 顯示為多列。
使用 repeat
print(a.repeat(1, 3))
結果:
tensor([[1, 1, 1],[2, 2, 2]])
- 每一行的元素真實地復制了3份,占用了新內存。
?? 三、重點報錯案例解釋
📌 示例 1:expand(1, 4)
報錯
c = torch.tensor([[7], [8]]) # shape: (2, 1)
print(c.expand(1, 4))
錯誤原因:
RuntimeError: The expanded size of the tensor (1) must match the existing size (2) at non-singleton dimension 0.
解釋:
- 原 tensor 的第0維是2,而你想擴展為1。
- 非1的維度不能進行expand擴展,會觸發報錯。
? 示例 2:expand(2, 4)
正確
c = torch.tensor([[7], [8]]) # shape: (2, 1)
print(c.expand(2, 4))
輸出:
tensor([[7, 7, 7, 7],[8, 8, 8, 8]])
- 第0維是2,不變 ?
- 第1維是1,被擴展為4 ?
🔁 四、repeat 的多種使用場景舉例
a = torch.tensor([[1, 2, 3]]) # shape: (1, 3)
print(a.repeat(2, 3))
輸出:
tensor([[1, 2, 3, 1, 2, 3],[1, 2, 3, 1, 2, 3]])
解釋:
(2, 3)
的含義是:行重復2次,列重復3次。- 數據真實復制!
🔍 五、輸入維度對 expand
和 repeat
的影響總結
操作 | 輸入維度形狀 | 輸入參數 | 說明 |
---|---|---|---|
expand | 必須是顯式維度 | 尺寸必須與原tensor維度數一致,且非1的維度不能變 | |
repeat | 任意形狀 | 每個維度對應復制幾次 | |
自動廣播 | 可擴展1維為任意數目 | ? | expand 底層用到 |
內存行為 | 不復制數據 | ? | expand 是 zero-copy |
內存行為 | 真正復制 | ? | repeat 用得多就要小心內存 |
🎯 六、常見錯誤總結
錯誤場景 | 示例 | 錯誤原因 |
---|---|---|
expand 維度不對 | tensor(2, 1).expand(1, 4) | 非1維度不能擴展 |
expand 維數不匹配 | tensor(2, 1).expand(4) | 參數數目與維度數不一致 |
repeat 維度數對不上 | tensor(2, 1).repeat(3) | 參數不夠,需要補齊 |
? 七、維度補齊技巧
有時原始張量的維度太少,需要先 .unsqueeze()
添加維度:
x = torch.tensor([1, 2, 3]) # shape: (3,)
x = x.unsqueeze(0) # shape: (1, 3)
x = x.expand(2, 3)
🎓 八、結語:如何選擇?
- 如果你只是想“假裝復制”以減少內存開銷 ?
expand()
- 如果你真的需要重復數據去喂模型 ?
repeat()
- 如果你想安全無腦復制 ?
repeat()
更通用但代價大 - 如果你要配合 broadcasting ?
expand()
是你的最優選擇
問題
a = torch.tensor([[1, 2, 3]]) # shape: (1, 3)
print(a.shape)
print(a.repeat(6, 4))a = torch.tensor([1, 2, 3]) # shape: (1, 3)
print(a.shape)
print(a.repeat(6, 4))
為什么維度不同但是輸出是一樣的?
1. PyTorch 自動廣播一維 tensor
在新版 PyTorch 中(大約 1.8 起),當你對 一維張量 調用 .repeat(m, n)
,PyTorch 會自動地把它當作 shape 為 (1, 3)
,然后再執行 repeat。這相當于隱式地:
a = torch.tensor([1, 2, 3]) # shape: (3,)
a = a.unsqueeze(0) # shape: (1, 3)
print(a.repeat(6, 4)) # 🔁 repeat(6, 4) 等價于 (6 rows, 12 columns)
2. 和二維 [1, 2, 3]
效果一樣?
是的。你對比的兩個 tensor:
a1 = torch.tensor([[1, 2, 3]]) # shape: (1, 3)
a2 = torch.tensor([1, 2, 3]) # shape: (3,)
print(a1.repeat(6, 4))
print(a2.repeat(6, 4)) # 現在兩者結果完全一致!
輸出都是 shape: (6, 12),值為重復的 [1, 2, 3]
:
tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3],...[1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]])
🔎 為什么以前會報錯?
在早期版本的 PyTorch 中(<1.8),repeat(6, 4)
要求參數個數和維度完全一致。所以對 a = torch.tensor([1,2,3])
(一維)來說,你只能:
a.repeat(6) # 正確,對一維張量
a.repeat(6, 4) # 錯誤(舊版本)
📌 總結規律(適用于新版本 PyTorch)
原始 tensor | repeat 維度 | 自動行為 | 結果 |
---|---|---|---|
[1,2,3] (1維) | repeat(6,4) | 自動 unsqueeze → (1,3) | ? |
[[1,2,3]] (2維) | repeat(6,4) | 直接 repeat | ? |
[1,2,3] (1維) | repeat(6) | 沿第0維重復 | ? |
[[1,2,3]] (2維) | repeat(6) | 報錯,維度不匹配 | ? |