問題描述
在使用 LangChain 和 Llama 模型生成 SQL 查詢時,遇到了 sqlite3.OperationalError
錯誤。錯誤信息如下:
OperationalError: (sqlite3.OperationalError) near "```sql
SELECT Name
FROM MediaType
LIMIT 5;
```": syntax error
[SQL: ```sql
SELECT Name
FROM MediaType
LIMIT 5;
```]
錯誤發生的原因是生成的 SQL 查詢包含了不必要的 Markdown 代碼塊標記 ```,也就是在生成SQL語句的過程中,產生了其他的不干凈文本,導致 SQL 語法錯誤。
最終解決方案
通過修改 PromptTemplate 來生成干凈的 SQL 查詢,確保生成的查詢不包含任何 Markdown 代碼塊標記或附加評論。以下是解決方案的詳細步驟和代碼實現:
1. 初始化環境
首先,初始化所需的環境變量和模型:
import getpass
import os
from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool# 如果沒有設置 GROQ_API_KEY,則提示用戶輸入
if not os.environ.get("GROQ_API_KEY"):os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")# 初始化 Llama 模型,使用 Groq 后端
llm = init_chat_model("llama-3.3-70b-versatile", model_provider="groq", temperature=0)
2. 定義自定義提示模板
定義一個自定義的 PromptTemplate,用于生成干凈的 SQL 查詢:
custom_prompt = PromptTemplate(input_variables=["dialect", "input", "table_info", "top_k"],template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Limit the results to at most {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)
3. 創建 SQL 查詢鏈
創建一個 SQL 查詢鏈,并使用自定義提示模板:
write_query = create_sql_query_chain(llm, db, prompt=custom_prompt)
4. 構造輸入數據字典
構造輸入數據字典,其中包含方言、表結構、問題和行數限制:
input_data = {"dialect": db.dialect, # 數據庫方言,如 "sqlite""table_info": db.get_table_info(), # 表結構信息"input": "What name of MediaType is?", # 問題"top_k": 5 # 行數限制
}
5. 調用鏈生成并執行 SQL 查詢
調用鏈生成 SQL 查詢,確保生成的查詢不包含 Markdown 代碼塊標記,然后執行查詢并打印結果:
response = write_query.invoke(input_data)
query = response["query"]# 執行 SQL 查詢并打印結果
execute_query = QuerySQLDataBaseTool(db=db)
result = execute_query.invoke({"query": query})
print(result)
總結
通過修改 PromptTemplate 來生成 SQL 查詢時,明確要求返回的 SQL 查詢不包含任何附加評論或 Markdown 格式,確保生成的 SQL 查詢是干凈的、可執行的。這樣可以避免由多余的標記導致的 SQL 語法錯誤。
最后提供完整代碼:
import getpass
import os
from langchain.chat_models import init_chat_model
from langchain_core.prompts import PromptTemplate
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from dotenv import load_dotenv
from pyprojroot import here
from langchain.chains import create_sql_query_chain
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabaseload_dotenv()# 如果沒有設置 GROQ_API_KEY,則提示用戶輸入
if not os.environ.get("GROQ_API_KEY"):os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")sqldb_directory = here("data/Chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")
table_info = db.get_table_info(["Album"]) # 注意需要傳遞列表
print(f"\n Original table info: {table_info}")# 初始化 Llama 模型,使用 Groq 后端
llm = init_chat_model("llama-3.3-70b-specdec", model_provider="groq", temperature=0)
# 定義自定義提示模板,用于生成 SQL 查詢
custom_prompt = PromptTemplate(input_variables=["dialect", "input", "table_info", "top_k"],template="""You are a SQL expert using {dialect}.
Given the following table schema:
{table_info}
Generate a syntactically correct SQL query to answer the question: "{input}".
Limit the results to at most {top_k} rows.
Return only the SQL query without any additional commentary or Markdown formatting.
"""
)write_query = create_sql_query_chain(llm, db,prompt=custom_prompt)
# 構造輸入數據字典,其中包含方言、表結構、問題和行數限制
input_data = {"dialect": db.dialect, # 數據庫方言,如 "sqlite""table_info": db.get_table_info(), # 表結構信息"question": "What name of MediaType is?","top_k": 5
}# 調用鏈生成 SQL 查詢,返回結果為一個字典,包含鍵 "query"
write_query_response = write_query.invoke(input_data)
print('\n write_query result:',write_query_response)#執行SQL語句
execute_query = QuerySQLDataBaseTool(db=db)
execute_response = execute_query.invoke(write_query_response)
print('\n execute_response result:',execute_response)#兩個動作合起來搞成鏈
chain = write_query | execute_query
result_chain = chain.invoke(input_data)
print('\n result_chain==',result_chain)
輸出: