文章目錄
- 介紹
- BERT 訓練之數據集處理
- BERT 原理及模型代碼實現
- 數據集處理
- 導包
- 加載數據
- 生成下一句預測任務的數據
- 從段落中獲取nsp數據
- 生成遮蔽語言模型任務的數據
- 從token中獲取mlm數據
- 將文本轉換為預訓練數據集
- 創建Dataset
- 加載WikiText-2數據集
- BERT 訓練代碼實現
- 導包
- 加載數據
- 構建BERT模型
- 模型損失
- 訓練
- 獲取BERT編碼器
個人主頁:道友老李
歡迎加入社區:道友老李的學習社區
介紹
**自然語言處理(Natural Language Processing,NLP)**是計算機科學領域與人工智能領域中的一個重要方向。它研究的是人類(自然)語言與計算機之間的交互。NLP的目標是讓計算機能夠理解、解析、生成人類語言,并且能夠以有意義的方式回應和操作這些信息。
NLP的任務可以分為多個層次,包括但不限于:
- 詞法分析:將文本分解成單詞或標記(token),并識別它們的詞性(如名詞、動詞等)。
- 句法分析:分析句子結構,理解句子中詞語的關系,比如主語、謂語、賓語等。
- 語義分析:試圖理解句子的實際含義,超越字面意義,捕捉隱含的信息。
- 語用分析:考慮上下文和對話背景,理解話語在特定情境下的使用目的。
- 情感分析:檢測文本中表達的情感傾向,例如正面、負面或中立。
- 機器翻譯:將一種自然語言轉換為另一種自然語言。
- 問答系統:構建可以回答用戶問題的系統。
- 文本摘要:從大量文本中提取關鍵信息,生成簡短的摘要。
- 命名實體識別(NER):識別文本中提到的特定實體,如人名、地名、組織名等。
- 語音識別:將人類的語音轉換為計算機可讀的文字格式。
NLP技術的發展依賴于算法的進步、計算能力的提升以及大規模標注數據集的可用性。近年來,深度學習方法,特別是基于神經網絡的語言模型,如BERT、GPT系列等,在許多NLP任務上取得了顯著的成功。隨著技術的進步,NLP正在被應用到越來越多的領域,包括客戶服務、智能搜索、內容推薦、醫療健康等。
BERT 訓練之數據集處理
BERT 原理及模型代碼實現
【自然語言處理(NLP)】基于Transformer架構的預訓練語言模型:BERT 原理及代碼實現
數據集處理
導包
import os
import random
import torch
import dltools
加載數據
def _read_wiki(data_dir):file_name = os.path.join(data_dir, 'wiki.train.tokens')with open(file_name, 'r',encoding="utf-8") as f:lines = f.readlines()# 大寫字母轉換為小寫字母paragraphs = [line.strip().lower().split(' . ') for line in lines if len(line.split(' . ')) >= 2]random.shuffle(paragraphs)return paragraphs_read_wiki('./wikitext-2')
生成下一句預測任務的數據
def _get_next_sentence(sentence, next_sentence, paragraphs):if random.random() < 0.5:is_next = Trueelse:# paragraphs是三重列表的嵌套next_sentence = random.choice(random.choice(paragraphs))is_next = Falsereturn sentence, next_sentence, is_next
從段落中獲取nsp數據
def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):nsp_data_from_paragraph = []for i in range(len(paragraph) - 1):tokens_a, tokens_b, is_next = _get_next_sentence(paragraph[i], paragraph[i + 1], paragraphs)# 考慮1個'<cls>'詞元和2個'<sep>'詞元if len(tokens_a) + len(tokens_b) + 3 > max_len:continuetokens, segments = dltools.get_tokens_and_segments(tokens_a, tokens_b)nsp_data_from_paragraph.append((tokens, segments, is_next))return nsp_data_from_paragraph
生成遮蔽語言模型任務的數據
- 為遮蔽語言模型的輸入創建新的詞元副本,其中輸入可能包含替換的mask或隨機詞元
- 打亂后用于在遮蔽語言模型任務中獲取15%的隨機詞元進行預測
- 80%的時間:將詞替換為mask詞元
- 10%的時間:保持詞不變
- 10%的時間:用隨機詞替換該詞
def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,vocab):# 為遮蔽語言模型的輸入創建新的詞元副本,其中輸入可能包含替換的“<mask>”或隨機詞元mlm_input_tokens = [token for token in tokens]pred_positions_and_labels = []# 打亂后用于在遮蔽語言模型任務中獲取15%的隨機詞元進行預測random.shuffle(candidate_pred_positions)for mlm_pred_position in candidate_pred_positions:if len(pred_positions_and_labels) >= num_mlm_preds:breakmasked_token = None# 80%的時間:將詞替換為“<mask>”詞元if random.random() < 0.8:masked_token = '<mask>'else:# 10%的時間:保持詞不變if random.random() < 0.5:masked_token = tokens[mlm_pred_position]# 10%的時間:用隨機詞替換該詞else:masked_token = random.choice(vocab.idx_to_token)mlm_input_tokens[mlm_pred_position] = masked_tokenpred_positions_and_labels.append((mlm_pred_position, tokens[mlm_pred_position]))return mlm_input_tokens, pred_positions_and_labels
從token中獲取mlm數據
在遮蔽語言模型任務中不會預測特殊詞元
def _get_mlm_data_from_tokens(tokens, vocab):candidate_pred_positions = []# tokens是一個字符串列表for i, token in enumerate(tokens):# 在遮蔽語言模型任務中不會預測特殊詞元if token in ['<cls>', '<sep>']:continuecandidate_pred_positions.append(i)# 遮蔽語言模型任務中預測15%的隨機詞元num_mlm_preds = max(1, round(len(tokens) * 0.15))mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds, vocab)pred_positions_and_labels = sorted(pred_positions_and_labels,key=lambda x: x[0])pred_positions = [v[0] for v in pred_positions_and_labels]mlm_pred_labels = [v[1] for v in pred_positions_and_labels]return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]
將文本轉換為預訓練數據集
- valid_lens不包括’'的計數
- 填充詞元的預測將通過乘以0權重在損失中過濾掉
def _pad_bert_inputs(examples, max_len, vocab):max_num_mlm_preds = round(max_len * 0.15)all_token_ids, all_segments, valid_lens, = [], [], []all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []nsp_labels = []for (token_ids, pred_positions, mlm_pred_label_ids, segments,is_next) in examples:all_token_ids.append(torch.tensor(token_ids + [vocab['<pad>']] * (max_len - len(token_ids)), dtype=torch.long))all_segments.append(torch.tensor(segments + [0] * (max_len - len(segments)), dtype=torch.long))# valid_lens不包括'<pad>'的計數valid_lens.append(torch.tensor(len(token_ids), dtype=torch.float32))all_pred_positions.append(torch.tensor(pred_positions + [0] * (max_num_mlm_preds - len(pred_positions)), dtype=torch.long))# 填充詞元的預測將通過乘以0權重在損失中過濾掉all_mlm_weights.append(torch.tensor([1.0] * len(mlm_pred_label_ids) + [0.0] * (max_num_mlm_preds - len(pred_positions)),dtype=torch.float32))all_mlm_labels.append(torch.tensor(mlm_pred_label_ids + [0] * (max_num_mlm_preds - len(mlm_pred_label_ids)), dtype=torch.long))nsp_labels.append(torch.tensor(is_next, dtype=torch.long))return (all_token_ids, all_segments, valid_lens, all_pred_positions,all_mlm_weights, all_mlm_labels, nsp_labels)
創建Dataset
- 輸入paragraphs[i]是代表段落的句子字符串列表
- 而輸出paragraphs[i]是代表段落的句子列表,其中每個句子都是詞元列表
- 獲取下一句子預測任務的數據
- 獲取遮蔽語言模型任務的數據
- 填充輸入
class _WikiTextDataset(torch.utils.data.Dataset):def __init__(self, paragraphs, max_len):# 輸入paragraphs[i]是代表段落的句子字符串列表;# 而輸出paragraphs[i]是代表段落的句子列表,其中每個句子都是詞元列表paragraphs = [dltools.tokenize(paragraph, token='word') for paragraph in paragraphs]sentences = [sentence for paragraph in paragraphsfor sentence in paragraph]self.vocab = dltools.Vocab(sentences, min_freq=5, reserved_tokens=['<pad>', '<mask>', '<cls>', '<sep>'])# 獲取下一句子預測任務的數據examples = []for paragraph in paragraphs:examples.extend(_get_nsp_data_from_paragraph(paragraph, paragraphs, self.vocab, max_len))# 獲取遮蔽語言模型任務的數據examples = [(_get_mlm_data_from_tokens(tokens, self.vocab)+ (segments, is_next))for tokens, segments, is_next in examples]# 填充輸入(self.all_token_ids, self.all_segments, self.valid_lens,self.all_pred_positions, self.all_mlm_weights,self.all_mlm_labels, self.nsp_labels) = _pad_bert_inputs(examples, max_len, self.vocab)def __getitem__(self, idx):return (self.all_token_ids[idx], self.all_segments[idx],self.valid_lens[idx], self.all_pred_positions[idx],self.all_mlm_weights[idx], self.all_mlm_labels[idx],self.nsp_labels[idx])def __len__(self):return len(self.all_token_ids)
加載WikiText-2數據集
def load_data_wiki(batch_size, max_len):"""加載WikiText-2數據集"""num_workers = dltools.get_dataloader_workers()data_dir = "./wikitext-2/"paragraphs = _read_wiki(data_dir)train_set = _WikiTextDataset(paragraphs, max_len)train_iter = torch.utils.data.DataLoader(train_set, batch_size,shuffle=True, num_workers=num_workers)return train_iter, train_set.vocabbatch_size, max_len = 512, 64
train_iter, vocab = load_data_wiki(batch_size, max_len)for (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X,mlm_Y, nsp_y) in train_iter:print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape,nsp_y.shape)break
torch.Size([512, 64]) torch.Size([512, 64]) torch.Size([512]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512])
len(vocab)
20256
BERT 訓練代碼實現
導包
import torch
from torch import nn
import dltools
加載數據
dltools中加載本地wiki文件,請自行修改路徑 ./data/wikitext-2
batch_size, max_len = 1, 64
# dltools中加載本地wiki文件,請自行修改路徑 ./data/wikitext-2
train_iter, vocab = dltools.load_data_wiki(batch_size, max_len)# tokens, segments, valid_lens, pred_positions, mlm_weights,mlm, nsp
for i in train_iter:break
i
構建BERT模型
net = dltools.BERTModel(len(vocab), num_hiddens=128, norm_shape=[128],ffn_num_input=128, ffn_num_hiddens=256, num_heads=2,num_layers=2, dropout=0.2, key_size=128, query_size=128,value_size=128, hid_in_features=128, mlm_in_features=128,nsp_in_features=128)
devices = dltools.try_all_gpus()
模型損失
- 前向傳播
- 計算遮蔽語言模型損失
- 計算下一句子預測任務的損失
loss = nn.CrossEntropyLoss()def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,segments_X, valid_lens_x,pred_positions_X, mlm_weights_X,mlm_Y, nsp_y):# 前向傳播_, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,valid_lens_x.reshape(-1),pred_positions_X)# 計算遮蔽語言模型損失mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)# 計算下一句子預測任務的損失nsp_l = loss(nsp_Y_hat, nsp_y)l = mlm_l + nsp_lreturn mlm_l, nsp_l, l
訓練
遮蔽語言模型損失的和,下一句預測任務損失的和,句子對的數量,計數
def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):net = nn.DataParallel(net, device_ids=devices).to(devices[0])trainer = torch.optim.Adam(net.parameters(), lr=0.01)step, timer = 0, dltools.Timer()animator = dltools.Animator(xlabel='step', ylabel='loss',xlim=[1, num_steps], legend=['mlm', 'nsp'])# 遮蔽語言模型損失的和,下一句預測任務損失的和,句子對的數量,計數metric = dltools.Accumulator(4)num_steps_reached = Falsewhile step < num_steps and not num_steps_reached:for tokens_X, segments_X, valid_lens_x, pred_positions_X,mlm_weights_X, mlm_Y, nsp_y in train_iter:tokens_X = tokens_X.to(devices[0])segments_X = segments_X.to(devices[0])valid_lens_x = valid_lens_x.to(devices[0])pred_positions_X = pred_positions_X.to(devices[0])mlm_weights_X = mlm_weights_X.to(devices[0])mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])trainer.zero_grad()timer.start()mlm_l, nsp_l, l = _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)l.backward()trainer.step()metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)timer.stop()animator.add(step + 1,(metric[0] / metric[3], metric[1] / metric[3]))step += 1if step == num_steps:num_steps_reached = Truebreakprint(f'MLM loss {metric[0] / metric[3]:.3f}, 'f'NSP loss {metric[1] / metric[3]:.3f}')print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on 'f'{str(devices)}')train_bert(train_iter, net, loss, len(vocab), devices, 500)
獲取BERT編碼器
def get_bert_encoding(net, tokens_a, tokens_b=None):tokens, segments = dltools.get_tokens_and_segments(tokens_a, tokens_b)token_ids = torch.tensor(vocab[tokens], device=devices[0]).unsqueeze(0)segments = torch.tensor(segments, device=devices[0]).unsqueeze(0)valid_len = torch.tensor(len(tokens), device=devices[0]).unsqueeze(0)encoded_X, _, _ = net(token_ids, segments, valid_len)return encoded_Xtokens_a = ['a', 'crane', 'is', 'flying']
encoded_text = get_bert_encoding(net, tokens_a)
# 詞元:'<cls>','a','crane','is','flying','<sep>'
encoded_text_cls = encoded_text[:, 0, :]
encoded_text_crane = encoded_text[:, 2, :]
encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]
(torch.Size([1, 6, 128]),torch.Size([1, 128]),tensor([-1.0005, 0.8355, 0.2930], grad_fn=<SliceBackward0>))
tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']
encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)
# 詞元:'<cls>','a','crane','driver','came','<sep>','he','just',
# 'left','<sep>'
encoded_pair_cls = encoded_pair[:, 0, :]
encoded_pair_crane = encoded_pair[:, 2, :]
encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]
(torch.Size([1, 10, 128]),torch.Size([1, 128]),tensor([-1.0168, 0.8235, 0.2141], grad_fn=<SliceBackward0>))