240705_昇思學習打卡-Day17-基于 MindSpore 實現 BERT 對話情緒識別

240705_昇思學習打卡-Day17-基于 MindSpore 實現 BERT對話情緒識別

近期確實太忙,此處僅作簡單記錄:

模型簡介

BERT全稱是來自變換器的雙向編碼器表征量(Bidirectional Encoder Representations from Transformers),它是Google于2018年末開發并發布的一種新型語言模型。與BERT模型相似的預訓練語言模型例如問答、命名實體識別、自然語言推理、文本分類等在許多自然語言處理任務中發揮著重要作用。模型是基于Transformer中的Encoder并加上雙向的結構,因此一定要熟練掌握Transformer的Encoder的結構。

image-20240705234457785

關于Transformer的Encoder的結構在這篇中有提及,可以去參考看看240701_昇思學習打卡-Day13-Vision Transformer圖像分類-CSDN博客

BERT模型的主要創新點都在pre-train方法上,即用了Masked Language Model和Next Sentence Prediction兩種方法分別捕捉詞語和句子級別的representation。

在用Masked Language Model方法訓練BERT的時候,隨機把語料庫中15%的單詞做Mask操作。對于這15%的單詞做Mask操作分為三種情況:80%的單詞直接用[Mask]替換、10%的單詞直接替換成另一個新的單詞、10%的單詞保持不變。

因為涉及到Question Answering (QA) 和 Natural Language Inference (NLI)之類的任務,增加了Next Sentence Prediction預訓練任務,目的是讓模型理解兩個句子之間的聯系。與Masked Language Model任務相比,Next Sentence Prediction更簡單些,訓練的輸入是句子A和B,B有一半的幾率是A的下一句,輸入這兩個句子,BERT模型預測B是不是A的下一句。

BERT預訓練之后,會保存它的Embedding table和12層Transformer權重(BERT-BASE)或24層Transformer權重(BERT-LARGE)。使用預訓練好的BERT模型可以對下游任務進行Fine-tuning,比如:文本分類、相似度判斷、閱讀理解等。

對話情緒識別(Emotion Detection,簡稱EmoTect),專注于識別智能對話場景中用戶的情緒,針對智能對話場景中的用戶文本,自動判斷該文本的情緒類別并給出相應的置信度,情緒類型分為積極、消極、中性。 對話情緒識別適用于聊天、客服等多個場景,能夠幫助企業更好地把握對話質量、改善產品的用戶交互體驗,也能分析客服服務質量、降低人工質檢成本。

下面以一個文本情感分類任務為例子來說明BERT模型的整個應用過程。

我們假設已經裝好了MindSpore環境

# 該案例在 mindnlp 0.3.1 版本完成適配,如果發現案例跑不通,可以指定mindnlp版本,執行`!pip install mindnlp==0.3.1`
!pip install mindnlp
import osimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn, contextfrom mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy
# prepare dataset
class SentimentDataset:"""Sentiment Dataset"""def __init__(self, path):self.path = pathself._labels, self._text_a = [], []self._load()def _load(self):with open(self.path, "r", encoding="utf-8") as f:dataset = f.read()lines = dataset.split("\n")for line in lines[1:-1]:label, text_a = line.split("\t")self._labels.append(int(label))self._text_a.append(text_a)def __getitem__(self, index):return self._labels[index], self._text_a[index]def __len__(self):return len(self._labels)
# 準備數據集
class 情感分析數據集(SentimentDataset):"""情感分析數據集類,用于加載和管理數據集。參數:path (str): 數據集文件的路徑。屬性:_labels (list): 存儲情感標簽的列表。_text_a (list): 存儲文本內容的列表。方法:_load(): 從指定路徑加載數據集文件,解析內容并存儲到_labels和_text_a中。__getitem__(index): 根據索引返回特定樣本的標簽和文本。__len__(): 返回數據集的樣本數量。"""def __init__(self, path):"""初始化情感分析數據集對象,設置數據路徑并加載數據。參數:path (str): 數據集文件的路徑。"""self.path = pathself._labels, self._text_a = [], []self._load()def _load(self):"""私有方法:讀取數據集文件,按行處理數據,分割標簽和文本,并存儲到實例變量中。"""with open(self.path, "r", encoding="utf-8") as f:dataset = f.read()lines = dataset.split("\n")for line in lines[1:-1]:  # 跳過首行(假設為列名)和末尾的空行label, text_a = line.split("\t")self._labels.append(int(label))  # 添加標簽到_labels列表self._text_a.append(text_a)  # 添加文本到_text_a列表def __getitem__(self, index):"""通過索引獲取數據集中對應樣本的標簽和文本。參數:index (int): 數據樣本的索引位置。返回:tuple: 包含樣本標簽和文本的元組 (label, text)。"""return self._labels[index], self._text_a[index]def __len__(self):"""返回數據集中的樣本數量。返回:int: 數據集樣本數量。"""return len(self._labels)

數據集

這里提供一份已標注的、經過分詞預處理的機器人聊天數據集,來自于百度飛槳團隊。數據由兩列組成,以制表符(‘\t’)分隔,第一列是情緒分類的類別(0表示消極;1表示中性;2表示積極),第二列是以空格分詞的中文文本,如下示例,文件為 utf8 編碼。

label–text_a

0–誰罵人了?我從來不罵人,我罵的都不是人,你是人嗎 ?

1–我有事等會兒就回來和你聊

2–我見到你很高興謝謝你幫我

這部分主要包括數據集讀取,數據格式轉換,數據 Tokenize 處理和 pad 操作。

# download dataset
!wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
!tar xvf emotion_detection.tar.gz

數據加載和數據預處理

新建 process_dataset 函數用于數據加載和數據預處理,具體內容可見下面代碼注釋。

import numpy as npdef process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):"""處理數據集,將其轉換為適合模型訓練的格式。參數:source: 數據集的來源,可以是文件路徑或數據生成器。tokenizer: 用于將文本序列化為模型輸入的標記化器。max_seq_len: 最大序列長度,超過這個長度的序列將被截斷。batch_size: 每個批次的樣本數量。shuffle: 是否在處理數據集前打亂數據順序。返回:經過處理后的數據集,包括輸入序列和標簽。"""# 判斷是否在昇騰設備上運行is_ascend = mindspore.get_context('device_target') == 'Ascend'# 定義數據集的列名column_names = ["label", "text_a"]# 創建數據集對象dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)# 將字符串類型轉換為整型type_cast_op = transforms.TypeCast(mindspore.int32)# 定義文本標記化和填充函數def tokenize_and_pad(text):"""對文本進行標記化和填充,以適應模型的要求。參數:text: 需要處理的文本。返回:標記化和填充后的輸入序列和注意力掩碼。"""if is_ascend:# 在昇騰設備上,使用特定的處理方式tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:# 在其他設備上,直接進行標記化tokenized = tokenizer(text)return tokenized['input_ids'], tokenized['attention_mask']# 對文本列進行標記化和填充處理dataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a", output_columns=['input_ids', 'attention_mask'])# 對標簽列進行類型轉換dataset = dataset.map(operations=[type_cast_op], input_columns="label", output_columns='labels')# 根據設備類型選擇合適的批次處理方式if is_ascend:# 在昇騰設備上,使用簡單的批次處理dataset = dataset.batch(batch_size)else:# 在其他設備上,使用帶填充的批次處理dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return dataset

數據預處理部分采用靜態Shape處理:

# 導入BertTokenizer類,用于BERT模型的預訓練 tokenizer
from mindnlp.transformers import BertTokenizer# 初始化一個BertTokenizer實例,用于處理中文文本
# 這里使用了預訓練的'bert-base-chinese'模型,該模型已經在中文文本上進行了預訓練
# 選擇這個預訓練模型是因為我們的任務是處理中文文本,需要一個針對中文優化的tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
tokenizer.pad_token_id
dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)
dataset_train.get_col_names()
print(next(dataset_train.create_tuple_iterator()))

image-20240706000224197

模型構建

通過 BertForSequenceClassification 構建用于情感分類的 BERT 模型,加載預訓練權重,設置情感三分類的超參數自動構建模型。后面對模型采用自動混合精度操作,提高訓練的速度,然后實例化優化器,緊接著實例化評價指標,設置模型訓練的權重保存策略,最后就是構建訓練器,模型開始訓練。

# 導入MindNLP庫中用于序列分類任務的BertForSequenceClassification模型與用于獲取文本編碼表示的BertModel
from mindnlp.transformers import BertForSequenceClassification, BertModel
# 導入auto_mixed_precision函數以啟用混合精度訓練,能夠加速訓練過程并減少內存占用
from mindnlp._legacy.amp import auto_mixed_precision# 根據預訓練的'bert-base-chinese'模型初始化BertForSequenceClassification模型,設置類別數為3
# 此模型適用于如文本分類任務,將輸入文本歸類到三個預定義類別中的一個
# 設置BERT模型配置及訓練所需參數
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
# 使用auto_mixed_precision函數對模型應用混合精度訓練策略,采用'O1'優化級別
# 混合精度訓練通過結合使用float16和float32數據類型來提升訓練速度并節省內存資源
model = auto_mixed_precision(model, 'O1')# 定義模型訓練使用的優化器為Adam算法,設置學習率為2e-5,并僅針對模型中可訓練參數進行優化
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)
# 初始化Accuracy類,用于計算模型預測的準確率
metric = Accuracy()# 定義回調函數以保存訓練過程中的檢查點
# CheckpointCallback用于在指定的epoch后保存模型,保存路徑為'checkpoint',檢查點文件名為'bert_emotect'
# 參數epochs設為1表示每個epoch后保存一次,keep_checkpoint_max=2表示最多保留2個檢查點文件
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)# BestModelCallback用于自動保存驗證性能最優的模型,同樣保存在'checkpoint'路徑下,文件名為'bert_emotect_best'
# 設置auto_load=True可在訓練結束后自動加載該最優模型
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)# 創建Trainer實例以組織訓練流程
# network參數指定訓練的模型,train_dataset和eval_dataset分別指定了訓練集和驗證集
# metrics參數指定了評估模型性能的指標,此處為剛剛定義的準確率Accuracy
# epochs設置訓練輪次為5,optimizer為訓練使用的優化器,callbacks列表包含了之前定義的保存檢查點和最佳模型的回調函數
trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_val, metrics=metric,epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])
%%time
# start training
trainer.run(tgt_columns="labels")

模型驗證

將驗證數據集加再進訓練好的模型,對數據集進行驗證,查看模型在驗證數據上面的效果,此處的評價指標為準確率。

# 初始化Evaluator對象,用于評估模型性能
# 參數說明:
# network: 待評估的模型
# eval_dataset: 用于評估的測試數據集
# metrics: 評估指標
evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)# 執行模型評估,指定目標列作為評估標簽
# 該步驟將計算模型在測試數據集上的指定評估指標
evaluator.run(tgt_columns="labels")
dataset_infer = SentimentDataset("data/infer.tsv")
def predict(text, label=None):"""根據給定的文本進行情感分析預測。參數:text (str): 需要進行情感分析的文本。label (int, optional): 用于比較的預定義標簽。如果提供,將打印預測標簽和給定標簽的比較。返回:無返回值,但打印了模型預測的情感標簽以及輸入文本。"""# 映射預測結果的標簽到人類可讀的情感描述label_map = {0: "消極", 1: "中性", 2: "積極"}# 將文本轉換為模型輸入所需的格式text_tokenized = Tensor([tokenizer(text).input_ids])# 使用模型預測文本的情感logits = model(text_tokenized)predict_label = logits[0].asnumpy().argmax()# 構建包含預測信息的字符串info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"if label is not None:# 如果提供了標簽,則添加實際標簽的信息info += f" , label: '{label_map[label]}'"# 打印預測結果print(info)
from mindspore import Tensorfor label, text in dataset_infer:predict(text, label)

image-20240706000417506

自定義推理數據集

自己輸入一句話,進行測試

predict("家人們咱就是說一整個無語住了 絕絕子疊buff")

打卡圖片:

image-20240705235200648

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/diannao/41828.shtml
繁體地址,請注明出處:http://hk.pswp.cn/diannao/41828.shtml
英文地址,請注明出處:http://en.pswp.cn/diannao/41828.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【wordpress教程】wordpress博客網站添加非法關鍵詞攔截

有的網站經常被惡意搜索,站長們不勝其煩。那我們如何屏蔽惡意搜索關鍵詞呢?下面就隨小編一起來解決這個問題吧。 后臺設置預覽圖: 設置教程: 1、把以下代碼添加至當前主題的 functions.php 文件中: add_action(admi…

【PyTorch】torch.fmod使用截斷正態分布truncated normal distribution初始化神經網絡的權重

這個代碼片段展示了如何用 PyTorch 初始化神經網絡的權重,具體使用的是截斷正態分布(truncated normal distribution)。截斷正態分布意味著生成的值會在一定范圍內截斷,以防止出現極端值。這里使用 torch.fmod 作為一種變通方法實…

配置linux net.ipv4.ip_forward數據包轉發

前言 出于系統安全考慮,在默認情況下,Linux系統是禁止數據包轉發的。數據包轉發指的是當主機擁有多個網卡時,通過一個網卡接收到的數據包,根據目的IP地址來轉發數據包到其他網卡。這個功能通常用于路由器。 如果在Linux系統中需要…

CVPR 2024最佳論文分享:通過解釋方法比較Transformers和CNNs的決策機制

CVPR(Conference on Computer Vision and Pattern Recognition)是計算機視覺領域最有影響力的會議之一,主要方向包括圖像和視頻處理、目標檢測與識別、三維視覺等。近期,CVPR 2024 公布了最佳論文。共有10篇論文獲獎,其…

計算組的妙用!!頁面權限控制

需求描述: 某些特殊的場景下,針對某頁看板,需要進行數據權限卡控,但是又不能對全部的數據進行RLS處理,這種情況下可以利用計算組來解決這個需求。 實際場景 事實表包含產品維度和銷售維度 兩個維度屬于同一公司下面的…

限幅濾波法

限幅濾波法 限幅濾波法:根據經驗判斷,確定兩次采樣允許的最大偏差值(設為A),每次檢測到新值時判斷:如果本次值與上次值之差<=A,則本次值有效,如果本次值與上次值之差>A,則本次值無效,放棄本次值,用上次值代替本次值。 優點: 能有效克服因偶然因素引起的脈沖…

【Python】已解決:FileNotFoundError: [Errno 2] No such file or directory: ‘./1.xml’

文章目錄 一、分析問題背景二、可能出錯的原因三、錯誤代碼示例四、正確代碼示例五、注意事項 已解決&#xff1a;FileNotFoundError: [Errno 2] No such file or directory: ‘./1.xml’ 一、分析問題背景 在Python編程中&#xff0c;FileNotFoundError是一個常見的異常&…

ChatGPT對話:Python程序自動模擬操作網頁,無法彈出下拉列表框

【編者按】需要編寫Python程序自動模擬操作網頁。編者有編程經驗&#xff0c;但沒有前端編程經驗&#xff0c;完全不知道如何編寫這種程序。通過與ChatGPT討論&#xff0c;1天完成了任務。因為沒有這類程序的編程經驗&#xff0c;需要邊學習&#xff0c;邊編程&#xff0c;遇到…

貝爾曼方程(Bellman Equation)

貝爾曼方程(Bellman Equation) 貝爾曼方程(Bellman Equation)是動態規劃和強化學習中的核心概念,用于描述最優決策問題中的價值函數的遞歸關系。它為狀態值函數和動作值函數提供了一個重要的遞推公式,幫助我們計算每個狀態或狀態-動作對的預期回報。 貝爾曼方程的原理 …

Python 自動化測試必會技能板塊—unittest框架

說到 Python 的單元測試框架&#xff0c;想必接觸過 Python 的朋友腦袋里第一個想到的就是 unittest。 的確&#xff0c;作為 Python 的標準庫&#xff0c;它很優秀&#xff0c;并被廣泛應用于各個項目。但其實在 Python 眾多項目中&#xff0c;主流的單元測試框架遠不止這一個…

西門子PLC1200--與電腦S7通訊

硬件構成 PLC為西門子1211DCDCDC 電腦上位機用PYTHON編寫 二者通訊用網線&#xff0c;通訊協議用S7 PLC上的數據 PLC上的數據是2個uint&#xff0c;在DB1&#xff0c;地址偏移分別是0和2 需要注意的是DB塊要關閉優化的塊訪問&#xff0c;否則是沒有偏移地址的 PLC中的數據內…

elementui中日期/時間的禁用處理,使用傳值的方式

項目中,經常會用到 在一個學年或者一個學期或者某一個時間段需要做的某件事情,則我們需要在創建這個事件的時候,需要設置一定的時間周期,那這個時間周期就需要給一定的限制處理,避免用戶的誤操作,優化用戶體驗 如下:需求為,在選擇學年后,學期的設置需要在學年中,且結束時間大…

Spring Cloud Gateway如何匹配某路徑并進行路由轉發

本案例&#xff0c;將/helloworld-app/**的請求轉發到helloworld微服務的/**路徑&#xff08;既如lb://helloworld/**&#xff09; 配置如下&#xff08;見spring.cloud.gateway.routes配置&#xff09;&#xff1a; spring:application:name: SpringCloudGatewayDemocloud:n…

軟件架構之計算機組成與體系結構

1.1計算機系統組成 計算機系統是一個硬件和軟件的綜合體&#xff0c;可以把它看成按功能劃分的多級層次結構。 1.1.1 計算機硬件的組成 硬件通常是指一切看得見&#xff0c;摸得到的設備實體。原始的馮?諾依曼&#xff08;VonNeumann&#xff09;計算機在結構上是以運算器為…

2024年中國十大杰出起名大師排行榜,最厲害的易經姓名學改名字專家

在2024年揭曉的中國十大杰出易學泰斗評選中&#xff0c;一系列對姓名學與國學易經有深入研究的專家榮登榜單。其中&#xff0c;中國十大權威姓名學專家泰斗頂級杰出代表人物的師傅顏廷利大師以其在國際舞臺上的卓越貢獻和深邃學識&#xff0c;被公認為姓名學及易經起名領域的權…

C#程序調用Sql Server存儲過程異常處理:調用存儲過程后不返回、不拋異常的解決方案

目錄 一、代碼解析&#xff1a; 二、解決方案 1、增加日志記錄 2、異步操作 注意事項 3、增加超時機制 4、使用線程池 5、使用信號量或事件 6、監控數據庫連接狀態 在C#程序操作Sql Server數據庫的實際應用中&#xff0c;若異常就會拋出異常&#xff0c;我們還能找到異…

Leetcode 完美數

1.題目要求: 對于一個 正整數&#xff0c;如果它和除了它自身以外的所有 正因子 之和相等&#xff0c;我們稱它為 「完美數」。給定一個 整數 n&#xff0c; 如果是完美數&#xff0c;返回 true&#xff1b;否則返回 false。示例 1&#xff1a;輸入&#xff1a;num 28 輸出&a…

2024年6月份找工作和面試總結

轉眼間6月份已經過完了&#xff0c;2024年已經過了一半&#xff0c;希望大家都找到了合適的工作。 本人前段時間寫了5月份找工作的情況&#xff0c;請查看2024年5月份面試總結-CSDN博客 但是后續寫的總結被和諧了&#xff0c;不知道這篇文章能不能發出來。 1、6月份面試機會依…

網絡爬蟲基礎

網絡爬蟲基礎 網絡爬蟲&#xff0c;也被稱為網絡蜘蛛或爬蟲&#xff0c;是一種用于自動瀏覽互聯網并從網頁中提取信息的軟件程序。它們能夠訪問網站&#xff0c;解析頁面內容&#xff0c;并收集所需數據。Python語言因其簡潔的語法和強大的庫支持&#xff0c;成為實現網絡爬蟲…

verilog讀寫文件注意事項

想要的16進制數是文本格式提供的文件&#xff0c;想將16進制數提取到變量內&#xff0c; 可以使用 f s c a n f ( f d 1 , " 也可以使用 fscanf(fd1,"%h",rd_byte);實現 也可以使用 fscanf(fd1,"也可以使用readmemh(“./FILE/1.txt”,mem);//fe放在mem[0…