在 PyTorch 中,torch.gather()
是一個非常實用的張量操作函數,主要用于根據索引從輸入張量中選擇特定位置的值。它常用于注意力機制、序列處理等場景。
函數定義
torch.gather(input, dim, index) → Tensor
input
:待提取數據的張量。dim
:在哪個維度上進行索引選擇。index
:一個與input
在除了dim
維度外相同形狀的張量,其值指定了從input
中提取的索引位置。- 返回值:從
input
的指定維度dim
上根據index
提取出的新張量。
形象理解
舉個簡單的例子:
示例 1:二維張量,按列(dim=1)提取
import torchinput = torch.tensor([[10, 20, 30],[40, 50, 60]])
index = torch.tensor([[2, 1, 0],[0, 1, 2]])output = torch.gather(input, dim=1, index=index)
print(output)
解釋:
- 對于第一行:從
[10, 20, 30]
中提取位置[2,1,0]
,結果是[30, 20, 10]
- 對于第二行:從
[40, 50, 60]
中提取位置[0,1,2]
,結果是[40, 50, 60]
輸出:
tensor([[30, 20, 10],[40, 50, 60]])
示例 2:按行(dim=0)提取
input = torch.tensor([[1, 2],[3, 4],[5, 6]])index = torch.tensor([[0, 1],[1, 2],[2, 0]])output = torch.gather(input, dim=0, index=index)
print(output)
解釋:
-
每個位置從第
dim=0
維度提取對應的元素。例如:- 第 (0,0) 位置:從 [1,3,5] 中取第 0 行,值為 1
- 第 (1,0) 位置:從 [1,3,5] 中取第 1 行,值為 3
- 第 (2,1) 位置:從 [2,4,6] 中取第 0 行,值為 2
輸出:
tensor([[1, 4],[3, 6],[5, 2]])
應用場景
- 注意力機制中的權重選擇
- 序列解碼中的 beam search
- 從嵌套表示中根據索引獲取嵌套內容
實戰場景舉例
假設有一個 batch 的 BERT 輸出,想從每個句子中提取第 N 個 token(如 [CLS]、某個關鍵詞)的表示向量。
假設數據
import torch
from transformers import BertModel, BertTokenizertokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")sentences = ["I love World", "Transformers are powerful"]
inputs = tokenizer(sentences, padding=True, return_tensors="pt")# 獲取 BERT 輸出
outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state # (batch_size, seq_len, hidden_size)print(last_hidden_state.shape)
# torch.Size([2, 5, 768]) 假設 padding 后為長度 5,hidden size 為 768
場景 1:提取每個句子的第一個 token(通常是 [CLS])
cls_embeddings = last_hidden_state[:, 0, :] # shape: (batch_size, hidden_size)
這個可以直接使用切片完成,不需要 gather
。
場景 2:提取每個句子中 指定位置的 token 表示(如“love”或“are”)
假設我們事先知道每個句子中感興趣 token 的位置:
# 每個句子中我們想要提取的 token 索引
# 假設我們想提取第 2 個 token
token_indices = torch.tensor([2, 1]) # shape: (batch_size,)
使用 gather
抽取對應 token 的向量:
# last_hidden_state: (batch_size, seq_len, hidden_size)
batch_size, seq_len, hidden_size = last_hidden_state.size()# 將 token_indices 轉成 index 用于 gather: shape (batch_size, 1, 1)
token_indices = token_indices.view(-1, 1, 1).expand(-1, 1, hidden_size) # (batch_size, 1, hidden_size)# gather on dim=1(seq_len)
token_embeddings = torch.gather(last_hidden_state, dim=1, index=token_indices) # (batch_size, 1, hidden_size)# squeeze 掉中間的維度
token_embeddings = token_embeddings.squeeze(1) # (batch_size, hidden_size)print(token_embeddings.shape)
小結
操作需求 | 用法 |
---|---|
取所有句子的第一個 token | output[:, 0, :] |
取所有句子的第 N 個 token | output[:, N, :] |
取每個句子的指定 token(不同位置) | torch.gather() (如上所示) |
注意事項
index
必須與input
的 shape 一致,除了在指定的dim
維度上的大小。index
的值必須小于input
在dim
維度上的長度。