bert 相似度任務訓練完整版

任務

之前寫了一個相似度任務的版本:bert 相似度任務訓練簡單版本,faiss 尋找相似 topk-CSDN博客

相似度用的是 0,1,相當于分類任務,現在我們相似度有評分,不再是 0,1 了,分數為 0-5,數字越大代表兩個句子越相似,這一次的比較完整,評估,驗證集,相似度模型都有了。

數據集

鏈接:https://pan.baidu.com/s/1B1-PKAKNoT_JwMYJx_zT1g?
提取碼:er1z?
原始數據好幾千條,我訓練數據用了部分 2500 條,驗證,測試 300 左右,使用 cpu 也用了好幾個小時

train.py

import torch
import os
import time
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, AdamW, get_cosine_schedule_with_warmup
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np# 設備選擇
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 'cpu'# 定義文本相似度數據集類
class TextSimilarityDataset(Dataset):def __init__(self, file_path, tokenizer, max_len=128):self.data = []with open(file_path, 'r', encoding='utf-8') as f:for line in f.readlines():text1, text2, similarity_score = line.strip().split('\t')inputs1 = tokenizer(text1, padding='max_length', truncation=True, max_length=max_len)inputs2 = tokenizer(text2, padding='max_length', truncation=True, max_length=max_len)self.data.append({'input_ids1': inputs1['input_ids'],'attention_mask1': inputs1['attention_mask'],'input_ids2': inputs2['input_ids'],'attention_mask2': inputs2['attention_mask'],'similarity_score': float(similarity_score),})def __len__(self):return len(self.data)def __getitem__(self, idx):return self.data[idx]def cosine_similarity_torch(vec1, vec2, eps=1e-8):dot_product = torch.mm(vec1, vec2.t())norm1 = torch.norm(vec1, 2, dim=1, keepdim=True)norm2 = torch.norm(vec2, 2, dim=1, keepdim=True)similarity_scores = dot_product / (norm1 * norm2.t()).clamp(min=eps)return similarity_scores# 定義模型,這里我們不僅計算兩段文本的[CLS] token的點積,而是整個句向量的余弦相似度
class BertSimilarityModel(torch.nn.Module):def __init__(self, pretrained_model):super(BertSimilarityModel, self).__init__()self.bert = BertModel.from_pretrained(pretrained_model)self.dropout = torch.nn.Dropout(p=0.1)  # 引入Dropout層以防止過擬合def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):embeddings1 = self.dropout(self.bert(input_ids=input_ids1, attention_mask=attention_mask1)['last_hidden_state'])embeddings2 = self.dropout(self.bert(input_ids=input_ids2, attention_mask=attention_mask2)['last_hidden_state'])# 計算兩個文本向量的余弦相似度embeddings1 = torch.mean(embeddings1, dim=1)embeddings2 = torch.mean(embeddings2, dim=1)similarity_scores = cosine_similarity_torch(embeddings1, embeddings2)# 映射到[0, 5]評分范圍normalized_similarities = (similarity_scores + 1) * 2.5return normalized_similarities.unsqueeze(1)# 自定義損失函數,使用Smooth L1 Loss,更適合處理回歸問題
class SmoothL1Loss(torch.nn.Module):def __init__(self):super(SmoothL1Loss, self).__init__()def forward(self, predictions, targets):diff = predictions - targetsabs_diff = torch.abs(diff)quadratic = torch.where(abs_diff < 1, 0.5 * diff ** 2, abs_diff - 0.5)return torch.mean(quadratic)def train_model(model, train_loader, val_loader, epochs=3, model_save_path='../output/bert_similarity_model.pth'):model.to(device)criterion = SmoothL1Loss()  # 使用自定義的Smooth L1 Lossoptimizer = AdamW(model.parameters(), lr=5e-5)  # 調整初始學習率為5e-5num_training_steps = len(train_loader) * epochsscheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0.1*num_training_steps, num_training_steps=num_training_steps)  # 使用帶有warmup的余弦退火學習率調度best_val_loss = float('inf')for epoch in range(epochs):model.train()for batch in train_loader:input_ids1 = batch['input_ids1'].to(device)attention_mask1 = batch['attention_mask1'].to(device)input_ids2 = batch['input_ids2'].to(device)attention_mask2 = batch['attention_mask2'].to(device)similarity_scores = batch['similarity_score'].to(device)optimizer.zero_grad()outputs = model(input_ids1, attention_mask1, input_ids2, attention_mask2)loss = criterion(outputs, similarity_scores.unsqueeze(1))loss.backward()optimizer.step()scheduler.step()# 驗證階段model.eval()with torch.no_grad():val_loss = 0total_val_samples = 0for batch in val_loader:input_ids1 = batch['input_ids1'].to(device)attention_mask1 = batch['attention_mask1'].to(device)input_ids2 = batch['input_ids2'].to(device)attention_mask2 = batch['attention_mask2'].to(device)similarity_scores = batch['similarity_score'].to(device)val_outputs = model(input_ids1, attention_mask1, input_ids2, attention_mask2)val_loss += criterion(val_outputs, similarity_scores.unsqueeze(1)).item()total_val_samples += len(similarity_scores)val_loss /= len(val_loader)print(f'Epoch {epoch + 1}, Validation Loss: {val_loss:.4f}')if val_loss < best_val_loss:best_val_loss = val_losstorch.save(model.state_dict(), model_save_path)def collate_to_tensors(batch):'''把數據處理為模型可用的數據,不同任務可能需要修改一下,'''input_ids1 = torch.tensor([example['input_ids1'] for example in batch])attention_mask1 = torch.tensor([example['attention_mask1'] for example in batch])input_ids2 = torch.tensor([example['input_ids2'] for example in batch])attention_mask2 = torch.tensor([example['attention_mask2'] for example in batch])similarity_score = torch.tensor([example['similarity_score'] for example in batch])return {'input_ids1': input_ids1, 'attention_mask1': attention_mask1, 'input_ids2': input_ids2,'attention_mask2': attention_mask2, 'similarity_score': similarity_score}# 加載數據集和預訓練模型
tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertSimilarityModel('../bert-base-chinese')# 加載數據并創建
train_data = TextSimilarityDataset('../data/STS-B/STS-B.train - 副本.data', tokenizer)
val_data = TextSimilarityDataset('../data/STS-B/STS-B.valid - 副本.data', tokenizer)
test_data = TextSimilarityDataset('../data/STS-B/STS-B.test - 副本.data', tokenizer)train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate_to_tensors)
val_loader = DataLoader(val_data, batch_size=32, collate_fn=collate_to_tensors)
test_loader = DataLoader(test_data, batch_size=32, collate_fn=collate_to_tensors)optimizer = AdamW(model.parameters(), lr=2e-5)# 開始訓練
train_model(model, train_loader, val_loader)# 加載最佳模型進行測試
model.load_state_dict(torch.load('../output/bert_similarity_model.pth'))
test_loss = 0
total_test_samples = 0with torch.no_grad():for batch in test_loader:input_ids1 = batch['input_ids1'].to(device)attention_mask1 = batch['attention_mask1'].to(device)input_ids2 = batch['input_ids2'].to(device)attention_mask2 = batch['attention_mask2'].to(device)similarity_scores = batch['similarity_score'].to(device)test_outputs = model(input_ids1, attention_mask1, input_ids2, attention_mask2)test_loss += torch.nn.functional.mse_loss(test_outputs, similarity_scores.unsqueeze(1)).item()total_test_samples += len(similarity_scores)test_loss /= len(test_loader)
print(f'Test Loss: {test_loss:.4f}')

predit.py

這個腳本是用來看看效果的,直接傳入兩個文本,使用訓練好的模型來計算相似度的

import torch
from transformers import BertTokenizer, BertModeldef cosine_similarity_torch(vec1, vec2, eps=1e-8):dot_product = torch.mm(vec1, vec2.t())norm1 = torch.norm(vec1, 2, dim=1, keepdim=True)norm2 = torch.norm(vec2, 2, dim=1, keepdim=True)similarity_scores = dot_product / (norm1 * norm2.t()).clamp(min=eps)return similarity_scores# 定義模型,這里我們不僅計算兩段文本的[CLS] token的點積,而是整個句向量的余弦相似度
class BertSimilarityModel(torch.nn.Module):def __init__(self, pretrained_model):super(BertSimilarityModel, self).__init__()self.bert = BertModel.from_pretrained(pretrained_model)self.dropout = torch.nn.Dropout(p=0.1)  # 引入Dropout層以防止過擬合def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):'''如果是用來預測,forward 會被禁用'''pass# 加載預訓練模型和分詞器
tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertSimilarityModel('../bert-base-chinese')
model.load_state_dict(torch.load('../output/bert_similarity_model.pth'))  # 請確保路徑正確
model.eval()  # 設置模型為評估模式def calculate_similarity(text1, text2):# 對輸入文本進行編碼inputs1 = tokenizer(text1, padding='max_length', truncation=True, max_length=128, return_tensors='pt')inputs2 = tokenizer(text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')# 計算相似度with torch.no_grad():embeddings1 = model.bert(**inputs1.to('cpu'))['last_hidden_state'][:, 0]embeddings2 = model.bert(**inputs2.to('cpu'))['last_hidden_state'][:, 0]similarity_score = cosine_similarity_torch(embeddings1, embeddings2).item()# 映射到[0, 5]評分范圍(假設訓練時有此步驟)normalized_similarity = (similarity_score + 1) * 2.5return normalized_similarity# 示例
text1 = "瑞典駐利比亞班加西領事館發生汽車炸彈襲擊,無人員傷亡"
text2 = "汽車炸彈擊中瑞典駐班加西領事館,無人受傷。"
similarity = calculate_similarity(text1, text2)
print(f"兩個句子的相似度為:{similarity:.2f}")

?

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/718168.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/718168.shtml
英文地址,請注明出處:http://en.pswp.cn/news/718168.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

EasyRecovery易恢復2024免費文件數據恢復軟件下載

一、軟件概述 EasyRecovery易恢復中文文件數據恢復軟件是一款專為中文用戶設計的強大數據恢復工具。該軟件致力于幫助用戶從各種存儲設備中恢復因各種原因丟失的中文文件&#xff0c;如文檔、圖片、視頻、音頻等。憑借其核心技術和多年的研發經驗&#xff0c;EasyRecovery易恢…

C語言計算誤碼率

#include <stdio.h> #include <stdlib.h> bool dayintrue; //是否打印 int main(){ int i,k,g0; int n10,n20; int good0,bad0; double rate; (dayin)? printf("打印具體數據\n"):printf("不打印具體數據\n\n");…

STM32-SPI通信協議

串行外設接口SPI&#xff08;Serial Peripheral Interface&#xff09;是由Motorola公司開發的一種通用數據總線。 在某些芯片上&#xff0c;SPI接口可以配置為支持SPI協議或者支持I2S音頻協議。 SPI接口默認工作在SPI方式&#xff0c;可以通過軟件把功能從SPI模式切換…

Python·算法·每日一題(3月4日)最長公共前綴

題目 編寫一個函數來查找字符串數組中的最長公共前綴。 如果不存在公共前綴&#xff0c;返回空字符串 “”。 示例 示例 1&#xff1a; 輸入&#xff1a;strs ["flower","flow","flight"] 輸出&#xff1a;"fl"示例 2&#xff1a;…

【數據結構與算法】常見排序算法(Sorting Algorithm)

文章目錄 相關概念1. 冒泡排序&#xff08;Bubble Sort&#xff09;2. 直接插入排序&#xff08;Insertion Sort&#xff09;3. 希爾排序&#xff08;Shell Sort&#xff09;4. 直接選擇排序&#xff08;Selection Sort&#xff09;5. 堆排序&#xff08;Heap Sort&#xff09;…

【腦科學相關合集】有關腦影像數據相關介紹的筆記及有關腦網絡的筆記合集

【腦科學相關合集】有關腦影像數據相關介紹的筆記及有關腦網絡的筆記合集 前言腦模板方面相關筆記清單 基于腦網絡的方法方面數據基本方面 前言 這里&#xff0c;我將展開有關我自己關于腦影像數據相關介紹的筆記及有關腦網絡的筆記合集。其中&#xff0c;腦網絡的相關論文主要…

【錯誤處理】【Hive】【Spark】ERROR FileFormatwriter: Aborting job null.

問題背景 近日&#xff0c;使用 Spark 在讀寫 Hive 表時發生了報錯&#xff1a;Aborting job null&#xff0c;如果怎么都使用不了那張表的話&#xff0c;大概率是那張表有臟數據&#xff0c;導致整張表無法正常使用。 ERROR FileFormatwriter: Aborting job null.解決方法 …

SpringBoot 如何快速過濾出一次請求的所有日志?

前言 在現網出現故障時&#xff0c;我們經常需要獲取一次請求流程里的所有日志進行定位。如果請求只在一個線程里處理&#xff0c;則我們可以通過線程ID來過濾日志&#xff0c;但如果請求包含異步線程的處理&#xff0c;那么光靠線程ID就顯得捉襟見肘了。 華為IoT平臺&#x…

《自然》:人工智能在創造性思維方面超越人類

發散性思維被認為是創造性思維的指標。ChatGPT-4 在三項有151名人類參與的**發散思維測試中&#xff0c;**展現出比人類更高水平的創造力&#xff0c;結果顯示人工智能在創意領域持續發展。 發散性思維的特點是能夠針對沒有預期解決方案的問題提出獨特的解決方案&#xff0c;例…

TOMCAT的安裝與基本信息

一、TOMCAT簡介 Tomcat 服務器是一個免費的開放源代碼的Web 應用服務器&#xff0c;屬于輕量級應用服務器&#xff0c;在中小型系統和并發訪問用戶不是很多的場合下被普遍使用&#xff0c;是開發和調試JSP 程序的首選。對于一個初學者來說&#xff0c;可以這樣認為&#xff0c…

IO 與 NIO

優質博文&#xff1a;IT-BLOG-CN 一、阻塞IO / 非阻塞NIO 阻塞IO&#xff1a;當一條線程執行read()或者write()方法時&#xff0c;這條線程會一直阻塞直到讀取到了一些數據或者要寫出去的數據已經全部寫出&#xff0c;在這期間這條線程不能做任何其他的事情。 非阻塞NIO&…

記錄踩過的坑-macOS下使用VS Code

目錄 切換主題 安裝插件 搭建Python開發環境 裝Python插件 配置解釋器 打開項目 打開終端 切換主題 安裝插件 方法1 方法2 搭建Python開發環境 裝Python插件 配置解釋器 假設解釋器已經通過Anaconda建好&#xff0c;只需要在VS Code中關聯。 打開項目 打開終端

ArmV8架構

Armv8/armv9架構入門指南 — Armv8/armv9架構入門指南 v1.0 documentation 上面只是給了一個比較好的參考文檔 其他內容待補充

網絡-httpclient調用https服務端繞過證書的方法

httpclient調用https服務端繞過證書的方法 在日常開發或者測試中&#xff0c;通常會遇到需要用httpclient客戶端調用對方http是服務器的場景&#xff0c;由于沒有證書&#xff0c;所以直接是無法調用的。采用下面的方法可以繞過證書驗證&#xff1a; TrustManager[] trustAll…

AutoSAR(基礎入門篇)13.5-Mcal Mcu時鐘的配置

目錄 一、EB的Mcu模塊結構 二、時鐘的配置 對Mcu的配置主要就是其時鐘樹的配置,但是EB將GTM、ERU等很多重要的模塊也都放在了Mcu里面做配置,所以這里的Mcu是一個很龐大的模塊, 我們目前只講時鐘樹的部分 一、EB的Mcu模塊結構 1. 所有的模塊都基本上是這么些配置類別,Mc…

單詞級文本攻擊—論文閱讀

TAAD2.2論文概覽 0.前言1-101.Bridge the Gap Between CV and NLP! A Gradient-based Textual Adversarial Attack Frameworka. 背景b. 方法c. 結果d. 論文及代碼 2.TextHacker: Learning based Hybrid Local Search Algorithm for Text Hard-label Adversarial Attacka. 背景b…

閱讀筆記 | Transformers in Time Series: A Survey

閱讀論文&#xff1a; Wen, Qingsong, et al. “Transformers in time series: A survey.” arXiv preprint arXiv:2202.07125 (2022). 這篇綜述主要對基于Transformer的時序建模方法進行介紹。論文首先簡單介紹了Transformer的基本原理&#xff0c;包括位置編碼、多頭注意力機…

OPENAI SORA:未來視頻創作的新引擎——淺析其背后的人工智能算法

Sora - 探索AI視頻模型的無限可能 隨著人工智能技術的飛速發展&#xff0c;AI視頻模型已成為科技領域的新熱點。而在這個浪潮中&#xff0c;OpenAI推出的首個AI視頻模型Sora&#xff0c;以其卓越的性能和前瞻性的技術&#xff0c;引領著AI視頻領域的創新發展。本文將探討SORA的…

回歸預測 | Matlab實現RIME-BP霜冰算法優化BP神經網絡多變量回歸預測

回歸預測 | Matlab實現RIME-BP霜冰算法優化BP神經網絡多變量回歸預測 目錄 回歸預測 | Matlab實現RIME-BP霜冰算法優化BP神經網絡多變量回歸預測預測效果基本描述程序設計參考資料 預測效果 基本描述 1.Matlab實現RIME-BP霜冰算法優化BP神經網絡多變量回歸預測&#xff08;完整…

自動化測試介紹、selenium用法(自動化測試框架+爬蟲可用)

文章目錄 一、自動化測試1、什么是自動化測試&#xff1f;2、手工測試 vs 自動化測試3、自動化測試常見誤區4、自動化測試的優劣5、自動化測試分層6、什么項目適合自動化測試 二、Selenuim1、小例子2、用法3、頁面操作獲取輸入內容模擬點擊清空文本元素拖拽frame切換窗口切換/標…