利用sentence bert 實現語義向量搜索

目錄

基于pytorch的中文語言模型預訓練:https://github.com/zhusleep/pytorch_chinese_lm_pretrain/tree/master

sentence_emb.py

search_faiss_robert768.py

faiss_index.py

gen_vec_save2_faiss.py


基于pytorch的中文語言模型預訓練:https://github.com/zhusleep/pytorch_chinese_lm_pretrain/tree/master

sentence_emb.py

#from transformers import BertTokenizer, BertModel
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#
## First we initialize our model and tokenizer:
#tokenizer = BertTokenizer.from_pretrained('./result')
#model = BertModel.from_pretrained('./result')def split_batch(init_list, batch_size):groups = zip(*(iter(init_list),) * batch_size)end_list = [list(i) for i in groups]count = len(init_list) % batch_sizeend_list.append(init_list[-count:]) if count != 0 else end_listreturn end_list"""
param: sentence list
return: embeddings
"""
def encode(sentences, tokenizer, model):tokens = {'input_ids': [], 'attention_mask': []}data_num = len(sentences)for sentence in sentences:# 編碼每個句子并添加到字典new_tokens = tokenizer.encode_plus(str(sentence), max_length=128,truncation=True, padding='max_length',return_tensors='pt')tokens['input_ids'].append(new_tokens['input_ids'][0])tokens['attention_mask'].append(new_tokens['attention_mask'][0])# 將張量列表重新格式化為一個張量tokens['input_ids'] = torch.stack(tokens['input_ids']).to(device)tokens['attention_mask'] = torch.stack(tokens['attention_mask']).to(device)model.eval()# We process these tokens through our model:with torch.no_grad():#添加這行代碼outputs = model(**tokens)# odict_keys(['last_hidden_state', 'pooler_output'])# The dense vector representations of our text are contained within the outputs 'last_hidden_state' tensor, which we access like so:embeddings = outputs[0]# To perform this operation, we first resize our attention_mask tensor:attention_mask = tokens['attention_mask']# attention_mask.shapemask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()# mask.shape# 上面的每個向量表示一個單獨token的掩碼現在每個token都有一個大小為768的向量,表示它的attention_mask狀態。然后將兩個張量相乘:masked_embeddings = embeddings * mask# masked_embeddings.shape# torch.Size([2, 128, 768])torch.Size([data_num, 128, 768])summed = torch.sum(masked_embeddings, 1)summed_mask = torch.clamp(mask.sum(1), min=1e-9)mean_pooled = summed / summed_mask# print(mean_pooled)# print(type(mean_pooled))return mean_pooled#sentences = [
#    "你叫什么名字?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#    "你的名字是什么?",
#]
#sb = split_batch(sentences, 2)
#embs = []
#for batch in sb:
#	emb = encode(batch)
#	embs += emb
#
#print(embs)
#print(len(embs))

search_faiss_robert768.py

import pickle
from faiss_index import faissIndex
import pandas as pd
import numpy as np
# from sentence_transformers import SentenceTransformer
# Download model
# model = SentenceTransformer('paraphrase-MiniLM-L6-v2/')
from sentence_emb import encodefrom transformers import BertTokenizer, BertModel
import torch
# First we initialize our model and tokenizer:
tokenizer = BertTokenizer.from_pretrained('./result')
model = BertModel.from_pretrained('./result').cuda()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# faiss_index_path = "faiss_index384.pkl"
faiss_index_path = "faiss_index_robert.pkl"symptom_name_df = pd.read_csv("col2.csv")# 從本地加載faiss_index模型
def load_faiss_index(var_faiss_model_path):# 從本地加載faiss_index模型# with open('strategy/semantic_recall/model/tt.txt', 'r') as f:#     print(f.readlines())with open(var_faiss_model_path, mode='rb', errors=None) as fr:index = pickle.load(fr, encoding='ASCII', errors='ASCII')return indexdef symptom_name_recall(symptom_name):# 將參數中當前的文本編碼成向量sentence = []sentence.append(symptom_name)# qyery_emb = model.encode(sentence)qyery_emb = encode(sentence,tokenizer,model)# 去faiss中檢索相近的faiss索引# 加載faissloaded_faiss_index = load_faiss_index(faiss_index_path)# 尋找最近k個物料# R, D, I = loaded_faiss_index.search_items(qyery_emb.reshape([-1, 384]), k=10, n_probe=5)R, D, I = loaded_faiss_index.search_items(np.array(qyery_emb.reshape([-1, 768]).cpu()), k=10, n_probe=5)# 從faiss庫中檢索的物料ID進行轉換result = []for id_list in R:for item in id_list:result.append(item)symptom_name_list = symptom_name_df[symptom_name_df['index'].isin(result)]['symptom_name'].to_list()# 從相似度檢索的結果中,去除自己if symptom_name in symptom_name_list:symptom_name_list.remove(symptom_name)print(symptom_name + ' 的相近的詞:' + str(symptom_name_list))word_lsit = ['頭痛','惡心吧吐','期飲酒','出血','失眠']
for word in word_lsit:symptom_name_recall(word)

faiss_index.py

import faiss
import numpy as npclass faissIndex:def __init__(self, dim, n_centroids, metric):self.dim = dimself.n_centriods = n_centroidsassert metric in ('INNER_PRODUCT', 'L2'), "Input metric not in 'INNER_PRODUCT' or 'L2'"self.metric = faiss.METRIC_INNER_PRODUCT if metric == 'INNER_PRODUCT' else faiss.METRIC_L2self._build_index()returndef _build_index(self):self._quantizer = faiss.IndexFlatL2(self.dim)self.index = faiss.IndexIVFFlat(self._quantizer, self.dim, self.n_centriods, self.metric)self.is_trained = self.index.is_trainedself.n_samples = 0  # 查詢向量池中的向量個數self.items = np.array([])  # 向量池中向量對應的item,數量應與self.n_samples保持一致,即向量與item一一對應return Truedef reset_index(self, dim, n_centroids, metric):self.dim = dimself.n_centriods = n_centroidsassert metric in ('INNER_PRODUCT', 'L2'), "Input metric not in 'INNER_PRODUCT' or 'L2'"self.metric = faiss.METRIC_INNER_PRODUCT if metric == 'INNER_PRODUCT' else faiss.METRIC_L2self._build_index()returndef train(self, vectors_train):self.index.train(vectors_train)self.is_trained = self.index.is_trainedreturndef add(self, vectors, items=None):if not items.empty:  # 當有輸入items時,驗證之前的item和vector數量是否匹配,以及當前輸入assert len(vectors) == len(items), "Length of vectors ({n_vectors}) and items ({n_items}) don't match, please check your input.".format(n_vectors=len(vectors), n_items=len(items))assert self.n_samples == len(self.items), "Amounts of added vectors and items don't match, cannot add more items."self.items = np.append(self.items, items.to_numpy())else:assert len(self.items) == 0, "There were items added previously, please added corresponding items in this batch."self.index.add(vectors)self.n_samples += len(vectors)returndef search(self, query_vector, k, n_probe=1):assert query_vector.shape[1] == self.dim, "The dimension of query vector ({dim_vector}) doesn't match the training vector set ({dim_index})!".format(dim_vector=query_vector.shape[1], dim_index=self.dim)assert self.is_trained, "Faiss index is not trained, please train index first!"assert self.n_samples > 0, "Faiss index doesn't have any vector for query, please add vectors into index first!"self.index.nprobe = n_probeD, I = self.index.search(query_vector, k)return D, I# k = 30 # 對每條向量(每行)尋找最近k個物料# n_probe = 5 # 每次查詢只查詢最近鄰n_probe個聚類def search_items(self, query_vector, k, n_probe=1):D, I = self.search(query_vector, k, n_probe)R = [self.items[i] for i in I]return R, D, I

gen_vec_save2_faiss.py

"""
# 訓練語義向量并保存在faiss中
step1: 將句子生成向量
step2: 將向量保存在faiss中
"""
import pandas as pd
import numpy as np
# from sentence_transformers import SentenceTransformer
# Download model
# model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
from sentence_emb import encode
import pickle
from faiss_index import faissIndex
from tqdm import tqdmfaiss_index_path = "faiss_index_robert.pkl"from transformers import BertTokenizer, BertModel
import torchdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# First we initialize our model and tokenizer:
tokenizer = BertTokenizer.from_pretrained('./result')
model = BertModel.from_pretrained('./result').cuda()# ====================== 創建faiss index并進行訓練 ======================
# 創建faiss index并進行訓練
def build_faiss_index(df_resources, semantic_vector, n_centroids=5, metric='L2'):print("現在開始進行faiss index模型訓練")# 構建faiss索引模型dim = semantic_vector.shape[1]print("訓練數據維度:", dim)print("聚類中心個數:", n_centroids)print("向量距離指標:", metric)# 訓練faiss索引index = faissIndex(dim, n_centroids, metric)# vectors = np.stack(df_resources['index'].values).astype('float32') # faiss只支持32位浮點數查詢vectors = semantic_vectoritems = df_resources['index']index.train(vectors)index.add(vectors, items)print("faiss index模型已訓練完成")return index# ====================== 保存faiss ======================
# 將index按照指定的日期命名并保存至本地
def save_index(index, path):print("現在開始將faiss index保存至本地")fw = open(path, mode='wb', errors=None)pickle.dump(index, fw)fw.close()print("faiss_index模型已保存至本地")def split_batch(init_list, batch_size):groups = zip(*(iter(init_list),) * batch_size)end_list = [list(i) for i in groups]count = len(init_list) % batch_sizeend_list.append(init_list[-count:]) if count != 0 else end_listreturn end_list"""
# 利用sentence transfermer 生成文本向量
# 訓練faiss
# 保存faiss
param: 
"""def sentence2faiss_transfermer():df = pd.read_csv('col2.csv')train_json = df.to_dict('records')# 取文本將文本轉化為向量title_list = [item['symptom_name'] for item in train_json]print(len(title_list))print("正在訓練中.......")# title_list = title_list[:500]sb = split_batch(title_list, 8)embeddings = []# print(len(title_list))# emb = encode(title_list, tokenizer, model)# print(emb)# exit()for batch in tqdm(sb):try:emb = encode(batch, tokenizer, model)emb = np.array(emb.to("cpu"))for item in emb:embeddings.append(item)except Exception as e:print(e)# print(len(embeddings))# embeddings = np.array(embeddings)
#    print(embeddings)
#    print(len(embeddings))# exit()# embeddings = encode(title_list)# 創建faiss index并進行訓練df_resources = pd.DataFrame(train_json)# print(embeddings.shape)print("==================================================")# emb = emb.cpu()# semantic_2d_array = np.array(embeddings.to("cpu"))# 將numpy數組轉換成CUDA張量# semantic_2d_array= torch.tensor([item.cpu().detach().numpy() for item in semantic_2d_array]).cuda()print("開始build_faiss_index")# print(len(np.array(emb)))trained_index = build_faiss_index(df_resources, np.array(embeddings), n_centroids=5, metric='L2')print("開始save_index")# 保存faiss模型save_index(trained_index, faiss_index_path)sentence2faiss_transfermer()

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/14513.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/14513.shtml
英文地址,請注明出處:http://en.pswp.cn/web/14513.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

[協議]stm32讀取AHT20程序示例

AHT20溫度傳感器使用程序&#xff1a; 使用i2c讀取溫度傳感器數據很簡單&#xff0c;但市面上有至少兩個手冊&#xff0c;我這個對應的手冊貼出來&#xff1a; main: #include "stm32f10x.h" // Device header #include <stdint.h> #includ…

數智賦能內澇治理,四信城市排水防澇解決方案保障城市安全運行

由強降雨、臺風造成城市低洼處出現大量積水、內澇的情況時有發生&#xff0c;給人們出行帶來了極大不便和安全隱患&#xff0c;甚至危及群眾生命財產安全。 為降低內澇造成的損失&#xff0c;一方面我們要大力加強城市排水基礎設施的建設&#xff1b;另一方面要全面掌握城市內澇…

U-Boot menu菜單分析

文章目錄 前言目標環境背景U-Boot如何自動調起菜單U-Boot添加自定義命令實踐 前言 在某個廠家的開發板中&#xff0c;在進入它的U-Boot后&#xff0c;會自動彈出一個菜單頁面&#xff0c;輸入對應的選項就會執行對應的功能。如SD卡鏡像更新、顯示設置等&#xff1a; 目標 本…

docker命令詳解大全

Docker是一種流行的容器化平臺&#xff0c;用于快速部署應用程序并管理容器的生命周期。以下是一些常用的Docker命令及其用途的概述&#xff1a; docker run&#xff1a;創建一個新容器并運行一個命令。docker ps&#xff1a;列出當前運行的容器。docker stop&#xff1a;停止…

Unity射擊游戲開發教程:(20)增加護盾強度

在本文中,我們將增強護盾,使其在受到超過 1 次攻擊后才會被禁用。 Player 腳本具有 Shield PowerUp 方法,我們需要調整盾牌在被摧毀之前可以承受的數量,因此我們將聲明一個 int 變量來設置盾牌可以承受的擊中數量。

微信小程序畫布顯示圖片繪制矩形選區

wxml <view class"page-body"><!-- 畫布 --><view class"page-body-wrapper"><canvas canvas-id"myCanvas" type"2d" id"myCanvas" classmyCanvas bindtouchstart"touchStart" bindtouchmo…

OpenFeign快速入門 替代RestTemplate

1.引入依賴 <!--openFeign--><dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-starter-openfeign</artifactId></dependency><!--負載均衡器--><dependency><groupId>org.spr…

【全網最全】2024電工杯數學建模B題問題一14頁論文+19建模過程代碼+py代碼+2種保獎思路+數據等(后續會更新成品論文等)

您的點贊收藏是我繼續更新的最大動力&#xff01; 一定要點擊如下的卡片鏈接&#xff0c;那是獲取資料的入口&#xff01; 【全網最全】2024電工杯數學建模B題問一論文19建模過程代碼py代碼2種保獎思路數據等&#xff08;后續會更新成品論文等&#xff09;「首先來看看目前已…

C++中的四種類型轉換運算符

隱式類型轉換是安全的&#xff0c;顯式類型轉換是有風險的&#xff0c;C語言之所以增加強制類型轉換的語法&#xff0c;就是為了強調風險&#xff0c;讓程序員意識到自己在做什么。但是&#xff0c;這種強調風險的方式還是比較粗放&#xff0c;粒度比較大&#xff0c;它并沒有表…

MySQL中如何知道數據庫表中所有表的字段的排序規則是什么?

查看所有表的字段及其排序規則&#xff1a; 你可以查詢 information_schema 數據庫中的 COLUMNS 表&#xff0c;來獲取所有表的字段及其排序規則。以下是一個示例查詢&#xff1a; SELECT TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, COLLATION_NAME FROM information_schema.COL…

【設計模式深度剖析】【5】【創建型】【原型模式】| 類比群發郵件,加深理解

&#x1f448;?上一篇:建造者模式 | 下一篇:創建型設計模式對比&#x1f449;? 目錄 原型模式(Prototype Pattern)概覽定義英文原話直譯 3個角色類圖1. 抽象原型&#xff08;Prototype&#xff09;角色2. 具體原型&#xff08;Concrete Prototype&#xff09;角色3. 客戶…

必示科技參與智能運維國家標準預研線下編寫會議并做主題分享

近日&#xff0c;《信息技術服務 智能運維 第3部分&#xff1a;算法治理》&#xff08;擬定名&#xff09;國家標準預研階段第一次編寫工作會議在杭州舉行。本次會議由浙商證券承辦。 此次編寫有來自銀行、證券、保險、通信、高校研究機構、互聯網以及技術方等29家單位&#xf…

在云計算環境中,如何實現資源的高效分配和調度?

在云計算環境中&#xff0c;可以通過以下幾種方法實現資源的高效分配和調度&#xff1a; 負載均衡&#xff1a;通過負載均衡算法&#xff0c;將云計算集群的負載均勻地分配到各個節點上。常見的負載均衡算法有輪詢、最小連接數、最短響應時間等。 資源調度算法&#xff1a;為了…

Linux基礎(四):Linux系統文件類型與文件權限

各位看官&#xff0c;好久不見&#xff0c;在正式介紹Linux的基本命令之前&#xff0c;我們首先了解一下&#xff0c;關于文件的知識。 目錄 一、文件類型 二、文件權限 2.1 文件訪問者的分類 2.2 文件權限 2.2.1 文件的基本權限 2.2.2 文件權限值的表示方法 三、修改文…

CSS3 新增背景屬性 + 新增邊框屬性(如果想知道CSS3新增背景屬性和新增邊框屬性的知識點,那么只看這一篇就夠了!)

前言&#xff1a;CSS3在CSS2的基礎上&#xff0c;新增了很多強大的新功能&#xff0c;從而解決一些實際面臨的問題&#xff0c;本篇文章主要講解的為CSS3新增背景屬性和新增邊框屬性。 ???這里是秋刀魚不做夢的BLOG ???想要了解更多內容可以訪問我的主頁秋刀魚不做夢-CSD…

視覺SLAM十四講:從理論到實踐(Chapter5:相機與圖像)

前言 學習筆記&#xff0c;僅供學習&#xff0c;不做商用&#xff0c;如有侵權&#xff0c;聯系我刪除即可 目標 理解針孔相機的模型、內參與徑向畸變參數。理解一個空間點是如何投影到相機成像平面的。掌握OpenCV的圖像存儲與表達方式。學會基本的攝像頭標定方法。 一、相…

機器學習第四十周周報 WDN GGNN

文章目錄 week40 WDN GGNN摘要Abstract一、文獻閱讀1. 題目2. abstract3. 網絡架構3.1 問題提出3.2 GNN3.3 CSI GGNN 4. 文獻解讀4.1 Introduction4.2 創新點4.3 實驗過程4.3.1 數據獲取4.3.2 參數設置4.3.3 實驗結果 5. 結論二、GGNN1. 代碼解釋2. 網絡結構小結參考文獻參考文…

Vue 2 和 Vue 3 中同步和異步

Vue 2 和 Vue 3 中同步和異步 Vue 2 同步和異步 同步更新 (Synchronous Updates) Vue 2 在數據更新后會進行同步渲染更新,但為了性能優化,Vue 會在內部隊列中異步地進行 DOM 更新。這意味著數據變化會立即被捕捉到,但實際的 DOM 更新會被推遲到下一個事件循環隊列中。new V…

基礎3 探索JAVA圖形編程桌面:邏輯圖形組件實現

在一個寬敞明亮的培訓教室里&#xff0c;陽光透過窗戶柔和地灑在地上&#xff0c;教室里擺放著整齊的桌椅。臥龍站在講臺上&#xff0c;面帶微笑&#xff0c;手里拿著激光筆&#xff0c;他的眼神中充滿了熱情和期待。他的聲音清晰而洪亮&#xff0c;傳遍了整個教室&#xff1a;…

Linux模擬考試

注意&#xff0c;以下答案僅供參考 1、某CentOS系統空間不夠&#xff0c;現加一塊100G的硬盤(是系統的第二塊硬盤&#xff09;&#xff0c;分為一個區99G&#xff0c;掛載點是/data&#xff0c;請寫出從分區到掛載并使用的整個步驟及相關命令。 1.創建分區&#xff1a; sudo f…