假設我們需要一個查找表(Lookup Table),我們可以根據索引數字快速定位查找表中某個具體位置并讀取出來。最簡單的方法,可以通過一個二維數組或者二維list來實現。但如果我希望查找表的值可以通過梯度反向傳播來修改,那么就需要用到nn.Embedding
來實現了。
其實,我們需要用反向傳播來修正表值的場景還是很多的,比如我們想存儲數據的通用特征時,這個通用特征就可以用nn.Embedding來表示,常見于現在的各種codebook的trick。閑話不多說,我們來看栗子:
import torch
from torch import nntable = nn.Embedding(10, 3)
print(table.weight)
idx = torch.LongTensor([[1]])
b = table(idx)
print(b)'''
output
Parameter containing:
tensor([[-0.2317, -0.9679, -1.9324],[ 0.2473, 1.1043, -0.7218],[ 0.5425, -0.3109, -0.1330],[-1.4006, -0.0675, 0.1376],[-0.1995, 0.7168, 0.5692],[-1.3572, -0.6407, -0.0128],[-0.0773, 1.1928, -1.0836],[ 0.1721, -0.9232, -0.4059],[ 1.6108, -0.4640, 0.3535],[ 0.6975, 1.6554, -0.2217]], requires_grad=True)
tensor([[[ 0.2473, 1.1043, -0.7218]]], grad_fn=<EmbeddingBackward0>)
'''
這段代碼實際上就實現了一個查找表的功能,索引值為[[1]](注意有兩個中括弧),返回值為對應的表值。我們還可以批量查找表值:
import torch
from torch import nntable = nn.Embedding(10, 3)
print(table)
print(table.weight)indices = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
print(indices)out = table(indices)
print(out)
print(out.shape)
通過輸入索引張量來獲取表值:[2,4] -> [2,4,3],請注意這個shape變化,即對應位置的索引獲得對應位置的表值。
參考:https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
本人親自整理,有問題可留言交流~