Text-to-SQL將自然語言轉換為數據庫查詢語句

有關Text-To-SQL方法,可以查閱我的另一篇文章,Text-to-SQL方法研究

直接與數據庫對話-text2sql

Text2sql就是把文本轉換為sql語言,這段時間公司有這方面的需求,調研了一下市面上text2sql的方法,比如阿里的Chat2DB,麻省理工開源的Vanna。試驗了一下,最終還是決定自研,基于Vanna的思想,RAG+大模型。

? ? 使用開源的Vanna實現text2sql比較方便,Vanna可以直接連接數據庫,但是當用戶權限能訪問多個數據庫的時候,就比較麻煩了,而且Vanna向量化存儲之后,新的question作對比時沒有區分數據庫。因此自己實現了一下text2sq,仍然采用Vanna的思想,提前訓練DDL,Sqlques,和數據庫document。

這里簡單做一下記錄,以供后續學習使用。

基本思路

1、數據庫DDL語句,SQL-Question,Dcoument信息獲取

2、基于用戶提問question和數據庫Document鎖定要分析的數據庫

3、模型訓練:借助數據庫的DDL語句、元數據(描述數據庫自身數據的信息)、相關文檔說明、參考樣例SQL等,訓練一個RAG“模型”。

這一模型結合了embedding技術和向量數據庫,使得數據庫的結構和內容能夠被高效地索引和檢索。

4、語義檢索: 當用戶輸入自然語言描述的問題時,①會從向量庫里面檢索,迅速找出與問題相關的內容;②進行BM25算法文本召回,找到與問題 最相關的內容;③分別使用RRF算法和Re-ranking重排序算法,鎖定最相關內容

語義匹配:使用算法(如BERT等)來理解查詢和文檔的語義相似性

文本召回匹配:BM25算法文本召回,找到與問題最相關的內容

rerank結果重排序:對搜索結果進行排序。

5、Prompt構建: 檢索到的相關信息會被組裝進Prompt中,形成一個結構化的查詢描述。這一Prompt隨后會被傳遞給LLM(大型語言模型)用于生成準確的SQL查詢。

實現邏輯圖

實現架構圖:

具體實現方式如下所示:

1.數據庫的選擇

class DataBaseSearch(object):def __init__(self, _model):self.name = 'DataBaseSearch'self.model = _modelself.instruction = "為這段內容生成表示以用于匹配文本描述:"self.SIZE = 1024self.index = faiss.IndexFlatL2(self.SIZE)self.textdata = []self.subdata = {}self.i2key = {}self.id2ddls = {}self.id2sqlques = {}self.id2docs = {}self.strtexts = {}# self.ddldata = []# self.sqlques_data = []# self.document_data = []self.load_textdata()         # 加載text數據self.load_textdata_vec()     # text數據向量化def load_textdata(self):try:response = requests.post(url="xxx",verify=False)print(response.text)jsonobj = json.loads(response.text)textdatas = jsonobj["data"]for textdata in textdatas:                                 # 提取每一個數據庫內容cid = textdata["dataSetID"]cddls = textdata["ddl"]csql_ques = textdata["exp"]cdocuments = textdata["Intro"]self.textdata.append((cid, cddls, csql_ques, cdocuments))   # 整合所有數據except Exception as e:print(e)# print("load textdata ", self.textdata)def load_textdata_vec(self):num0 = 0for recode in self.textdata:_id = recode[0]_ddls = recode[1]_sql_ques = recode[2]_documents = recode[3]# _strtexts = str(_ddls) + str(_sql_ques) + str(_documents)_strtexts = str(_sql_ques) + str(_documents)text_embeddings = self.model.encode([_strtexts], normalize_embeddings=True)self.index.add(text_embeddings)self.i2key[num0] = _idself.strtexts[_id] = _strtextsself.id2ddls[_id] = _ddlsself.id2sqlques[_id] = _sql_quesself.id2docs[_id] = _documentsnum0 += 1# print("init instruction vec", num0)def calculate_score(self, score, question, kws):passdef find_vec_database(self, question, k, theata):# print(question)q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)D, I = self.index.search(q_embeddings, k)result = []for i in range(k):sim_i = I[0][i]uuid = self.i2key.get(sim_i, "none")sim_v = D[0][i]database_texts = self.strtexts.get(uuid, "none")# score = self.calculate_score(sim_v, question, database_texts) # wait implementscore = int(sim_v*1000)if score < theata:doc = {}doc["score"] = scoredoc["dataSetID"] = uuidresult.append(doc)# print(result)return resultif __name__ == '__main__':modelpath = "E:\\module\\bge-large-zh-v1.5"model = SentenceTransformer(modelpath)vs = DataBaseSearch(model)result = vs.find_vec_database("查詢濟南市第三幼兒園所有小班班級?", 1, 2000)print(result)

2.sql_ques:sql問題訓練

class SqlQuesSearch(object):def __init__(self, _model):self.name = "SqlQuesSearch"self.model = _modelself.instruction = "為這段內容生成表示以用于匹配文本描述:"self.SIZE = 1024self.index = faiss.IndexFlatL2(self.SIZE)self.sqlquedata = []self.i2dbid = {}self.i2sqlid = {}self.id2sqlque = {}self.id2que = {}self.id2sql = {}self.dbid2sqlques = {}## self.sqlques = {}## self.i2key = {}## self.id2sqlques = {}## self.num2sqlque = {}# self.ddldata = []# self.sqlques_data = []# self.document_data = []self.load_textdata()  # 加載text數據self.load_textdata_vec()  # text數據向量化def load_textdata(self):try:response = requests.post(url="xxx",verify=False)print(response.text)jsonobj = json.loads(response.text)textdatas = jsonobj["data"]datadatas = jsonobj["data"]for datadata in datadatas:  # 提取每一個數據庫sql-ques內容dbid = datadata["dataSetID"]sql_ques = datadata["exp"]self.sqlquedata.append((dbid, sql_ques))  # 整合sql數據except Exception as e:print(e)# print("load textdata ", self.sqlquedata)def load_textdata_vec(self):num0 = 0for recode in self.sqlquedata:db_id = recode[0]sql_ques = recode[1]for sql_que in sql_ques:sql_id = sql_que["sql_id"]question = sql_que["question"]sql = sql_que["sql"]ddl_embeddings = self.model.encode([question], normalize_embeddings=True)self.index.add(ddl_embeddings)self.i2dbid[num0] = db_idself.i2sqlid[num0] = sql_idself.id2que[sql_id] = questionself.id2sql[sql_id] = sqlnum0 += 1print("init sql-que vec", num0)def calculate_score(sim_v, question, sql_ques):passdef find_vec_sqlque(self, question, k, theta, dataSetID, number):q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)D, I = self.index.search(q_embeddings, k)result = []for i in range(k):sim_i = I[0][i]dbid = self.i2dbid.get(sim_i, "none")  # 獲取數據庫idsqlid = self.i2sqlid.get(sim_i, "none")question = self.id2que.get(sqlid, "none")sql = self.id2sql.get(sqlid, "none")if dbid == dataSetID:sim_v = D[0][i]score = int(sim_v * 1000)if score < theta:doc = {}doc["score"] = scoredoc["question"] = questiondoc["sql"] = sqlresult.append(doc)if len(result) == number:breakreturn resultif __name__ == '__main__':modelpath = "E:\\module\\bge-large-zh-v1.5"model = SentenceTransformer(modelpath)vs = SqlQuesSearch(model)result = vs.find_vec_sqlque("查詢7月18日所有的兒童觀察記錄?", 3, 2000, dataSetID=111)print(result)

3.數據庫DDL訓練

class DdlQuesSearch(object):def __init__(self, _model):self.name = "DdlQuesSearch"self.model = _modelself.instruction = "為這段內容生成表示以用于匹配文本描述:"self.SIZE = 1024self.index = faiss.IndexFlatL2(self.SIZE)self.ddldata = []self.sqlques = {}self.i2dbid = {}self.i2ddlid = {}self.dbid2ddls = {}self.id2ddl = {}self.ddlid2dbid = {}# self.ddldata = []# self.sqlques_data = []# self.document_data = []self.load_ddldata()  # 加載text數據self.load_ddl_vec()  # text數據向量化def load_ddldata(self):try:response = requests.post(url="xxx",verify=False)print(response.text)jsonobj = json.loads(response.text)for database in databases:db_id = database["dataSetID"]ddls = database["ddl"]self.ddldata.append((db_id, ddls))# print(db_id)# for ddl in database["ddl"]:#     ddl_id = ddl["ddl_id"]#     ddl = ddl['ddl']##     self.id2ddl[ddl_id] = ddl# self.dbid2ddls[db_id] = self.id2ddlexcept Exception as e:print(e)# print("load textdata ", self.ddldata)def load_ddl_vec(self):num0 = 0for recode in self.ddldata:db_id = recode[0]ddls = recode[1]for ddl in ddls:ddl_id = ddl["ddl_id"]ddl_name = ddl["TABLE"]ddl = ddl['ddl']ddl_embeddings = self.model.encode([ddl], normalize_embeddings=True)self.index.add(ddl_embeddings)self.i2dbid[num0] = db_idself.i2ddlid[num0] = ddl_idself.id2ddl[ddl_id] = ddlself.ddlid2dbid[ddl_id] = db_idnum0 += 1self.dbid2ddls[db_id] = self.id2ddlprint("init ddl vec", num0)def find_vec_ddl(self, question, k, theata, dataSetID, number):       # dataSetID:數據庫id# self.id2ddls.get(action_id)q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)D, I = self.index.search(q_embeddings, k)result = []for i in range(k):sim_i = I[0][i]dbid = self.i2dbid.get(sim_i, "none")         # 獲取數據庫idddlid = self.i2ddlid.get(sim_i, "none")if dbid == dataSetID:sim_v = D[0][i]score = int(sim_v * 1000)if score < theata:doc = {}doc["score"] = scoredoc["ddl"] = self.id2ddl.get(ddlid, "none")result.append(doc)if len(result) == number:breakreturn resultif __name__ == '__main__':modelpath = "E:\\module\\bge-large-zh-v1.5"model = SentenceTransformer(modelpath)vs = DdlQuesSearch(model)ss = vs.find_vec_ddl("定時任務執行記錄表", 2, 2000, 111)print(ss)

4.數據庫document訓練

class DocQuesSearch(object):def __init__(self):self.name = "TestDataSearch"self.docdata = []self.load_doc_data()def load_doc_data(self):try:response = requests.post(url="xxx",verify=False)print(response.text)jsonobj = json.loads(response.text)databases = jsonobj["data"]for database in databases:db_id = database["dataSetID"]doc = database["Intro"]self.docdata.append((db_id, doc))except Exception as e:print(e)# print("load ddldata ", self.docdata)def find_similar_doc(self, dataSetID):result = []for recode in self.docdata:dbid = recode[0]doc = recode[1]if dbid == dataSetID:result.append(doc)return resultif __name__ == '__main__':docques_search = DocQuesSearch()result = docques_search.find_similar_doc(222)print(result)

5.生成sql語句,這里使用的qwen-max模型

import re
import random
import os, json
import dashscope
from dashscope.api_entities.dashscope_response import Message
from ddl_engine import DdlQuesSearch
from dashscope import Generation
from sqlques_engine import SqlQuesSearch
from sentence_transformers import SentenceTransformerclass Genarate(object):def __init__(self):self.api_key = os.environ.get('api_key')self.model_name = os.environ.get('model')def system_message(self, message):return {'role': 'system', 'content': message}def user_message(self, message):return {'role': 'user', 'content': message}def assistant_message(self, message):return {'role': 'assistant', 'content': message}def submit_prompt(self, prompt):resp = Generation.call(model=self.model_name,messages=prompt,seed=random.randint(1, 10000),result_format='message',api_key=self.api_key)if resp["status_code"] == 200:answer = resp.output.choices[0].message.contentglobal DEBUG_INFODEBUG_INFO = (prompt, answer)return answerelse:answer = Nonereturn answerdef generate_sql(self, question, sql_result, ddl_result, doc_result):prompt = self.get_sql_prompt(question = question,sql_result = sql_result,ddl_result = ddl_result,doc_result = doc_result)print("SQL Prompt:",prompt)llm_response = self.submit_prompt(prompt)sql = self.extrat_sql(llm_response)return sqldef extrat_sql(self, llm_response):sqls = re.findall(r"WITH.*?;", llm_response, re.DOTALL)if sqls:sql = sqls[-1]return sqlsqls = re.findall(r"SELECT.*?;", llm_response, re.DOTALL)if sqls:sql = sqls[-1]return sqlsqls = re.findall(r"```sql\n(.*)```", llm_response, re.DOTALL)if sqls:sql = sqls[-1]return sqlsqls = re.findall(r"```(.*)```", llm_response, re.DOTALL)if sqls:sql = sqls[-1]return sqlreturn llm_responsedef get_sql_prompt(self, question, sql_result, ddl_result, doc_result):initial_prompt = "You are a SQL expert. " + \"Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "initial_prompt = self.add_ddl_to_prompt( initial_prompt, ddl_result)initial_prompt = self.add_documentation_to_prompt(initial_prompt, doc_result)initial_prompt += ("===Response Guidelines \n""1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n""2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n""3. If the provided context is insufficient, please explain why it can't be generated. \n""4. Please use the most relevant table(s). \n""5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n")message_log = [self.system_message(initial_prompt)]message_log = self.add_sqlques_to_prompt(question, sql_result, message_log)return message_logdef add_ddl_to_prompt(self, initial_prompt, ddl_result):""":param initial_prompt::param ddl_result::return:"""ddl_list = [ ddl_['ddl'] for ddl_ in ddl_result]if len(ddl_list) > 0:initial_prompt += "\n===Tables \n"for ddl in ddl_list:initial_prompt += f"{ddl}\n\n"return initial_promptdef add_sqlques_to_prompt(self, question, sql_result, message_log):""":param sql_result::return:"""if len(sql_result) > 0:for example in sql_result:if example is not None and "question" in example and "sql" in example:message_log.append(self.user_message(example["question"]))message_log.append(self.assistant_message(example["sql"]))message_log.append(self.user_message(question))return message_logdef add_documentation_to_prompt(self, initial_prompt, doc_result):if len(doc_result) > 0:initial_prompt += "\n===Additional Context \n\n"for doc in doc_result:initial_prompt += f"{doc}\n\n"return initial_promptif __name__ == '__main__':modelpath = "E:\\module\\bge-large-zh-v1.5"model = SentenceTransformer(modelpath)vs = DdlQuesSearch(model)ss = vs.find_vec_ddl("定時任務執行記錄表", 1, 2000, 111)print(ss)

6.執行結果顯示

如圖可以看到正確生成了sql,可以正常執行,因為表是拉取到,沒有數據,所以查詢結果為空。

需要源碼的同學,可以留言。

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/899649.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/899649.shtml
英文地址,請注明出處:http://en.pswp.cn/news/899649.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

golang 的strconv包常用方法

目錄 1. 字符串與整數的轉換 2. 字符串與浮點數的轉換 3. 布爾值的轉換 4. 字符串的轉義 5. 補充&#xff1a;rune 類型的使用 方法功能詳解 代碼示例&#xff1a; 1. 字符串與整數的轉換 方法名稱功能描述示例Atoi將字符串轉換為十進制整數。strconv.Atoi("123&q…

MATLAB詳細圖文安裝教程(附安裝包)

前言 MATLAB&#xff08;Matrix Laboratory&#xff09;是由MathWorks公司開發的一款高性能的編程語言和交互式環境&#xff0c;主要用于數值計算、數據分析和算法開發。內置數學函數和工具箱豐富&#xff0c;開發效率高&#xff0c;特別適合矩陣運算和領域特定問題。接下來就…

ShapeCrawler:.NET開發者的PPTX操控魔法

引言 在當今的軟件開發領域&#xff0c;隨著數據可視化和信息展示需求的不斷增長&#xff0c;處理 PPTX 文件的場景日益頻繁。無論是自動化生成報告、批量制作演示文稿&#xff0c;還是對現有 PPT 進行內容更新與格式調整&#xff0c;開發者都需要高效的工具來完成這些任務。傳…

HTML5貪吃蛇游戲開發經驗分享

HTML5貪吃蛇游戲開發經驗分享 這里寫目錄標題 HTML5貪吃蛇游戲開發經驗分享項目介紹技術棧核心功能實現1. 游戲初始化2. 蛇的移動控制3. 碰撞檢測4. 食物生成 開發心得項目收獲后續優化方向結語 項目介紹 在這個項目中&#xff0c;我使用HTML5 Canvas和原生JavaScript實現了一…

有關pip與conda的介紹

Conda vs. Pip vs. Virtualenv 命令對比 任務Conda 命令Pip 命令Virtualenv 命令安裝包conda install $PACKAGE_NAMEpip install $PACKAGE_NAMEX更新包conda update --name $ENVIRONMENT_NAME $PACKAGE_NAMEpip install --upgrade $PACKAGE_NAMEX更新包管理器conda update con…

【Linux】調試器——gdb使用

目錄 一、預備知識 二、常用指令 三、調試技巧 &#xff08;一&#xff09;監視變量的變化指令 watch &#xff08;二&#xff09;更改指定變量的值 set var 正文 一、預備知識 程序的發布形式有兩種&#xff0c;debug和release模式&#xff0c;Linux gcc/g出來的二進制…

【Ubuntu常用命令】

1.將本地服務器文件或文件夾傳輸到遠程服務器 文件 scp /data/a.txt administrator10.60.51.20:/home/administrator/ 文件夾 scp -r /data/ administrator10.60.51.20:/home/administrator/ 2.從遠程服務器傳輸文件到本地服務器 scp administrator10.60.51.20:/data/a.txt /h…

golang 的time包的常用方法

目錄 time 包方法總結 類型 time.Time 的方法 庫函數 代碼示例&#xff1a; time 包方法總結 類型 time.Time 的方法 方法名描述示例               ?Now()獲取當前時間和日期time.Now()Format()格式化時間為字符串time.Now().Format("2006-01-02 15…

Elasticsearch:使用 Azure AI 文檔智能解析 PDF 文本和表格數據

作者&#xff1a;來自 Elastic James Williams 了解如何使用 Azure AI 文檔智能解析包含文本和表格數據的 PDF 文檔。 Azure AI 文檔智能是一個強大的工具&#xff0c;用于從 PDF 中提取結構化數據。它可以有效地提取文本和表格數據。提取的數據可以索引到 Elastic Cloud Serve…

【ArcGIS操作】ArcGIS 進行空間聚類分析

ArcGIS 是一個強大的地理信息系統&#xff08;GIS&#xff09;軟件&#xff0c;主要用于地理數據的存儲、分析、可視化和制圖 啟動 ArcMap 在 Windows 中&#xff0c;點擊“開始”菜單&#xff0c;找到 ArcGIS文件夾&#xff0c;然后點擊 ArcMap 添加數據 添加數據 - 點擊工具…

RabbitMQ消息相關

MQ的模式&#xff1a; 基本消息模式&#xff1a;一個生產者&#xff0c;一個消費者work模式&#xff1a;一個生產者&#xff0c;多個消費者訂閱模式&#xff1a; fanout廣播模式&#xff1a;在Fanout模式中&#xff0c;一條消息&#xff0c;會被所有訂閱的隊列都消費。 在廣播…

緩存使用紀要

一、本地緩存&#xff1a;Caffeine 1、簡介 Caffeine是一種高性能、高命中率、內存占用低的本地緩存庫&#xff0c;簡單來說它是 Guava Cache 的優化加強版&#xff0c;是當下最流行、最佳&#xff08;最優&#xff09;緩存框架。 Spring5 即將放棄掉 Guava Cache 作為緩存機…

2025年3月29日筆記

問題&#xff1a;創建一個長度為99的整數數組&#xff0c;輸出數組的每個位置數字是幾&#xff1f; 解題思路&#xff1a; 1.因為題中沒有明確要求需要輸入,所以所有類型的答案都需要寫出 解法1&#xff1a; #include<iostream> #include<bits/stdc.h> using n…

hadoop相關面試題以及答案

什么是Hadoop&#xff1f;它的主要組件是什么&#xff1f; Hadoop是一個開源的分布式計算框架&#xff0c;用于處理大規模數據的存儲和計算。其主要組件包括Hadoop Distributed File System&#xff08;HDFS&#xff09;和MapReduce。 解釋HDFS的工作原理。 HDFS采用主從架構&…

微信小程序:數據拼接方法

1. 使用 concat() 方法拼接數組 // 在原有數組基礎上拼接新數組 Page({data: {originalArray: [1, 2, 3]},appendData() {const newData [4, 5, 6];const combinedArray this.data.originalArray.concat(newData);this.setData({originalArray: combinedArray});} }) 2. 使…

Python之貪心算法

Python實現貪心算法(Greedy Algorithm) 概念 貪心算法是一種在每一步選擇中都采取當前狀態下最優的選擇&#xff0c;從而希望導致結果是全局最優的算法策略。 基本特點 局部最優選擇&#xff1a;每一步都做出當前看起來最佳的選擇不可回退&#xff1a;一旦做出選擇&#xf…

【 <二> 丹方改良:Spring 時代的 JavaWeb】之 Spring Boot 中的 AOP:實現日志記錄與性能監控

<前文回顧> 點擊此處查看 合集 https://blog.csdn.net/foyodesigner/category_12907601.html?fromshareblogcolumn&sharetypeblogcolumn&sharerId12907601&sharereferPC&sharesourceFoyoDesigner&sharefromfrom_link <今日更新> 一、開篇整…

TCP/IP協議簇

文章目錄 應用層http/httpsDNS補充 傳輸層TCP1. 序列號與確認機制2. 超時重傳3. 流量控制&#xff08;滑動窗口機制&#xff09;4. 擁塞控制5. 錯誤檢測與校驗6. 連接管理總結 網絡層ARP**ARP 的核心功能**ARP 的工作流程1. ARP 請求&#xff08;Broadcast&#xff09;2. ARP 緩…

SpringBoot分布式項目訂單管理實戰:Mybatis最佳實踐全解

一、架構設計與技術選型 典型分布式訂單系統架構&#xff1a; [網關層] → [訂單服務] ←→ [分布式緩存]↑ ↓ [用戶服務] [支付服務]↓ ↓ [MySQL集群] ← [分庫分表中間件]技術棧組合&#xff1a; Spring Boot 3.xMybatis-Plus 3.5.xShardingSpher…

微服務架構中的精妙設計:環境和工程搭建

一.前期準備 1.1開發環境安裝 Oracle從JDK9開始每半年發布?個新版本, 新版本發布后, ?版本就不再進?維護. 但是會有?個?期維護的版本. ?前?期維護的版本有: JDK8, JDK11, JDK17, JDK21 在 JDK版本的選擇上&#xff0c;盡量選擇?期維護的版本. 為什么選擇JDK17? S…