RNN
x 為當前狀態下數據的輸入, h 表示接收到的上一個節點的輸入。
y為當前節點狀態下的輸出,而h′h^\primeh′為傳遞到下一個節點的輸出.
LSTM
#定義網絡
lstm = nn.LSTM(input_size=20,hidden_size=50,num_layers=2)
#輸入變量
input_data = Variable(torch.randn(100,32,20))
#初始隱狀態
h_0 = Variable(torch.randn(2,32,50))
#輸出記憶細胞
c_0 = Variable(torch.randn(2,32,50))
#輸出變量
output,(h_t,c_t) = lstm(input_data,(h_0,c_0))
print(output.size())
print(h_t.size())
print(c_t.size())
#參數大小為(50x4,20),是RNN的四倍
print(lstm.weight_ih_l0)
print(lstm.weight_ih_l0.size())
打印結果:
torch.Size([100, 32, 50])
torch.Size([2, 32, 50])
torch.Size([2, 32, 50])
tensor([[ 0.0068, -0.0925, -0.0343, …, -0.1059, 0.0045, -0.1335],
[-0.0509, 0.0135, 0.0100, …, 0.0282, -0.1232, 0.0330],
[-0.0425, 0.1392, 0.1140, …, -0.0740, -0.1214, 0.1087],
…,
[ 0.0217, -0.0032, 0.0815, …, -0.0605, 0.0636, 0.1197],
[ 0.0144, 0.1288, -0.0569, …, 0.1361, 0.0837, -0.0021],
[ 0.0355, 0.1045, 0.0339, …, 0.1412, 0.0371, 0.0649]],
requires_grad=True)
torch.Size([200, 20])
注意LSTM的參數,rnn.weight_ih_l0 為 wiw_i~wi?? 的權重
rnn.weight_hh_l0 為 whw_h~wh?? 的權重,并且為hidden_size的4倍。
GRU
兩個門控
PyTorch中的循環神經網絡(RNN+LSTM+GRU)
人人都能看懂的GRU
人人都能看懂的LSTM