bert 相似度任務訓練,簡單版本

目錄

任務

代碼

train.py

predit.py

數據


任務

使用 bert-base-chinese 訓練相似度任務,參考:微調BERT模型實現相似性判斷 - 知乎

參考他上面代碼,他使用的是?BertForNextSentencePrediction 模型,BertForNextSentencePrediction?原本是設計用于下一個句子預測任務的。在BERT的原始訓練中,模型會接收到一對句子,并試圖預測第二個句子是否緊跟在第一個句子之后;所以使用這個模型標簽(label)只能是 0,1,相當于二分類任務了

但其實在相似度任務中,我們每一條數據都是【text1\ttext2\tlabel】的形式,其中 label 代表相似度,可以給兩個文本打分表示相似度,也可以映射為分類任務,0 代表不相似,1 代表相似,他這篇文章利用了這種思想,對新手還挺有用的。

現在我搞了一個招聘數據,里面有辦公區域列,處理過了,每一行代表【地址1\t地址2\t相似度】

只要兩文本中有一個地址相似我就作為相似,標簽為 1,否則 0

利用這數據微調,沒有使用驗證數據集,就最后使用測試集來看看效果。

代碼

train.py

import json
import torch
from transformers import BertTokenizer, BertForNextSentencePrediction
from torch.utils.data import DataLoader, Dataset# 能用gpu就用gpu
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")bacth_size = 32
epoch = 3
auto_save_batch = 5000
learning_rate = 2e-5# 準備數據集
class MyDataset(Dataset):def __init__(self, data_file_paths):self.texts = []self.labels = []# 分詞器用默認的self.tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')# 自己實現對數據集的解析with open(data_file_paths, 'r', encoding='utf-8') as f:for line in f:text1, text2, label = line.split('\t')self.texts.append((text1, text2))self.labels.append(int(label))def __len__(self):return len(self.texts)def __getitem__(self, idx):text1, text2 = self.texts[idx]label = self.labels[idx]encoded_text = self.tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')return encoded_text, label# 訓練數據文件路徑
train_dataset = MyDataset('../data/train.txt')# 定義模型
# num_labels=5 定義相似度評分有幾個
model = BertForNextSentencePrediction.from_pretrained('../bert-base-chinese', num_labels=6)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 訓練模型
train_loader = DataLoader(train_dataset, batch_size=bacth_size, shuffle=True)
trained_data = 0
batch_after_last_save = 0
total_batch = 0
total_epoch = 0for epoch in range(epoch):trained_data = 0for batch in train_loader:inputs, labels = batch# 不知道為啥,出來的數據維度是 (batch_size, 1, 128),需要把第二維去掉inputs['input_ids'] = inputs['input_ids'].squeeze(1)inputs['token_type_ids'] = inputs['token_type_ids'].squeeze(1)inputs['attention_mask'] = inputs['attention_mask'].squeeze(1)# 因為要用GPU,將數據傳輸到gpu上inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad()outputs = model(**inputs, labels=labels)loss, logits = outputs[:2]loss.backward()optimizer.step()trained_data += len(labels)trained_process = float(trained_data) / len(train_dataset)batch_after_last_save += 1total_batch += 1# 每訓練 auto_save_batch 個 batch,保存一次模型if batch_after_last_save >= auto_save_batch:batch_after_last_save = 0model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))print("訓練進度:{:.2f}%, loss={:.4f}".format(trained_process * 100, loss.item()))total_epoch += 1model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))

訓練好后的文件,輸出的最后一個文件夾才是效果最好的模型:

predit.py

import torch
from transformers import BertTokenizer, BertForNextSentencePredictiontokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertForNextSentencePrediction.from_pretrained('../output/cn_equal_model_3_171.pth')with torch.no_grad():with open('../data/test.txt', 'r', encoding='utf8') as f:lines = f.readlines()correct = 0for i, line in enumerate(lines):text1, text2, label = line.split('\t')encoded_text = tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')outputs = model(**encoded_text)res = torch.argmax(outputs.logits, dim=1).item()print(text1, text2, label, res)if str(res) == label.strip('\n'):correct += 1print(f'{i + 1}/{len(lines)}')print(f'acc:{correct / len(lines)}')

可以看到還是較好的學習了我數據特征:只要兩文本中有一個地址相似我就作為相似,標簽為 1,否則 0

數據

鏈接:https://pan.baidu.com/s/1Cpr-ZD9Neakt73naGdsVTw?
提取碼:eryw?
鏈接:https://pan.baidu.com/s/1qHYjXC7UCeUsXVnYTQIPCg?
提取碼:o8py?
?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/712869.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/712869.shtml
英文地址,請注明出處:http://en.pswp.cn/news/712869.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

thinkphp學習10-數據庫的修改刪除

數據修改 使用 update()方法來修改數據,修改成功返回影響行數,沒有修改返回 0 public function index(){$data [username > 孫悟空1,];return Db::name(user)->where(id,11)->update($data);}如果修改數據包含了主鍵信息,比如 i…

STM32標準庫開發——BKP備份RTC時鐘

備份寄存器BKP(Backup Registers) 由于RTC與BKP關聯性較高,所以RTC的時鐘校準寄存器以及一些功能都放在了BKP中。TAMPER引腳主要用于防止芯片數據泄露,可以設計一個機關當TAMPER引腳發生電平跳變時自動清除寄存器內數據不同芯片BKP區別,主要體…

c++入門(2)

上期我們說到了部分c修補C語言的不足,今天我們將剩下的一一說清楚。 函數重載 (1).函數重載的形式 C語言不允許函數名相同的同時存在,但是C允許同名函數存在,但是有要求:函數名相同,參數不同,構成函數重…

【數據結構-圖論】并查集

并查集(Union-Find)是一種數據結構,它提供了處理一些不交集的合并及查詢問題的高效方法。并查集主要支持兩種操作: 查找(Find):確定某個元素屬于哪個子集,這通常意味著找到該子集的…

vue購物車實戰

1.引入vue <script src"https://cdn.jsdelivr.net/npm/vue2.7.14/dist/vue.js"></script> 2.設置高亮部分的樣式 <style> table tr{text-align: center;}.skyblue{background-color: skyblue;}</style> 3.設置body的基本樣式 <div id&q…

人大金倉與mysql的差異與替換

人大金倉中不能使用~下面的符號&#xff0c;字段中使用”&#xff0c;無法識別建表語句 創建表時語句中只定義字段名.字段類型.是否是否為空 Varchar類型改為varchar&#xff08;長度 char&#xff09; Int(0) 類型為int4 定義主鍵&#xff1a;CONSTRAINT 鍵名 主鍵類型&#x…

Found option without preceding group in config file 問題解決

方法就是用記事本打開 然后 左上角點擊 文件 有另存為 就可以選擇編碼格式

Linux設置程序任意位置執行(設置環境變量)

問題 直接編譯出來的可執行程序在執行時需要寫出完整路徑比較麻煩&#xff0c;設置環境變量可以實現在任意位置直接運行。 解決 1.打開.bashrc文件 vim ~/.bashrc 2.修改該文件&#xff08;實現將/home/zhangziheng/file/seqrequester/build/bin&#xff0c;路徑下的可執…

文件流【文件輸入流】

文件流&#xff1a;使用文件輸入流讀取文件中的數據&#xff1a; public class FISDemo {public static void main(String[] args) throws IOException {//將fos.dat文件中的字節讀取回來/*fos.dat文件中的數據:00000001 00000010*/FileInputStream fis new FileInputStream(…

第六節:Vben Admin權限-后端控制方式

系列文章目錄 第一節:Vben Admin介紹和初次運行 第二節:Vben Admin 登錄邏輯梳理和對接后端準備 第三節:Vben Admin登錄對接后端login接口 第四節:Vben Admin登錄對接后端getUserInfo接口 第五節:Vben Admin權限-前端控制方式 文章目錄 系列文章目錄前言一、角色權限(后端…

Java面試題總結6

Spring中有哪些設計模式 簡單工廠&#xff1a;由一個工廠類根據傳入的參數&#xff0c;動態決定應該創建哪一個產品類 工廠方法&#xff1a;實現FactoryBean接口的bean是一類叫做factory的bean 單例模式&#xff1a;保證一個類僅有一個實例&#xff0c;并提供一個訪問它的全…

【辦公類-18-03】(Python)中班米羅可兒證書批量生成打印(班級、姓名)

作品展示——米羅可兒證書打印幼兒姓名 背景需求 2024年3月1日&#xff0c;中4班孩子一起整理美術操作材料《米羅可兒》的操作本——將每一頁紙撕下來&#xff0c;分類擺放、確保紙張上下位置正確。每位孩子們都非常厲害&#xff0c;不僅完成了自己的一本&#xff0c;還將沒有…

C++數據結構與算法——二叉搜索樹的屬性

C第二階段——數據結構和算法&#xff0c;之前學過一點點數據結構&#xff0c;當時是基于Python來學習的&#xff0c;現在基于C查漏補缺&#xff0c;尤其是樹的部分。這一部分計劃一個月&#xff0c;主要利用代碼隨想錄來學習&#xff0c;刷題使用力扣網站&#xff0c;不定時更…

C++編程面試復盤:數組降重+快排+函數指針+類模板

面試真題 真題1 #include <iostream> // 在函數 find_repetnum 的參數列表中&#xff0c;int& length 中的 & 符號表示引用。通過將 length 聲明為引用&#xff0c;函數可以修改傳入的 length 變量的值&#xff0c;并且這種修改會在函數外部生效。 void find_r…

Vue2:路由history模式的項目部署后頁面刷新404問題處理

一、問題描述 我們把Vue項目的路由模式&#xff0c;設置成history 然后&#xff0c;build 并把dist中的代碼部署到nodeexpress服務中 訪問頁面后&#xff0c;刷新頁面報404問題 二、原因分析 server.js文件 會發現&#xff0c;文件中配置的路徑沒有Vue項目中對應的路徑 所以…

React withRouter的使用及源碼實現

一 基本介紹 作用&#xff1a; 把不是通過路由切換過來的組件中&#xff0c;將react-router 的 history、location、match 三個對象傳入props對象上。比如首頁&#xff01; 默認情況下必須是經過路由匹配渲染的組件才存在this.props&#xff0c;才擁有路由參數&#xff0c;才能…

嵌入式學習筆記Day27

今天主要學習了進程間的通信&#xff0c;主要學習了通過管道進行通信 一、進程間的通信 進程間的通信方式有以下幾種&#xff1a; 1.管道 2.信號 3.消息隊列 4.共享內存 5.信號燈 6.套接字二、管道 2.1無名管道 無名管道只能用于具有親緣關系的進程間通信 函數接口&#x…

Nacos進階

目錄 Nacos支持三種配置加載方案 Namespace方案 DataID方案 Group方案 同時加載多個配置集 Nacos支持三種配置加載方案 Nacos支持“Namespacegroupdata ID”的配置解決方案。 詳情見&#xff1a;Nacos config alibaba/spring-cloud-alibaba Wiki GitHub Namespace方案…

《TCP/IP詳解 卷一》第12章 TCP初步介紹

目錄 12.1 引言 12.1.1 ARQ和重傳 12.1.2 滑動窗口 12.1.3 變量窗口&#xff1a;流量控制和擁塞控制 12.1.4 設置重傳的超時值 12.2 TCP的引入 12.2.1 TCP服務模型 12.2.2 TCP可靠性 12.3 TCP頭部和封裝 12.4 總結 12.1 引言 關于TCP詳細內容&#xff0c;原書有5個章…

【C++ map和set】

文章目錄 map和set序列式容器和關聯式容器鍵值對setset的主要操作 mapmap主要操作 multiset和multimap map和set 序列式容器和關聯式容器 之前我們接觸的vector,list,deque等&#xff0c;這些容器統稱為序列式容器&#xff0c;其底層為線性序列的的數據結構&#xff0c;里面存…