本系列教程適用于沒有任何pytorch的同學(簡單的python語法還是要的),從代碼的表層出發挖掘代碼的深層含義,理解具體的意思和內涵。pytorch的很多函數看著非常簡單,但是其中包含了很多內容,不了解其中的意思就只能【看懂代碼】,無法【理解代碼】。
目錄
- 官方定義
- demo
- one-hot
官方定義
torch.tensor.scatter_
是PyTorch中的一個函數,用于將指定索引處的值替換為給定的值。
函數定義:
Tensor.scatter_(dim, index, src, reduce=None) → Tensor
官方解釋:
-
將張量
src
中的所有值寫入索引張量中指定的index
處的self。 -
對于
src
中的每個值,它的輸出索引由其在src
中的索引(dimension != dim)
和在index中對應的值(dimension = dim)
指定。
非常難以理解,十分抽象,從我個人的角度來說就是:
- 第一個參數
dim
表示維度,即在第幾維度處理數據,保持其它維度不變。 reduce
參數是一個可選參數,用于指定如何在執行散射(scatter)操作時對重復的索引值進行合并或聚合。- index則是需要填充的列的索引,即根據維度從src中取對應的值填充到tensor中去。
怎么映射的,比如一個一個3維張量:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
官方的文檔如下,TORCH.TENSOR.SCATTER_:
即使如此理解起來也是很復雜,下面從例子中去理解:
demo
下面是一個官方文檔給出的例子:
import torchsrc = torch.Tensor([[-1.0276, 0.2673, -1.1752, -0.8823],[-0.6447, -0.8256, 0.1542, -0.4242]])
print(src)output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])output = output.scatter(1, index, src)
print(output)
輸出的結果:
我們一步步理解代碼:
- 首先,定義了一個
src
張量,后續output即從src中取值。 - 其次,定義了
output
,其值為二行五列的全零張量,后續對output
進行修改。 - 接著,定義了index,即從src取值的索引。
- 最后,根據index從src取值填充到output中,即完成操作。
那么具體是如何取值的呢?
首先,dim = 1
,意味著從維度值為1的地方取值,維度值為0的地方不變,那就是:
self[i][index[i][j]] = src[i][j] # if dim == 1
具體來說:
當i = 0, j = 0
時,output[0][index[0][0]] = src[0][0]
,因為index[0][0] = 3
,所以output[0][3] = src[0][0] = -1.0276
,這時候我們檢查輸出的output
值,確實是-1.0276
。
同理:
i = 0, j = 1
: output[0][index[0][1]] = output[0][1] = src[0][1] = 0.2673
i = 0, j = 2
: output[0][index[0][2]] = output[0][2] = src[0][2] = -1.1752
one-hot
作者在學習該函數時實在遇到one-hot編碼時遇到的,而該函數在one-hot中應用很廣:
index = torch.tensor([[3], [2], [0], [1]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)