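For Text-to-SQL methods in general, see my other article, "Text-to-SQL方法研究" (A Study of Text-to-SQL Methods).
Talking Directly to the Database: text2sql
Text2sql converts natural-language text into SQL. My company recently needed this capability, so I surveyed the text2sql tools on the market, such as Alibaba's Chat2DB and the MIT-licensed open-source Vanna. After trying them out, I decided to build our own, following Vanna's idea: RAG + a large language model.
Using the open-source Vanna for text2sql is convenient, and Vanna can connect to a database directly. However, it becomes cumbersome when a user has permission to access several databases. Also, once Vanna has vectorized and stored its training data, a new question is compared against all of it without distinguishing which database each item belongs to. So I implemented text2sql myself, still following Vanna's approach of training in advance on DDL, SQL-question pairs (Sqlques), and database documentation.
I'm keeping a brief record here for later reference.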
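Basic approach
1. Collect each database's DDL statements, SQL-question pairs, and documentation.
2. Use the user's question and the database documentation to determine which database to query.
3. Model training: use the database's DDL statements, metadata (information that describes the database's own data), related documentation, and reference SQL examples to train a RAG "model".
This "model" combines embeddings with a vector database so that the database's structure and content can be indexed and retrieved efficiently.
4. Semantic retrieval: when the user asks a question in natural language, (1) retrieve related content from the vector store; (2) run BM25 text recall to find the content most relevant to the question; (3) apply the RRF algorithm and a re-ranking model to pin down the most relevant content.
Semantic matching: use a model (such as BERT; the code below uses bge-large-zh-v1.5) to capture the semantic similarity between the query and the documents.
Text recall matching: BM25 text recall finds the content most relevant to the question.
Re-ranking: reorder the search results.
5. Prompt construction: the retrieved information is assembled into a prompt, forming a structured description of the query. This prompt is then passed to an LLM (large language model) to generate an accurate SQL query.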
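The code in this post only implements the vector-retrieval half of step 4. As a sketch of the other half, here is a minimal pure-Python BM25 scorer plus Reciprocal Rank Fusion (RRF). Assumptions: documents and queries are already tokenized (Chinese text would first need a word segmenter such as jieba); k1=1.5, b=0.75 and the RRF constant k=60 are common defaults, not values from the post; and the vector ranking is a hard-coded stand-in for FAISS output. The re-ranking model is left out.

```python
import math
from collections import Counter


def bm25_scores(query_tokens, docs_tokens, k1=1.5, b=0.75):
    """Okapi BM25 score of every tokenized document for one tokenized query."""
    n = len(docs_tokens)
    avgdl = sum(len(d) for d in docs_tokens) / n
    df = Counter(t for d in docs_tokens for t in set(d))  # number of docs containing each term
    scores = []
    for doc in docs_tokens:
        tf = Counter(doc)
        s = 0.0
        for t in query_tokens:
            if tf[t] == 0:
                continue
            idf = math.log((n - df[t] + 0.5) / (df[t] + 0.5) + 1)
            s += idf * tf[t] * (k1 + 1) / (tf[t] + k1 * (1 - b + b * len(doc) / avgdl))
        scores.append(s)
    return scores


def rrf(rankings, k=60):
    """Reciprocal Rank Fusion: sum 1/(k + rank) over every ranking a doc appears in."""
    fused = Counter()
    for ranking in rankings:
        for rank, doc_id in enumerate(ranking, start=1):
            fused[doc_id] += 1.0 / (k + rank)
    return [doc_id for doc_id, _ in fused.most_common()]


# Hypothetical table descriptions, tokenized by hand
ids = ["t_child_obs", "t_class", "t_task_log"]
docs_tokens = [
    ["child", "observation", "record", "table"],
    ["kindergarten", "class", "table"],
    ["scheduled", "task", "execution", "log"],
]
query = ["child", "observation", "record"]

scores = bm25_scores(query, docs_tokens)
# BM25 recall: keep only documents that share at least one term with the query
bm25_ranking = [ids[i] for i in sorted(range(len(ids)), key=lambda i: -scores[i]) if scores[i] > 0]
vector_ranking = ["t_class", "t_child_obs", "t_task_log"]  # stand-in for FAISS results
print(rrf([vector_ranking, bm25_ranking]))  # ['t_child_obs', 't_class', 't_task_log']
```

Although the vector search ranked `t_class` first, `t_child_obs` appears in both lists and wins after fusion. RRF only uses ranks, so BM25 scores and L2 distances never have to be put on a common scale.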
[Figure: implementation logic diagram]
[Figure: implementation architecture diagram]
The implementation is as follows:
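Every loader below POSTs to an endpoint elided as "xxx" and reads `jsonobj["data"]`. Judging from the field names the code reads, each entry describes one database. The sketch below is an illustrative guess at that shape with made-up values, plus a hypothetical `split_payload` helper that separates it into the three stores the retrievers use. The real endpoint may return more fields or different types (for example, `Intro` is assumed here to be a single string).

```python
import json

# Hypothetical response body: field names come from the loaders, values are made up
SAMPLE_RESPONSE = json.dumps({
    "data": [
        {
            "dataSetID": 111,
            "Intro": "Kindergarten management database: classes, children, observation records.",
            "ddl": [
                {"ddl_id": 1, "TABLE": "child_observation",
                 "ddl": "CREATE TABLE child_observation (id INT, child_id INT, observe_date DATE);"},
            ],
            "exp": [
                {"sql_id": 10, "question": "How many observation records are there?",
                 "sql": "SELECT COUNT(*) FROM child_observation;"},
            ],
        }
    ]
})


def split_payload(text):
    """Split the payload into per-database DDL, SQL-question and documentation stores."""
    ddls, sql_ques, docs = {}, {}, {}
    for db in json.loads(text)["data"]:
        db_id = db["dataSetID"]
        ddls[db_id] = [(d["ddl_id"], d["TABLE"], d["ddl"]) for d in db["ddl"]]
        sql_ques[db_id] = [(e["sql_id"], e["question"], e["sql"]) for e in db["exp"]]
        docs[db_id] = db["Intro"]
    return ddls, sql_ques, docs


ddls, sql_ques, docs = split_payload(SAMPLE_RESPONSE)
print(ddls[111][0][1])  # child_observation
```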
1. Database selection
import json

import faiss
import requests
from sentence_transformers import SentenceTransformer


class DataBaseSearch(object):
    """Pick the database (dataSetID) whose SQL examples + documentation best match a question."""

    def __init__(self, _model):
        self.name = 'DataBaseSearch'
        self.model = _model
        # Query-side instruction prefix for the bge model; stored texts are encoded without it
        self.instruction = "為這段內容生成表示以用于匹配文本描述:"
        self.SIZE = 1024  # embedding dimension of bge-large-zh-v1.5
        self.index = faiss.IndexFlatL2(self.SIZE)
        self.textdata = []
        self.i2key = {}  # faiss row number -> dataSetID
        self.id2ddls = {}
        self.id2sqlques = {}
        self.id2docs = {}
        self.strtexts = {}
        self.load_textdata()  # load the text data
        self.load_textdata_vec()  # vectorize the text data

    def load_textdata(self):
        try:
            response = requests.post(url="xxx", verify=False)
            print(response.text)
            jsonobj = json.loads(response.text)
            textdatas = jsonobj["data"]
            for textdata in textdatas:  # one entry per database
                cid = textdata["dataSetID"]
                cddls = textdata["ddl"]
                csql_ques = textdata["exp"]
                cdocuments = textdata["Intro"]
                self.textdata.append((cid, cddls, csql_ques, cdocuments))  # collect everything
        except Exception as e:
            print(e)

    def load_textdata_vec(self):
        num0 = 0
        for _id, _ddls, _sql_ques, _documents in self.textdata:
            # _strtexts = str(_ddls) + str(_sql_ques) + str(_documents)
            _strtexts = str(_sql_ques) + str(_documents)  # DDL deliberately left out
            text_embeddings = self.model.encode([_strtexts], normalize_embeddings=True)
            self.index.add(text_embeddings)
            self.i2key[num0] = _id
            self.strtexts[_id] = _strtexts
            self.id2ddls[_id] = _ddls
            self.id2sqlques[_id] = _sql_ques
            self.id2docs[_id] = _documents
            num0 += 1

    def calculate_score(self, score, question, kws):
        pass

    def find_vec_database(self, question, k, theata):
        q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)
        D, I = self.index.search(q_embeddings, k)
        result = []
        for i in range(k):
            sim_i = I[0][i]
            if sim_i < 0:  # faiss pads with -1 when the index holds fewer than k vectors
                continue
            uuid = self.i2key.get(sim_i, "none")
            sim_v = D[0][i]  # squared L2 distance: lower means more similar
            # score = self.calculate_score(sim_v, question, self.strtexts.get(uuid, "none"))  # to be implemented
            score = int(sim_v * 1000)
            if score < theata:
                result.append({"score": score, "dataSetID": uuid})
        return result


if __name__ == '__main__':
    modelpath = "E:\\module\\bge-large-zh-v1.5"
    model = SentenceTransformer(modelpath)
    vs = DataBaseSearch(model)
    # "List all junior (小班) classes of Jinan No. 3 Kindergarten"
    result = vs.find_vec_database("查詢濟南市第三幼兒園所有小班班級?", 1, 2000)
    print(result)
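On the `theata=2000` threshold: `faiss.IndexFlatL2` returns squared Euclidean distances, and because `normalize_embeddings=True` makes every vector unit-length, that distance equals 2 - 2*cos and lies in [0, 4]. So `score = int(distance * 1000) < 2000` keeps any hit with positive cosine similarity, which is fairly loose. The NumPy sketch below checks the identity without needing FAISS; `flat_l2_search` and `score_to_cosine` are illustrative helpers, not part of the post's code.

```python
import numpy as np


def normalize(v):
    """Scale to unit length, as normalize_embeddings=True does."""
    return v / np.linalg.norm(v, axis=-1, keepdims=True)


def flat_l2_search(index_vecs, query_vec, k):
    """Brute-force stand-in for faiss.IndexFlatL2.search on one query (squared L2, ascending)."""
    d = ((index_vecs - query_vec) ** 2).sum(axis=1)
    order = np.argsort(d)[:k]
    return d[order], order


def score_to_cosine(score):
    """Invert score = distance * 1000 (ignoring int truncation) via distance = 2 - 2*cos."""
    return 1.0 - (score / 1000.0) / 2.0


rng = np.random.default_rng(0)
vecs = normalize(rng.normal(size=(5, 8)))
q = normalize(rng.normal(size=8))
D, I = flat_l2_search(vecs, q, 3)
print(np.allclose(D, 2 - 2 * (vecs[I] @ q)))  # True: squared L2 == 2 - 2*cos for unit vectors
print(score_to_cosine(2000))  # 0.0, i.e. the post's threshold means cos > 0
```

For a tighter cut, a threshold of 1000 corresponds to cos > 0.5. Alternatively, `faiss.IndexFlatIP` on normalized vectors returns the cosine similarity directly (higher is better, so the comparison flips).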
2. sql_ques: SQL-question training
import json

import faiss
import requests
from sentence_transformers import SentenceTransformer


class SqlQuesSearch(object):
    """Retrieve few-shot (question, SQL) examples belonging to one database."""

    def __init__(self, _model):
        self.name = "SqlQuesSearch"
        self.model = _model
        self.instruction = "為這段內容生成表示以用于匹配文本描述:"
        self.SIZE = 1024
        self.index = faiss.IndexFlatL2(self.SIZE)
        self.sqlquedata = []
        self.i2dbid = {}  # faiss row number -> dataSetID
        self.i2sqlid = {}  # faiss row number -> sql_id
        self.id2que = {}
        self.id2sql = {}
        self.load_textdata()  # load the text data
        self.load_textdata_vec()  # vectorize the text data

    def load_textdata(self):
        try:
            response = requests.post(url="xxx", verify=False)
            print(response.text)
            jsonobj = json.loads(response.text)
            for datadata in jsonobj["data"]:  # SQL-question pairs of each database
                dbid = datadata["dataSetID"]
                sql_ques = datadata["exp"]
                self.sqlquedata.append((dbid, sql_ques))  # collect the SQL data
        except Exception as e:
            print(e)

    def load_textdata_vec(self):
        num0 = 0
        for db_id, sql_ques in self.sqlquedata:
            for sql_que in sql_ques:
                sql_id = sql_que["sql_id"]
                question = sql_que["question"]
                sql = sql_que["sql"]
                # only the question is embedded; the SQL is returned as the few-shot answer
                que_embeddings = self.model.encode([question], normalize_embeddings=True)
                self.index.add(que_embeddings)
                self.i2dbid[num0] = db_id
                self.i2sqlid[num0] = sql_id
                self.id2que[sql_id] = question
                self.id2sql[sql_id] = sql
                num0 += 1
        print("init sql-que vec", num0)

    def calculate_score(self, sim_v, question, sql_ques):
        pass

    def find_vec_sqlque(self, question, k, theta, dataSetID, number):
        q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)
        D, I = self.index.search(q_embeddings, k)
        result = []
        for i in range(k):
            sim_i = I[0][i]
            dbid = self.i2dbid.get(sim_i, "none")  # database id of this hit
            if dbid != dataSetID:  # keep only hits from the selected database
                continue
            sqlid = self.i2sqlid.get(sim_i, "none")
            score = int(D[0][i] * 1000)
            if score < theta:
                # separate name so the method's `question` argument is not overwritten
                hit_question = self.id2que.get(sqlid, "none")
                result.append({"score": score, "question": hit_question,
                               "sql": self.id2sql.get(sqlid, "none")})
                if len(result) == number:
                    break
        return result


if __name__ == '__main__':
    modelpath = "E:\\module\\bge-large-zh-v1.5"
    model = SentenceTransformer(modelpath)
    vs = SqlQuesSearch(model)
    # "List all child observation records for July 18"
    result = vs.find_vec_sqlque("查詢7月18日所有的兒童觀察記錄?", 3, 2000, dataSetID=111, number=3)
    print(result)
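`find_vec_sqlque` (and `find_vec_ddl` below) first takes the top `k` rows across all databases and only then drops hits whose `dataSetID` does not match. If the nearest `k` rows belong to other databases, the result is empty even when the selected database has good matches. One way to address the multi-database issue raised at the start is to rank only the selected database's rows. A minimal sketch, using a NumPy matrix per `dataSetID` as a stand-in for one FAISS index per database; the class name `PerDatabaseIndex` is hypothetical.

```python
import numpy as np


class PerDatabaseIndex:
    """One vector matrix per dataSetID; search ranks only the requested database's rows."""

    def __init__(self):
        self.vecs = {}  # dataSetID -> (n, dim) float32 array
        self.payloads = {}  # dataSetID -> payloads aligned with the rows of vecs

    def add(self, db_id, vec, payload):
        row = np.asarray(vec, dtype=np.float32).reshape(1, -1)
        self.vecs[db_id] = np.vstack([self.vecs[db_id], row]) if db_id in self.vecs else row
        self.payloads.setdefault(db_id, []).append(payload)

    def search(self, db_id, query_vec, k):
        """Return up to k (squared L2 distance, payload) pairs from db_id, nearest first."""
        if db_id not in self.vecs:
            return []
        q = np.asarray(query_vec, dtype=np.float32)
        d = ((self.vecs[db_id] - q) ** 2).sum(axis=1)
        return [(float(d[i]), self.payloads[db_id][i]) for i in np.argsort(d)[:k]]


idx = PerDatabaseIndex()
idx.add(222, [1.0, 0.0], "db222: close match")
idx.add(222, [0.9, 0.1], "db222: another close match")
idx.add(111, [0.0, 1.0], "db111: only row")
# A global top-1 search would return a db222 row, and filtering for 111 would leave nothing;
# searching db111 directly still returns its best row.
print(idx.search(111, [1.0, 0.0], 1))  # [(2.0, 'db111: only row')]
```

With real FAISS this would be a dict of `faiss.IndexFlatL2` objects keyed by `dataSetID`; the simpler workaround within the existing code is to search with a much larger `k` than `number`.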
3. Database DDL training
import json

import faiss
import requests
from sentence_transformers import SentenceTransformer


class DdlQuesSearch(object):
    """Retrieve the CREATE TABLE statements of one database that best match a question."""

    def __init__(self, _model):
        self.name = "DdlQuesSearch"
        self.model = _model
        self.instruction = "為這段內容生成表示以用于匹配文本描述:"
        self.SIZE = 1024
        self.index = faiss.IndexFlatL2(self.SIZE)
        self.ddldata = []
        self.i2dbid = {}  # faiss row number -> dataSetID
        self.i2ddlid = {}  # faiss row number -> ddl_id
        self.dbid2ddls = {}  # dataSetID -> {ddl_id: ddl}
        self.id2ddl = {}
        self.ddlid2dbid = {}
        self.load_ddldata()  # load the DDL data
        self.load_ddl_vec()  # vectorize the DDL data

    def load_ddldata(self):
        try:
            response = requests.post(url="xxx", verify=False)
            print(response.text)
            jsonobj = json.loads(response.text)
            databases = jsonobj["data"]  # was never assigned in the original, causing a NameError
            for database in databases:
                db_id = database["dataSetID"]
                ddls = database["ddl"]
                self.ddldata.append((db_id, ddls))
        except Exception as e:
            print(e)

    def load_ddl_vec(self):
        num0 = 0
        for db_id, ddls in self.ddldata:
            for item in ddls:
                ddl_id = item["ddl_id"]
                ddl_name = item["TABLE"]  # table name (not used yet)
                ddl = item["ddl"]
                ddl_embeddings = self.model.encode([ddl], normalize_embeddings=True)
                self.index.add(ddl_embeddings)
                self.i2dbid[num0] = db_id
                self.i2ddlid[num0] = ddl_id
                self.id2ddl[ddl_id] = ddl
                self.ddlid2dbid[ddl_id] = db_id
                # per-database dict (the original assigned the shared id2ddl to every database)
                self.dbid2ddls.setdefault(db_id, {})[ddl_id] = ddl
                num0 += 1
        print("init ddl vec", num0)

    def find_vec_ddl(self, question, k, theata, dataSetID, number):  # dataSetID: database id
        q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)
        D, I = self.index.search(q_embeddings, k)
        result = []
        for i in range(k):
            sim_i = I[0][i]
            dbid = self.i2dbid.get(sim_i, "none")  # database id of this hit
            if dbid != dataSetID:
                continue
            ddlid = self.i2ddlid.get(sim_i, "none")
            score = int(D[0][i] * 1000)
            if score < theata:
                result.append({"score": score, "ddl": self.id2ddl.get(ddlid, "none")})
                if len(result) == number:
                    break
        return result


if __name__ == '__main__':
    modelpath = "E:\\module\\bge-large-zh-v1.5"
    model = SentenceTransformer(modelpath)
    vs = DdlQuesSearch(model)
    # "scheduled task execution log table"
    ss = vs.find_vec_ddl("定時任務執行記錄表", 2, 2000, 111, 2)
    print(ss)
4. Database document training
import json

import requests


class DocQuesSearch(object):
    """Documentation lookup: an exact match on dataSetID, no vector search needed."""

    def __init__(self):
        self.name = "DocQuesSearch"
        self.docdata = []
        self.load_doc_data()

    def load_doc_data(self):
        try:
            response = requests.post(url="xxx", verify=False)
            print(response.text)
            jsonobj = json.loads(response.text)
            databases = jsonobj["data"]
            for database in databases:
                db_id = database["dataSetID"]
                doc = database["Intro"]
                self.docdata.append((db_id, doc))
        except Exception as e:
            print(e)

    def find_similar_doc(self, dataSetID):
        result = []
        for dbid, doc in self.docdata:
            if dbid == dataSetID:
                result.append(doc)
        return result


if __name__ == '__main__':
    docques_search = DocQuesSearch()
    result = docques_search.find_similar_doc(222)
    print(result)
5. Generating the SQL statement (using the qwen-max model here)
import os
import random
import re

from dashscope import Generation
from sentence_transformers import SentenceTransformer

from ddl_engine import DdlQuesSearch
from sqlques_engine import SqlQuesSearch

DEBUG_INFO = None  # (prompt, answer) of the last successful LLM call


class Genarate(object):
    def __init__(self):
        self.api_key = os.environ.get('api_key')
        self.model_name = os.environ.get('model', 'qwen-max')  # qwen-max unless overridden

    def system_message(self, message):
        return {'role': 'system', 'content': message}

    def user_message(self, message):
        return {'role': 'user', 'content': message}

    def assistant_message(self, message):
        return {'role': 'assistant', 'content': message}

    def submit_prompt(self, prompt):
        global DEBUG_INFO
        resp = Generation.call(model=self.model_name,
                               messages=prompt,
                               seed=random.randint(1, 10000),
                               result_format='message',
                               api_key=self.api_key)
        if resp.status_code == 200:
            answer = resp.output.choices[0].message.content
            DEBUG_INFO = (prompt, answer)
            return answer
        return None  # API error

    def generate_sql(self, question, sql_result, ddl_result, doc_result):
        prompt = self.get_sql_prompt(question=question, sql_result=sql_result,
                                     ddl_result=ddl_result, doc_result=doc_result)
        print("SQL Prompt:", prompt)
        llm_response = self.submit_prompt(prompt)
        return self.extract_sql(llm_response)

    def extract_sql(self, llm_response):
        """Pull the SQL out of the LLM reply, falling back to the raw reply."""
        if llm_response is None:  # submit_prompt failed
            return None
        # prefer the last WITH ...; or SELECT ...; statement, then fenced code blocks
        for pattern in (r"WITH.*?;", r"SELECT.*?;", r"```sql\n(.*)```", r"```(.*)```"):
            sqls = re.findall(pattern, llm_response, re.DOTALL)
            if sqls:
                return sqls[-1]
        return llm_response

    def get_sql_prompt(self, question, sql_result, ddl_result, doc_result):
        initial_prompt = "You are a SQL expert. " + \
            "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
        initial_prompt = self.add_ddl_to_prompt(initial_prompt, ddl_result)
        initial_prompt = self.add_documentation_to_prompt(initial_prompt, doc_result)
        initial_prompt += (
            "===Response Guidelines \n"
            "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
            "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
            "3. If the provided context is insufficient, please explain why it can't be generated. \n"
            "4. Please use the most relevant table(s). \n"
            "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
        )
        message_log = [self.system_message(initial_prompt)]
        message_log = self.add_sqlques_to_prompt(question, sql_result, message_log)
        return message_log

    def add_ddl_to_prompt(self, initial_prompt, ddl_result):
        """Append the retrieved CREATE TABLE statements under a ===Tables section."""
        ddl_list = [ddl_['ddl'] for ddl_ in ddl_result]
        if len(ddl_list) > 0:
            initial_prompt += "\n===Tables \n"
            for ddl in ddl_list:
                initial_prompt += f"{ddl}\n\n"
        return initial_prompt

    def add_sqlques_to_prompt(self, question, sql_result, message_log):
        """Add retrieved (question, SQL) pairs as few-shot user/assistant turns, then the question."""
        for example in sql_result:
            if example is not None and "question" in example and "sql" in example:
                message_log.append(self.user_message(example["question"]))
                message_log.append(self.assistant_message(example["sql"]))
        message_log.append(self.user_message(question))
        return message_log

    def add_documentation_to_prompt(self, initial_prompt, doc_result):
        if len(doc_result) > 0:
            initial_prompt += "\n===Additional Context \n\n"
            for doc in doc_result:
                initial_prompt += f"{doc}\n\n"
        return initial_prompt


if __name__ == '__main__':
    modelpath = "E:\\module\\bge-large-zh-v1.5"
    model = SentenceTransformer(modelpath)
    question = "查詢7月18日所有的兒童觀察記錄?"  # "List all child observation records for July 18"
    ddl_result = DdlQuesSearch(model).find_vec_ddl(question, 5, 2000, 111, 3)
    sql_result = SqlQuesSearch(model).find_vec_sqlque(question, 5, 2000, 111, 3)
    print(ddl_result)
    # documentation retrieval is omitted here; pass DocQuesSearch results as doc_result if available
    print(Genarate().generate_sql(question, sql_result, ddl_result, doc_result=[]))
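None of the `__main__` blocks shows how the `dataSetID` chosen in step 1 is passed on to the other steps (they hard-code 111 or 222). A minimal sketch of that wiring, assuming the method signatures shown above. The function `text2sql` and the `SimpleNamespace` stand-ins are hypothetical, so the sketch runs without the HTTP endpoint, the embedding model, or DashScope; in practice you would pass `DataBaseSearch(model)`, `SqlQuesSearch(model)`, `DdlQuesSearch(model)`, `DocQuesSearch()` and `Genarate()`.

```python
from types import SimpleNamespace


def text2sql(question, db_search, sqlque_search, ddl_search, doc_search, generator,
             k=5, threshold=2000, number=3):
    """Chain the steps: pick the database, retrieve its context, then ask the LLM for SQL."""
    hits = db_search.find_vec_database(question, 1, threshold)  # step 1: database selection
    if not hits:
        return None  # no database is close enough to the question
    db_id = hits[0]["dataSetID"]
    # every retrieval below is restricted to the selected database
    sql_result = sqlque_search.find_vec_sqlque(question, k, threshold, db_id, number)
    ddl_result = ddl_search.find_vec_ddl(question, k, threshold, db_id, number)
    doc_result = doc_search.find_similar_doc(db_id)
    # prompt construction + LLM call
    return generator.generate_sql(question, sql_result, ddl_result, doc_result)


# Stand-ins with the same method names as the classes above (values are made up)
db_stub = SimpleNamespace(find_vec_database=lambda q, k, t: [{"score": 500, "dataSetID": 111}])
sq_stub = SimpleNamespace(find_vec_sqlque=lambda q, k, t, db, n: [{"score": 300, "question": "...", "sql": "SELECT 1;"}])
ddl_stub = SimpleNamespace(find_vec_ddl=lambda q, k, t, db, n: [{"score": 400, "ddl": "CREATE TABLE t (a INT);"}])
doc_stub = SimpleNamespace(find_similar_doc=lambda db: ["Kindergarten management database"])
gen_stub = SimpleNamespace(generate_sql=lambda q, s, d, doc: f"{len(s)} example(s), {len(d)} DDL(s), {len(doc)} doc(s)")

print(text2sql("查詢7月18日所有的兒童觀察記錄?", db_stub, sq_stub, ddl_stub, doc_stub, gen_stub))
# 1 example(s), 1 DDL(s), 1 doc(s)
```

Taking only the top database hit keeps the prompt focused on one schema; if several databases score similarly, it may be worth asking the user to choose rather than guessing.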
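6. Execution results
As the screenshot shows, the SQL was generated correctly and executes normally. The query result is empty only because the tables were pulled over without any data.
If you need the source code, leave a comment.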
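The check described above (running the generated SQL against tables that exist but hold no data) can be automated. A minimal sketch using Python's built-in `sqlite3` as a stand-in: it assumes the DDL and the generated SQL are SQLite-compatible, and DDL written for another engine (for example with MySQL-style `COMMENT` or `ENGINE=` clauses) will not parse, so against the real database you would apply the same idea through its own driver. `check_sql` is a hypothetical helper.

```python
import sqlite3


def check_sql(ddls, sql):
    """Create the tables in an empty in-memory SQLite DB, then run the SQL.

    Returns (True, rows) on success or (False, error message) on failure.
    """
    conn = sqlite3.connect(":memory:")
    try:
        for ddl in ddls:
            conn.execute(ddl)
        rows = conn.execute(sql).fetchall()
        return True, rows
    except sqlite3.Error as e:
        return False, str(e)
    finally:
        conn.close()


ddls = ["CREATE TABLE child_observation (id INTEGER, child_id INTEGER, observe_date TEXT);"]
print(check_sql(ddls, "SELECT * FROM child_observation WHERE observe_date LIKE '%-07-18';"))
print(check_sql(ddls, "SELECT * FROM observation_log;"))
```

The first call succeeds with an empty row list, matching the result described above; the second fails with SQLite's "no such table" error, which is the kind of hallucinated table name this check is meant to catch. Because the database is a throwaway in-memory copy, even a generated DELETE or DROP cannot touch real data.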