BERT數據處理,模型,預訓練

代碼來自李沐老師《動手學pytorch》
在數據處理時,首先執行以下代碼
def load_data_wiki(batch_size, max_len):"""加載WikiText-2數據集"""num_workers = d2l.get_dataloader_workers()data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')以上兩句代碼,不再說明paragraphs = _read_wiki(data_dir)train_set = _WikiTextDataset(paragraphs, max_len)train_iter = torch.utils.data.DataLoader(train_set, batch_size,shuffle=True)return train_iter, train_set.vocab

d2l.DATA_HUB['wikitext-2'] = ('https://s3.amazonaws.com/research.metamind.io/wikitext/''wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')#@save
def _read_wiki(data_dir):file_name = os.path.join(data_dir, 'wiki.train.tokens')with open(file_name, 'r',encoding='utf-8') as f:lines = f.readlines()# 大寫字母轉換為小寫字母 ,每行文本中包含兩個句子,才進行處理,否則舍去文本paragraphs = [line.strip().lower().split(' . ')for line in lines if len(line.split(' . ')) >= 2]random.shuffle(paragraphs)return paragraphs

首先讀取文本,每個文本必須包含兩個以上句子(為了第二個預訓練任務:判斷兩個句子,是否連續)。paragraphs 其中一部分結果如下所示

文本中包含了三個句子,每個’‘里面,代表一個句子
['common starlings are trapped for food in some mediterranean countries'
, 'the meat is tough and of low quality , so it is <unk> or made into <unk>'
, 'one recipe said it should be <unk> " until tender , however long that may be "'
, 'even when correctly prepared , it may still be seen as an acquired taste .']
class _WikiTextDataset(torch.utils.data.Dataset):def __init__(self, paragraphs, max_len):'''每一個paragraph就是上面的包含多個句子的列表,將其進行分詞處理。下面是一個分詞的例子[['common', 'starlings', 'are', 'trapped', 'for', 'food', 'in', 'some', 'mediterranean', 'countries'], ['the', 'meat', 'is', 'tough', 'and', 'of', 'low', 'quality', ',', 'so', 'it', 'is', '<unk>', 'or', 'made', 'into', '<unk>'], ['one', 'recipe', 'said', 'it', 'should', 'be', '<unk>', '"', 'until', 'tender', ',', 'however', 'long', 'that', 'may', 'be', '"'], ['even', 'when', 'correctly', 'prepared', ',', 'it', 'may', 'still', 'be', 'seen', 'as', 'an', 'acquired', 'taste', '.']]'''paragraphs = [d2l.tokenize(paragraph, token='word') for paragraph in paragraphs]#將詞提取處理,保存sentences = [sentence for paragraph in paragraphsfor sentence in paragraph]#形成一個詞典,min_freq為詞最少出現的次數,少于5次,則不保存進詞典中self.vocab = d2l.Vocab(sentences, min_freq=5, reserved_tokens=['<pad>', '<mask>', '<cls>', '<sep>'])# 獲取下一句子預測任務的數據examples = []for paragraph in paragraphs:examples.extend(_get_nsp_data_from_paragraph(paragraph, paragraphs, self.vocab, max_len))'''
def _get_nsp_data_from_paragraph(paragraph,paragraphs,vocab,max_len):nsp_data_from_paragraph=[]for i in range(len(paragraph)-1):_get_next_sentence函數傳入的是相鄰的句子a,b。函數中b會有一定概率替換為其他的句子tokens_a, tokens_b, is_next = _get_next_sentence(paragraph[i], paragraph[i + 1], paragraphs)句子長度大于bert限制的長度,則舍去。if len(tokens_a)+len(tokens_b)+3>max_len:continue#加上<cls>和<sep>,segments用于區token在哪個句子中tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)nsp_data_from_paragraph.append((tokens, segments, is_next))return nsp_data_from_paragraphtoken和segments的例子: True表示兩個句子相鄰,False表示b被隨機替換,a,b不相鄰。(['<cls>', 'mushrooms', 'grow', '<unk>', 'or', 'in', '"', '<unk>', 'groups', '"', 'in', 'late', 'summer', 'and', 'throughout', 'autumn', ',', 'though', 'it', 'is', 'not', 'commonly', 'encountered', 'species', '<sep>', 'it','can', 'be', 'found', 'in', 'europe', ',', 'asia', 'and', 'north', 'america', '.', '<sep>'], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,1, 1, 1, 1, 1], True),'''# 獲取遮蔽語言模型任務的數據'''在這里我們會將句子中單詞,替換為在詞典中的索引。13意思為,句子的第13個詞,進行了處理,可能不變,可能替換為其他詞,可能替換為mask。在這里這個詞沒有替換。0與1區分兩個句子,False代表兩個句子不相鄰。examples中的結果;([3, 2510, 31, 337, 9, 0, 6, 6891, 8, 11621, 6, 21, 11, 60, 3405, 14, 1542, 9546, 4, 2524,21, 185, 4421, 649, 38, 277, 2872, 13233, 4], [13], [60], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], False)'''examples = [(_get_mlm_data_from_tokens(tokens, self.vocab)+ (segments, is_next))for tokens, segments, is_next in examples]#_pad_bert_inputs對數據進行填充,all_mlm_weights中1為需要預測,0為填充#    all_mlm_weights= tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.](self.all_token_ids, self.all_segments, self.valid_lens,self.all_pred_positions, self.all_mlm_weights,self.all_mlm_labels, self.nsp_labels) = _pad_bert_inputs(examples, max_len, self.vocab)def __getitem__(self, idx):return (self.all_token_ids[idx], self.all_segments[idx],self.valid_lens[idx], self.all_pred_positions[idx],self.all_mlm_weights[idx], self.all_mlm_labels[idx],self.nsp_labels[idx])def __len__(self):return len(self.all_token_ids)

上述已經將數據處理完,最后看一下處理后的例子:

將原來的句子列表填充1,一直到到大小為64
tensor([[    3,     5,     0, 18306,    23,    11,  2659,   156,  5779,   382,1296,   110,   158,    22,     5,  1771,   496,     0,  3398,     2,5,  3496,   110,  5038,   179,     4,    16,    11, 19837,     6,58,    13,     5,   685,     7,    66,   156,     0,  3063,    77,3842,    19,     4,     1,     1,     1,     1,     1,     1,     1,1,     1,     1,     1,     1,     1,     1,     1,     1,     1,1,     1,     1,     1]])
segments用于區分兩個句子,0為第一個句子中的詞,1為第二個句子中的詞,后面的0為填充
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
valid_lens表示句子列表的有效長度
tensor([43.])
pred_positions需要預測的位置,0為填充
tensor([[19,  0,  0,  0,  0,  0,  0,  0,  0,  0]])
mlm_weights需要預測多少個詞,0為填充
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
預測位置的真實標簽,0為填充
tensor([[22,  0,  0,  0,  0,  0,  0,  0,  0,  0]])
兩句話是否相鄰
tensor([0])

隨后就是把處理好的數據,送入bert中。在 BERTEncoder 中,執行如下代碼:

 def forward(self, tokens, segments, valid_lens):# Shape of `X` remains unchanged in the following code snippet:# (batch size, max sequence length, `num_hiddens`)#  將token和segment分別進行embedding,X = self.token_embedding(tokens) + self.segment_embedding(segments)#加入位置編碼X = X + self.pos_embedding.data[:, :X.shape[1], :]for blk in self.blks:X = blk(X, valid_lens)return X

將編碼完后的數據,進行多頭注意力和殘差化

    def forward(self, X, valid_lens):Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))return self.addnorm2(Y, self.ffn(Y))

將結果返回到如下代碼中:其中encoded_X .shape=torch.Size([1, 64, 128]),1代表批次大小為1,我們設置的每個批次只有行文本,每行文本由64個詞組成,bert提取128維的向量來表示每個詞。隨后進行兩個任務,一個是預測被掩蓋的單詞,另一個為判斷兩個句子是否為相鄰。

    def forward(self, tokens, segments, valid_lens=None, pred_positions=None):encoded_X = self.encoder(tokens, segments, valid_lens)if pred_positions is not None:mlm_Y_hat = self.mlm(encoded_X, pred_positions)else:mlm_Y_hat = None# The hidden layer of the MLP classifier for next sentence prediction.# 0 is the index of the '<cls>' tokennsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))return encoded_X, mlm_Y_hat, nsp_Y_hat

第一個任務為預測被mask的單詞:

'''
例如:batch為1,X為1*64*128,其中num_pred_positions =10,batch_idx 會重復為[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],pred_positions為[ 3,  6, 10, 12, 15, 20,  0,  0,  0,  0],X[batch_idx, pred_positions]會將需要預測的向量取出。然后reshape為1*10*128的矩陣。最后連接一個mlp,經過規范化后接nn.Linear(num_hiddens, vocab_size)),會生成再vocab上的預測'''def forward(self, X, pred_positions):num_pred_positions = pred_positions.shape[1]pred_positions = pred_positions.reshape(-1)batch_size = X.shape[0]batch_idx = torch.arange(0, batch_size)# Suppose that `batch_size` = 2, `num_pred_positions` = 3, then# `batch_idx` is `torch.tensor([0, 0, 0, 1, 1, 1])`batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)masked_X = X[batch_idx, pred_positions]masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))mlm_Y_hat = self.mlp(masked_X)return mlm_Y_hat

結束后,會返回到上層的代碼中:

def forward(self, tokens, segments, valid_lens=None, pred_positions=None):encoded_X = self.encoder(tokens, segments, valid_lens)if pred_positions is not None:mlm_Y_hat = self.mlm(encoded_X, pred_positions)else:mlm_Y_hat = None# The hidden layer of the MLP classifier for next sentence prediction.# 0 is the index of the '<cls>' token判斷句子是否連續,將<cls>的向量,放入mlp中,接一個nn.Linear(num_inputs, 2),最后變成一個二分類問題。nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))return encoded_X, mlm_Y_hat, nsp_Y_hat

后面就是計算損失:

將mlm_Y_hat進行reshap,與mlm_Y求loss,最后需要乘mlm_weights_X,將填充的無用數據進行去除。mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)取平均lossmlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)nsp_l = loss(nsp_Y_hat, nsp_y)l = mlm_l + nsp_l

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/40643.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/40643.shtml
英文地址,請注明出處:http://en.pswp.cn/news/40643.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

django——配置 settings.py 及相關參數說明

3. 配置 settings.py 及相關參數說明 3.1 配置setting.py文件 設置setting.py文件 加入安裝的庫 apps.erp_test, rest_framework, django_filters, drf_spectacular,加入新增的APP users啟動項目 # 運行項目先執行數據庫相關操作&#xff0c;再啟動 django 項目 python manag…

【JavaSE】面向對象之繼承

繼承 繼承概念繼承的語法父類成員的訪問子類和父類沒有同名成員變量子類和父類有同名成員變量成員方法名字不同成員方法名字相同 super關鍵字子類構造方法super和this繼承方式 繼承概念 繼承(inheritance)機制&#xff1a;是面向對象程序設計使代碼可以復用的最重要的手段&…

docker 安裝nacos

1、下載nacos docker pull nacos/nacos-server2、啟動nacos docker run --restart always --env MODEstandalone --name nacos -d -p 8848:8848 -p 9848:9848 -p 9849:9849 nacos/nacos-server3、驗證nacos http://localhost:8848/nacos 默認用戶名和密碼&#xff1a;nacos

lvs集群與nat模式

一&#xff0c;什么是集群&#xff1a; 集群&#xff0c;群集&#xff0c;Cluster&#xff0c;由多臺主機構成&#xff0c;但是對外只表現為一個整體&#xff0c;只提供一個訪問入口&#xff08;域名與ip地址&#xff09;&#xff0c;相當于一臺大型計算機。 二&#xff0c;集…

Java書簽 #使用MyBatis接入多數據源

楔子&#xff1a;當然&#xff0c;世上有很多優秀的女性&#xff0c;我也會被她們吸引。這對男人來說是理所當然的。但目光被吸引和內心被吸引是截然不同的。- 東野圭吾《黎明之街》 今日書簽 在一些應用場景中&#xff0c;可能需要連接多個不同的數據庫&#xff0c;例如連接不…

Centos 防火墻命令

查看防火墻狀態 systemctl status firewalld.service 或者 firewall-cmd --state 開啟防火墻 單次開啟防火墻 systemctl start firewalld.service 開機自啟動防火墻 systemctl enable firewalld.service 重啟防火墻 systemctl restart firewalld.service 防火墻設置開…

8.15 IO的多路復用

select的TCP客戶端 poll的TCP客戶端

Chart GPT免費可用地址共享資源

GPT4.0&#xff1a; https://gpt4e.ninvfeng.xyz github:https://github.com/ninvfeng/chatgpt4 WeUseAi&#xff1a;https://chatb.weuseai.pro AI.LS&#xff1a;https://n7.gpt03.xyz ChatX (iOS/macOS應用)&#xff1a;https://itunes.apple.com/app/id6446304087 ch…

C/C++ : C/C++的詳解,C語言與C++的常用算法以及算法的各自用法和應用(初級,中級),C++ CSP考題(J居多,S偏少)的詳解,NOI的真題題解

目錄 1.C語言 2.C 3.C與C語言的共同/不同點 4.導讀 5.相關文章 5.1&#xff1a;Dev-C是Windows 環境下的一個輕量級 C/C 集成開發環境&#xff08;IDE&#xff09; 5.2&#xff1a;C是從C語言發展而來的&#xff0c;而C語言的歷史可以追溯到1969年 6.C/C最新年度總…

?LeetCode解法匯總88. 合并兩個有序數組

目錄鏈接&#xff1a; 力扣編程題-解法匯總_分享記錄-CSDN博客 GitHub同步刷題項目&#xff1a; https://github.com/September26/java-algorithms 原題鏈接&#xff1a;力扣&#xff08;LeetCode&#xff09;官網 - 全球極客摯愛的技術成長平臺 描述&#xff1a; 給你兩個按…

解決方案:如何在 Amazon EMR Serverless 上執行純 SQL 文件?

長久已來&#xff0c;SQL以其簡單易用、開發效率高等優勢一直是ETL的首選編程語言&#xff0c;在構建數據倉庫和數據湖的過程中發揮著不可替代的作用。Hive和Spark SQL也正是立足于這一點&#xff0c;才在今天的大數據生態中牢牢占據著主力位置。在常規的Spark環境中&#xff0…

國企的大數據崗位方向的分析

現如今大數據已無所不在&#xff0c;并且正被越來越廣泛的被應用到歷史、政治、科學、經濟、商業甚至滲透到我們生活的方方面面中&#xff0c;獲取的渠道也越來越便利。 今天我們就來聊一聊“大屏應用”&#xff0c;說到大屏就一定要聊到數據可視化&#xff0c;現如今&#xf…

【Git】(三)回退版本

1、git reset命令 1.1 回退至上一個版本 git reset --hard HEAD^ 1.2 將本地的狀態回退到和遠程的一樣 git reset --hard origin/master 注意&#xff1a;謹慎使用 –-hard 參數&#xff0c;它會刪除回退點之前的所有信息。HEAD 說明&#xff1a;HEAD 表示當前版本HEAD^ 上…

服務鏈路追蹤

一、服務鏈路追蹤導論 1.背景 對于一個大型的幾十個、幾百個微服務構成的微服務架構系統&#xff0c;通常會遇到下面一些問題&#xff0c;比如&#xff1a; 如何串聯整個調用鏈路&#xff0c;快速定位問題&#xff1f;如何理清各個微服務之間的依賴關系&#xff1f;如何進行…

pycorrector一鍵式文本糾錯工具,整合了BERT、MacBERT、ELECTRA、ERNIE等多種模型,讓您立即享受糾錯的便利和效果

pycorrector&#xff1a;一鍵式文本糾錯工具&#xff0c;整合了Kenlm、ConvSeq2Seq、BERT、MacBERT、ELECTRA、ERNIE、Transformer、T5等多種模型&#xff0c;讓您立即享受糾錯的便利和效果 pycorrector: 中文文本糾錯工具。支持中文音似、形似、語法錯誤糾正&#xff0c;pytho…

Python OpenGL環境配置

1.Python的安裝請參照 Anconda安裝_安裝anconda_lwb-nju的博客-CSDN博客anconda安裝教程_安裝ancondahttps://blog.csdn.net/lwbCUMT/article/details/125322193?spm1001.2014.3001.5501 Anconda換源虛擬環境創建及使用&#xff08;界面操作&#xff09;_anconda huanyuan_l…

徹底卸載Android Studio

永恒的愛是永遠恪守最初的諾言。 在安裝Android Studio會有很多問題導致無法正常運行&#xff0c;多次下載AS多次錯誤后了解到&#xff0c;刪除以下四個文件才能徹底卸載Android Studio。 第一個文件&#xff1a;.gradle 路徑&#xff1a;C:\Users\yao&#xff08;這里yao是本…

解密人工智能:線性回歸 | 邏輯回歸 | SVM

文章目錄 1、機器學習算法簡介1.1 機器學習算法包含的兩個步驟1.2 機器學習算法的分類 2、線性回歸算法2.1 線性回歸的假設是什么&#xff1f;2.2 如何確定線性回歸模型的擬合優度&#xff1f;2.3 如何處理線性回歸中的異常值&#xff1f; 3、邏輯回歸算法3.1 什么是邏輯函數?…

火山引擎聯合Forrester發布《中國云原生安全市場現狀及趨勢白皮書》,賦能企業構建云原生安全體系

國際權威研究咨詢公司Forrester 預測&#xff0c;2023年全球超過40%的企業將會采用云原生優先戰略。然而&#xff0c;云原生在改變企業上云及構建新一代基礎設施的同時&#xff0c;也帶來了一系列的新問題&#xff0c;針對涵蓋云原生應用、容器、鏡像、編排系統平臺以及基礎設施…

用棧解決有效的括號匹配問題

//用數組實現棧 typedef char DataType; typedef struct stack {DataType* a;//動態數組int top;//棧頂int capacity; //容量 }ST;void STInit(ST*pst);//初始化void STDestroy(ST* pst);//銷毀所有空間void STPush(ST* pst, DataType x);//插入數據到棧中void STPop(ST* pst);…