PyTorch核心函數詳解:gather與where的實戰指南

PyTorch中的torch.gathertorch.where是處理張量數據的關鍵工具,前者實現基于索引的靈活數據提取,后者完成條件篩選與動態生成。本文通過典型應用場景和代碼演示,深入解析兩者的工作原理及使用技巧,幫助開發者提升數據處理的靈活性與效率。

在深度學習中,我們經常需要根據特定規則提取或生成數據。例如:

  • 從預測概率中提取Top-K類別索引
  • 根據掩碼篩選有效數據點
  • 動態生成條件化張量

torch.gathertorch.where正是解決這類問題的核心函數。下文將結合圖像處理、數據篩選等場景,詳解它們的用法與差異。
在這里插入圖片描述

一、torch.gather:基于索引的精準提取

功能描述

torch.gather(input, dim, index) 沿指定維度dim,根據index張量中的索引值,從input中提取對應元素,輸出形狀與index一致。

參數說明
  • input:源張量
  • dim:指定操作的維度
  • index:索引張量,其值必須為整數類型

核心規則

  • 索引穿透性:索引值直接映射源張量的位置,不改變維度
  • 廣播機制:當index維度小于input時,會自動廣播到匹配形狀
  • 多維索引:支持通過多維索引張量提取復雜結構的數據

應用場景與示例

場景1:圖像數據批量提取

假設需要從批量圖像中提取特定位置的像素值:

# 假設images是形狀為(2,3,3)的圖像批次 (批次大小2,通道3,分辨率3x3)
images = torch.tensor([[[1,2,3],[4,5,6],[7,8,9]],  # 第一張圖像[[10,11,12],[13,14,15],[16,17,18]]  # 第二張圖像
])# 提取所有圖像的第0行第1列像素 (shape: (2,))
pixels = torch.gather(images, dim=2, index=torch.tensor([[[0,1,0],[0,1,0]], [[0,1,0],[0,1,0]]]))
print(pixels)
# 輸出: tensor([[1, 2, 1],
#                [10, 11, 10]])
場景2:從概率分布提取Top-K結果

在NLP任務中提取預測詞ID:

logits = torch.tensor([[0.1, 0.4, 0.5], [0.3, 0.6, 0.1]])  # 2個樣本的3個類別的概率
topk_indices = logits.topk(k=2, dim=1).indices  # 獲取Top-2索引# 使用gather提取Top-2概率值
topk_probs = torch.gather(logits, dim=1, index=topk_indices)
print(topk_probs)
# 輸出:
# tensor([[0.5, 0.4],
#         [0.6, 0.3]])

二、torch.where:條件驅動的動態生成

功能描述

torch.where(condition, x, y) 根據布爾條件condition,從張量xy中選擇元素,生成與輸入同形狀的新張量。

參數說明
  • condition:布爾型張量,決定元素來源
  • x:滿足條件時選擇的元素來源
  • y:不滿足條件時選擇的元素來源

核心特性

  • 自動廣播:支持不同形狀的條件與輸入張量
  • 元素級操作:逐元素比較生成動態結果
  • 類型轉換:輸出類型由xy決定

應用場景與示例

場景1:數據清洗與過濾

篩選出溫度超過30℃且濕度低于60%的記錄:

temperature = torch.tensor([25.0, 32.5, 28.0, 35.0])
humidity = torch.tensor([55.0, 58.0, 70.0, 50.0])# 生成布爾掩碼
mask = (temperature > 30) & (humidity < 60)# 根據條件生成標簽
labels = torch.where(mask, torch.tensor("High Risk"), torch.tensor("Normal"))
print(labels)
# 輸出: tensor(['Normal', 'High Risk', 'Normal', 'Normal'], dtype=string)
場景2:圖像二值化處理

將灰度圖像轉換為二值掩碼:

gray_image = torch.tensor([[0.1, 0.8], [0.6, 0.3]], dtype=torch.float32)
threshold = 0.5# 生成二值掩碼
binary_mask = torch.where(gray_image > threshold, torch.tensor(1.0), torch.tensor(0.0))
print(binary_mask)
# 輸出:
# tensor([[0., 1.],
#         [1., 0.]])

三、函數對比與選擇指南

特性torch.gathertorch.where
核心功能基于索引精確提取元素條件驅動動態生成元素
輸入要求需顯式提供索引張量需條件張量及候選值張量
維度匹配嚴格匹配索引與源張量維度自動廣播兼容不同形狀
典型應用多維數據查詢、Top-K提取條件篩選、數據轉換、掩碼生成
性能消耗較高(涉及索引計算)較低(基于原生條件判斷)

四、綜合實戰:圖像語義分割后處理

任務需求

將模型輸出的概率圖轉換為二值掩碼,并提取連通區域標簽。

解決方案

# 假設prob_map是模型輸出的概率圖 (H,W)
prob_map = torch.rand(256, 256) > 0.5  # 二值化處理# 使用where生成掩碼
mask = torch.where(prob_map, torch.tensor(1), torch.tensor(0))# 使用gather提取連通區域標簽(假設labels是預測的類別索引)
labels = torch.randint(0, 10, (256, 256))
selected_labels = torch.gather(labels, dim=0, index=mask.nonzero(as_tuple=True)[0])

五、注意事項與最佳實踐

  1. 索引越界預防

    # 錯誤示例:索引超出范圍會導致錯誤
    valid_indices = torch.clamp(indices, min=0, max=max_dim-1)
    
  2. 類型一致性

    # 確保index張量為整型
    index = index.long()  
    
  3. 內存優化

    # 優先使用in-place操作減少顯存占用
    mask.masked_fill_(condition, value)
    

結語

torch.gathertorch.where作為PyTorch生態中的基石函數,在數據工程與模型開發中扮演著不可替代的角色。理解它們的底層邏輯與適用場景,能夠幫助您:

  • 更高效地實現復雜數據操作
  • 優化模型推理與訓練流程
  • 解決各類條件化數據處理難題

掌握這兩把利器,您將在PyTorch開發中如魚得水!

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/76965.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/76965.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/76965.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

聲學測溫度原理解釋

已知聲速&#xff0c;就可以得到溫度。 不同溫度下的勝訴不同。 25度的聲速大約346m/s 絕對溫度-273度 不同溫度下的聲速。 FPGA 通過測距雷達測溫度&#xff0c;固定測量距離&#xff0c;或者可以測出當前距離。已知距離&#xff0c;然后雷達發出聲波到接收到回波的時間&a…

【網絡篇】UDP協議的封裝分用全過程

大家好呀 我是浪前 今天講解的是網絡篇的第二章&#xff1a;UDP協議的封裝分用 我們的協議最開始是OSI七層網絡協議 這個OSI 七層網絡協議 是計算機的大佬寫的&#xff0c;但是這個協議一共有七層&#xff0c;太多了太麻煩了&#xff0c;于是我們就把這個七層網絡協議就簡化為…

spring-ai-alibaba使用Agent實現智能機票助手

示例目標是使用 Spring AI Alibaba 框架開發一個智能機票助手&#xff0c;它可以幫助消費者完成機票預定、問題解答、機票改簽、取消等動作&#xff0c;具體要求為&#xff1a; 基于 AI 大模型與用戶對話&#xff0c;理解用戶自然語言表達的需求支持多輪連續對話&#xff0c;能…

嵌入式C語言高級編程:OOP封裝、TDD測試與防御性編程實踐

一、面向對象編程(OOP) 盡管 C 語言并非面向對象編程語言&#xff0c;但借助一些編程技巧&#xff0c;也能實現面向對象編程&#xff08;OOP&#xff09;的核心特性&#xff0c;如封裝、繼承和多態。 1.1 封裝 封裝是把數據和操作數據的函數捆綁在一起&#xff0c;對外部隱藏…

藍橋杯 web 常考到的一些知識點

filter&#xff1a;filter方法創建一個新數組&#xff0c;其包含通過所提供函數實現的測試的所有元素。這個 方法不會改變原數組&#xff0c;而是返回一個新的數組。 map&#xff1a;map方法創建一個新數組&#xff0c;其結果是該數組中的每個元素都調用一個提供的函數后的 返回…

音視頻小白系統入門筆記-0

本系列筆記為博主學習李超老師課程的課堂筆記&#xff0c;僅供參閱 音視頻小白系統入門課 音視頻基礎ffmpeg原理 緒論 ffmpeg推流 ffplay/vlc拉流 使用rtmp協議 ffmpeg -i <source_path> -f flv rtmp://<rtmp_server_path> 為什么會推流失敗&#xff1f; 默認…

mysql按條件三表并聯查詢

下面為你呈現一個 MySQL 按條件三表并聯查詢的示例。假定有三個表&#xff1a;students、courses 和 enrollments&#xff0c;它們的結構和關聯如下&#xff1a; students 表&#xff1a;包含學生的基本信息&#xff0c;有 student_id 和 student_name 等字段。courses 表&…

UML之序列圖的消息

序列圖表現各參與者之間為完成某個行為而發生的交互及其時間順序&#xff0c;序列圖中的交互通過消息實現。消息是從一條生命線到另一條生命線的通信&#xff0c;它們通常是水平或傾斜向下的箭頭&#xff0c;從發送方生命線離開&#xff0c;到達接收方生命線。如果需要&#xf…

UniAD:自動駕駛的統一架構 - 創新與挑戰并存

引言 自動駕駛技術正經歷一場架構革命。傳統上&#xff0c;自動駕駛系統采用模塊化設計&#xff0c;將感知、預測和規劃分離為獨立組件。而上海人工智能實驗室的OpenDriveLab團隊提出的UniAD&#xff08;Unified Autonomous Driving&#xff09;則嘗試將這些任務整合到一個統一…

如何寫好合同管理系統需求分析

引言 在當今企業數字化轉型的浪潮中&#xff0c;合同管理系統作為企業法律合規和商業運營的重要支撐工具&#xff0c;其需求分析的準確性和完整性直接關系到系統建設的成敗。本文基于Volere需求過程方法論&#xff0c;結合江鈴汽車集團合同管理系統需求規格說明書實踐案例&…

libevent服務器附帶qt界面開發(附帶源碼)

本章是入門章節&#xff0c;講解如何實現一個附帶界面的服務器&#xff0c;后續會完善與優化 使用qt編譯libevent源碼演示視頻qt的一些知識 1.主要功能有登錄界面 2.基于libevent實現的服務器的業務功能 使用qt編譯libevent 下載這個&#xff0c;其他版本也可以 主要是github上…

八、自動化函數

1.元素的定位 web自動化測試的操作核心是能夠找到頁面對應的元素&#xff0c;然后才能對元素進行具體的操作。 常見的元素定位方式非常多&#xff0c;如id,classname,tagname,xpath,cssSelector 常用的主要由cssSelector和xpath 1.1 cssSelector選擇器 選擇器的功能&#x…

Web三漏洞學習(其二:sql注入)

靶場&#xff1a;NSSCTF 、云曦歷年考核題 二、sql注入 NSSCTF 【SWPUCTF 2021 新生賽】easy_sql 這題雖然之前做過&#xff0c;但為了學習sql&#xff0c;整理一下就再寫一次 打開以后是杰哥的界面 注意到html網頁標題的名稱是 “參數是wllm” 那就傳參數值試一試 首先判…

單片機非耦合業務邏輯框架

在小型單片機項目開發初期&#xff0c;由于業務邏輯相對簡單&#xff0c;我們往往較少關注程序架構層面的設計。 然而隨著項目經驗的積累&#xff0c;開發者會逐漸意識到模塊間的耦合問題&#xff1a;當功能迭代時&#xff0c;一處修改可能引發連鎖反應。 此時&#xff0c;構…

Zookeeper三臺服務器三節點集群部署(docker-compose方式)

1. 準備工作 - 服務器:3 臺服務器,IP 地址分別為 `10.10.10.11`、`10.10.10.12`、`10.10.10.13`。 - 安裝 Docker:確保每臺服務器已安裝 Docker 和 Docker Compose。 - 網絡通信:確保三臺服務器之間可以通過 IP 地址互相訪問,并開放以下端口: - `2181`:Zookeeper 客戶…

Mac關閉sip方法

Mac關閉sip方法 導航 文章目錄 Mac關閉sip方法導航完整操作流程圖詳細步驟 完整操作流程圖 這東西是我在網上搬運下來的&#xff0c;但是我在為業務實操過程中&#xff0c;根據實操情況還是有新的注意點的 詳細步驟 1.在「關于本機」-「系統報告」-「軟件」;查看SIP是否開啟…

C++| 深入剖析std::list底層實現:鏈表結構與內存管理機制

引言 std::list的底層實現基于雙向鏈表&#xff0c;其設計哲學與std::vector截然不同。本文將深入探討其節點結構、內存分配策略及迭代器實現原理&#xff0c;揭示鏈表的性能優勢和潛在代價。 1. 底層數據結構&#xff1a;雙向鏈表 每個std::list節點包含&#xff1a; 數據域…

漢諾塔問題——用貪心算法解決

目錄 一&#xff1a;起源 二&#xff1a;問題描述 三&#xff1a;規律 三&#xff1a;解決方案 遞歸算法 四&#xff1a;代碼實現 復雜度分析 一&#xff1a;起源 漢諾塔&#xff08;Tower of Hanoi&#xff09;問題起源于一個印度的古老傳說。在世界中心貝拿勒斯&#…

【Python】Python 100題 分類入門練習題 - 新手友好

Python 100題 分類入門練習題 - 新手友好篇 - 整合篇 一、數學問題題目1&#xff1a;組合數字題目2&#xff1a;利潤計算題目3&#xff1a;完全平方數題目4&#xff1a;日期天數計算題目11&#xff1a;兔子繁殖問題題目18&#xff1a;數列求和題目19&#xff1a;完數判斷題目21…

【linux】--- 進程概念

進程概念 1.認識馮諾依曼結構2. 操作系統&#xff08;Operator system)2.1 概念2.2 設計OS的目的2.3 理解操作系統2.4 如何理解管理2.5 理解系統調用和庫函數 3. 進程3.1 基本概念和基本操作3.1.1 描述進程 - PCB3.1.2 task_struct3.1.3 查看進程 3.2 進程狀態3.2.1 運行&&…