Pytorch中gather()函數詳解和實戰示例

在 PyTorch 中,torch.gather() 是一個非常實用的張量操作函數,主要用于根據索引從輸入張量中選擇特定位置的值。它常用于注意力機制、序列處理等場景。


函數定義

torch.gather(input, dim, index) → Tensor
  • input:待提取數據的張量。
  • dim:在哪個維度上進行索引選擇。
  • index:一個與 input 在除了 dim 維度外相同形狀的張量,其值指定了從 input 中提取的索引位置。
  • 返回值:從 input 的指定維度 dim 上根據 index 提取出的新張量。

形象理解

舉個簡單的例子:

示例 1:二維張量,按列(dim=1)提取

import torchinput = torch.tensor([[10, 20, 30],[40, 50, 60]])
index = torch.tensor([[2, 1, 0],[0, 1, 2]])output = torch.gather(input, dim=1, index=index)
print(output)

解釋

  • 對于第一行:從 [10, 20, 30] 中提取位置 [2,1,0],結果是 [30, 20, 10]
  • 對于第二行:從 [40, 50, 60] 中提取位置 [0,1,2],結果是 [40, 50, 60]

輸出

tensor([[30, 20, 10],[40, 50, 60]])

示例 2:按行(dim=0)提取

input = torch.tensor([[1, 2],[3, 4],[5, 6]])index = torch.tensor([[0, 1],[1, 2],[2, 0]])output = torch.gather(input, dim=0, index=index)
print(output)

解釋

  • 每個位置從第 dim=0 維度提取對應的元素。例如:

    • 第 (0,0) 位置:從 [1,3,5] 中取第 0 行,值為 1
    • 第 (1,0) 位置:從 [1,3,5] 中取第 1 行,值為 3
    • 第 (2,1) 位置:從 [2,4,6] 中取第 0 行,值為 2

輸出

tensor([[1, 4],[3, 6],[5, 2]])

應用場景

  1. 注意力機制中的權重選擇
  2. 序列解碼中的 beam search
  3. 從嵌套表示中根據索引獲取嵌套內容

實戰場景舉例

假設有一個 batch 的 BERT 輸出,想從每個句子中提取第 N 個 token(如 [CLS]、某個關鍵詞)的表示向量


假設數據

import torch
from transformers import BertModel, BertTokenizertokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")sentences = ["I love World", "Transformers are powerful"]
inputs = tokenizer(sentences, padding=True, return_tensors="pt")# 獲取 BERT 輸出
outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)print(last_hidden_state.shape)
# torch.Size([2, 5, 768])  假設 padding 后為長度 5,hidden size 為 768

場景 1:提取每個句子的第一個 token(通常是 [CLS])

cls_embeddings = last_hidden_state[:, 0, :]  # shape: (batch_size, hidden_size)

這個可以直接使用切片完成,不需要 gather


場景 2:提取每個句子中 指定位置的 token 表示(如“love”或“are”)

假設我們事先知道每個句子中感興趣 token 的位置:
# 每個句子中我們想要提取的 token 索引
# 假設我們想提取第 2 個 token
token_indices = torch.tensor([2, 1])  # shape: (batch_size,)

使用 gather 抽取對應 token 的向量:

# last_hidden_state: (batch_size, seq_len, hidden_size)
batch_size, seq_len, hidden_size = last_hidden_state.size()# 將 token_indices 轉成 index 用于 gather: shape (batch_size, 1, 1)
token_indices = token_indices.view(-1, 1, 1).expand(-1, 1, hidden_size)  # (batch_size, 1, hidden_size)# gather on dim=1(seq_len)
token_embeddings = torch.gather(last_hidden_state, dim=1, index=token_indices)  # (batch_size, 1, hidden_size)# squeeze 掉中間的維度
token_embeddings = token_embeddings.squeeze(1)  # (batch_size, hidden_size)print(token_embeddings.shape)

小結

操作需求用法
取所有句子的第一個 tokenoutput[:, 0, :]
取所有句子的第 N 個 tokenoutput[:, N, :]
取每個句子的指定 token(不同位置)torch.gather()(如上所示)

注意事項

  • index 必須與 input 的 shape 一致,除了在指定的 dim 維度上的大小。
  • index 的值必須小于 inputdim 維度上的長度。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/bicheng/85683.shtml
繁體地址,請注明出處:http://hk.pswp.cn/bicheng/85683.shtml
英文地址,請注明出處:http://en.pswp.cn/bicheng/85683.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

uniapp 微信小程序在線引入字體圖標

在線引入字體圖標,出現體驗版,真機調試字體圖標不出來,模擬器上是好的 由于字體圖標和小程序域名不在同一個,所以出現了跨域問題,將字體圖標文件放到小程序同一個域名下就好了

macOS版的節點小寶上架蘋果APP Store了

前言 前段時間很多小伙伴按照小白的教程在飛牛NAS部署了節點小寶之后,Windows的小伙伴玩得不亦樂乎! 反觀macOS用戶……因為沒有#macOS版本的節點小寶,就算是在飛牛NAS上部署了節點小寶,卻一點也開心不起來。 畢竟iOS版本的節點…

tensor向量按任意維度進行切片、拆分、組合

torch.index_select(input_tensor, 切片維度, 切片索引) 注意:切完之后,轉onnx時會生成Gather節點; torch自帶切片操作: start : end : step: 范圍前閉后開,將其放在哪個維度上,就對那個維度…

(八)Linux進程程序替換

1 進程替換 進程替換是為了讓程序能在不創建新進程的情況下&#xff0c;讓父進程和子進程執行不同的代碼&#xff0c;以實現控制清晰、執行高效的程序調度機制。 1.1 先看效果 #include <stdio.h> #include <unistd.h> int main() {printf("before:I am a p…

支持 TDengine 的數據庫管理工具—qStudio

qStudio qStudio 是一款免費的多平臺 SQL 數據分析工具&#xff0c;可以輕松瀏覽數據庫中的表、變量、函數和配置設置。最新版本 qStudio 內嵌支持 TDengine。 前置條件? 使用 qStudio 連接 TDengine 需要以下幾方面的準備工作。 安裝 qStudio。qStudio 支持主流操作系統包…

破解 VMP+OLLVM 混淆:通過 Hook jstring 快速定位加密算法入口

版權歸作者所有&#xff0c;如有轉發&#xff0c;請注明文章出處&#xff1a;https://cyrus-studio.github.io/blog/ VMP 殼 OLLVM 的加密算法 某電商APP的加密算法經過dex脫殼分析&#xff0c;找到參數加密的方法在 DuHelper.doWork 中 package com.shizhuang.duapp.common…

Automatisch:開源的工作流自動化利器

在當今數字化的時代,企業和個人都在尋找高效的方式來自動化業務流程,減少手動操作帶來的時間和成本消耗。Automatisch 作為一款開源的 Zapier 替代方案,為我們提供了一個強大而靈活的工具,讓工作流自動化變得更加簡單和可控。 一、Automatisch 簡介 Automatisch 是一個商…

RAG應用效果評估框架與優化指南

1. 引言:為何RAG評估至關重要? 一個RAG系統通常包含多個可調參數和可替換組件(如不同的嵌入模型、向量數據庫、LLM、Prompt模板等)。沒有有效的評估機制,優化過程就像“盲人摸象”,難以判斷改動是否帶來了真正的提升。 RAG評估的核心目的: 量化系統性能:將RAG的“好壞…

豆包大模型應用場景

豆包作為通用大模型&#xff0c;應用場景其實覆蓋了個人和企業兩端。個人端要突出生活化功能——比如幫學生解題、幫上班族寫周報&#xff1b;企業端則要強調降本增效&#xff0c;比如客服自動化、代碼生成這些硬需求。用戶沒指定角度&#xff0c;那就都覆蓋吧。 注意到用戶用“…

OSITCP/IP

模型&協議 在互聯網發展的早期,不同的計算機廠商有不同的網絡傳輸協議,例如:IBM的SNA協議、蘋果的AppleTalk協議等,這些協議互不兼容,導致雖然不同的產商計算機在物理層面是鏈接的,但是在網絡上基本無法完成正常通信。這就導致一個用戶如果使用了某個廠商的某個網絡…

店匠科技閃耀“跨博會”,技術+生態打造靈活出海能力

2025年6月16日至18日&#xff0c;第八屆全球跨境電商節暨第十屆深圳國際跨境電商貿易博覽會&#xff08;簡稱“跨博會”&#xff09;在深圳會展中心舉行。作為全球跨境電商行業的年度盛會&#xff0c;本屆展會以“文化跨境、品牌出海、智量強國”為主題&#xff0c;匯聚近 1500…

selenium彈框元素定位-凍結界面

有些網站上面的元素&#xff0c;我們鼠標放在上面&#xff0c;會動態彈出一些內容。 但是當我們的鼠標從音樂圖標移開&#xff0c;這個欄目就整個消失了&#xff0c;就沒法查看其對應的HTML。 怎么辦&#xff1f;在開發者工具欄console里面執行如下js代碼 &#xff1a; setTi…

美學心得(第二百七十九集)羅國正

美學心得&#xff08;第二百七十九集&#xff09; 羅國正 &#xff08;2025年6月&#xff09; 3299、分清不同本體、主體及其之間的關系&#xff0c;是 正確的審美、判斷首先的關鍵 羅國正 &#xff08;2025年6月11日于廣州&#xff09; “人也按照美的規律來建造。”這句話…

云祺容災備份系統公有云備份與恢復實操-AWS

1、創建訪問密鑰 訪問并登錄AWS控制臺&#xff0c;點擊右上角用戶名、安全憑證&#xff0c;在我的安全憑證窗口中&#xff0c;下拉找到訪問密鑰&#xff0c;并點擊創建訪問密鑰&#xff0c;選擇其他&#xff0c;點擊下一步&#xff0c;即可獲得密鑰信息如圖1至圖6。 注意&…

windows內網穿透

內網穿透&#xff08;NAT穿透&#xff09;是一種通過技術手段將局域網&#xff08;內網&#xff09;中的服務暴露到公網&#xff08;外網&#xff09;的方法&#xff0c;使外部用戶能夠訪問內網資源。其核心是解決因NAT&#xff08;網絡地址轉換&#xff09;或防火墻限制導致的…

threejs 實現720°全景圖,;兩種方式:環境貼圖、CSS3DRenderer渲染

前提 有一個前提條件&#xff1a;六張大小一致的圖片&#xff0c;六個圖片分別對應的是720全景圖的六個面&#xff1a;上、下、左、右、前、后。 這個不是那種無人機拍攝的全景圖&#xff0c;是六個圖片拼起來的&#xff0c;這樣的取景方式要比無人機的要經濟一些。 ---…

老牌軟件 Ghost 備份還原操作基礎

一、Ghost 簡介 Symantec Ghost&#xff08;也稱為 Norton Ghost&#xff09; 是一款強大的磁盤克隆和備份還原工具&#xff0c;廣泛用于系統部署、數據恢復和災難恢復。其主要功能包括&#xff1a; 創建磁盤鏡像&#xff08;.GHO文件&#xff09;備份/還原分區或整個硬盤支持…

SSH連接服務器并同步本地文件

SSH連接服務器并同步本地文件 1. 復制本地公鑰 cat ~/.ssh/id_rsa.pub如果不確定本地是否有公鑰 ls ~/.ssh/id_rsa.pub# 如果出現如下&#xff0c;則說明你本地存在公鑰 # /Users/username/.ssh/id_rsa.pub若沒有公鑰&#xff0c;需生成 # 使用下面命令&#xff0c;然后一路回…

中英泰馬來語訂貨系統:助力東南亞批發貿易企業數字化轉型升級

隨著全球數字化轉型浪潮的推進&#xff0c;東南亞地區的批發貿易企業也正逐步邁向數字化發展道路。特別是在中英泰馬來語訂貨系統的推動下&#xff0c;東南亞的批發商和零售商能夠更高效、便捷地開展跨國貿易與供應鏈管理。這不僅幫助傳統企業提高了運營效率&#xff0c;還助力…

微信小程序獲取指定元素,滾動頁面到指定位置

微信小程序獲取指定元素&#xff0c;滾動頁面到指定位置 微信小程序獲取指定元素的寬高等信息,并滾動頁面到指定位置 微信小程序獲取指定元素的寬高等信息,并滾動頁面到指定位置 注&#xff1a;原生小程序開發&#xff1a; createSelectorQuery() 創建一個選擇器查詢實例。 sel…