RNN小練習
要求:
假設有 4 個字 吃 了 沒 ?
,請使用 torch.nn.RNN
完成以下任務
- 將每個進行 one-hot 編碼
- 請使用
吃 了 沒
作為輸入序列,了 沒 ?
作為輸出序列 - RNN 的
hidden_size = 64
- 請將 RNN 的輸出使用全連接轉換成 4 個特征,并使用 CrossEntropyLoss 訓練模型
- 訓練模型并驗證
1、準備數據集
import torch.nn.functional
from torch.utils.data import Datasetclass mydataset(Dataset):def __init__(self):super().__init__()texts = '吃 了 沒 ?'self.words = texts.split()self.input = self.words[:3]self.label = self.words[1:]def __len__(self):return 1def __getitem__(self, idx):# 對輸入進行 one_hot 編碼inp = torch.nn.functional.one_hot(torch.tensor([self.words.index(word) for word in self.input]),len(self.words)).float()# 對標簽進行編碼,返回文字的索引label = torch.tensor([self.words.index(word) for word in self.label])return inp, label
2、創建模型
import torch.nn as nnclass mymodel(nn.Module):def __init__(self):super().__init__()self.rnn = nn.RNN(4,64,nonlinearity='relu')self.fc1 = nn.Linear(64,4)def forward(self, x,h=None):x,h = self.rnn(x,h)x = self.fc1(x)return x,h
3、訓練模型以及預測
import torch.nn as nn
from torch import optimfrom myset import mydataset
from mymodel import mymodelEPOCH = 1000
LR = 1e-2ds = mydataset()
inputs,lables = ds[0]model = mymodel()loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),lr=LR)for epoch in range(EPOCH):optimizer.zero_grad()y,h = model(inputs)loss = loss_fn(y,lables)print(loss)loss.backward()optimizer.step()model.eval()y,h = model(inputs)y = y.softmax(-1)
maxarg = y.argmax(-1)print([ds.words[indx] for indx in maxarg.tolist()])