基于transformers框架實踐Bert系列6-完形填空

本系列用于Bert模型實踐實際場景,分別包括分類器、命名實體識別、選擇題、文本摘要等等。(關于Bert的結構和詳細這里就不做講解,但了解Bert的基本結構是做實踐的基礎,因此看本系列之前,最好了解一下transformers和Bert等)
本篇主要講解完形填空應用場景。本系列代碼和數據集都上傳到GitHub上:https://github.com/forever1986/bert_task

1 環境說明

1)本次實踐的框架采用torch-2.1+transformer-4.37
2)另外還采用或依賴其它一些庫,如:evaluate、pandas、datasets、accelerate等

2 前期準備

Bert模型是一個只包含transformer的encoder部分,并采用雙向上下文和預測下一句訓練而成的預訓練模型。可以基于該模型做很多下游任務。

2.1 了解Bert的輸入輸出

Bert的輸入:input_ids(使用tokenizer將句子向量化),attention_mask,token_type_ids(句子序號)、labels(結果)
Bert的輸出:
last_hidden_state:最后一層encoder的輸出;大小是(batch_size, sequence_length, hidden_size)(注意:這是關鍵輸出,本次任務就需要獲取該值,可以取出那個被mask掉的token,獲取其前幾個,取score最高的(當然也可以使用top_k或者top_p方式獲取一定隨機性)
pooler_output:這是序列的第一個token(classification token)的最后一層的隱藏狀態,輸出的大小是(batch_size, hidden_size),它是由線性層和Tanh激活函數進一步處理的。(通常用于句子分類,至于是使用這個表示,還是使用整個輸入序列的隱藏狀態序列的平均化或池化,視情況而定)。
hidden_states: 這是輸出的一個可選項,如果輸出,需要指定config.output_hidden_states=True,它也是一個元組,它的第一個元素是embedding,其余元素是各層的輸出,每個元素的形狀是(batch_size, sequence_length, hidden_size)
attentions:這是輸出的一個可選項,如果輸出,需要指定config.output_attentions=True,它也是一個元組,它的元素是每一層的注意力權重,用于計算self-attention heads的加權平均值。

2.2 數據集與模型

1)數據集來自:ChnSentiCorp(該數據集本身是做情感分類,但是我們只需要取其text部分即可)
2)模型權重使用:bert-base-chinese

2.3 任務說明

完形填空其實就是在一段文字中mask掉幾個字,讓模型能夠自動填充字。這里本身就是bert模型做預訓練是所做的事情之一,因此就是讓數據給模型做訓練的過程。

2.4 實現關鍵

1)數據集結構是一個帶有text和label兩列的數據,我們只需要獲取到text部分即可。
在這里插入圖片描述
2)隨機mask掉部分數據,這個本身也是bert的訓練過程,因此在transforms框架中DataCollatorForLanguageModeling已經實現了,你也可以自己實現隨機mask掉你的數據進行訓練

3 關鍵代碼

3.1 數據集處理

數據集不需要做過多處理,只需要將text部分進行tokenizer,并制定max_length和truncation即可

def process_function(datas):tokenized_datas = tokenizer(datas["text"], max_length=256, truncation=True)return tokenized_datas
new_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)

3.2 模型加載

model = BertForMaskedLM.from_pretrained(model_path)

注意:這里使用的是transformers中的BertForMaskedLM,該類對bert模型進行封裝。如果我們不使用該類,需要自己定義一個model,繼承bert,增加分類線性層。另外使用AutoModelForMaskedLM也可以,其實AutoModel最終返回的也是BertForMaskedLM,它是根據你config中的model_type去匹配的。
這里列一下BertForMaskedLM的關鍵源代碼說明一下transformers幫我們做了哪些關鍵事情

# 在__init__方法中增加增加了BertOnlyMLMHead,BertOnlyMLMHead其實就是一個二層神經網絡,一層是BertPredictionHeadTransform(包括linear+geluAct+ln),一層是decoder(hidden_size*vocab_size大小的linear)。
self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)
# 將輸出結果outputs取第一個返回值,也就是last_hidden_state
sequence_output = outputs[0]
# 將last_hidden_state輸入到cls層中,獲得最終結果(預測的score和詞)
prediction_scores = self.cls(sequence_output)

3.3 自動并隨機mask數據

關鍵代碼在于DataCollatorForLanguageModeling,該類會實現自動mask。參考torch_mask_tokens方法。

trainer = Trainer(model=model,args=train_args,train_dataset=new_datasets["train"],data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15),)

4 整體代碼

"""
基于BERT做完形填空
1)數據集來自:ChnSentiCorp
2)模型權重使用:bert-base-chinese
"""
# step 1 引入數據庫
from datasets import DatasetDict
from transformers import TrainingArguments, Trainer, BertTokenizerFast, BertForMaskedLM, DataCollatorForLanguageModeling, pipelinemodel_path = "./model/tiansz/bert-base-chinese"
data_path = "./data/ChnSentiCorp"# step 2 數據集處理
datasets = DatasetDict.load_from_disk(data_path)
tokenizer = BertTokenizerFast.from_pretrained(model_path)def process_function(datas):tokenized_datas = tokenizer(datas["text"], max_length=256, truncation=True)return tokenized_datasnew_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)# step 3 加載模型
model = BertForMaskedLM.from_pretrained(model_path)# step 4 創建TrainingArguments
# 原先train是9600條數據,batch_size=32,因此每個epoch的step=300
train_args = TrainingArguments(output_dir="./checkpoints",      # 輸出文件夾per_device_train_batch_size=32,  # 訓練時的batch_sizenum_train_epochs=1,              # 訓練輪數logging_steps=30,                # log 打印的頻率)# step 5 創建Trainer
trainer = Trainer(model=model,args=train_args,train_dataset=new_datasets["train"],# 自動MASK關鍵所在,通過DataCollatorForLanguageModeling實現自動MASK數據data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15),)# Step 6 模型訓練
trainer.train()# step 7 模型評估
pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, device=0)
str = datasets["test"][3]["text"]
str = str.replace("方便","[MASK][MASK]")
results = pipe(str)
# results[0][0]["token_str"]
print(results[0][0]["token_str"]+results[1][0]["token_str"])

5 運行效果

在這里插入圖片描述

注:本文參考來自大神:https://github.com/zyds/transformers-code

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/13512.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/13512.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/13512.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

自己動手寫docker——Namespace

Linux Namespace linux Namespace用于隔離一系列的系統資源,例如pid,userid,netword等,借助于Linux Namespace,可以實現容器的基本隔離。 Namespce介紹 Namespace類型系統調用參數作用Mount NamespaceCLONE_NEWNS隔離…

Python筑基之旅-MySQL數據庫(一)

目錄 一、MySQL數據庫 1、簡介 2、優點 2-1、開源和免費 2-2、高性能 2-3、可擴展性 2-4、易用性 2-5、靈活性 2-6、安全性和穩定性 2-7、豐富的功能 2-8、結合其他工具和服務 2-9、良好的兼容性和移植性 3、缺點 3-1、對大數據的支持有限 3-2、缺乏全文…

微服務如何做好監控

大家好,我是蒼何。 在脈脈上看到這條帖子,說阿里 P8 因為上面 P9 斗爭失敗走人,以超齡 35 被裁,Boss 上找工作半年,到現在還處于失業中。 看了下溝通記錄, 溝通了 1000 多次,但沒有一個邀請投遞…

uniapp中使用 iconfont字體

下載 iconfont 字體文件 打開 iconfont.css 文件,修改一下 把文件 復制到 static/iconfont/… 目錄下 在App.vue中引入iconfont 5. 使用iconfont 使用 iconfont 有兩種方式, 一種是 class 方式, 一種是使用 unicode 的方式 5.1 使用 class 的…

【Mac】Dreamweaver 2021 for mac v21.3 Rid中文版安裝教程

軟件介紹 Dreamweaver是Adobe公司開發的一款專業網頁設計與前端開發軟件。它集成了所見即所得(WYSIWYG)編輯器和代碼編輯器,可以幫助開發者快速創建和編輯網頁。Dreamweaver提供了豐富的功能和工具,包括代碼提示、語法高亮、代碼…

51單片機學習(1)2-1點亮一個LED

#include <REGX52.H> void() { p20xFE;//1111 1110 while(1) { //讓程序停了下來了。 } }

教你一分鐘搭建適合IT人員的在線開發工具箱

文章目錄 1. 使用Docker本地部署it-tools2. 本地訪問it-tools3. 安裝cpolar內網穿透4. 固定it-tools公網地址 本篇文章將介紹如何在Windows上使用Docker本地部署IT- Tools&#xff0c;并且同樣可以結合cpolar實現公網訪問。 在前一篇文章中我們講解了如何在Linux中使用Docker搭…

Anaconda Jupyter 報錯及解決方法記錄

一、AttributeError: module lib has no attribute X509_V_FLAG_CB_ISSUER_CHECK 背景&#xff1a;Anaconda更新版本后&#xff0c;運行import oss2時報錯 ~/anaconda3/lib/python3.8/site-packages/OpenSSL/crypto.py in X509StoreFlags() 1535 NOTIFY_POLICY _lib…

【Java基礎】集合(1) —— Collection

存儲不同類型的對象: Object[] arrnew object[5];數組的長度是固定的, 添加或刪除數據比較耗時 集合: Object[] toArray可以存儲不同類型的對象隨著存儲的對象的增加&#xff0c;會自動的擴容集合提供了非常豐富的方法&#xff0c;便于操縱集合相當于容器&#xff0c;可以存儲多…

探索Git之旅:倉庫代碼版本控制藝術

探索Git之旅&#xff1a;倉庫代碼版本控制藝術 引言Git基礎與核心概念什么是版本控制&#xff1f;Git的工作流程分布式特性 Git實戰操作指南安裝與配置克隆倉庫日常操作分支管理解決沖突 高級技巧與最佳實踐Git FlowGit鉤子Git別名 安全與性能考量結語與引發討論 引言 在軟件開…

馮喜運:5.16黃金是否突破阻力?黃金原油趨勢分析

【黃金消息面分析】&#xff1a;周四(5月16日)亞市盤中&#xff0c;現貨黃金延續昨日升勢&#xff0c;金價目前最高觸及2397.44美元/盎司&#xff0c;為4月19日以來新高。FXStreet首席分析師Valeria Bednarik撰文&#xff0c;對黃金技術前景進行分析。Bednarik指出&#xff0c;…

「51媒體」北京財經媒體有哪些?媒體邀約宣傳

傳媒如春雨&#xff0c;潤物細無聲&#xff0c;大家好&#xff0c;我是51媒體網胡老師。 北京作為中國的首都&#xff0c;擁有眾多的財經媒體&#xff0c;這些媒體在財經新聞報道、經濟分析、市場研究等方面發揮著重要作用。根據搜索結果&#xff0c;以下是一些北京地區的財經…

富格林:曝光虛假套路規避虧損

富格林指出&#xff0c;在現貨黃金市場中&#xff0c;交易時間很充足投資機會也多的是&#xff0c;但為什么還是有人虧損甚至爆倉呢&#xff1f;其實導致這種情況&#xff0c;是因為有一些投資者不知道其中的虛假套路&#xff0c;很容易就一頭栽進去了。要規避虛假套路帶來的虧…

CV每日論文--2024.5.15

1、Can Better Text Semantics in Prompt Tuning Improve VLM Generalization? 中文標題&#xff1a;更好的文本語義在提示微調中能否提高視覺語言模型的泛化能力? 簡介&#xff1a;這篇論文介紹了一種新的可學習提示調整方法,該方法超越了僅對視覺語言模型進行微調的傳統方…

Lazyboy品牌發布會“球幕氣膜”

Lazyboy品牌發布會“球幕氣膜”為品牌活動提供了一個獨特、現代化、環保的展示空間。這座球幕氣膜不僅為發布會提供了一個視覺震撼的場地&#xff0c;也為與會嘉賓帶來了全新的體驗。作為輕空間&#xff08;江蘇&#xff09;膜科技有限公司&#xff08;以下簡稱“輕空間”&…

使用Docker在阿里云ECS上部署Gitlab,提供代碼托管、CICD 和 docker鏡像服務

文章目錄 使用Docker在阿里云ECS上部署Gitlab1.購買一個數據&#xff0c;掛載到/data用于存儲gitlab相關數據2. 部署docker引擎3. 調整ssh的默認端口&#xff0c;將22端口留給gitlab4. 部署gitlab5. 進入docker容器獲取gitlab的默認密碼6. 登錄gitlab&#xff0c;完成gitlab-ru…

linux ndk編譯搭建測試

一、ndk下載 NDK 下載 | Android NDK | Android Developers 二、ndk環境變量配置 ndk解壓&#xff1a; unzip android-ndk-r26d-linux.zip 環境變量配置&#xff1a; export NDK_HOME/rd/own/test/android-ndk-r26d/ export PATH$PATH:$NDK_HOME 三、編譯測試驗證 …

虛函數應用和原理

虛函數的表現形式 用子類初始化父類指針, 調用虛函數時, 仍然調用的是子類的虛函數 測試代碼如下 #include <iostream> #include <string.h>using namespace std;class A { public:void test() { cout << a << endl; };virtual void test2 (){ cout …

LeetCode-2589. 完成所有任務的最少時間【棧 貪心 數組 二分查找 排序】

LeetCode-2589. 完成所有任務的最少時間【棧 貪心 數組 二分查找 排序】 題目描述&#xff1a;解題思路一&#xff1a;貪心暴力解題思路二&#xff1a;棧二分查找解題思路三&#xff1a;簡化版 題目描述&#xff1a; 你有一臺電腦&#xff0c;它可以 同時 運行無數個任務。給你…

解鎖電商數據之門:京東商品詳情API接口的深度解析與應用指南

一、京東商品詳情API簡介 京東商品詳情API是京東開放平臺提供的一項服務&#xff0c;允許第三方應用通過調用接口獲取京東商城中商品的詳細信息。這些信息包括但不限于商品名稱、價格、庫存、詳情描述、用戶評價等。 二、功能特點 數據全面&#xff1a;提供商品的全方位數據…