從零開始的NLP:使用序列到序列網絡和注意力機制進行翻譯
我們將編寫自己的類和函數來預處理數據以完成我們的 NLP 建模任務。
在這個項目中,我們將訓練一個神經網絡將法語翻譯成英語。
[KEY: > input, = target, < output]> il est en train de peindre un tableau .
= he is painting a picture .
< he is painting a picture .> pourquoi ne pas essayer ce vin delicieux ?
= why not try that delicious wine ?
< why not try that delicious wine ?> elle n est pas poete mais romanciere .
= she is not a poet but a novelist .
< she not not a poet but a novelist .> vous etes trop maigre .
= you re too skinny .
< you re all alone .
… 取得不同程度的成功。
這得益于簡單而強大的序列到序列網絡思想,其中兩個循環神經網絡協同工作以將一個序列轉換為另一個序列。編碼器網絡將輸入序列壓縮成一個向量,解碼器網絡將該向量展開成一個新的序列。
為了改進這個模型,我們將使用一個注意力機制,它允許解碼器學習集中關注輸入序列的特定范圍。
你還會發現之前的從零開始的 NLP:使用字符級 RNN 分類名稱 和 從零開始的 NLP:使用字符級 RNN 生成名稱 教程非常有用,因為這些概念分別與編碼器和解碼器模型非常相似。
要求
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import re
import randomimport torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as Fimport numpy as np
from torch.utils.data import TensorDataset, DataLoader, RandomSamplerdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
加載數據文件
本項目的數據集包含數千對英法翻譯對。
Open Data Stack Exchange 上的這個問題 指引我找到了開放翻譯網站 https://tatoeba.org/,該網站在 https://tatoeba.org/eng/downloads 提供下載——更妙的是,有人做了額外的工作,將語言對分成獨立的文本文件,在此處提供:https://www.manythings.org/anki/
英法翻譯對文件太大,無法包含在倉庫中,請在繼續之前下載到 data/eng-fra.txt
。該文件是以制表符分隔的翻譯對列表
I am cold. J'ai froid.
注意
從這里下載數據并將其解壓到當前目錄。
與字符級 RNN 教程中使用的字符編碼類似,我們將語言中的每個詞表示為一個獨熱向量,或者一個巨大的零向量,除了一個位置為一(該詞的索引)。與語言中可能存在的幾十個字符相比,詞的數量要多得多,因此編碼向量更大。然而,我們將稍微做一些妥協,僅使用每種語言的幾千個詞來修剪數據。
我們稍后需要為每個詞設置一個唯一的索引,用作網絡的輸入和目標。為了跟蹤所有這些,我們將使用一個名為 Lang
的輔助類,它包含 word → index (word2index
) 和 index → word (index2word
) 字典,以及每個詞的計數 word2count
,這將用于稍后替換罕見詞。
SOS_token = 0
EOS_token = 1class Lang:def __init__(self, name):self.name = nameself.word2index = {}self.word2count = {}self.index2word = {0: "SOS", 1: "EOS"}self.n_words = 2 # Count SOS and EOSdef addSentence(self, sentence):for word in sentence.split(' '):self.addWord(word)def addWord(self, word):if word not in self.word2index:self.word2index[word] = self.n_wordsself.word2count[word] = 1self.index2word[self.n_words] = wordself.n_words += 1else:self.word2count[word] += 1
所有文件都是 Unicode 格式,為了簡化,我們將 Unicode 字符轉換為 ASCII,全部轉換為小寫,并去除大部分標點符號。
# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):return ''.join(c for c in unicodedata.normalize('NFD', s)if unicodedata.category(c) != 'Mn')# Lowercase, trim, and remove non-letter characters
def normalizeString(s):s = unicodeToAscii(s.lower().strip())s = re.sub(r"([.!?])", r" \1", s)s = re.sub(r"[^a-zA-Z!?]+", r" ", s)return s.strip()
為了讀取數據文件,我們將文件按行分割,然后將行分割成對。文件都是英語 → 其他語言,因此如果我們要從其他語言 → 英語翻譯,我添加了 reverse
標志來反轉翻譯對。
def readLangs(lang1, lang2, reverse=False):print("Reading lines...")# Read the file and split into lineslines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').\read().strip().split('\n')# Split every line into pairs and normalizepairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]# Reverse pairs, make Lang instancesif reverse:pairs = [list(reversed(p)) for p in pairs]input_lang = Lang(lang2)output_lang = Lang(lang1)else:input_lang = Lang(lang1)output_lang = Lang(lang2)re