吳恩達深度學習作業 RNN模型——字母級語言模型

一. 簡單復習一下RNN

RNN

RNN適用于處理序列數據,令x^{<i>}是序列的第i個元素,那么x^{<1>}x^{<2>}....x^{<T_{x}>}就是一個長度為T_{x}的序列,NLP中最常見的元素是單詞,對應的序列是句子。

RNN使用同一個神經網絡處理序列中的每一個元素。同時,為了表示序列的先后關系,RNN還有表示記憶的隱變量a,它記錄了前幾個元素的信息。對第t個元素的運算如下:

其中,W,b都是線性運算的參數,g是激活函數,隱藏層的激活函數一般用tanh,輸出層的激活函數根據實際情況選用。另外a的初始值a^{<1>},a^{<1>}=\vec{0}

語言模型

語言模型是NLP中的一個基礎任務。語言模型是NLP中的一個基礎任務。假設我們以單詞為基本元素,句子為序列,那么一個語言模型能夠輸出某句話的出現概率。通過比較不同句子的出現概率,我們能夠開發出很多應用。比如在英語里,同音的"apple and pear"比"apple and pair"的出現概率高(更可能是一個合理的句子)。當一個語音識別軟件聽到這句話時,可以分別寫下這兩句發音相近的句子,再根據語言模型斷定這句話應該寫成前者。

規范地說,對于序列x^{1}...x^{<T_{x}>},語言模型的輸出是P(x^{<1>}x^{<2>}....x^{<T_{x}>})?這個柿子也可以寫成:

即一句話的出現概率,等于第一個單詞出現在句首的概率,乘上第二個單詞在第一個單詞之后的概率,乘上第三個單詞再第一、二個單詞之后的概率,這樣一直乘下去。

單詞級的語言模型需要的數據量比較大,在這個項目中,我們將搭建一個字母級語言模型。即我們以字母為基本元素,單詞為序列。語言模型會輸出每個單詞的概率。比如我們輸入"apple"和"appll",語言模型會告訴我們單詞"apple"的概率更高,這個單詞更可能是一個正確的英文單詞。

RNN語言模型

為了計算語言模型的概率,我們可以用RNN分別輸出P(x^{<1>}),P(x^{<2>}|x^{<1>}),...最后把這些概率乘起來。

P(x^{<t>}|x^{<1>}x^{<2>},...x^{<t-1>})這個式子,說白了就是i給定前t-1個字母,猜一猜第t個字母最可能是哪個,比如給定了前四個字母"appl",第五個單詞構成"apply", "apple"的概率比較大,構成"appll", "appla"的概率較小。

為了讓神經網絡學會這個概率,我們可以令RNN的輸入為<sos> x_1, x_2, ..., x_T,RNN的標簽為x_1, x_2, ..., x_T, <eos><sos><eos>是句子開始和結束的特殊字符,實際實現中可以都用空格' '表示。<sos>也可以粗暴地用全零向量表示),即輸入和標簽都是同一個單詞,只是它們的位置差了一格。模型每次要輸出一個softmax的多分類概率,預測給定前幾個字母時下一個字母的概率。這樣,這個模型就能學習到前面那個條件概率了。

二. ?代碼細節

參考?https://zhuanlan.zhihu.com/p/558838663

1. 數據集獲取:

為了搭建字母級語言模型,我們只需要隨便找一個有很多單詞的數據集。這里我選擇了斯坦福大學的大型電影數據集,它收錄了IMDb上的電影評論,正面評論和負面評論各25000條。這個數據集本來是用于情感分類這一比較簡單的NLP任務,拿來搭字母級語言模型肯定是沒問題的。

這個數據集的文件結構大致如下:

├─test
│  ├─neg
│  │  ├ 0_2.txt
│  │  ├ 1_3.txt
│  │  └ ...
│  └─pos
├─train
│   ├─neg
│   └─pos
└─imdb.vocab

其中,imdb.vocab記錄了數據集中的所有單詞,一行一個。test和train測試集和訓練集,它們的neg和pos子文件夾分別記錄了負面評論和正面評論。每一條評論都是一句話,存在txt文件里。

代碼細節:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import re#from dldemos.BasicRNN.constant import EMBEDDING_LENGTH, LETTER_LIST, LETTER_MAPdef read_imdb_words(dir = 'data',split='pos',is_train=True,n_files=1000):subdir = 'train' if is_train else 'test'dir = os.path.join(dir,subdir,split)all_str = ''for file in os.listdir(dir):if n_files <= 0:breakwith open(os.path.join(dir, file), 'rb') as f:line = f.read().decode('utf-8')all_str += linen_files -= 1words = re.sub(u'([^\u0020\u0061-\u007a])','',all_str.lower()).split(' ')return wordsdef read_imdb(dir='data', split = 'pos', is_train = True):subdir = 'train' if is_train else 'test'dir = os.path.join(dir, subdir, split)lines = []for file in os.listdir(dir):with open(os.path.join(dir, file), 'rb') as f:line = f.read().decode('utf-8')lines.append(line)return linesdef read_imbd_vocab(dir='data'):fn = os.path.join(dir, 'imdb.vocab')with open(fn, 'rb') as f:word = f.read().decode('utf-8').replace('\n', ' ')print("read_imbd_vocab:",word)# words = re.sub(r'([^\u0020a-z])', '',word.lower().split(' '))# 清理字符串(移除所有非空格和非小寫字母的字符)words= re.sub(r'[^ a-z]', '', word.lower())# 按空格分割單詞words = words.split(' ')print("read_imbd_vocab:",words)filtered_words = [w for w in words if len(w) > 0]return filtered_wordsvocab = read_imbd_vocab()
print(vocab[0])
print(vocab[1])lines = read_imdb()
print('Length of the files:', len(lines))
print('lines[0]', lines[0])
words = read_imdb_words(n_files=100)print('Length of the words:', len(words))
for i in range(5):print(words[i])

read_imbd_vocab最終返回的數據單詞如下:?

1)words = re.sub(u'([^\u0020\u0061-\u007a])', '', all_str.lower())
步驟 1:?all_str.lower()
  • 作用:將原字符串?all_str?轉換為全小寫。

  • 示例
    "Hello World! 123"?→?"hello world! 123"

步驟 2:?re.sub(u'([^\u0020\u0061-\u007a])', '', ...)
  • 正則表達式模式([^\u0020\u0061-\u007a])

    • \u0020:Unicode 的空格字符(ASCII 32)。

    • \u0061-\u007a:Unicode 的小寫字母范圍(a?到?z,對應 ASCII 97-122)。

    • [^...]:匹配不包含在括號內的任何字符。

    • 整體含義:匹配所有非空格非小寫字母的字符。

  • 替換操作:將這些字符替換為空字符串(即刪除它們)。

  • 示例
    "hello world! 123"?→?"hello world "
    (移除了?!?和?123,但末尾可能留下多余空格)

  • import re# 清理字符串(移除所有非空格和非小寫字母的字符)
    cleaned_word = re.sub(r'[^ a-z]', '', word.lower())# 按空格分割單詞
    words = cleaned_word.split(' ')

步驟 3:?split(' ')
  • 作用:按空格分割字符串為單詞列表。

  • 潛在問題:連續空格可能導致空字符串(如?"hello world"?→?["hello", "", "world"])。

  • 示例
    "hello world "?→?["hello", "world", ""]

  1. 清理字符串

    • 轉換為全小寫。

    • 刪除所有非小寫字母(a-z)和非空格(?)的字符。

  2. 分割單詞
    按空格分割成單詞列表(可能包含空字符串)。

2).?output = torch.empty_like(word)?

這行代碼的作用是:

  • torch.empty_like(input)?是一個PyTorch函數,它會創建一個新張量(output),滿足以下條件:

    • 形狀相同:與輸入張量?word?的維度(shape)完全一致。

    • 數據類型相同:與?word?的數據類型(dtype,如?float32int64)相同。

    • 設備相同:與?word?所在的設備(如CPU或GPU)一致。

    • 未初始化內存:新張量的元素值是未定義的(可能是任意隨機值,取決于內存的當前狀態)。

2.數據集讀取

RNN的輸入不是字母,而是表示字母的向量。最簡單的字母表示方式是one-hot編碼,每一個字母用一個某一維度為1,其他維度為0的向量表示。比如我有a, b, c三個字母,它們的one-hot編碼分別為:

a: [1,0,0]
b: [0,1,0]
c: [0,0,1]?
EMBEDDING_LENGTH = 27
LETTER_MAP = {' ': 0}
ENCODING_MAP =[' ']
for i in range(26):LETTER_MAP[chr(ord('a')+i)] =i +1ENCODING_MAP.append(chr(ord('a')+i))
LETTER_LIST = list(LETTER_MAP.keys())print("LETTER_MAP:",LETTER_MAP)
print("ENCODING_MAP:",ENCODING_MAP)
print("LETTER_LIST:",LETTER_LIST)'''
字符生成: chr(ord('a') + i)動態生成每個小寫字母:
ord('a')返回a的ASCII碼97。
97 + i隨i從0到25變化, 得到97到122 (對應ASCII中的a到z)。
chr()將ASCII碼轉換為字符,得到a, b, ..., z。

打印結果更直觀:?

?Pytorch提供了用于管理數據讀取的Dataset類。Dataset一般只會存儲數據的信息,而非原始數據,比如存儲圖片路徑,而每次讀取時,Dataset才會去實際讀取數據。在這個項目里,我們用Data set存儲原始的單詞數組,實際讀取時,每次返回一個one-hot 編碼的向量。

實際dataset使用時,要繼承這個類,實現_len_和__getitem__方法。前者表示獲取數據集的長度,后者表示獲取某項數據。

import torch
from torch.utils.data import DataLoader,Datasetclass WordDataset(Dataset):def __init__(self, words, max_length, is_one_hot=True):super().__init__()self.words = wordsself.n_words = len(words)self.max_length = max_lengthself.is_onehot = is_one_hotdef __len__(self):return self.n_wordsdef __getitem__(self, index):word = self.words[index] + ' 'word_length = len(word)#print("word:",word)if self.is_onehot:tensor = torch.zeros(self.max_length, EMBEDDING_LENGTH)for i in range(self.max_length):if i < word_length:tensor[i][LETTER_MAP[word[i]]] = 1else:tensor[i][0] = 1else:tensor = torch.zeros(self.max_length, dtype = torch.long)for i in range(word_length):tensor[i] = LETTER_MAP[word[i]]return tensor

構造數據集的參數是words, max_length, is_onehotwords是單詞數組。max_length表示單詞的最大長度。在訓練時,我們一般要傳入一個batch的單詞。可是,單詞有長有短,我們不可能拿一個動態長度的數組去表示單詞。為了統一地表達所有單詞,我們可以記錄單詞的最大長度,把較短的單詞填充空字符,直到最大長度。is_onehot表示是不是one-hot編碼,我設計的這個數據集既能輸出用數字標簽表示的單詞(比如abc表示成[0, 1, 2]),也能輸出one-hoe編碼表示的單詞(比如abc表示成[[1, 0, 0], [0, 1, 0], [0, 0, 1]])。

在獲取數據集時,我們要根據是不是one-hot編碼,先準備好一個全是0的輸出張量。如果存的是one-hot編碼,張量的形狀是[MAX_LENGTH, EMBEDDING_LENGTH],第一維是單詞的最大長度,第二維是one-hot編碼的長度。而如果是普通的標簽數組,則張量的形狀是[MAX_LENGTH]。準備好張量后,遍歷每一個位置,令one-hot編碼的對應位為1,或者填入數字標簽。

另外,我們用空格表示單詞的結束。要在處理前給單詞加一個' ',保證哪怕最長的單詞也會至少有一個空格。

有了數據集類,結合之前寫好的數據集獲取函數,可以搭建一個DataLoader。DataLoader是PyTorch提供的數據讀取類,它可以方便地從Dataset的子類里讀取一個batch的數據,或者以更高級的方式取數據(比如隨機取數據)。?

def get_dataloader_and_max_langth(limit_length = None, is_one_hot = True, is_vocab = True):if is_vocab:words = read_imbd_vocab()else:words = read_imdb_words(n_files=200)max_length = 0for word in words:max_length = max(max_length, len(word))if limit_length is not None and max_length > limit_length:words = [w for w in words if len(w) <= limit_length]max_length = limit_lengthmax_length +=1dataset = WordDataset(words, max_length, is_one_hot)print("max_length:",max_length)return DataLoader(dataset, batch_size=256), max_length

這個函數會先調用之前編寫的數據讀取API獲取單詞數組。之后,函數會計算最長的單詞長度。這里,我用limit_length過濾了過長的單詞。據實驗,這個數據集里最長的單詞竟然有60多個字母,把短單詞填充至60需要浪費大量的計算資源。因此,我設置了limit_length這個參數,不去讀取那些過長的單詞。

計算完最大長度后,別忘了+1,保證每個單詞后面都有一個表示單詞結束的空格。

最后,用DataLoader(dataset, batch_size=256)就可以得到一個DataLoader。batch_size就是指定batch size的參數。我們這個神經網絡很小,輸入數據也很小,可以選一個很大的batch size加速訓練。

3.模型預覽

class RNN1(nn.Module):def __init__(self, hidden_units = 32):super().__init__()self.hidden_units = hidden_unitsself.linear_a = nn.Linear(self.hidden_units + EMBEDDING_LENGTH, hidden_units)self.linear_y = nn.Linear(hidden_units, EMBEDDING_LENGTH)self.tanh = nn.Tanh()def forward(self, word: torch.Tensor):#word shape: [batch, max_word_length, embedding_length]batch, Tx = word.shape[0:2]#word shape: [max_word_length, batch, embedding_length]word = torch.transpose(word, 0, 1)output = torch.empty_like(word)a = torch.zeros(batch, self.hidden_units)x = torch.zeros(batch, EMBEDDING_LENGTH)for i in range (Tx):next_a = self.tanh(self.linear_a(torch.cat((x,a),1)))hat_y = self.linear_y(next_a)output[i] = hat_yx = word[i]a = next_areturn torch.transpose(output, 0, 1)

我們可以把第一行公式里的兩個合并一下,拼接一下。這樣,只需要兩個線性層就可以描述RNN了。

因此,在初始化函數中,我們定義兩個線性層linear_alinear_y。另外,hidden_units表示隱藏層linear_a的神經元數目。tanh就是普通的tanh函數,它用作第一層的激活函數。

linear_a就是公式的第一行,由于我們把輸入x和狀態a拼接起來了,這一層的輸入通道數是hidden_units + EMBEDDING_LENGTH,輸出通道數是hidden_units。第二層linear_y表示公式的第二行。我們希望RNN能預測下一個字母的出現概率,因此這一層的輸出通道數是EMBEDDING_LENGTH=27,即字符個數。

在描述模型運行的forward函數中,我們先準備好輸出張量,再初始化好隱變量a和第一輪的輸入x。根據公式,循環遍歷序列的每一個字母,用a, x計算hat_y,并維護每一輪的a, x。最后,所有hat_y拼接成的output就是返回結果。

我們來看一看這個函數的細節。一開始,輸入張量word的形狀是[batch數,最大單詞長度,字符數=27]。我們提前獲取好形狀信息。

# word shape: [batch, max_word_length, embedding_length]
batch, Tx = word.shape[0:2]

我們循環遍歷的其實是單詞長度那一維。為了方便理解代碼,我們可以把單詞長度那一維轉置成第一維。根據這個新的形狀,我們準備好同形狀的輸出張量。輸出張量output[i][j]表示第j個batch的序列的第i個元素的27個字符預測結果。

4.訓練

首先,調用之前編寫的函數,準備好dataloadermodel。同時,準備好優化器optimizer和損失函數citerion。優化器和損失函數按照常見配置選擇即可。

這個語言模型一下就能訓練完,做5個epoch就差不多了。每一代訓練中, 先調用模型求出hat_y,再調用損失函數citerion,最后反向傳播并優化模型參數。

def train_rnn1():data, max_length = get_dataloader_and_max_langth(19)model = RNN1()optimizer = torch.optim.Adam(model.parameters(), lr=0.001)citerion = torch.nn.CrossEntropyLoss()for epoch in range(5):loss_sum = 0dataset_len = len(data.dataset)for y in data:hat_y = model(y)n, Tx, _ = hat_y.shapehat_y = torch.reshape(hat_y,(n*Tx,-1))y = torch.reshape(y, (n* Tx, -1))label_y = torch.argmax(y, 1)loss = citerion(hat_y, label_y)optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)optimizer.step()loss_sum += lossprint(f'Epoch {epoch}. loss: {loss_sum / dataset_len}')torch.save(model.state_dict(), 'rnn1.pth')return model           

算損失函數前需要預處理一下數據,交叉熵損失函數默認hat_y的維度是[batch數,類型數]label_y是一個一維整形標簽數組。而模型的輸出形狀是[batch數,最大單詞長度,字符數],我們要把前兩個維度融合在一起。另外,我們并沒有提前準備好label_y,需要調用argmax把one-hot編碼轉換回標簽。

之后就是調用PyTorch的自動求導功能。注意,為了防止RNN梯度過大,我們可以用clip_grad_norm_截取梯度的最大值。?

輸出:

5. 測試

我們可以手動為字母級語言模型寫幾個測試用例,看看每一個單詞的概率是否和期望的一樣。我的測試單詞列表是:

test_words = ['apple', 'appll', 'appla', 'apply', 'bear', 'beer', 'berr', 'beee', 'car','cae', 'cat', 'cac', 'caq', 'query', 'queee', 'queue', 'queen', 'quest','quess', 'quees'
]

幾組長度一樣,但是最后幾個字母不太一樣的“單詞”。通過觀察這些詞的概率,我們能夠驗證語言模型的正確性。理論上來說,英文里的正確單詞的概率會更高。

我們的模型只能輸出每一個單詞的softmax前結果。我們還要為模型另寫一個求語言模型概率的函數。

@torch.no_grad()def language_model(self, word: torch.Tensor):# word shape: [batch, max_word_length, embedding_length]batch, Tx = word.shape[0:2]# word shape: [max_word_length, batch,  embedding_length]# word_label shape: [max_word_length, batch]word = torch.transpose(word, 0, 1)word_label = torch.argmax(word, 2)# output shape: [batch]output = torch.ones(batch, device=word.device)a = torch.zeros(batch, self.hidden_units, device=word.device)x = torch.zeros(batch, EMBEDDING_LENGTH, device=word.device)for i in range(Tx):next_a = self.tanh(self.linear_a(torch.cat((a, x), 1)))tmp = self.linear_y(next_a)hat_y = F.softmax(tmp, 1)probs = hat_y[torch.arange(batch), word_label[i]]#從hat_y里取出每一個batch里word_label[i]處的概率output *= probsx = word[i]a = next_areturn output@torch.no_grad()def sample_word(self):batch = 1output = ''a = torch.zeros(batch, self.hidden_units)x = torch.zeros(batch, EMBEDDING_LENGTH)for i in range(10):next_a = self.tanh(self.linear_a(torch.cat((a,x),1)))tmp = self.linear_y(next_a)hat_y = F.softmax(tmp, 1)np_prob = hat_y[0].detach().cpu().numpy()letter = np.random.choice(LETTER_LIST, p=np_prob)output += letterif letter == ' ':breakx = torch.zeros(batch, EMBEDDING_LENGTH)x[0][LETTER_MAP[letter]] = 1a = next_areturn output

這個函數和forward大致相同。只不過,這次我們的輸出output要表示每一個單詞的概率。因此,它被初始化成一個全1的向量。

# output shape: [batch]
output = torch.ones(batch, device=word.device)

每輪算完最后一層的輸出后,我們手動調用F.softmax得到softmax的概率值。

tmp = self.linear_y(next_a)
hat_y = F.softmax(tmp, 1)

接下來,我們要根據每一個batch當前位置的單詞,去hat_y里取出需要的概率。比如第2個batch當前的字母是b,我們就要取出hat_y[2][2]

i輪所有batch的字母可以用word_label[i]表示。根據這個信息,我們可以用probs = hat_y[torch.arange(batch), word_label[i]]神奇地從hat_y里取出每一個batch里word_label[i]處的概率。把這個概率乘到output上就算完成了一輪計算

有了語言模型函數,我們可以測試一下開始那些單詞的概率。

def sample(model):words =[]for _ in range(20):word = model.sample_word()words.append(word)print(*words)
def test_language_model(model, is_onehot=True):data, max_length = get_dataloader_and_max_langth(19)if is_onehot:test_word = words_to_onehot(test_words, max_length)else:test_word = words_to_label_array(test_words, max_length)probs = model.language_model(test_word)for word, prob in zip(test_words, probs):print(f'{word}: {prob}')  

#rnn1 = train_rnn1()
#rnn1 = RNN1()state_dict = torch.load('rnn1.pth')rnn1.load_state_dict(state_dict)rnn1.eval()# Dropout 層被禁用,BatchNorm 使用全局統計量
test_language_model(rnn1)
sample(rnn1)

輸出:?

?

采樣單詞:

語言模型有一個很好玩的應用:我們可以根據語言模型輸出的概率分布,采樣出下一個單詞;輸入這一個單詞,再采樣下一個單詞。這樣一直采樣,直到采樣出空格為止。使用這種采樣算法,我們能夠讓模型自動生成單詞,甚至是英文里不存在,卻看上去很像那么回事的單詞。

我們要為模型編寫一個新的方法sample_word,采樣出一個最大長度為10的單詞。這段代碼的運行邏輯和之前的forward也很相似。只不過,這一次我們沒有輸入張量,每一輪的x要靠采樣獲得。np.random.choice(LETTER_LIST, p=np_prob)可以根據概率分布np_prob對列表LETTER_LIST進行采樣。根據每一輪采樣出的單詞letter,我們重新生成一個x,給one-hot編碼的對應位置賦值1。

    @torch.no_grad()def sample_word(self):batch = 1output = ''a = torch.zeros(batch, self.hidden_units)x = torch.zeros(batch, EMBEDDING_LENGTH)for i in range(10):next_a = self.tanh(self.linear_a(torch.cat((a,x),1)))tmp = self.linear_y(next_a)hat_y = F.softmax(tmp, 1)np_prob = hat_y[0].detach().cpu().numpy()letter = np.random.choice(LETTER_LIST, p=np_prob)output += letterif letter == ' ':breakx = torch.zeros(batch, EMBEDDING_LENGTH)x[0][LETTER_MAP[letter]] = 1a = next_areturn output

使用這個方法,我們可以寫一個采樣20次的腳本:

def sample(model):words = []for _ in range(20):word = model.sample_word()words.append(word)print(*words)

輸出:

采樣出來的單詞幾乎不會是英文里的正確單詞。不過,這些單詞的詞綴很符合英文的造詞規則,非常好玩。如果為采樣函數加一些限制,比如只考慮概率前3的字母,那么算法應該能夠采樣出更正確的單詞。?

三.PyTorch里的RNN函數?

剛剛我們手動編寫了RNN的實現細節。實際上,PyTorch提供了更高級的函數,我們能夠更加輕松地實現RNN。其他部分的代碼邏輯都不怎么要改,這里只展示一下要改動的關鍵部分。

新的模型的主要函數如下:

class RNN2(torch.nn.Module):def __init__(self, hidden_units=64, embeding_dim=64, dropout_rate=0.2):super().__init__()self.drop = nn.Dropout(dropout_rate)self.encoder = nn.Embedding(EMBEDDING_LENGTH, embeding_dim)self.rnn = nn.GRU(embeding_dim, hidden_units, 1, batch_first=True)self.decoder = torch.nn.Linear(hidden_units, EMBEDDING_LENGTH)self.hidden_units = hidden_unitsself.init_weights()def init_weights(self):initrange = 0.1nn.init.uniform_(self.encoder.weight, -initrange, initrange)nn.init.zeros_(self.decoder.bias)nn.init.uniform_(self.decoder.weight, -initrange, initrange)def forward(self, word: torch.Tensor):# word shape: [batch, max_word_length]batch, Tx = word.shape[0:2]first_letter = word.new_zeros(batch, 1)x = torch.cat((first_letter, word[:, 0:-1]), 1)hidden = torch.zeros(1, batch, self.hidden_units, device=word.device)emb = self.drop(self.encoder(x))output, hidden = self.rnn(emb, hidden)y = self.decoder(output.reshape(batch * Tx, -1))return y.reshape(batch, Tx, -1)

初始化時,我們用nn.Embedding表示單詞的向量。詞嵌入(Embedding)是《深度學習專項-RNN》第二門課的內容,我會在下一篇筆記里介紹。這里我們把nn.Embedding看成一種代替one-hot編碼的更高級的向量就行。這些向量和線性層參數W一樣,是可以被梯度下降優化的。這樣,不僅是RNN可以優化,每一個單詞的表示方法也可以被優化。

注意,使用nn.Embedding后,輸入的張量不再是one-hot編碼,而是數字標簽。代碼中的其他地方也要跟著修改。

nn.GRU可以創建GRU。其第一個參數是輸入的維度,第二個參數是隱變量a的維度,第三個參數是層數,這里我們只構建1層RNN,batch_first表示輸入張量的格式是[batch, Tx, embedding_length]還是[Tx, batch, embedding_length]

貌似RNN中常用的正則化是靠dropout實現的。我們要提前準備好dropout層。

def __init__(self, hidden_units=64, embeding_dim=64, dropout_rate=0.2):super().__init__()self.drop = nn.Dropout(dropout_rate)self.encoder = nn.Embedding(EMBEDDING_LENGTH, embeding_dim)self.rnn = nn.GRU(embeding_dim, hidden_units, 1, batch_first=True)self.decoder = torch.nn.Linear(hidden_units, EMBEDDING_LENGTH)self.hidden_units = hidden_unitsself.init_weights()

準備好了計算層后,在forward里只要依次調用它們就行了。其底層原理和我們之前手寫的是一樣的。其中,self.rnn(emb, hidden)這個調用完成了循環遍歷的計算。

由于輸入格式改了,令第一輪輸入為空字符的操作也更繁瑣了一點。我們要先定義一個空字符張量,再把它和輸入的第一至倒數第二個元素拼接起來,作為網絡的真正輸入。

def forward(self, word: torch.Tensor):# word shape: [batch, max_word_length]batch, Tx = word.shape[0:2]first_letter = word.new_zeros(batch, 1)x = torch.cat((first_letter, word[:, 0:-1]), 1)hidden = torch.zeros(1, batch, self.hidden_units, device=word.device)emb = self.drop(self.encoder(x))output, hidden = self.rnn(emb, hidden)y = self.decoder(output.reshape(batch * Tx, -1))return y.reshape(batch, Tx, -1)

PyTorch里的RNN用起來非常靈活。我們不僅能夠給它一個序列,一次輸出序列的所有結果,還可以只輸入一個元素,得到一輪的結果。在采樣單詞時,我們不得不每次輸入一個元素。有關采樣的邏輯如下:

@torch.no_grad()
def sample_word(self, device='cuda:0'):batch = 1output = ''hidden = torch.zeros(1, batch, self.hidden_units, device=device)x = torch.zeros(batch, 1, device=device, dtype=torch.long)for _ in range(10):emb = self.drop(self.encoder(x))rnn_output, hidden = self.rnn(emb, hidden)hat_y = self.decoder(rnn_output)hat_y = F.softmax(hat_y, 2)np_prob = hat_y[0, 0].detach().cpu().numpy()letter = np.random.choice(LETTER_LIST, p=np_prob)output += letterif letter == ' ':breakx = torch.zeros(batch, 1, device=device, dtype=torch.long)x[0] = LETTER_MAP[letter]return output

以上就是PyTorch高級RNN組件的使用方法。在使用PyTorch的RNN時,主要的改變就是輸入從one-hot向量變成了標簽,數據預處理會更加方便一些。另外,PyTorch的RNN會自動完成循環,可以給它輸入任意長度的序列。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/pingmian/79516.shtml
繁體地址,請注明出處:http://hk.pswp.cn/pingmian/79516.shtml
英文地址,請注明出處:http://en.pswp.cn/pingmian/79516.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

基于python的哈希查表搜索特定文件

Python有hashlib庫&#xff0c;支持多種哈希算法&#xff0c;比如MD5、SHA1、SHA256等。通常SHA256比較安全&#xff0c;但MD5更快&#xff0c;但可能存在碰撞風險&#xff0c;得根據自己需求決定。下面以SHA256做例。 import hashlib import os from typing import Dict, Lis…

idea創建springboot項目無法創建jdk8原因及多種解決方案

idea創建springboot項目無法創建jdk8原因及多種解決方案 提示&#xff1a;幫幫志會陸續更新非常多的IT技術知識&#xff0c;希望分享的內容對您有用。本章分享的是springboot的使用。前后每一小節的內容是存在的有&#xff1a;學習and理解的關聯性。【幫幫志系列文章】&#x…

【C++進階十】多態深度剖析

【C進階十】多態深度剖析 1.多態的概念及條件2.虛函數的重寫3.重寫、重定義、重載區別4.C11新增的override 和final5.抽象類6.虛表指針和虛表6.1什么是虛表指針6.2指向誰調用誰&#xff0c;傳父類調用父類&#xff0c;傳子類調用子類 7.多態的原理8.單繼承的虛表狀態9.多繼承的…

面向網絡安全的開源 大模型-Foundation-Sec-8B

1. Foundation-Sec-8B 整體介紹 Foundation-Sec-8B 是一個專注于網絡安全領域的大型語言模型 (LLM),由思科的基礎人工智能團隊 (Foundation AI) 開發 。它基于 Llama 3.1-8B 架構構建,并通過在一個精心策劃和整理的網絡安全專業語料庫上進行持續預訓練而得到增強 。該模型旨在…

Python爬蟲的基礎用法

Python爬蟲的基礎用法 python爬蟲一般通過第三方庫進行完成 導入第三方庫&#xff08;如import requests &#xff09; requests用于處理http協議請求的第三方庫,用python解釋器中查看是否有這個庫&#xff0c;沒有點擊安裝獲取網站url&#xff08;url一定要解析正確&#xf…

WHAT - Tailwind CSS + Antd = MetisUI組件庫

文章目錄 Tailwind 和 Antd 組件庫MetisUI 組件庫 Tailwind 和 Antd 組件庫 在 WHAT - Tailwind 樣式方案&#xff08;不寫任何自定義樣式&#xff09; 中我們介紹了 Tailwind&#xff0c;至于 Antd 組件庫&#xff0c;我們應該都耳熟能詳&#xff0c;官網地址&#xff1a;htt…

Day 4:牛客周賽Round 91

好久沒寫了&#xff0c;問題還蠻多的。聽說這次是苯環哥哥出題 F題 小苯的因子查詢 思路 考慮求因子個數&#xff0c;用質因數分解&#xff1b;奇數因子只需要去掉質數為2的情況&#xff0c;用除法。 這里有個比較妙的細節是&#xff0c;提前處理出數字x的最小質因數&#xff0…

使用直覺理解不等式

問題是這個&#xff1a; 題目 探究 ∣ max ? b { q 1 ( z , b ) } ? max ? b { q 2 ( z , b ) } ∣ ≤ max ? b ∣ q 1 ( z , b ) ? q 2 ( z , b ) ∣ |\max_b\{q_1(z,b)\}-\max_b\{q_2(z,b)\}|\le\max_b|q_1(z,b)-q_2(z,b)| ∣maxb?{q1?(z,b)}?maxb?{q2?(z,b)}∣≤…

惡心的win11更新DIY 設置win11更新為100年

?打開注冊表編輯器?&#xff1a;按下Win R鍵&#xff0c;輸入regedit&#xff0c;然后按回車打開注冊表編輯器。?12?導航到指定路徑?&#xff1a;在注冊表編輯器中&#xff0c;依次展開HKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\WindowsUpdate\UX\Settings?新建DWORD值?&…

嵌入式驅動學習

時鐘 定義 周期型的0、1信號 時鐘信號由“心臟”時鐘源產生&#xff0c;通過“動脈”時鐘樹傳播到整個芯片中。 SYSCLK系統時鐘&#xff0c;由HSI、HSE、PLLCLK三選一。 HCLK是AHB總線時鐘&#xff0c; PCLK是APB總線時鐘。 使用某個外設&#xff0c;必須要先使能該外設時鐘系統…

Java:從入門到精通,你的編程之旅

Java&#xff0c;一門歷久彌新的編程語言&#xff0c;自誕生以來就以其跨平臺性、面向對象、穩定性和安全性等特性&#xff0c;在企業級應用開發領域占據著舉足輕重的地位。無論你是初學者還是經驗豐富的開發者&#xff0c;Java 都能為你提供強大的工具和廣闊的舞臺。 為什么選…

Linux:深入理解數據鏈路層

實際上一臺主機中&#xff0c;報文并沒有通過網絡層直接發送出去&#xff0c;而是交給了自己的下一層協議——數據鏈路層&#xff01;&#xff01; 一、理解數據鏈路層 網絡層交付給鏈路層之前&#xff0c;會先做決策再行動&#xff08;會先查一下路由表&#xff0c;看看目標網…

Python基本語法(類和實例)

類和實例 類和對象是面向對象編程的兩個主要方面。類創建一個新類型&#xff0c;而對象是這個 類的實例&#xff0c;類使用class關鍵字創建。類的域和方法被列在一個縮進塊中&#xff0c;一般函數 也可以被叫作方法。 &#xff08;1&#xff09;類的變量&#xff1a;甴一個類…

2025 年如何使用 Pycharm、Vscode 進行樹莓派 Respberry Pi Pico 編程開發詳細教程(更新中)

micropython 概述 micropython 官方網站&#xff1a;https://www.micropython.org/ 安裝 Micropython 支持固件 樹莓派 Pico 安裝 Micropython 支持固件 下載地址&#xff1a;https://www.raspberrypi.com/documentation/microcontrollers/ 選擇 MicroPython 下載 RPI_PIC…

flink rocksdb狀態說明

文章目錄 1.默認情況2.flink中的狀態3.RocksDB4.對比情況5.使用6.RocksDB架構7.參考文章8.總結提示:以下主要考慮flink 狀態永久存儲 rocksdb情況,做一些簡單說明 1.默認情況 當flink使用rocksdb存儲狀態時。無論是永久存儲還是臨時存儲都可能會落盤寫文件(如果沒有配置存儲…

安裝SDL和FFmpeg

1、先記錄SDL 這玩意還是有一點講究的 具體步驟&#xff1a; 下載 SDL包&#xff1a; 鏈接&#xff1a;https://www.libsdl.org/release/SDL2-2.0.14.tar.gz 可以用迅雷&#xff0c;下載完之后&#xff0c; 解壓&#xff1a; tar -zxvf SDL2-2.0.14.tar.gz進入安裝目錄 cd …

2022年408真題及答案

2022年計算機408真題 2022年計算機408答案 2022 408真題下載鏈接 2022 408答案下載鏈接

Spring AI聊天模型API:輕松構建智能聊天交互

Spring AI聊天模型API&#xff1a;輕松構建智能聊天交互 前言 在當今數字化時代&#xff0c;智能聊天功能已成為眾多應用程序提升用戶體驗、增強交互性的關鍵要素。Spring AI的聊天模型API為開發者提供了一條便捷通道&#xff0c;能夠將強大的AI驅動的聊天完成功能無縫集成到…

Softmax回歸與單層感知機對比

(1) 輸出形式 Softmax回歸 輸出是一個概率分布&#xff0c;通過Softmax函數將線性得分轉換為概率&#xff1a; 其中 KK 是類別數&#xff0c;模型同時計算所有類別的概率。 單層感知機 輸出是二分類的硬決策&#xff08;如0/1或1&#xff09;&#xff1a; 無概率解釋&#x…

【React】Hooks 解鎖外部狀態安全訂閱 useSyncExternalStore 應用與最佳實踐

一、背景 useSyncExternalStore 是 React 18 引入的一個 Hook&#xff1b;用于從外部存儲&#xff08;例如狀態管理庫、瀏覽器 API 等&#xff09;獲取狀態并在組件中同步顯示。這對于需要跟蹤外部狀態的應用非常有用。 二、場景 訂閱外部 store 例如(redux,mobx,Zustand,jo…