torch.gather
介紹
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
沿由 dim 指定的軸收集值。
對于三維張量,輸出按如下方式確定:
out[i][j][k] = input[index[i][j][k]][j][k] # 如果 dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # 如果 dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # 如果 dim == 2
input 和 index 必須具有相同的維度數。同時要求對于所有不等于 dim 的維度 d,滿足 index.size(d) <= input.size(d)。輸出的形狀將與 index 相同。注意 input 和 index 不會相互廣播。
參數
-
input (Tensor) – 源張量
-
dim (int) – 進行索引的軸
-
index (LongTensor) – 要收集的元素的索引
關鍵字參數
-
sparse_grad (bool, 可選) – 如果為 True,則關于 input 的梯度將是一個稀疏張量。
-
out (Tensor, 可選) – 目標張量
示例:
t = torch.tensor([[1, 2], [3, 4]])
torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1, 1], [ 4, 3]])
舉例
其實torch文檔給的形式非常清晰,只是一上來可能不太好理解
假如input是一個shape=[2,2]的矩陣,此時dim只能等于0或者1,index的shape也只能大于或者等于[2,2]
input=torch.tensor([[1,2][3,4]])
index = torch.tensor([[0, 1], [1, 2]])output = torch.gather(input,dim=0,index)
output[[],[]]
上面dim=0表示 output[i][j] = t[ index[i][j] ][ j ]
意思新的output矩陣行索引取值input矩陣的行索引,列索引取index矩陣中的元素值
所以取值如下
[[input[0,0],input[0,1]],[input[1,1],input[1,2]]
]
[1,23,4
]
總結
將index矩陣中的元素當成對input取值的行索引或者列索引,同時注意index矩陣中的元素值不能超過input的行或者列大小, 比如dim=0,那么index中元素值不能超過input的列大小2,否則就會報錯