基于RNN和Transformer的詞級語言建模 代碼分析 _generate_square_subsequent_mask
flyfish
Word-level Language Modeling using RNN and Transformer
word_language_model
PyTorch 提供的 word_language_model 示例展示了如何使用循環神經網絡RNN(GRU或LSTM)和 Transformer 模型進行詞級語言建模 。默認情況下,訓練使用Wikitext-2數據集,generate.py可以使用訓練好的模型來生成新文本。
源碼地址
https://github.com/pytorch/examples/tree/main/word_language_model
文件:model.py
import torch
import matplotlib.pyplot as plt
import numpy as npdef _generate_square_subsequent_mask(sz):return torch.log(torch.tril(torch.ones(sz, sz)))# 設置矩陣大小
sz = 5
mask = _generate_square_subsequent_mask(sz)# 將 mask 轉換為 numpy 數組,方便可視化
mask_np = mask.numpy()# 可視化
plt.imshow(mask_np, cmap='viridis')
plt.colorbar()
plt.title("Square Subsequent Mask")
plt.show()
可視化圖示
在可視化結果中,你會看到一個下三角矩陣,其值為 0 的部分為下三角部分,值為負無窮的部分為上三角部分。圖像中通常負無窮會被顯示為一種不同的顏色。
這樣,你可以直觀地理解生成的掩碼矩陣的結構和作用。這個掩碼矩陣主要用于 Transformer 模型中,以確保模型在預測時只能看到當前時刻及之前的時刻信息,而不能看到未來的信息。
結果
運行這段代碼,你會看到一個 5x5 的矩陣,其中下三角部分是 0(因為 log(1) = 0),上三角部分是負無窮(由于 log(0) 是負無窮)。
def _generate_square_subsequent_mask(sz):return torch.log(torch.tril(torch.ones(sz, sz)))
# 設置矩陣大小
sz = 5
mask = _generate_square_subsequent_mask(sz)# 打印矩陣
print(mask)
輸出
tensor([[0., -inf, -inf, -inf, -inf],[0., 0., -inf, -inf, -inf],[0., 0., 0., -inf, -inf],[0., 0., 0., 0., -inf],[0., 0., 0., 0., 0.]])
在數學上,定義對數函數時,log(0) 是未定義的,但在計算中,我們處理這種情況的方式是認為 log(0) 的極限值是負無窮。因此,計算機通常會返回負無窮來表示這種情況。
在 PyTorch 中,torch.log(0) 的結果是 -inf(負無窮)。這是因為對數函數是單調遞增的,并且在接近0時值會急劇下降到負無窮。