在 PyTorch 的 torch.cat 函數中,out 參數用于指定輸出張量的存儲位置。是否使用 out 參數直接影響結果的存儲方式和張量的內存行為。以下是詳細解釋:
- 不使用 out 參數(默認行為)
含義:不提供 out 參數時,torch.cat 會創建一個新的張量來存儲拼接后的結果,并返回這個新張量。
特點:
內存分配:PyTorch 會為結果分配新的內存空間。
原張量不變:輸入的原始張量(如 tensors 中的張量)不會被修改。
返回新張量:返回的張量是獨立的,與輸入張量沒有內存共享 - 使用 out 參數
含義:通過 out 參數提供一個已存在的張量,torch.cat 將直接將結果寫入該張量中,無需創建新張量。
特點:
內存復用:避免分配新內存,直接利用已有張量的內存空間。
原張量被修改:out 指定的張量會被覆蓋,其內容會被替換為拼接結果。
形狀匹配:out 張量的形狀必須與拼接后的結果完全一致,否則會報錯。
以下是關于 torch.cat
在不同 dim
下的拼接過程的公式化描述及可視化示例,用分塊矩陣的形式呈現:
1. 數學公式化描述
1.1 沿 dim=0
(行方向)拼接
假設:
- 張量 A 的形狀為 m × n m\times n m×n
- 張量 B 的形狀為 p × n p\times n p×n
- 所有張量在非拼接維度(列數 n n n)必須一致。
拼接后的張量 C 形狀為 ( m + p ) × n (m + p) \times n (m+p)×n,公式表示為: C = [ A B ] C = \begin{bmatrix}A \\B\end{bmatrix} C=[AB?]
即:
C i , j = { A i , j , 若? 1 ≤ i ≤ m , B i ? m , j , 若? m + 1 ≤ i ≤ m + p . C_{i,j} = \begin{cases} A_{i,j}, & \text{若 } 1 \leq i \leq m, \\ B_{i-m,j}, & \text{若 } m+1 \leq i \leq m+p. \end{cases} Ci,j?={Ai,j?,Bi?m,j?,?若?1≤i≤m,若?m+1≤i≤m+p.?
1.2 沿 dim=1
(水平/列方向)拼接
假設條件:
- 張量 A A A 的形狀為 m × n m \times n m×n
- 張量 B B B 的形狀為 m × p m \times p m×p
- 兩個張量在非拼接維度(行維度 m m m)上必須保持一致
拼接操作:
水平拼接后的張量 C C C 形狀為 m × ( n + p ) m \times (n + p) m×(n+p),其數學表示為:
C = [ A B ] C = \begin{bmatrix} A & B \end{bmatrix} C=[A?B?]
元素級定義:
C i , j = { A i , j 當? 1 ≤ j ≤ n B i , j ? n 當? n + 1 ≤ j ≤ n + p C_{i,j} = \begin{cases} A_{i,j} & \text{當 } 1 \leq j \leq n \\ B_{i,j-n} & \text{當 } n+1 \leq j \leq n+p \end{cases} Ci,j?={Ai,j?Bi,j?n??當?1≤j≤n當?n+1≤j≤n+p?
維度說明:
- 行維度: m m m(保持不變)
- 列維度: n + p n + p n+p( A A A 和 B B B 列數的總和)
2. 具體示例(使用數值矩陣)
張量拼接示例
示例 1:沿第 0 維度拼接 (dim=0
)
輸入張量:
- A = [ 1 2 3 4 ] ( 形狀? 2 × 2 ) A = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \quad (\text{形狀 } 2 \times 2) A=[13?24?](形狀?2×2)
- B = [ 5 6 7 8 ] ( 形狀? 2 × 2 ) B = \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} \quad (\text{形狀 } 2 \times 2) B=[57?68?](形狀?2×2)
拼接操作:
C = concat ? ( A , B , dim = 0 ) C = \operatorname{concat}(A, B, \text{dim}=0) C=concat(A,B,dim=0)
輸出結果:
C = [ 1 2 3 4 5 6 7 8 ] ( 形狀? 4 × 2 ) C = \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ \hline 5 & 6 \\ 7 & 8 \end{bmatrix} \quad (\text{形狀 } 4 \times 2) C= ?1357?2468?? ?(形狀?4×2)
示例 2:沿第 1 維度拼接 (dim=1
)
輸入張量:
- A = [ 1 2 3 4 ] ( 形狀? 2 × 2 ) A = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \quad (\text{形狀 } 2 \times 2) A=[13?24?](形狀?2×2)
- B = [ 5 6 7 8 ] ( 形狀? 2 × 2 ) B = \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} \quad (\text{形狀 } 2 \times 2) B=[57?68?](形狀?2×2)
拼接操作:
C = concat ? ( A , B , dim = 1 ) C = \operatorname{concat}(A, B, \text{dim}=1) C=concat(A,B,dim=1)
輸出結果:
C = [ 1 2 5 6 3 4 7 8 ] ( 形狀? 2 × 4 ) C = \begin{bmatrix} 1 & 2 & 5 & 6 \\ 3 & 4 & 7 & 8 \end{bmatrix} \quad (\text{形狀 } 2 \times 4) C=[13?24?57?68?](形狀?2×4)
關鍵說明
dim=0
表示垂直拼接(沿行方向堆疊)dim=1
表示水平拼接(沿列方向連接)- 拼接維度的大小可以不同,但其他維度必須完全相同(例如
dim=1
拼接時,兩個張量的行數必須相等)
以下是用表格形式展示 dim=0
和 dim=1
的拼接結果:
沿 dim=0
拼接
初始張量 ( A ) | 初始張量 ( B ) | 拼接結果 ( C )(dim=0) |
---|---|---|
[[1, 2], | [[5, 6], | [[1, 2], |
[3, 4]] | [7, 8]] | [3, 4], |
形狀:2×2 | 形狀:2×2 | [5, 6], |
[7, 8]] | ||
形狀:4×2 |
沿 dim=1
拼接
初始張量 ( A ) | 初始張量 ( B ) | 拼接結果 ( C )(dim=1) |
---|---|---|
[[1, 2], | [[5, 6], | [[1, 2, 5, 6], |
[3, 4]] | [7, 8]] | [3, 4, 7, 8]] |
形狀:2×2 | 形狀:2×2 | 形狀:2×4 |
3. 關鍵點總結
-
維度一致性:
dim=0
:所有張量的列數(n
)必須相同。dim=1
:所有張量的行數(m
)必須相同。
-
拼接方向:
dim=0
:垂直方向拼接(行數相加)。dim=1
:水平方向拼接(列數相加)。
-
數學符號表示:
dim=0
: C = [ A B ] C = \begin{bmatrix} A \\ B \end{bmatrix} C=[AB?]dim=1
: C = [ A B ] C = \begin{bmatrix} A & B \end{bmatrix} C=[A?B?]
5. 擴展示例(多維張量)
假設張量為三維(如圖像的批處理):
A ∈ R B × C × H × W A \in \mathbb{R}^{B \times C \times H \times W} A∈RB×C×H×W
B ∈ R B ′ × C × H × W B \in \mathbb{R}^{B' \times C \times H \times W} B∈RB′×C×H×W
- 拼接
dim=0
(批處理方向):
C ∈ R ( B + B ′ ) × C × H × W C \in \mathbb{R}^{(B+B') \times C \times H \times W} C∈R(B+B′)×C×H×W
以下是三維張量拼接的示例,使用分塊矩陣的形式展示沿不同維度(dim=0
, dim=1
, dim=2
)的拼接過程:
三維張量拼接示例
假設兩個三維張量 ( A ) 和 ( B ),形狀分別為:
- A ∈ R 2 × 3 × 2 (形狀: 2 × 3 × 2 ) A \in \mathbb{R}^{2 \times 3 \times 2}(形狀:2×3×2) A∈R2×3×2(形狀:2×3×2)
- B ∈ R 1 × 3 × 2 (形狀: 1 × 3 × 2 ) B \in \mathbb{R}^{1 \times 3 \times 2} (形狀:1×3×2) B∈R1×3×2(形狀:1×3×2)
1. 沿 dim=0
拼接(擴展第一個維度)
- 拼接條件:除
dim=0
外,其他維度(3×2)必須一致。 - 拼接結果形狀: ( 2 + 1 ) × 3 × 2 = 3 × 3 × 2 (2 + 1) \times 3 \times 2 = 3 \times 3 \times 2 (2+1)×3×2=3×3×2
- 數學表示:
C = [ A 1 A 2 B 1 ] C = \begin{bmatrix} A_{1} \\ A_{2} \\ B_{1} \end{bmatrix} C= ?A1?A2?B1?? ?
其中 A 1 , A 2 A_1, A_2 A1?,A2? 是 A A A的兩個“塊”, B 1 B_1 B1?是 B B B的唯一塊。
數值示例:
-
輸入張量:
- A = [ [ 1 2 3 4 5 6 ] , [ 7 8 9 10 11 12 ] ] A = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}, & \begin{bmatrix} 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{bmatrix} \end{bmatrix} A= ? ?135?246? ?,? ?7911?81012? ?? ?
- B = [ [ 13 14 15 16 17 18 ] ] B = \begin{bmatrix} \begin{bmatrix} 13 & 14 \\ 15 & 16 \\ 17 & 18 \end{bmatrix} \end{bmatrix} B= ? ?131517?141618? ?? ?
-
拼接結果(
dim=0
):
C = [ [ 1 2 3 4 5 6 ] , [ 7 8 9 10 11 12 ] , [ 13 14 15 16 17 18 ] ] C = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}, & \begin{bmatrix} 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{bmatrix}, & \begin{bmatrix} 13 & 14 \\ 15 & 16 \\ 17 & 18 \end{bmatrix} \end{bmatrix} C= ? ?135?246? ?,? ?7911?81012? ?,? ?131517?141618? ?? ?
形狀: 3 × 3 × 2 3 \times 3 \times 2 3×3×2
2. 沿 dim=1
拼接(擴展第二個維度)
- 拼接條件:除
dim=1
外,其他維度(2×2)必須一致。 - 假設調整后的張量形狀:
- A ∈ R 2 × 3 × 2 A \in \mathbb{R}^{2 \times 3 \times 2} A∈R2×3×2
- B ∈ R 2 × 2 × 2 B \in \mathbb{R}^{2 \times 2 \times 2} B∈R2×2×2
- A ∈ R 2 × 3 × 2 A \in \mathbb{R}^{2 \times 3 \times 2} A∈R2×3×2
- 拼接結果形狀: 2 × ( 3 + 2 ) × 2 = 2 × 5 × 2 2 \times (3 + 2) \times 2 = 2 \times 5 \times 2 2×(3+2)×2=2×5×2
- 數學表示:
C = [ A 1 B 1 A 2 B 2 ] C = \begin{bmatrix} A_{1} & B_{1} \\ A_{2} & B_{2} \end{bmatrix} C=[A1?A2??B1?B2??]
其中 ( A 1 , A 2 ) ( A_1, A_2 ) (A1?,A2?) 是 ( A ) ( A ) (A) 的塊, ( B 1 , B 2 ) ( B_1, B_2 ) (B1?,B2?) 是 B B B 的塊。
數值示例:
-
輸入張量:
- A = [ [ 1 2 3 4 5 6 ] , [ 7 8 9 10 11 12 ] ] A = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}, & \begin{bmatrix} 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{bmatrix} \end{bmatrix} A= ? ?135?246? ?,? ?7911?81012? ?? ?- B = [ [ 13 14 15 16 ] , [ 17 18 19 20 ] ] B = \begin{bmatrix} \begin{bmatrix} 13 & 14 \\ 15 & 16 \end{bmatrix}, & \begin{bmatrix} 17 & 18 \\ 19 & 20 \end{bmatrix} \end{bmatrix} B=[[1315?1416?],?[1719?1820?]?]
-
拼接結果(
dim=1
):
C = [ [ 1 2 3 4 5 6 ] [ 13 14 15 16 ] , [ 7 8 9 10 11 12 ] [ 17 18 19 20 ] ] C = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix} \quad \begin{bmatrix} 13 & 14 \\ 15 & 16 \end{bmatrix}, & \begin{bmatrix} 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{bmatrix} \quad \begin{bmatrix} 17 & 18 \\ 19 & 20 \end{bmatrix} \end{bmatrix} C= ? ?135?246? ?[1315?1416?],? ?7911?81012? ?[1719?1820?]? ?
形狀: 2 × 5 × 2 2 \times 5 \times 2 2×5×2
3. 沿 dim=2
拼接(擴展第三個維度)
- 拼接條件:除
dim=2
外,其他維度(2×3)必須一致。 - 假設調整后的張量形狀:
- A ∈ R 2 × 3 × 2 A \in \mathbb{R}^{2 \times 3 \times 2} A∈R2×3×2
- B ∈ R 2 × 3 × 3 B \in \mathbb{R}^{2 \times 3 \times 3} B∈R2×3×3
- 拼接結果形狀: ( 2 × 3 × ( 2 + 3 ) = 2 × 3 × 5 ) ( 2 \times 3 \times (2 + 3) = 2 \times 3 \times 5 ) (2×3×(2+3)=2×3×5)
- 數學表示:
C = [ A 1 B 1 A 2 B 2 ] C = \begin{bmatrix} A_{1} & B_{1} \\ A_{2} & B_{2} \end{bmatrix} C=[A1?A2??B1?B2??]
其中每個塊在第三個維度上拼接。
數值示例:
-
輸入張量:
- A = [ [ 1 2 3 4 5 6 ] , [ 7 8 9 10 11 12 ] ] A = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}, & \begin{bmatrix} 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{bmatrix} \end{bmatrix} A= ? ?135?246? ?,? ?7911?81012? ?? ?
- B = [ [ 13 14 15 16 17 18 19 20 21 ] , [ 22 23 24 25 26 27 28 29 30 ] ] B = \begin{bmatrix} \begin{bmatrix} 13 & 14 & 15 \\ 16 & 17 & 18 \\ 19 & 20 & 21 \end{bmatrix}, & \begin{bmatrix} 22 & 23 & 24 \\ 25 & 26 & 27 \\ 28 & 29 & 30 \end{bmatrix} \end{bmatrix} B= ? ?131619?141720?151821? ?,? ?222528?232629?242730? ?? ?
- A = [ [ 1 2 3 4 5 6 ] , [ 7 8 9 10 11 12 ] ] A = \begin{bmatrix} \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}, & \begin{bmatrix} 7 & 8 \\ 9 & 10 \\ 11 & 12 \end{bmatrix} \end{bmatrix} A= ? ?135?246? ?,? ?7911?81012? ?? ?
-
拼接結果(
dim=2
):
C = [ [ 1 2 13 14 15 3 4 16 17 18 5 6 19 20 21 ] , [ 7 8 22 23 24 9 10 25 26 27 11 12 28 29 30 ] ] C = \begin{bmatrix} \begin{bmatrix} 1 & 2 & 13 & 14 & 15 \\ 3 & 4 & 16 & 17 & 18 \\ 5 & 6 & 19 & 20 & 21 \end{bmatrix}, & \begin{bmatrix} 7 & 8 & 22 & 23 & 24 \\ 9 & 10 & 25 & 26 & 27 \\ 11 & 12 & 28 & 29 & 30 \end{bmatrix} \end{bmatrix} C= ? ?135?246?131619?141720?151821? ?,? ?7911?81012?222528?232629?242730? ?? ?
形狀:$2 \times 3 \times 5 )
關鍵點總結
-
維度擴展方向:
dim=0
:增加第一個維度的大小(如批處理大小)。dim=1
:增加第二個維度的大小(如通道數或行數)。dim=2
:增加第三個維度的大小(如列數或深度)。
-
形狀一致性:
- 所有輸入張量在非拼接維度的形狀必須完全一致。
-
應用場景:
dim=0
:合并不同批次的圖像數據。dim=1
:在通道維度拼接特征圖(如圖像處理中的多模態數據)。dim=2
:擴展特征的維度(如時間序列中的時間步)。
通過上述示例和表格,可以直觀理解三維張量在不同維度上的拼接邏輯。