RNN循環神經網絡
什么是循環神經網絡?
循環神經網絡(Recurrent Neural Network, RNN)是一類專門用于處理序列數據的神經網絡架構。與傳統的前饋神經網絡不同,RNN具有"記憶"能力,能夠捕捉數據中的時間依賴關系。
核心特點:
- 循環連接:RNN單元之間存在循環連接,使得信息能夠在網絡內部持續傳遞
- 參數共享:相同的權重參數在時間步之間共享,大大減少了模型參數數量
- 序列處理:能夠處理可變長度的輸入序列,適用于時序數據
基本結構:
RNN的基本單元包含一個隱藏狀態(hidden state),它在每個時間步都會被更新:
- 新隱藏狀態 = f(當前輸入, 前一個隱藏狀態)
舉一個簡單的例子:
簡單的循環神經網絡例子(多對多)
我們來做一個簡單的循環神經網絡,其實也就是跟上圖一致。
import torch
from torch import nnclass RNNCell(nn.Module):def __init__(self,input_size,hidden_size):super().__init__()self.input_size = input_sizeself.hidden_size = hidden_sizeself.w_hidden = torch.randn(hidden_size,hidden_size)self.w_input = torch.randn(input_size,hidden_size)self.tanh = nn.Tanh()def forward(self,x,hidden_state=None):N,input_size = x.shapeif hidden_state is None:hidden_state = torch.zeros(N,self.hidden_size)hidden_state = self.tanh(hidden_state @ self.w_hidden + x @ self.w_input)return hidden_stateclass RNN(nn.Module):def __init__(self,input_size,hidden_size):super().__init__()self.cell = RNNCell(input_size,hidden_size)self.w_output = torch.randn(hidden_size,hidden_size)def forward(self,x,hidden_state=None):N,L,input_size = x.shapeoutputs = []for i in range(L):x_i = x[:,i]hidden_state = self.cell(x_i,hidden_state)out = hidden_state @ self.w_outputoutputs.append(out)outputs = torch.stack(outputs,dim=1)return outputs,hidden_stateif __name__ == "__main__":x = torch.randn(5,3,10)model = RNN(10,20)y,h = model(x)print(y.shape)print(h.shape)
雙向循環神經網絡
雙向RNN其實也就是兩層RNN的疊加,分別更新的是兩層隱藏狀態以及兩層輸出。
import torch
from torch import nnclass BiRNN(nn.Module):def __init__(self,input_size,hidden_size):super().__init__()self.input_size = input_sizeself.hidden_size = hidden_size#前向RNN和線性層self.forward_cell = nn.RNNCell(input_size,hidden_size)self.backward_cell = nn.RNNCell(input_size,hidden_size)#反向RNN和線性層self.forward_Linear = nn.Linear(hidden_size,hidden_size)self.backward_Linear = nn.Linear(hidden_size,hidden_size)def forward(self,x,hidden = None):N,L,input_size = x.shapeif hidden is None:#堆疊兩層隱藏層hidden = torch.zeros(2,N,self.hidden_size)h_forward = hidden[0]out_forward = []for i in range(L):h_forward = self.forward_cell(x[:,i],h_forward)out = self.forward_Linear(h_forward)out_forward.append(out)out_forward = torch.stack(out_forward,dim=1)x = torch.flip(x,dims=[1])h_backward = hidden[1]out_backward = []for i in range(L):h_backward = self.backward_cell(x[:,i],h_backward)out = self.backward_Linear(h_backward)out_backward.append(out)out_backward = torch.stack(out_backward,dim=1)outputs = torch.concat((out_forward,out_backward),dim=-1)hidden = torch.stack([h_forward,h_backward])return outputs,hiddenif __name__ == '__main__':x = torch.randn((5,3,10))model = BiRNN(10,20)outputs,hidden = model(x)print(outputs.shape)print(hidden.shape)