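1. Overview
The earlier article "Design and Implementation of an Intelligent Question-Answering System Based on SQL Databases" covered the basics of building a question-answering system over tabular data in a database: we ask the system a question about the data and get back a natural-language answer.
To write valid queries against a database, the model needs the table names, the table schemas, and sample column values. When there are many tables, many columns, or high-cardinality columns, we cannot dump everything about the database into every prompt. Instead, we need a way to dynamically insert only the most relevant information into the prompt.
This article covers methods for identifying that relevant information and feeding it into the query-generation step. We will cover:
- identifying the relevant subset of tables;
- identifying the relevant subset of column values.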
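As a rough illustration of the idea, the sketch below uses only Python's standard library. The schemas dictionary and the build_prompt helper are hypothetical names invented for this example and are not part of LangChain; the point is only that the prompt ends up containing the schemas of the selected tables and nothing else.

```python
# Hypothetical, hand-written schema snippets; a real system would read them from the database.
schemas = {
    "Artist": "CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)",
    "Album": "CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER)",
    "Invoice": "CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, Total NUMERIC)",
}


def build_prompt(question: str, relevant_tables: list[str]) -> str:
    """Build a prompt that includes only the schemas of the relevant tables."""
    table_info = "\n".join(schemas[name] for name in relevant_tables)
    return f"Only use the following tables:\n{table_info}\n\nQuestion: {question}"


prompt_text = build_prompt("How many albums does AC/DC have?", ["Artist", "Album"])
print(prompt_text)
```

The Invoice schema never reaches the prompt, which is the saving we are after when the database has hundreds of tables.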
2. Install Dependencies
%pip install --upgrade --quiet langchain langchain-community langchain-openai
3. Sample Data
# Download the SQL script
wget https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql
sqlite3 Chinook.db
.read Chinook_Sqlite.sql
SELECT * FROM Artist LIMIT 10;
$ sqlite3 Chinook.db
SQLite version 3.45.3 2024-04-15 13:34:05
Enter ".help" for usage hints.
sqlite> .read Chinook_Sqlite.sql
sqlite> SELECT * FROM Artist LIMIT 10;
1|AC/DC
2|Accept
3|Aerosmith
4|Alanis Morissette
5|Alice In Chains
6|Antônio Carlos Jobim
7|Apocalyptica
8|Audioslave
9|BackBeat
10|Billy Cobham
sqlite> .quit
Chinook.db is now in our directory, and we can interact with it through the SQLAlchemy-driven SQLDatabase class:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
print(db.run("SELECT * FROM Artist LIMIT 10;"))
sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]
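For intuition about where table names and schemas come from, the following sketch reads them with Python's built-in sqlite3 module from an in-memory database. The two-table schema is a made-up stand-in for Chinook, and this is not how SQLDatabase is implemented internally (it goes through SQLAlchemy); it only shows the kind of metadata that ends up in the prompt.

```python
import sqlite3

# In-memory stand-in for Chinook.db with a made-up, two-table schema.
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);
    CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER);
    INSERT INTO Artist VALUES (1, 'AC/DC'), (2, 'Accept');
    """
)

# sqlite_master stores one row per table, including its CREATE TABLE statement.
rows = conn.execute(
    "SELECT name, sql FROM sqlite_master WHERE type = 'table' ORDER BY name"
).fetchall()
table_names = [name for name, _ in rows]
table_schemas = {name: sql for name, sql in rows}

print(table_names)
print(table_schemas["Artist"])
```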
4. LLM
%pip install langchain-openai
import os

os.environ["OPENAI_BASE_URL"] = "https://api.siliconflow.cn/v1/"
os.environ["OPENAI_API_KEY"] = "sk-xxx"

from langchain.chat_models import init_chat_model

llm = init_chat_model("Qwen/Qwen3-8B", model_provider="openai")
# llm = init_chat_model("THUDM/GLM-Z1-9B-0414", model_provider="openai")
# llm = init_chat_model("deepseek-ai/DeepSeek-R1-0528-Qwen3-8B", model_provider="openai")
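This uses SiliconFlow's free model service. The code above uses the Qwen/Qwen3-8B model, but any of the other free models works as well: copy the model name exactly as it appears on the SiliconFlow website. After registering there, create an API key and the models are ready to use.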
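Hard-coding an API key in a notebook makes it easy to leak. One option, sketched below, is to read the key from the environment and fail early if it is missing. The require_env helper is a hypothetical name introduced for this example.

```python
import os


def require_env(name: str) -> str:
    """Return the value of an environment variable, or raise if it is unset or empty."""
    value = os.environ.get(name, "")
    if not value:
        raise RuntimeError(f"Environment variable {name} is not set")
    return value


# Demo only: set a placeholder so the example runs; in practice, export the key in your shell.
os.environ.setdefault("OPENAI_API_KEY", "sk-xxx")
api_key = require_env("OPENAI_API_KEY")
```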
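5. Relevant Tables
One of the main pieces of information we need to include in the prompt is the schemas of the relevant tables. When there are very many tables, we cannot fit all of their schemas into a single prompt. In that case, we can first extract the names of the tables most relevant to the user's input. A simple and reliable way to do this is tool calling: the model uses the tool to produce output in the required format (here, a list of table names). We bind a Pydantic-defined tool with the chat model's .bind_tools method and pass the result to an output parser, which reconstructs the objects from the model's response.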
from langchain_core.output_parsers.openai_tools import PydanticToolsParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field


class Table(BaseModel):
    """Table in SQL database."""

    name: str = Field(description="Name of table in SQL database.")


table_names = "\n".join(db.get_usable_table_names())
# System prompt (in Chinese): return the names of all SQL tables that may be relevant
# to the user's question, including tables the model is not sure are needed.
system = f"""返回可能與用戶問題相關的所有SQL表名稱。表包括:{table_names}。請記住包含所有可能相關的表,即使不確定是否需要它們。"""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{input}"),
    ]
)
llm_with_tools = llm.bind_tools([Table])
output_parser = PydanticToolsParser(tools=[Table])

table_chain = prompt | llm_with_tools | output_parser
# Question (in Chinese): "What are all the genres of Alanis Morissette songs?"
table_chain.invoke({"input": "Alanis Morissette 歌曲的所有流派是什么"})
[Table(name='Artist'), Table(name='Track'), Table(name='Genre')]
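Since the model returns table names as free-form strings, it can be worth checking them against the tables that actually exist before using them. The sketch below is one way to do that in plain Python; keep_known_tables is a hypothetical helper, and the case-insensitive matching is an assumption about what is useful, not something the chain above does.

```python
def keep_known_tables(predicted: list[str], usable: list[str]) -> list[str]:
    """Keep predicted names that match a real table (case-insensitively), without duplicates."""
    canonical = {name.lower(): name for name in usable}
    kept = []
    for name in predicted:
        match = canonical.get(name.strip().lower())
        if match is not None and match not in kept:
            kept.append(match)
    return kept


usable_tables = ["Album", "Artist", "Genre", "Track"]
# "Songs" does not exist, and "genre" differs from the real name only in case.
checked = keep_known_tables(["Artist", "Songs", "genre", "Artist"], usable_tables)
print(checked)
```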
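The result looks good: of the three table names returned (Artist, Track, Genre), Genre is the one the question is really about. In practice we also need a few other tables to connect the chain of joins, but the model can hardly know that from the user's question alone. In such cases, we can simplify the model's job by grouping the tables: we only ask the model to choose between the "Music" and "Business" categories, and then select all the tables in the chosen category ourselves.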
system = """返回與用戶問題相關的SQL表名。可用的表有:
1. 音樂(Music)
2. 業務(Business)
"""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{input}"),
    ]
)

category_chain = prompt | llm_with_tools | output_parser
category_chain.invoke({"input": "Alanis Morissette 歌曲的所有流派是什么"})
[Table(name='Music')]
Then expand the returned categories into concrete table names:
from typing import List


def get_tables(categories: List[Table]) -> List[str]:
    """Map each category chosen by the model to the tables it contains."""
    tables = []
    for category in categories:
        if category.name == "Music":
            tables.extend(
                [
                    "Album",
                    "Artist",
                    "Genre",
                    "MediaType",
                    "Playlist",
                    "PlaylistTrack",
                    "Track",
                ]
            )
        elif category.name == "Business":
            tables.extend(["Customer", "Employee", "Invoice", "InvoiceLine"])
    return tables


table_chain = category_chain | get_tables
table_chain.invoke({"input": "Alanis Morissette 歌曲的所有流派是什么"})
['Album', 'Artist', 'Genre', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
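A hard-coded category map can silently go stale when tables are added or renamed. As a sanity check, the sketch below compares such a map with the list of tables. CATEGORY_TABLES and find_uncategorized are hypothetical names, and the table list is copied from the db.get_usable_table_names() output shown earlier.

```python
CATEGORY_TABLES = {
    "Music": ["Album", "Artist", "Genre", "MediaType", "Playlist", "PlaylistTrack", "Track"],
    "Business": ["Customer", "Employee", "Invoice", "InvoiceLine"],
}

all_tables = [
    "Album", "Artist", "Customer", "Employee", "Genre", "Invoice",
    "InvoiceLine", "MediaType", "Playlist", "PlaylistTrack", "Track",
]


def find_uncategorized(tables: list[str], categories: dict[str, list[str]]) -> list[str]:
    """Return the tables that do not appear in any category."""
    categorized = {name for names in categories.values() for name in names}
    return [name for name in tables if name not in categorized]


missing = find_uncategorized(all_tables, CATEGORY_TABLES)
print(missing)
```

An empty list means every table is reachable through some category; any name that shows up here could never be offered to the query-generation step.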
We now have a chain that outputs the relevant tables for any query. Next we combine it with create_sql_query_chain, which accepts a table_names_to_use list that determines which table schemas are included in the prompt:
from operator import itemgetter

from langchain.chains import create_sql_query_chain
from langchain_core.runnables import RunnablePassthrough

query_chain = create_sql_query_chain(llm, db)
# Convert "question" key to the "input" key expected by current table_chain.
table_chain = {"input": itemgetter("question")} | table_chain
# Set table_names_to_use using table_chain.
full_chain = RunnablePassthrough.assign(table_names_to_use=table_chain) | query_chain
query_chain
RunnableAssign(mapper={input: RunnableLambda(...),table_info: RunnableLambda(...)
})
| RunnableLambda(lambda x: {k: v for (k, v) in x.items() if k not in ('question', 'table_names_to_use')})
| PromptTemplate(input_variables=['input', 'table_info'], input_types={}, partial_variables={'top_k': '5'}, template='You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attention to use date(\'now\') function to get the current date, if the question involves "today".\n\nUse the following format:\n\nQuestion: Question here\nSQLQuery: SQL Query to run\nSQLResult: Result of the SQLQuery\nAnswer: Final answer here\n\nOnly use the following tables:\n{table_info}\n\nQuestion: {input}')
| RunnableBinding(bound=ChatOpenAI(client=<openai.resources.chat.completions.completions.Completions object at 0x74101c7aacb0>, async_client=<openai.resources.chat.completions.completions.AsyncCompletions object at 0x74101c7a9de0>, root_client=<openai.OpenAI object at 0x74101e9b24a0>, root_async_client=<openai.AsyncOpenAI object at 0x74101c2af100>, model_name='Qwen/Qwen3-8B', model_kwargs={}, openai_api_key=SecretStr('**********')), kwargs={'stop': ['\nSQLResult:']}, config={}, config_factories=[])
| StrOutputParser()
| RunnableLambda(_strip)
full_chain
RunnableAssign(mapper={table_names_to_use: {input: RunnableLambda(itemgetter('question'))}| ChatPromptTemplate(input_variables=['input'], input_types={}, partial_variables={}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=[], input_types={}, partial_variables={}, template='返回與用戶問題相關的SQL表名。可用的表有:\n1. 音樂(Music)\n2. 業務(Business)\n'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['input'], input_types={}, partial_variables={}, template='{input}'), additional_kwargs={})])| RunnableBinding(bound=ChatOpenAI(client=<openai.resources.chat.completions.completions.Completions object at 0x74101c7aacb0>, async_client=<openai.resources.chat.completions.completions.AsyncCompletions object at 0x74101c7a9de0>, root_client=<openai.OpenAI object at 0x74101e9b24a0>, root_async_client=<openai.AsyncOpenAI object at 0x74101c2af100>, model_name='Qwen/Qwen3-8B', model_kwargs={}, openai_api_key=SecretStr('**********')), kwargs={'tools': [{'type': 'function', 'function': {'name': 'Table', 'description': 'Table in SQL database.', 'parameters': {'properties': {'name': {'description': 'Name of table in SQL database.', 'type': 'string'}}, 'required': ['name'], 'type': 'object'}}}]}, config={}, config_factories=[])| PydanticToolsParser(tools=[<class '__main__.Table'>])| RunnableLambda(get_tables)
})
| RunnableAssign(mapper={input: RunnableLambda(...),table_info: RunnableLambda(...)})
| RunnableLambda(lambda x: {k: v for (k, v) in x.items() if k not in ('question', 'table_names_to_use')})
| PromptTemplate(input_variables=['input', 'table_info'], input_types={}, partial_variables={'top_k': '5'}, template='You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attention to use date(\'now\') function to get the current date, if the question involves "today".\n\nUse the following format:\n\nQuestion: Question here\nSQLQuery: SQL Query to run\nSQLResult: Result of the SQLQuery\nAnswer: Final answer here\n\nOnly use the following tables:\n{table_info}\n\nQuestion: {input}')
| RunnableBinding(bound=ChatOpenAI(client=<openai.resources.chat.completions.completions.Completions object at 0x74101c7aacb0>, async_client=<openai.resources.chat.completions.completions.AsyncCompletions object at 0x74101c7a9de0>, root_client=<openai.OpenAI object at 0x74101e9b24a0>, root_async_client=<openai.AsyncOpenAI object at 0x74101c2af100>, model_name='Qwen/Qwen3-8B', model_kwargs={}, openai_api_key=SecretStr('**********')), kwargs={'stop': ['\nSQLResult:']}, config={}, config_factories=[])
| StrOutputParser()
| RunnableLambda(_strip)
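The printed structure can be hard to read. Conceptually, RunnablePassthrough.assign keeps the input dictionary and adds new keys computed from it. The sketch below imitates that data flow with ordinary functions; assign, fake_table_chain, and fake_query_chain are hypothetical stand-ins rather than LangChain APIs, and no model is called.

```python
from typing import Any, Callable, Dict


def assign(**steps: Callable[[Dict[str, Any]], Any]) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
    """Return a function that copies its input dict and adds one key per step."""

    def run(inputs: Dict[str, Any]) -> Dict[str, Any]:
        extra = {key: step(inputs) for key, step in steps.items()}
        return {**inputs, **extra}

    return run


def fake_table_chain(inputs: Dict[str, Any]) -> list:
    # Stand-in for the table-selection chain: always answers with two tables.
    return ["Artist", "Genre"]


def fake_query_chain(inputs: Dict[str, Any]) -> str:
    # Stand-in for the query-generation chain: just reports what it received.
    tables = ", ".join(inputs["table_names_to_use"])
    return f"question={inputs['question']!r}; tables={tables}"


add_tables = assign(table_names_to_use=fake_table_chain)
result = fake_query_chain(add_tables({"question": "Which genres?"}))
print(result)
```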
Test:
query = full_chain.invoke(
    {"question": "Alanis Morissette 歌曲的所有流派是什么"}
)
print(query)
SQLQuery: SELECT DISTINCT "Genre"."Name" FROM "Track" JOIN "Album" ON "Track"."AlbumId" = "Album"."AlbumId" JOIN "Artist" ON "Album"."ArtistId" = "Artist"."ArtistId" JOIN "Genre" ON "Track"."GenreId" = "Genre"."GenreId" WHERE "Artist"."Name" = 'Alanis Morissette' LIMIT 5;
Execute the SQL:
db.run(query.replace("SQLQuery: ",""))
"[('Rock',)]"
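With that, the chain dynamically supplies only the relevant tables in the prompt.
Another possible approach is to let an agent decide for itself when to look up tables by calling a tool, which may take several tool calls. For details, see the Agent section of "Design and Implementation of an Intelligent Question-Answering System Based on SQL Databases".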
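The db.run(query.replace("SQLQuery: ", "")) call earlier in this section strips the prefix with a fixed string, which works for that output, but model output can vary in format. A slightly more defensive cleanup is sketched below with the standard re module. extract_sql is a hypothetical helper, and the cases it handles (an optional "SQLQuery:" label and an optional surrounding Markdown fence) are assumptions about typical model output.

```python
import re

FENCE = "`" * 3  # three backticks, as used by Markdown code fences


def extract_sql(text: str) -> str:
    """Strip an optional 'SQLQuery:' label and an optional surrounding Markdown fence."""
    text = text.strip()
    # Drop a leading "SQLQuery:" label, if present.
    text = re.sub(r"^SQLQuery:\s*", "", text, flags=re.IGNORECASE)
    # Drop a surrounding Markdown fence, with an optional "sql" language tag.
    if text.startswith(FENCE) and text.endswith(FENCE):
        text = text[len(FENCE):-len(FENCE)]
        text = re.sub(r"^sql\b", "", text, flags=re.IGNORECASE)
    return text.strip()


cleaned_query = extract_sql('SQLQuery: SELECT "Name" FROM "Genre" LIMIT 5;')
print(cleaned_query)
```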
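6. High-Cardinality Columns
To filter on columns that contain proper nouns (such as addresses, song titles, or artist names), we first need to double-check the spelling so that the filter actually matches the data. We can do this by building a vector store of all the distinct proper nouns in the database. Then, whenever the user's question contains a proper noun, we query the vector store for the correct spelling. This way it is clear which entity the user means before the target query is built.
First, parse the query results into a list of elements: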
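import ast
import re


def query_as_list(db, query):
    res = db.run(query)
    # db.run returns a string such as "[('AC/DC',), ...]"; parse it, flatten it, and drop empty values.
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    # Remove standalone numbers and trim surrounding whitespace.
    res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
    return res


proper_nouns = query_as_list(db, "SELECT Name FROM Artist")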
proper_nouns += query_as_list(db, "SELECT Title FROM Album")
proper_nouns += query_as_list(db, "SELECT Name FROM Genre")
len(proper_nouns)
proper_nouns[:5]
['AC/DC', 'Accept', 'Aerosmith', 'Alanis Morissette', 'Alice In Chains']
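To see what each step of query_as_list does, the sketch below applies the same two transformations to a hard-coded string in the format that db.run returns. The sample rows are made up for illustration.

```python
import ast
import re

# Made-up sample in the string format returned by db.run for a one-column query.
raw = "[('AC/DC',), ('Garage Inc. (Disc 1)',), (None,), ('',)]"

rows = ast.literal_eval(raw)  # parse the string into a list of tuples
flat = [el for row in rows for el in row if el]  # flatten; drop None and empty strings
cleaned = [re.sub(r"\b\d+\b", "", s).strip() for s in flat]  # remove standalone numbers

print(cleaned)
```

Removing standalone numbers is why entries such as 'Garage Inc. (Disc )' appear in the vector search results later in this section.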
Now we can embed all the values and store them in a vector database:
# %pip install faiss-gpu
%pip install faiss-cpu
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

vector_db = FAISS.from_texts(
    proper_nouns,
    OpenAIEmbeddings(
        model="BAAI/bge-m3", base_url="http://localhost:8000/v1", api_key="EMPTY"
    ),
)
retriever = vector_db.as_retriever(search_kwargs={"k": 15})
Vector query:
retriever.invoke("elanis Morisset")
[Document(id='0a7ad312-dbba-4a56-883a-7cf90edfbdf5', metadata={}, page_content='Alanis Morissette'),Document(id='9176353f-7047-4c9a-8780-b24174fb1f3d', metadata={}, page_content='Elis Regina'),Document(id='9a876473-aaea-4b86-8467-132008632795', metadata={}, page_content='Volume Dois'),Document(id='17298f91-b479-4447-9480-9712ef722412', metadata={}, page_content='Xis'),Document(id='0f7d044d-290d-4b43-b436-ab8e82b86688', metadata={}, page_content='Handel: Music for the Royal Fireworks (Original Version )'),Document(id='ef588f64-b5bf-40f0-9b06-f55ef7927435', metadata={}, page_content='LOST, Season'),Document(id='5bd80286-6b27-44af-b8a5-d5f52de2e125', metadata={}, page_content='Garage Inc. (Disc )'),Document(id='405229d7-a098-4425-ad97-7afb4d3a459a', metadata={}, page_content='Garage Inc. (Disc )'),Document(id='a108b0fc-e7ab-4095-9f6f-9a1a359717e1', metadata={}, page_content='Surfing with the Alien (Remastered)'),Document(id='36f7c22d-a4a1-4644-a026-a283f78dd761', metadata={}, page_content="Christopher O'Riley"),Document(id='71563ff6-ee38-439b-b743-c7e187bf34d5', metadata={}, page_content='Speak of the Devil'),Document(id='d254a91f-9745-4a69-ac07-39ae4b1655c1', metadata={}, page_content='The Police'),Document(id='5b74ec24-428a-42b8-99fb-9c1442ca47c0', metadata={}, page_content='Vs.'),Document(id='1a1ce9a3-3b28-4c73-a1ba-443e11051672', metadata={}, page_content='Elis Regina-Minha História'),Document(id='4679d89e-e1a8-46fb-884d-0ddf7b4e9953', metadata={}, page_content='Blue Moods')]
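To make the idea of nearest-neighbor lookup concrete without an embedding service, the sketch below ranks candidates by cosine similarity over character-bigram counts. This is a deliberately simple stand-in for the BAAI/bge-m3 embeddings and FAISS index used above, not an equivalent: it only captures spelling overlap. The names bigram_vector, cosine, and top_k are hypothetical.

```python
import math
from collections import Counter


def bigram_vector(text: str) -> Counter:
    """Count overlapping character bigrams of the lower-cased text."""
    text = text.lower()
    return Counter(text[i : i + 2] for i in range(len(text) - 1))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse count vectors."""
    dot = sum(count * b[gram] for gram, count in a.items())
    norm = math.sqrt(sum(c * c for c in a.values())) * math.sqrt(sum(c * c for c in b.values()))
    return dot / norm if norm else 0.0


def top_k(query: str, candidates: list[str], k: int = 3) -> list[str]:
    """Return the k candidates most similar to the query, best first."""
    q = bigram_vector(query)
    return sorted(candidates, key=lambda c: cosine(q, bigram_vector(c)), reverse=True)[:k]


names = ["AC/DC", "Accept", "Aerosmith", "Alanis Morissette", "Alice In Chains", "Elis Regina"]
best = top_k("elanis Morisset", names, k=2)
print(best)
```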
Now compose a query chain that first retrieves values from the vector store and inserts them into the prompt:
from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

# System prompt (in Chinese): you are a SQLite expert; generate a syntactically correct query
# returning at most {top_k} rows; return only the SQL; check spelling against {proper_nouns}.
system = """您是SQLite專家。根據輸入問題生成語法正確的SQLite查詢,除非另有說明,返回不超過{top_k}行結果。

僅返回SQL查詢語句,不要包含任何標記或解釋。

相關表信息:{table_info}

以下是可能特征值的非窮舉列表。若需按特征值篩選,請先核對拼寫:

{proper_nouns}
"""

prompt = ChatPromptTemplate.from_messages([("system", system), ("human", "{input}")])

query_chain = create_sql_query_chain(llm, db, prompt=prompt)
retriever_chain = (
    itemgetter("question")
    | retriever
    | (lambda docs: "\n".join(doc.page_content for doc in docs))
)
chain = RunnablePassthrough.assign(proper_nouns=retriever_chain) | query_chain
query_chain
RunnableAssign(mapper={input: RunnableLambda(...),table_info: RunnableLambda(...)
})
| RunnableLambda(lambda x: {k: v for (k, v) in x.items() if k not in ('question', 'table_names_to_use')})
| ChatPromptTemplate(input_variables=['input', 'proper_nouns', 'table_info'], input_types={}, partial_variables={'top_k': '5'}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=['proper_nouns', 'table_info', 'top_k'], input_types={}, partial_variables={}, template='您是SQLite專家。根據輸入問題生成語法正確的SQLite查詢,除非另有說明,返回不超過{top_k}行結果。\n\n僅返回SQL查詢語句,不要包含任何標記或解釋。\n\n相關表信息:{table_info}\n\n以下是可能特征值的非窮舉列表。若需按特征值篩選,請先核對拼寫:\n\n{proper_nouns}\n'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['input'], input_types={}, partial_variables={}, template='{input}'), additional_kwargs={})])
| RunnableBinding(bound=ChatOpenAI(client=<openai.resources.chat.completions.completions.Completions object at 0x74101c7aacb0>, async_client=<openai.resources.chat.completions.completions.AsyncCompletions object at 0x74101c7a9de0>, root_client=<openai.OpenAI object at 0x74101e9b24a0>, root_async_client=<openai.AsyncOpenAI object at 0x74101c2af100>, model_name='Qwen/Qwen3-8B', model_kwargs={}, openai_api_key=SecretStr('**********')), kwargs={'stop': ['\nSQLResult:']}, config={}, config_factories=[])
| StrOutputParser()
| RunnableLambda(_strip)
chain
RunnableAssign(mapper={proper_nouns: RunnableLambda(itemgetter('question'))| VectorStoreRetriever(tags=['FAISS', 'OpenAIEmbeddings'], vectorstore=<langchain_community.vectorstores.faiss.FAISS object at 0x7410111f3b80>, search_kwargs={'k': 15})| RunnableLambda(...)
})
| RunnableAssign(mapper={input: RunnableLambda(...),table_info: RunnableLambda(...)})
| RunnableLambda(lambda x: {k: v for (k, v) in x.items() if k not in ('question', 'table_names_to_use')})
| ChatPromptTemplate(input_variables=['input', 'proper_nouns', 'table_info'], input_types={}, partial_variables={'top_k': '5'}, messages=[SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=['proper_nouns', 'table_info', 'top_k'], input_types={}, partial_variables={}, template='您是SQLite專家。根據輸入問題生成語法正確的SQLite查詢,除非另有說明,返回不超過{top_k}行結果。\n\n僅返回SQL查詢語句,不要包含任何標記或解釋。\n\n相關表信息:{table_info}\n\n以下是可能特征值的非窮舉列表。若需按特征值篩選,請先核對拼寫:\n\n{proper_nouns}\n'), additional_kwargs={}), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['input'], input_types={}, partial_variables={}, template='{input}'), additional_kwargs={})])
| RunnableBinding(bound=ChatOpenAI(client=<openai.resources.chat.completions.completions.Completions object at 0x74101c7aacb0>, async_client=<openai.resources.chat.completions.completions.AsyncCompletions object at 0x74101c7a9de0>, root_client=<openai.OpenAI object at 0x74101e9b24a0>, root_async_client=<openai.AsyncOpenAI object at 0x74101c2af100>, model_name='Qwen/Qwen3-8B', model_kwargs={}, openai_api_key=SecretStr('**********')), kwargs={'stop': ['\nSQLResult:']}, config={}, config_factories=[])
| StrOutputParser()
| RunnableLambda(_strip)
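Now we can test the effect: let's see what comes back when the artist's name is misspelled in the question, first without retrieval and then with it.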
# Without retrieval
query = query_chain.invoke(
    {"question": "elanis Morissette歌曲的所有流派是什么", "proper_nouns": ""}
)
print(query)
db.run(query)
SELECT Genre.Name
FROM Track
JOIN Genre ON Track.GenreId = Genre.GenreId
WHERE Track.Composer LIKE '%Elanis Morissette%';
''
# With retrieval
query = chain.invoke({"question": "Alanis Morissett歌曲的所有流派是什么"})
print(query)
db.run(query)
SELECT DISTINCT g.Name
FROM Track t
JOIN Genre g ON t.GenreId = g.GenreId
WHERE t.Composer LIKE '%Alanis Morissette%' OR t.Name LIKE '%Alanis Morissette%';
"[('Rock',)]"
As we can see, retrieval corrects the misspelling and the query returns a valid result.
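The effect can be reproduced on a toy scale with only the standard library. In the sketch below, difflib.get_close_matches plays the role of the retriever and an in-memory SQLite table with made-up rows plays the role of Chinook. Both are stand-ins chosen for illustration, and no language model is involved.

```python
import difflib
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);
    INSERT INTO Artist VALUES (1, 'AC/DC'), (4, 'Alanis Morissette'), (41, 'Elis Regina');
    """
)
known_names = [row[0] for row in conn.execute("SELECT Name FROM Artist")]


def find_artist(name: str) -> list:
    # Parameterized query: the name is passed as data, not spliced into the SQL text.
    return conn.execute("SELECT ArtistId, Name FROM Artist WHERE Name = ?", (name,)).fetchall()


misspelled = "elanis Morisset"
without_lookup = find_artist(misspelled)

# Replace the misspelled name with its closest known spelling, if one is close enough.
matches = difflib.get_close_matches(misspelled, known_names, n=1, cutoff=0.6)
with_lookup = find_artist(matches[0]) if matches else []

print(without_lookup, with_lookup)
```

The exact-match filter finds nothing for the misspelled name, while the corrected spelling returns the intended row, mirroring the two runs above.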
References
- https://python.langchain.ac.cn/docs/how_to/sql_large_db/