PyTorch使用(4)-張量拼接操作

文章目錄

  • 張量拼接操作
  • 1. torch.cat 函數的使用
      • 1.1. torch.cat 定義
      • 1.2. 語法
      • 1.3. 關鍵規則
    • 1.4. 示例代碼
      • 1.4.1. 沿行拼接(dim=0)
      • 1.4.2. 沿列拼接(dim=1)
      • 1.4.3. 高維拼接(dim=2)
    • 1.5. 錯誤場景分析
      • 1.5.1. 維度數不一致
      • 1.5.2. 非拼接維度大小不匹配
      • 1.5.3. 設備或數據類型不一致
      • 1.6. 與 torch.stack 的區別
    • 1.7. 高級用法
      • 1.7.1. 批量拼接(Batch-wise Concatenation)
      • 1.7.2. 自動廣播支持
    • 1.8. 總結
  • 2. torch.stack 函數的使用
    • 2.1. 函數定義
    • 2.2. 核心規則
    • 2.3. 使用示例
    • 2.4. 與 torch.cat 的對比
    • 2.4. 常見錯誤與調試
    • 2.5. 工程實踐技巧
    • 2.7. 性能優化建議
    • 2.8. 總結

張量拼接操作

1. torch.cat 函數的使用

在 PyTorch 中,torch.cat 是用于沿指定維度拼接多個張量的核心函數

1.1. torch.cat 定義

功能: 將多個張量沿指定維度(dim)拼接,生成新張量。

輸入要求:

所有輸入張量的 維度數必須相同。

非拼接維度的大小必須一致。

張量必須位于 同一設備 且 數據類型相同。

1.2. 語法

torch.cat(tensors, dim=0, *, out=None) → Tensor

參數:

tensors (sequence of Tensors):需拼接的張量序列(列表或元組)。

dim (int, optional):拼接的維度索引,默認為 0。

out (Tensor, optional):可選輸出張量。

1.3. 關鍵規則

規則示例
輸入張量維度數必須相同不允許將 2D 張量與 3D 張量拼接
非拼接維度大小必須一致若 dim=1,所有張量的 dim=0、dim=2 等大小必須相同
拼接維度大小可以不同沿 dim=0 拼接形狀為 (2, 3) 和 (3, 3) 的張量,結果為 (5, 3)
輸出維度數與輸入相同輸入均為 3D 張量,輸出仍為 3D 張量

1.4. 示例代碼

1.4.1. 沿行拼接(dim=0)

import torchA = torch.tensor([[1, 2], [3, 4]])    # shape: (2, 2)
B = torch.tensor([[5, 6], [7, 8]])    # shape: (2, 2)
C = torch.cat([A, B], dim=0)          # shape: (4, 2)
print(C)
# 輸出:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6],
#         [7, 8]])

1.4.2. 沿列拼接(dim=1)

D = torch.tensor([[9], [10]])          # shape: (2, 1)
E = torch.cat([A, D], dim=1)          # shape: (2, 3)
print(E)
# 輸出:
# tensor([[ 1,  2,  9],
#         [ 3,  4, 10]])

1.4.3. 高維拼接(dim=2)

F = torch.randn(2, 3, 4)              # shape: (2, 3, 4)
G = torch.randn(2, 3, 5)              # shape: (2, 3, 5)
H = torch.cat([F, G], dim=2)          # shape: (2, 3, 9)

1.5. 錯誤場景分析

1.5.1. 維度數不一致

A_2D = torch.randn(2, 3)
B_3D = torch.randn(2, 3, 4)
try:torch.cat([A_2D, B_3D], dim=0)  # 報錯:維度數不同
except RuntimeError as e:print("錯誤:", e)

1.5.2. 非拼接維度大小不匹配

A = torch.randn(2, 3)
B = torch.randn(3, 3)              # dim=0 大小不同
try:torch.cat([A, B], dim=1)       # 報錯:非拼接維度大小不一致
except RuntimeError as e:print("錯誤:", e)

1.5.3. 設備或數據類型不一致

if torch.cuda.is_available():A_cpu = torch.randn(2, 3)B_gpu = torch.randn(2, 3).cuda()try:torch.cat([A_cpu, B_gpu], dim=0)  # 報錯:設備不一致except RuntimeError as e:print("錯誤:", e)

1.6. 與 torch.stack 的區別

函數輸入維度輸出維度核心用途
torch.cat所有張量維度相同維度數與輸入相同沿現有維度擴展張量
torch.stack所有張量形狀嚴格相同新增一個維度創建新維度合并張量

示例對比:

A = torch.tensor([1, 2])          # shape: (2)
B = torch.tensor([3, 4])          # shape: (2)# cat 沿 dim=0
C_cat = torch.cat([A, B])         # shape: (4)# stack 沿 dim=0
C_stack = torch.stack([A, B])     # shape: (2, 2)

1.7. 高級用法

1.7.1. 批量拼接(Batch-wise Concatenation)

# 批量數據拼接(batch_size=2)
batch_A = torch.randn(2, 3, 4)    # shape: (2, 3, 4)
batch_B = torch.randn(2, 3, 5)    # shape: (2, 3, 5)
batch_C = torch.cat([batch_A, batch_B], dim=2)  # shape: (2, 3, 9)

1.7.2. 自動廣播支持

torch.cat 不支持廣播,必須顯式匹配形狀:

A = torch.randn(3, 1)            # shape: (3, 1)
B = torch.randn(1, 3)            # shape: (1, 3)
try:torch.cat([A, B], dim=1)     # 報錯:非拼接維度大小不一致
except RuntimeError as e:print("錯誤:", e)

1.8. 總結

適用場景:合并同維度的特征、批量數據拼接等。

核心規則

1、輸入張量維度數相同。2、非拼接維度大小嚴格一致。3、設備與數據類型一致。

優先使用 torch.cat:當需要在現有維度擴展時;需新增維度時選擇 torch.stack。

2. torch.stack 函數的使用

2.1. 函數定義

torch.stack(tensors, dim=0, *, out=None) → Tensor

功能:將多個張量沿新維度堆疊(非拼接),要求所有輸入張量形狀嚴格相同。

  • 輸入:
    • tensors (sequence of Tensors):形狀相同的張量序列(列表/元組)。
    • dim (int):新維度的插入位置(支持負數索引)。
  • 輸出:
    • 比輸入張量多一維的新張量。

2.2. 核心規則

規則示例
輸入張量形狀必須完全相同(3, 4) 只能與 (3, 4) 堆疊,不能與 (3, 5) 堆疊
輸出維度 = 輸入維度 + 1輸入(3, 4) → 輸出 (n, 3, 4)(n為堆疊數量)
新維度大小 = 張量數量堆疊3個張量 → 新維度大小為3
設備/數據類型必須一致所有張量需在同一設備(CPU/GPU)且 dtype 相同

2.3. 使用示例

(1) 基礎用法

import torch
# 定義兩個相同形狀的張量
A = torch.tensor([1, 2, 3])      # shape: (3,)
B = torch.tensor([4, 5, 6])      # shape: (3,)# 沿新維度0堆疊
C = torch.stack([A, B])          # shape: (2, 3)
print(C)
# tensor([[1, 2, 3],
#         [4, 5, 6]])# 沿新維度1堆疊
D = torch.stack([A, B], dim=1)   # shape: (3, 2)
print(D)
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])

(2) 高維張量堆疊

# 形狀為 (2, 3) 的張量
X = torch.randn(2, 3)
Y = torch.randn(2, 3)# 沿dim=0堆疊(新增最外層維度)
Z0 = torch.stack([X, Y])         # shape: (2, 2, 3)# 沿dim=1堆疊(插入到第二維)
Z1 = torch.stack([X, Y], dim=1)  # shape: (2, 2, 3)# 沿dim=-1堆疊(插入到最后一維)
Z2 = torch.stack([X, Y], dim=-1) # shape: (2, 3, 2)

(3) 批量數據構建

# 模擬批量圖像數據(單張圖像shape: (3, 32, 32))
image1 = torch.randn(3, 32, 32)
image2 = torch.randn(3, 32, 32)
image3 = torch.randn(3, 32, 32)# 構建batch維度(batch_size=3)
batch = torch.stack([image1, image2, image3])  # shape: (3, 3, 32, 32)

2.4. 與 torch.cat 的對比

特性 torch.stack torch.cat
輸入要求 所有張量形狀嚴格相同 僅需非拼接維度相同
輸出維度 比輸入多1維 與輸入維度相同
內存開銷 更高(新增維度) 更低(復用現有維度)
典型場景 構建batch、新增序列維度 合并特征、擴展現有維度
示例對比:

A = torch.tensor([1, 2])
B = torch.tensor([3, 4])# stack -> 新增維度
stacked = torch.stack([A, B])    # shape: (2, 2)# cat -> 沿現有維度擴展
concatenated = torch.cat([A, B]) # shape: (4)

2.4. 常見錯誤與調試

(1) 形狀不匹配

A = torch.randn(2, 3)
B = torch.randn(2, 4)  # 第二維不同
try:torch.stack([A, B])
except RuntimeError as e:print("Error:", e)  # Sizes of tensors must match

(2) 設備不一致

A_cpu = torch.randn(3, 4)
B_gpu = torch.randn(3, 4).cuda()
try:torch.stack([A_cpu, B_gpu])
except RuntimeError as e:print("Error:", e)  # Expected all tensors to be on the same device

(3) 空張量處理

empty_tensors = [torch.tensor([]) for _ in range(3)]
try:torch.stack(empty_tensors)  # 可能引發未定義行為
except RuntimeError as e:print("Error:", e)

2.5. 工程實踐技巧

(1) 批量數據預處理

# 從數據加載器中逐批讀取數據并堆疊
batch_images = []
for image in dataloader:batch_images.append(image)if len(batch_images) == batch_size:batch = torch.stack(batch_images)  # shape: (batch_size, C, H, W)process_batch(batch)batch_images = []

(2) 序列建模中的時間步堆疊

# RNN輸入序列構建(T個時間步,每個步長特征dim=D)
time_steps = [torch.randn(1, D) for _ in range(T)]
input_seq = torch.stack(time_steps, dim=1)  # shape: (1, T, D)

(3) 多任務輸出合并

# 多任務學習中的輸出堆疊
task1_out = torch.randn(batch_size, 10)
task2_out = torch.randn(batch_size, 5)
multi_out = torch.stack([task1_out, task2_out], dim=1)  # shape: (batch_size, 2, ...)

2.7. 性能優化建議

避免循環中頻繁堆疊:優先在內存中收集所有張量后一次性堆疊。

# 低效做法
result = None
for x in data_stream:if result is None:result = x.unsqueeze(0)else:result = torch.stack([result, x.unsqueeze(0)])# 高效做法
tensor_list = [x for x in data_stream]
result = torch.stack(tensor_list)

顯存不足時考慮分塊處理:

chunk_size = 1000
for i in range(0, len(big_list), chunk_size):chunk = torch.stack(big_list[i:i+chunk_size])process(chunk)

2.8. 總結

核心用途:構建batch、新增維度、多任務輸出整合。

關鍵檢查點:

  • 輸入張量形狀完全一致。
  • 設備與數據類型統一。
  • 合理選擇 dim 參數控制維度擴展位置。

優先選擇場景:當需要顯式創建新維度時使用;若僅需擴展現有維度,用 torch.cat 更高效。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/77060.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/77060.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/77060.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

linux命令之yes(Linux Command Yes)

linux命令之yes 簡介與功能 yes 命令在 Linux 系統中用于重復輸出一行字符串,直到被殺死(kill)。該命令最常見的用途是自動化控制腳本中的交互式命令,以便無需用戶介入即可進行連續的確認操作。 用法示例 基本用法非常簡單&am…

《算法筆記》10.3小節——圖算法專題->圖的遍歷 問題 B: 連通圖

題目描述 給定一個無向圖和其中的所有邊&#xff0c;判斷這個圖是否所有頂點都是連通的。 輸入 每組數據的第一行是兩個整數 n 和 m&#xff08;0<n<1000&#xff09;。n 表示圖的頂點數目&#xff0c;m 表示圖中邊的數目。如果 n 為 0 表示輸入結束。隨后有 m 行數據…

使用Prometheus監控systemd服務并可視化

實訓背景 你是一家企業的運維工程師&#xff0c;需將服務器的systemd服務監控集成到Prometheus&#xff0c;并通過Grafana展示實時數據。需求如下&#xff1a; 數據采集&#xff1a;監控所有systemd服務的狀態&#xff08;運行/停止&#xff09;、資源占用&#xff08;CPU、內…

OpenCV--圖像邊緣檢測

在計算機視覺和圖像處理領域&#xff0c;邊緣檢測是極為關鍵的技術。邊緣作為圖像中像素值發生急劇變化的區域&#xff0c;承載了圖像的重要結構信息&#xff0c;在物體識別、圖像分割、目標跟蹤等眾多應用場景中發揮著核心作用。OpenCV 作為強大的計算機視覺庫&#xff0c;提供…

Rollup詳解

Rollup 是一個 JavaScript 模塊打包工具&#xff0c;專注于 ES 模塊的打包&#xff0c;常用于打包 JavaScript 庫。下面從它的工作原理、特點、使用場景、配置和與其他打包工具對比等方面進行詳細講解。 一、 工作原理 Rollup 的核心工作是分析代碼中的 import 和 export 語句…

Chapter 7: Compiling C++ Sources with CMake_《Modern CMake for C++》_Notes

Chapter 7: Compiling C Sources with CMake 1. Understanding the Compilation Process Key Points: Four-stage process: Preprocessing → Compilation → Assembly → LinkingCMake abstracts low-level commands but allows granular controlToolchain configuration (c…

5分鐘上手GitHub Copilot:AI編程助手實戰指南

引言 近年來&#xff0c;AI編程工具逐漸成為開發者提升效率的利器。GitHub Copilot作為由GitHub和OpenAI聯合推出的智能代碼補全工具&#xff0c;能夠根據上下文自動生成代碼片段。本文將手把手教你如何快速安裝、配置Copilot&#xff0c;并通過實際案例展示其強大功能。 一、…

謝志輝和他的《韻之隊詩集》:探尋生活與夢想交織的詩意世界

大家好&#xff0c;我是謝志輝&#xff0c;一個扎根在文字世界&#xff0c;默默耕耘的寫作者。寫作于我而言&#xff0c;早已不是簡單的愛好&#xff0c;而是生命中不可或缺的一部分。無數個寂靜的夜晚&#xff0c;當世界陷入沉睡&#xff0c;我獨自坐在書桌前&#xff0c;伴著…

Logo語言的死鎖

Logo語言的死鎖現象研究 引言 在計算機科學中&#xff0c;死鎖是一個重要的研究課題&#xff0c;尤其是在并發編程中。它指的是兩個或多個進程因爭奪資源而造成的一種永久等待狀態。在編程語言的設計與實現中&#xff0c;如何避免死鎖成為了優化系統性能和提高程序可靠性的關…

深入理解矩陣乘積的導數:以線性回歸損失函數為例

深入理解矩陣乘積的導數&#xff1a;以線性回歸損失函數為例 在機器學習和數據分析領域&#xff0c;矩陣微積分扮演著至關重要的角色。特別是當我們涉及到優化問題&#xff0c;如最小化損失函數時&#xff0c;對矩陣表達式求導變得必不可少。本文將通過一個具體的例子——線性…

real_time_camera_audio_display_with_animation

視頻錄制 import cv2 import pyaudio import wave import threading import os import tkinter as tk from PIL import Image, ImageTk # 視頻錄制設置 VIDEO_WIDTH = 640 VIDEO_HEIGHT = 480 FPS = 20.0 VIDEO_FILENAME = _video.mp4 AUDIO_FILENAME = _audio.wav OUTPUT_…

【Pandas】pandas DataFrame astype

Pandas2.2 DataFrame Conversion 方法描述DataFrame.astype(dtype[, copy, errors])用于將 DataFrame 中的數據轉換為指定的數據類型 pandas.DataFrame.astype pandas.DataFrame.astype 是一個方法&#xff0c;用于將 DataFrame 中的數據轉換為指定的數據類型。這個方法非常…

Johnson

理論 全源最短路算法 Floyd 算法&#xff0c;時間復雜度為 O(n)跑 n 次 Bellman - Ford 算法&#xff0c;時間復雜度是 O(nm)跑 n 次 Heap - Dijkstra 算法&#xff0c;時間復雜度是 O(nmlogm) 第 3 種算法被 Johnson 做了改造&#xff0c;可以求解帶負權邊的全源最短路。 J…

Exce格式化批處理工具詳解:高效處理,讓數據更干凈!

Exce格式化批處理工具詳解&#xff1a;高效處理&#xff0c;讓數據更干凈&#xff01; 1. 概述 在數據分析、報表整理、數據庫管理等工作中&#xff0c;數據清洗是不可或缺的一步。原始Excel數據常常存在格式不統一、空值、重復數據等問題&#xff0c;影響數據的準確性和可用…

(三十七)Dart 中使用 Pub 包管理系統與 HTTP 請求教程

Dart 中使用 Pub 包管理系統與 HTTP 請求教程 Pub 包管理系統簡介 Pub 是 Dart 和 Flutter 的包管理系統&#xff0c;用于管理項目的依賴。通過 Pub&#xff0c;開發者可以輕松地添加、更新和管理第三方庫。 使用 Pub 包管理系統 1. 找到需要的庫 訪問以下網址&#xff0c…

代碼隨想錄算法訓練營第三十五天 | 416.分割等和子集

416. 分割等和子集 題目鏈接&#xff1a;416. 分割等和子集 - 力扣&#xff08;LeetCode&#xff09; 文章講解&#xff1a;代碼隨想錄 視頻講解&#xff1a;動態規劃之背包問題&#xff0c;這個包能裝滿嗎&#xff1f;| LeetCode&#xff1a;416.分割等和子集_嗶哩嗶哩_bilibi…

HTTP 教程 : 從 0 到 1 全面指南 教程【全文三萬字保姆級詳細講解】

目錄 HTTP 的請求-響應 HTTP 方法 HTTP 狀態碼 HTTP 版本 安全性 HTTP/HTTPS 簡介 HTTP HTTPS HTTP 工作原理 HTTPS 作用 HTTP 與 HTTPS 區別 HTTP 消息結構 客戶端請求消息 服務器響應消息 實例 HTTP 請求方法 各個版本定義的請求方法 HTTP/1.0 HTTP/1.1 …

spring功能匯總

1.創建一個dao接口&#xff0c;實現類&#xff1b;service接口&#xff0c;實現類并且service里用new創建對象方式調用dao的方法 2.使用spring分別獲取dao和service對象(IOC) 注意 2中的service里面獲取dao的對象方式不用new的(DI) 運行測試&#xff1a; 使用1的方式創建servic…

Vue.js 實現下載模板和導入模板、數據比對功能核心實現。

在前端開發中&#xff0c;數據比對是一個常見需求&#xff0c;尤其在資產管理等場景中。本文將基于 Vue.js 和 Element UI&#xff0c;通過一個簡化的代碼示例&#xff0c;展示如何實現“新建比對”和“開始比對”功能的核心部分。 一、功能簡介 我們將聚焦兩個核心功能&…

volatile關鍵字用途說明

volatile 關鍵字在 C# 中用于指示編譯器和運行時系統&#xff0c;某個字段可能會被多個線程同時訪問&#xff0c;并且該字段的讀寫操作不應被優化&#xff08;例如緩存到寄存器或重排序&#xff09;&#xff0c;以確保所有線程都能看到最新的值。這使得 volatile 成為一種輕量級…