The code in this post comes from Mu Li's *Dive into Deep Learning* (PyTorch implementation).
During data processing, the following code is run first:
```python
def load_data_wiki(batch_size, max_len):
    """Load the WikiText-2 dataset."""
    num_workers = d2l.get_dataloader_workers()
    data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')
    # The two lines above will not be discussed further
    paragraphs = _read_wiki(data_dir)
    train_set = _WikiTextDataset(paragraphs, max_len)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size,
                                             shuffle=True,
                                             num_workers=num_workers)
    return train_iter, train_set.vocab
```
```python
#@save
d2l.DATA_HUB['wikitext-2'] = (
    'https://s3.amazonaws.com/research.metamind.io/wikitext/'
    'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
```
```python
def _read_wiki(data_dir):
    file_name = os.path.join(data_dir, 'wiki.train.tokens')
    with open(file_name, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    # Lowercase the text; a line is processed only if it contains at least
    # two sentences, otherwise it is discarded
    paragraphs = [line.strip().lower().split(' . ')
                  for line in lines if len(line.split(' . ')) >= 2]
    random.shuffle(paragraphs)
    return paragraphs
```
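Putting these pieces together, a typical call looks like the sketch below; the values are illustrative and simply match what this post uses later (a batch size of 1 and a maximum length of 64).

```python
# Illustrative call; this post uses batch_size=1 and max_len=64 in the
# examples that follow
batch_size, max_len = 1, 64
train_iter, vocab = load_data_wiki(batch_size, max_len)
print(len(vocab))  # size of the vocabulary built from WikiText-2
```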
First the text is read in; each paragraph must contain at least two sentences (this is required by the second pretraining task: judging whether two sentences are consecutive). Part of the resulting `paragraphs` is shown below.
This paragraph contains four sentences; each quoted string is one sentence.
['common starlings are trapped for food in some mediterranean countries'
, 'the meat is tough and of low quality , so it is <unk> or made into <unk>'
, 'one recipe said it should be <unk> " until tender , however long that may be "'
, 'even when correctly prepared , it may still be seen as an acquired taste .']
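As a quick illustration of the filtering and splitting rule in `_read_wiki` (keep a line only if it has at least two sentences separated by ' . '), here is a tiny self-contained sketch with made-up lines:

```python
# Toy illustration of the " . " splitting rule used in _read_wiki
lines = [
    'the cat sat on the mat . it fell asleep soon after .\n',  # two sentences -> kept
    'a single short sentence .\n',                             # only one -> dropped
]
paragraphs = [line.strip().lower().split(' . ')
              for line in lines if len(line.split(' . ')) >= 2]
print(paragraphs)
# [['the cat sat on the mat', 'it fell asleep soon after .']]
```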
The dataset class then tokenizes the paragraphs, builds the vocabulary, and generates the training examples:

```python
class _WikiTextDataset(torch.utils.data.Dataset):
    def __init__(self, paragraphs, max_len):
        # Each paragraph is a list of sentences like the one shown above; it
        # is first tokenized at the word level. A tokenized example:
        # [['common', 'starlings', 'are', 'trapped', 'for', 'food', 'in',
        #   'some', 'mediterranean', 'countries'],
        #  ['the', 'meat', 'is', 'tough', 'and', 'of', 'low', 'quality', ',',
        #   'so', 'it', 'is', '<unk>', 'or', 'made', 'into', '<unk>'],
        #  ['one', 'recipe', 'said', 'it', 'should', 'be', '<unk>', '"',
        #   'until', 'tender', ',', 'however', 'long', 'that', 'may', 'be', '"'],
        #  ['even', 'when', 'correctly', 'prepared', ',', 'it', 'may', 'still',
        #   'be', 'seen', 'as', 'an', 'acquired', 'taste', '.']]
        paragraphs = [d2l.tokenize(paragraph, token='word')
                      for paragraph in paragraphs]
        # Flatten all sentences into a single list to build the vocabulary
        sentences = [sentence for paragraph in paragraphs
                     for sentence in paragraph]
        # Build the vocabulary; min_freq=5 means a word appearing fewer than
        # 5 times is not kept in the vocabulary
        self.vocab = d2l.Vocab(sentences, min_freq=5, reserved_tokens=[
            '<pad>', '<mask>', '<cls>', '<sep>'])
        # Get the data for the next-sentence-prediction task
        examples = []
        for paragraph in paragraphs:
            examples.extend(_get_nsp_data_from_paragraph(
                paragraph, paragraphs, self.vocab, max_len))
```
The `_get_nsp_data_from_paragraph` helper used above builds the sentence pairs for the next-sentence-prediction task:

```python
def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
    nsp_data_from_paragraph = []
    for i in range(len(paragraph) - 1):
        # _get_next_sentence receives the adjacent sentences a and b; inside
        # it, b is replaced by a random sentence with some probability
        tokens_a, tokens_b, is_next = _get_next_sentence(
            paragraph[i], paragraph[i + 1], paragraphs)
        # If the pair exceeds BERT's maximum length, discard it
        # (+3 accounts for one '<cls>' and two '<sep>' tokens)
        if len(tokens_a) + len(tokens_b) + 3 > max_len:
            continue
        # Add '<cls>' and '<sep>'; segments marks which sentence a token is in
        tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
        nsp_data_from_paragraph.append((tokens, segments, is_next))
    return nsp_data_from_paragraph
```

An example of `tokens` and `segments` is shown below. `True` means the two sentences are adjacent; `False` means b was randomly replaced, so a and b are not adjacent.

```
(['<cls>', 'mushrooms', 'grow', '<unk>', 'or', 'in', '"', '<unk>', 'groups', '"', 'in', 'late', 'summer', 'and', 'throughout', 'autumn', ',', 'though', 'it', 'is', 'not', 'commonly', 'encountered', 'species', '<sep>', 'it', 'can', 'be', 'found', 'in', 'europe', ',', 'asia', 'and', 'north', 'america', '.', '<sep>'],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 True)
```

Back in `__init__`, the data for the masked-language-model task is built next. Here every token is replaced by its index in the vocabulary. `[13]` means the token at position 13 of the sequence was chosen for prediction; a chosen token may stay unchanged, be replaced by a random word, or be replaced by `<mask>` (in this example it was left unchanged), and `[60]` is the label, i.e. the original token id at that position. The 0/1 list distinguishes the two sentences, and `False` means they are not adjacent. One entry of `examples` looks like this:

```
([3, 2510, 31, 337, 9, 0, 6, 6891, 8, 11621, 6, 21, 11, 60, 3405, 14, 1542, 9546, 4, 2524, 21, 185, 4421, 649, 38, 277, 2872, 13233, 4],
 [13], [60],
 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 False)
```

```python
        # (continuing _WikiTextDataset.__init__)
        # Get the data for the masked-language-model task
        examples = [(_get_mlm_data_from_tokens(tokens, self.vocab)
                     + (segments, is_next))
                    for tokens, segments, is_next in examples]
        # _pad_bert_inputs pads the data; in all_mlm_weights, 1 marks a
        # position that must be predicted and 0 marks padding, e.g.
        # all_mlm_weights = tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
        (self.all_token_ids, self.all_segments, self.valid_lens,
         self.all_pred_positions, self.all_mlm_weights,
         self.all_mlm_labels, self.nsp_labels) = _pad_bert_inputs(
            examples, max_len, self.vocab)

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx], self.all_pred_positions[idx],
                self.all_mlm_weights[idx], self.all_mlm_labels[idx],
                self.nsp_labels[idx])

    def __len__(self):
        return len(self.all_token_ids)
```
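Two helpers used above are not shown in this post. For reference, the sketch below paraphrases how they work in the book's code, so treat the details as approximate: `_get_next_sentence` keeps the true next sentence with probability 0.5 and otherwise samples a random one, and `get_tokens_and_segments` (called as `d2l.get_tokens_and_segments` above) adds the `<cls>`/`<sep>` tokens and builds the segment ids.

```python
import random

def _get_next_sentence(sentence, next_sentence, paragraphs):
    # With probability 0.5 keep the real next sentence, otherwise replace it
    # with a random sentence from a random paragraph
    if random.random() < 0.5:
        is_next = True
    else:
        next_sentence = random.choice(random.choice(paragraphs))
        is_next = False
    return sentence, next_sentence, is_next

def get_tokens_and_segments(tokens_a, tokens_b=None):
    # '<cls>' + tokens_a + '<sep>' belongs to segment 0
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    segments = [0] * (len(tokens_a) + 2)
    if tokens_b is not None:
        # tokens_b + '<sep>' belongs to segment 1
        tokens += tokens_b + ['<sep>']
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments
```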
With the data processing above complete, let's look at a processed example.
The original token-id list is padded with 1 (the index of `<pad>`) until its length reaches 64:
tensor([[ 3, 5, 0, 18306, 23, 11, 2659, 156, 5779, 382,1296, 110, 158, 22, 5, 1771, 496, 0, 3398, 2,5, 3496, 110, 5038, 179, 4, 16, 11, 19837, 6,58, 13, 5, 685, 7, 66, 156, 0, 3063, 77,3842, 19, 4, 1, 1, 1, 1, 1, 1, 1,1, 1, 1, 1, 1, 1, 1, 1, 1, 1,1, 1, 1, 1]])
`segments` distinguishes the two sentences: 0 marks tokens of the first sentence, 1 marks tokens of the second sentence, and the trailing 0s are padding:
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
`valid_lens` is the number of valid (non-padding) tokens in the sequence:
tensor([43.])
`pred_positions` gives the positions to predict; 0 is padding:
tensor([[19, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
`mlm_weights` indicates how many tokens actually need to be predicted (1 for a real prediction, 0 for padding):
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
The ground-truth labels (token ids) at the predicted positions; 0 is padding:
tensor([[22, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
Whether the two sentences are adjacent (the NSP label):
tensor([0])
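The tensors above can be reproduced by iterating over the `train_iter` created earlier (batch_size=1, max_len=64); a minimal sketch:

```python
# Print the seven fields of one batch; with batch_size=1 each tensor holds a
# single sample, matching the printout above
for (tokens_X, segments_X, valid_lens_x, pred_positions_X,
     mlm_weights_X, mlm_Y, nsp_y) in train_iter:
    print(tokens_X)           # padded token ids
    print(segments_X)         # segment ids
    print(valid_lens_x)       # valid length
    print(pred_positions_X)   # masked positions
    print(mlm_weights_X)      # MLM weights
    print(mlm_Y)              # MLM labels
    print(nsp_y)              # NSP label
    break
```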
隨后就是把處理好的數據,送入bert中。在 BERTEncoder 中,執行如下代碼:
```python
def forward(self, tokens, segments, valid_lens):
    # Shape of `X` remains unchanged in the following code snippet:
    # (batch size, max sequence length, `num_hiddens`)
    # Embed the tokens and the segments separately, then sum them
    X = self.token_embedding(tokens) + self.segment_embedding(segments)
    # Add the positional embeddings
    X = X + self.pos_embedding.data[:, :X.shape[1], :]
    for blk in self.blks:
        X = blk(X, valid_lens)
    return X
```
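To make the embedding sum concrete, here is a small standalone sketch with toy dimensions (this is not the book's `BERTEncoder` class itself): token, segment, and position information are each mapped into the same `num_hiddens`-dimensional space and added together.

```python
import torch
from torch import nn

vocab_size, num_hiddens, max_len = 1000, 128, 64      # toy sizes
token_embedding = nn.Embedding(vocab_size, num_hiddens)
segment_embedding = nn.Embedding(2, num_hiddens)
# Learnable positional embedding, one vector per position
pos_embedding = nn.Parameter(torch.randn(1, max_len, num_hiddens))

tokens = torch.randint(0, vocab_size, (1, max_len))    # (batch, seq_len)
segments = torch.zeros(1, max_len, dtype=torch.long)
X = token_embedding(tokens) + segment_embedding(segments)
X = X + pos_embedding.data[:, :X.shape[1], :]
print(X.shape)  # torch.Size([1, 64, 128])
```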
The embedded data then passes through each encoder block: multi-head attention with a residual connection and layer normalization, followed by a position-wise feed-forward network with another residual connection:
```python
def forward(self, X, valid_lens):
    Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
    return self.addnorm2(Y, self.ffn(Y))
```
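Here `addnorm1` and `addnorm2` are residual-connection-plus-layer-normalization layers. Below is a minimal sketch of such an add-and-norm layer; it is my paraphrase of the d2l `AddNorm` module, so the exact signature may differ from the book.

```python
import torch
from torch import nn

class AddNorm(nn.Module):
    """Residual connection followed by layer normalization."""
    def __init__(self, normalized_shape, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        # Y is the sublayer output (attention or FFN); X is that sublayer's input
        return self.ln(self.dropout(Y) + X)
```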
The result is returned to the code below, where encoded_X.shape = torch.Size([1, 64, 128]): 1 is the batch size (each batch here contains a single text sequence), each sequence consists of 64 tokens, and BERT represents each token with a 128-dimensional vector. Two tasks follow: predicting the masked tokens, and judging whether the two sentences are adjacent.
```python
def forward(self, tokens, segments, valid_lens=None, pred_positions=None):
    encoded_X = self.encoder(tokens, segments, valid_lens)
    if pred_positions is not None:
        mlm_Y_hat = self.mlm(encoded_X, pred_positions)
    else:
        mlm_Y_hat = None
    # The hidden layer of the MLP classifier for next sentence prediction.
    # 0 is the index of the '<cls>' token
    nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))
    return encoded_X, mlm_Y_hat, nsp_Y_hat
```
The first task is predicting the masked tokens:
For example, with a batch size of 1, X is 1×64×128. If num_pred_positions = 10, batch_idx is repeated into [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] and pred_positions is [3, 6, 10, 12, 15, 20, 0, 0, 0, 0]; X[batch_idx, pred_positions] gathers the vectors at the positions to predict, which are then reshaped into a 1×10×128 tensor. Finally an MLP is applied: after normalization, nn.Linear(num_hiddens, vocab_size) produces predictions over the vocabulary.

```python
def forward(self, X, pred_positions):
    num_pred_positions = pred_positions.shape[1]
    pred_positions = pred_positions.reshape(-1)
    batch_size = X.shape[0]
    batch_idx = torch.arange(0, batch_size)
    # Suppose that `batch_size` = 2, `num_pred_positions` = 3, then
    # `batch_idx` is `torch.tensor([0, 0, 0, 1, 1, 1])`
    batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)
    masked_X = X[batch_idx, pred_positions]
    masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))
    mlm_Y_hat = self.mlp(masked_X)
    return mlm_Y_hat
```
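The key step is the advanced indexing `X[batch_idx, pred_positions]`, which picks out, for every sample, the hidden vectors at the positions to predict. A tiny standalone demo with made-up sizes:

```python
import torch

X = torch.arange(2 * 5 * 3, dtype=torch.float32).reshape(2, 5, 3)  # (batch=2, seq=5, hidden=3)
pred_positions = torch.tensor([[1, 3], [0, 4]])   # 2 positions to predict per sample
num_pred_positions = pred_positions.shape[1]
batch_idx = torch.repeat_interleave(torch.arange(2), num_pred_positions)
# batch_idx = tensor([0, 0, 1, 1]); paired with the flattened positions [1, 3, 0, 4]
masked_X = X[batch_idx, pred_positions.reshape(-1)]
print(masked_X.reshape(2, num_pred_positions, -1).shape)  # torch.Size([2, 2, 3])
```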
When this finishes, control returns to the calling code:
```python
def forward(self, tokens, segments, valid_lens=None, pred_positions=None):
    encoded_X = self.encoder(tokens, segments, valid_lens)
    if pred_positions is not None:
        mlm_Y_hat = self.mlm(encoded_X, pred_positions)
    else:
        mlm_Y_hat = None
    # The hidden layer of the MLP classifier for next sentence prediction.
    # 0 is the index of the '<cls>' token.
    # To judge whether the two sentences are consecutive, the '<cls>' vector
    # is passed through an MLP followed by nn.Linear(num_inputs, 2), turning
    # this into a binary classification problem.
    nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))
    return encoded_X, mlm_Y_hat, nsp_Y_hat
```
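The NSP head itself is just a linear layer that maps the transformed `<cls>` vector to two logits. Below is a minimal sketch paraphrasing the book's `NextSentencePred`; the preceding `self.hidden` (a Linear + Tanh in the book) is not shown here, and the exact constructor may differ.

```python
import torch
from torch import nn

class NextSentencePred(nn.Module):
    """Binary classifier for next-sentence prediction."""
    def __init__(self, num_inputs):
        super().__init__()
        self.output = nn.Linear(num_inputs, 2)

    def forward(self, X):
        # X: (batch_size, num_inputs) -- the transformed '<cls>' representation
        return self.output(X)

# Toy usage: one 128-dimensional '<cls>' vector -> 2 logits
nsp = NextSentencePred(128)
print(nsp(torch.randn(1, 128)).shape)  # torch.Size([1, 2])
```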
后面就是計算損失:
將mlm_Y_hat進行reshap,與mlm_Y求loss,最后需要乘mlm_weights_X,將填充的無用數據進行去除。mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)取平均lossmlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)nsp_l = loss(nsp_Y_hat, nsp_y)l = mlm_l + nsp_l
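To see how the weighting removes the padded positions from the MLM loss, here is a small self-contained sketch with made-up numbers. It assumes a cross-entropy loss with reduction='none' so that per-position losses are available, and it flattens the weights to match those per-position losses, which may differ slightly from the exact lines in the book.

```python
import torch
from torch import nn

loss = nn.CrossEntropyLoss(reduction='none')     # per-position losses
vocab_size = 10

# One sample with 3 prediction slots; only the first is a real prediction
mlm_Y_hat = torch.randn(1, 3, vocab_size)        # logits over the vocabulary
mlm_Y = torch.tensor([[7, 0, 0]])                # labels, trailing 0s are padding
mlm_weights_X = torch.tensor([[1.0, 0.0, 0.0]])  # 1 = predict, 0 = padding

# Per-position loss, zeroed at the padded slots by the weights
mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size),
             mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1)
# Average over the real predictions only
mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
print(mlm_l)
```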