Transformer實戰(18)——微調Transformer語言模型進行回歸分析

Transformer實戰(18)——微調Transformer語言模型進行回歸分析

    • 0. 前言
    • 1. 回歸模型
    • 2. 數據處理
    • 3. 模型構建與訓練
    • 4. 模型推理
    • 小結
    • 系列鏈接

0. 前言

在自然語言處理領域中,預訓練 Transformer 模型不僅能勝任離散類別預測,也可用于連續數值回歸任務。本節介紹了如何將 DistilBert 轉變為回歸模型,為模型賦予預測連續相似度分值的能力。我們以 GLUE 基準中的語義文本相似度 (STS-B) 數據集為例,詳細介紹配置 DistilBertConfig、加載數據集、分詞并構建 TrainingArguments,并定義 Pearson/Spearman 相關系數等回歸指標。

1. 回歸模型

回歸模型通常最后一層只有一個神經元,它不會通過 softmax 邏輯回歸處理,而是進行歸一化。為了定義模型并在頂部添加一個單神經元的輸出層,有兩種方法:直接在 BERT.from_pretrained() 方法中使用參數 num_labels=1,或者通過 config 對象傳遞此信息。首先需要從預訓練模型的 config 對象中復制這些信息:

from transformers import DistilBertConfig, DistilBertTokenizerFast, DistilBertForSequenceClassification
MODEL_PATH='distilbert-base-uncased'
config = DistilBertConfig.from_pretrained(MODEL_PATH, num_labels=1)
tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_PATH)
model = DistilBertForSequenceClassification.from_pretrained(MODEL_PATH, config=config)

由于我們設置了 num_labels=1 參數,因此預訓練模型的輸出層包含一個神經元。接下來,準備數據集微調模型進行回歸分析。
在本節中,我們將使用語義文本相似度基準 (STS-B) 數據集,它包含從新聞標題等多種內容中提取的句子對。每對句子都有一個從 15 的相似度評分,我們的任務是微調 DistilBert 模型以預測這些評分,并使用 Pearson/Spearman 相關系數來評估模型。

2. 數據處理

(1) 加載數據。將原始數據分為三部分,但由于測試集沒有標簽,所以我們可以將驗證數據分為兩部分:

import datasets
from datasets import load_dataset
stsb_train= load_dataset('glue','stsb', split="train")
stsb_validation = load_dataset('glue','stsb', split="validation")
stsb_validation=stsb_validation.shuffle(seed=42)
stsb_val= datasets.Dataset.from_dict(stsb_validation[:750])
stsb_test= datasets.Dataset.from_dict(stsb_validation[750:])

(2) 使用 pandas 來整理 stsb_train 訓練數據:

import pandas as pd
pd.DataFrame(stsb_train)

整理后的訓練數據樣本如下:

數據樣本

(3) 查看三個數據集的形狀:

stsb_train.shape, stsb_val.shape, stsb_test.shape
# ((5749, 4), (750, 4), (750, 4))

(4) 對數據集進行分詞處理:

enc_train = stsb_train.map(lambda e: tokenizer( e['sentence1'],e['sentence2'], padding=True, truncation=True), batched=True, batch_size=1000) 
enc_val =   stsb_val.map(lambda e: tokenizer( e['sentence1'],e['sentence2'], padding=True, truncation=True), batched=True, batch_size=1000) 
enc_test =  stsb_test.map(lambda e: tokenizer( e['sentence1'],e['sentence2'], padding=True, truncation=True), batched=True, batch_size=1000) 

(5) 分詞器將兩個句子用 [SEP] 分隔符連接,并為句子對生成 input_idsattention_mask

pd.DataFrame(enc_train)

輸出結果如下:

輸出結果

3. 模型構建與訓練

(1)TrainingArguments 類中定義參數集:

from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(# The output directory where the model predictions and checkpoints will be writtenoutput_dir='./stsb-model', do_train=True,do_eval=True,#  The number of epochs, defaults to 3.0 num_train_epochs=3,              per_device_train_batch_size=32,  per_device_eval_batch_size=64,# Number of steps used for a linear warmupwarmup_steps=100,                weight_decay=0.01,# TensorBoard log directorylogging_strategy='steps',                logging_dir='./logs',            logging_steps=50,# other options : no, stepsevaluation_strategy="epoch",save_strategy="epoch",fp16=True,load_best_model_at_end=True
)

(2) 定義 compute_metrics 函數。其中,評估指標基于皮爾遜相關系數 (Pearson correlation coefficient) 和斯皮爾曼等級相關系數 (Spearman’s rank correlation) 法,此外,還提供均方誤差 (Mean Square Error, MSE)、均方根誤差 (Root Mean Square Error, RMSE) 和平均絕對誤差 (Mean Absolute Error, MAE) 等常用的回歸模型評估指標:

from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'
import numpy as np
from scipy.stats import pearsonr
from scipy.stats import spearmanr
def compute_metrics(pred):preds = np.squeeze(pred.predictions) return {"MSE": ((preds - pred.label_ids) ** 2).mean().item(),"RMSE": (np.sqrt ((  (preds - pred.label_ids) ** 2).mean())).item(),"MAE": (np.abs(preds - pred.label_ids)).mean().item(),"Pearson" : pearsonr(preds,pred.label_ids)[0],"Spearman's Rank" : spearmanr(preds,pred.label_ids)[0]}

(3) 實例化 Trainer 對象:

trainer = Trainer(model=model,args=training_args,train_dataset=enc_train,eval_dataset=enc_val,compute_metrics=compute_metrics,tokenizer=tokenizer)

(4) 運行訓練過程:

train_result = trainer.train()
metrics = train_result.metrics

輸出結果如下:

輸出結果
最佳驗證損失為 0.542073,評估最佳權重模型:

q=[trainer.evaluate(eval_dataset=data) for data in [enc_train, enc_val, enc_test]]
pd.DataFrame(q, index=["train","val","test"]).iloc[:,:6]

輸出結果如下:

輸出結果

在測試數據集上,PearsonSpearman 相關系數得分分別為 87.6987.64

4. 模型推理

(1) 運行模型進行推理。以下面兩個意義相同的句子為例,將它們輸入模型:

s1,s2="A plane is taking off.",	"An air plane is taking off."
encoding = tokenizer(s1,s2, return_tensors='pt', padding=True, truncation=True, max_length=512)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
outputs = model(input_ids, attention_mask=attention_mask)
outputs.logits.item()
# 4.57421875

(2) 接下來,將語義不同的句子對輸入模型:

s1,s2="The men are playing soccer.",	"A man is riding a motorcycle."
encoding = tokenizer("hey how are you there","hey how are you", return_tensors='pt', padding=True, truncation=True, max_length=512)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
outputs = model(input_ids, attention_mask=attention_mask)
outputs.logits.item()
# 3.1953125

(3) 最后,保存模型:

model_path = "sentence-pair-regression-model"
trainer.save_model(model_path)
tokenizer.save_pretrained(model_path)

小結

本節介紹了如何基于預訓練 DistilBert 架構完成語義相似度回歸分析。首先,通過修改配置或傳參的方式,為模型頂層添加單神經元回歸頭;隨后,借助 STS-B 數據集構建訓練、驗證與測試集,并應用分詞器生成模型輸入。接著,使用 Trainer 框架與自定義的 compute_metrics 函數,對模型在 MSERMSEMAEPearsonSpearman 相關性等多維度指標上進行評估,驗證了微調方法在回歸任務中的有效性。

系列鏈接

Transformer實戰(1)——詞嵌入技術詳解
Transformer實戰(2)——循環神經網絡詳解
Transformer實戰(3)——從詞袋模型到Transformer:NLP技術演進
Transformer實戰(4)——從零開始構建Transformer
Transformer實戰(5)——Hugging Face環境配置與應用詳解
Transformer實戰(6)——Transformer模型性能評估
Transformer實戰(7)——datasets庫核心功能解析
Transformer實戰(8)——BERT模型詳解與實現
Transformer實戰(9)——Transformer分詞算法詳解
Transformer實戰(10)——生成式語言模型 (Generative Language Model, GLM)
Transformer實戰(11)——從零開始構建GPT模型
Transformer實戰(12)——基于Transformer的文本到文本模型
Transformer實戰(13)——從零開始訓練GPT-2語言模型
Transformer實戰(14)——微調Transformer語言模型用于文本分類
Transformer實戰(15)——使用PyTorch微調Transformer語言模型
Transformer實戰(16)——微調Transformer語言模型用于多類別文本分類
Transformer實戰(17)——微調Transformer語言模型進行多標簽文本分類

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/96594.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/96594.shtml
英文地址,請注明出處:http://en.pswp.cn/web/96594.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

【Linux】【實戰向】Linux 進程替換避坑指南:從理解 bash 阻塞等待,到親手實現能執行 ls/cd 的 Shell

前言:歡迎各位光臨本博客,這里小編帶你直接手撕,文章并不復雜,愿諸君耐其心性,忘卻雜塵,道有所長!!!! IF’Maxue:個人主頁🔥 個人專欄…

linux常用命令 (3)——系統包管理

博客主頁:christine-rr-CSDN博客 ????? ?? hi,大家好,我是christine-rr ! 今天來分享一下linux常用命令——系統包管理 目錄linux常用命令---系統包管理(一)Debian 系發行版(Ubuntu、Debian、Linux …

YOLOv8 mac-intel芯片 部署指南

🚀 在 Jupyter Notebook 和 PyCharm 中使用 Conda 虛擬環境(YOLOv8 部署指南,Python 3.9) YOLOv8 是 Ultralytics 開源的最新目標檢測模型,輕量高效,支持分類、檢測、分割等多種任務。 在 Mac(…

【高等數學】第十一章 曲線積分與曲面積分——第六節 高斯公式 通量與散度

上一節:【高等數學】第十一章 曲線積分與曲面積分——第五節 對坐標的曲面積分 總目錄:【高等數學】 目錄 文章目錄1. 高斯公式2. 沿任意閉曲面的曲面積分為零的條件3. 通量與散度1. 高斯公式 設空間區域ΩΩΩ是由分片光滑的閉曲面ΣΣΣ所圍成&#x…

IDEA試用過期,無法登錄,重置方法

IDEA過期,重置方法: IntelliJ IDEA 2024.2.0.2 (親測有效) 最新Idea重置辦法!: 方法一: 1、刪除C:\Users\{用戶名}\AppData\Local\JetBrains\IntelliJIdea2024.2 下所有文件(注意:是子目錄全部刪除) 2、刪除C:\Users\{用戶名}\App…

創建用戶自定義橋接網絡并連接容器

1.創建用戶自定義的 alpine-net 網絡[roothost1 ~]# docker network create --driver bridge alpine-net 9f6d634e6bd7327163a9d83023e435da6d61bc6cf04c9d96001d1b64eefe4a712.列出 Docker 主機上的網絡[roothost1 ~]# docker network ls NETWORK ID NAME DRIVER …

Vue3 + Vite + Element Plus web轉為 Electron 應用,解決無法登錄、隱藏自定義導航欄

如何在vue3 Vite Element Plus搭好的架構下轉為 electron應用呢? https://www.electronjs.org/zh/docs/latest/官方文檔 https://www.electronjs.org/zh/docs/latest/ 第一步:安裝 electron相關依賴 npm install electron electron-builder concurr…

qt QAreaLegendMarker詳解

1. 概述QAreaLegendMarker 是 Qt Charts 模塊中的一部分,用于在圖例(Legend)中表示 QAreaSeries 的標記。它負責顯示區域圖的圖例項,通常包含區域顏色樣例和對應的描述文字。圖例標記和對應的區域圖關聯,顯示區域的名稱…

linux 函數 kstrtoul

kstrtoul 函數概述 kstrtoul 是 Linux 內核中的一個函數&#xff0c;用于將字符串轉換為無符號長整型&#xff08;unsigned long&#xff09;。該函數定義在 <linux/kernel.h> 頭文件中&#xff0c;常用于內核模塊中解析用戶空間傳遞的字符串參數。 函數原型 int kstrtou…

LLM(三)

一、人類反饋的強化學習&#xff08;RLHF&#xff09;微調的目標是通過指令&#xff0c;包括路徑方法&#xff0c;進一步訓練你的模型&#xff0c;使他們更好地理解人類的提示&#xff0c;并生成更像人類的回應。RLHF&#xff1a;使用人類反饋微調型語言模型&#xff0c;使用強…

DPO vs PPO,偏好優化的兩條技術路徑

1. 背景在大模型對齊&#xff08;alignment&#xff09;里&#xff0c;常見的兩類方法是&#xff1a;PPO&#xff1a;強化學習經典算法&#xff0c;OpenAI 在 RLHF 里用它來“用獎勵模型更新策略”。DPO&#xff1a;2023 年提出的新方法&#xff08;參考論文《Direct Preferenc…

BLE6.0信道探測,如何重構物聯網設備的距離感知邏輯?

在物聯網&#xff08;IoT&#xff09;無線通信技術快速滲透的當下&#xff0c;實現人與物、物與物之間對物理距離的感知響應能力已成為提升設備智能高度與人們交互體驗的關鍵所在。當智能冰箱感知用戶靠近而主動亮屏顯示內部果蔬時、當門禁系統感知到授權人士靠近而主動開門時、…

【計算機 UTF-8 轉換為本地編碼的含義】

UTF-8 轉換為本地編碼的含義 詳細解釋一下"UTF-8轉換為本地編碼"的含義以及為什么在處理中文時這很重要。 基本概念 UTF-8 編碼 國際標準&#xff1a;UTF-8 是一種能夠表示世界上幾乎所有字符的 Unicode 編碼方式跨平臺兼容&#xff1a;無論在哪里&#xff0c;UTF-8 …

4.6 變體

1.變體簡介 2.為什么需要變體 3.變體是如何產生的 4.變體帶來的麻煩 5.multi_compile和shader_feature1.變體簡介 比如我們開了一家餐廳, 你有一本萬能的菜單(Shader源代碼), 上面包含了所有可能的菜式; 但是顧客每次來點餐時, 不可能將整本菜單都做一遍, 他們會根據今天有沒有…

猿輔導Android開發面試題及參考答案(下)

為什么開發中要使用線程池,而不是直接創建線程(如控制線程數量、復用線程、降低開銷)? 開發中優先使用線程池而非直接創建線程,核心原因是線程池能優化線程管理、降低資源消耗、提高系統穩定性,而直接創建線程存在難以解決的缺陷,具體如下: 控制線程數量,避免資源耗盡…

【網絡通信】IP 地址深度解析:從技術原理到企業級應用?

IP 地址深度解析&#xff1a;從技術原理到企業級應用? 文章目錄IP 地址深度解析&#xff1a;從技術原理到企業級應用?前言一、基礎認知&#xff1a;IP 地址的技術定位與核心特性?1.1 定義與網絡層角色1.2 核心屬性與表示法深化二、地址分類&#xff1a;從類別劃分到無類別路…

grafana實踐

一、如何找到grafana的插件目錄 whereis grafana grafana: /etc/grafana /usr/share/grafana插件安裝目錄、默認安裝目錄&#xff1a; 把vertamedia-clickhouse-datasource-3.4.4.zip解壓到下面目錄&#xff0c;然后重啟就可以了 /var/lib/grafana/plugins# 6. 設置權限 sudo …

uniapp 文件查找失敗:main.js

重裝HbuilderX vue.config.js 的 配置 有問題main.js 框架能自動識別 到&#xff0c;不用多余的配置

KEIL燒錄時提示“SWD/JTAG communication failure”的解決方法

最新在使用JTAG仿真器串口下載調試程序時&#xff0c;老是下載不成功&#xff0c;識別不到芯片&#xff0c;我嘗試重啟keil5或者重新插拔仿真器連接線、甚至重啟電腦也都不行&#xff0c;每次下載程序都提示如下信息&#xff1a;在確定硬件連接沒有問題之后&#xff0c;就開始分…

紅日靶場(三)——個人筆記

環境搭建 添加一張網卡&#xff08;僅主機模式&#xff09;&#xff0c;192.168.93.0/24 網段 開啟centos&#xff0c;第一次運行&#xff0c;重啟網絡服務 service network restart192.168.43.57/24&#xff08;外網ip&#xff09; 192.168.93.100/24&#xff08;內網ip&am…