NLP2SQL(自然語言處理到 SQL 查詢的轉換)是一個重要的自然語言處理(NLP)任務,其目標是將用戶的自然語言問題轉換為相應的 SQL 查詢。這一任務在許多場景下具有廣泛的應用,尤其是在與數據庫交互的場景中,例如數據分析、業務智能和問答系統。
任務目標
- 理解自然語言: 理解用戶輸入的自然語言問題,包括意圖、實體和上下文。
- 生成 SQL 查詢: 將理解后的信息轉換為正確的 SQL 查詢,以從數據庫中檢索所需的數據。
例如
輸入: 用戶的自然語言問題,“獲取 Gelderland 區的總人口。”
輸出: 對應的 SQL 查詢,
SELECT population FROM districts WHERE name = 'Gelderland';
Spider?是一個難度最大數據集
耶魯大學在2018年新提出的一個大規模的NL2SQL(Text-to-SQL)數據集。
該數據集包含了10,181條自然語言問句、分布在200個獨立數據庫中的5,693條SQL,內容覆蓋了138個不同的領域。
涉及的SQL語法最全面,是目前難度最大的NL2SQL數據集。
下載查看spider數據集內容
Question 1: How many singers do we have ? ||| concert_singer
SQL: select count(*) from singer
Question 2: What is the total number of singers ? ||| concert_singer
SQL: select count(*) from singer
Question 3: Show name , country , age for all singers ordered by age from the oldest to the youngest . ||| concert_singer
SQL: select name , country , age from singer order by age desc...
首先需要轉換為Spider的標準格式(參考tables.json
和train.json
):
{"db_id": "concert_singer","question": "Show name, country, age...","query": "SELECT name, country, age FROM singer ORDER BY age DESC","schema": {"table_names": ["singer"],"column_names": [[0, "name", "text"],[0, "country", "text"],[0, "age", "int"]]}
}
拆分為table.json的原因可能涉及到數據組織和重用。每個數據庫的結構(表、列、外鍵)在多個問題中都會被重復使用。如果每個問題都附帶完整的schema信息,會導致數據冗余,增加存儲和處理的開銷。所以,將schema單獨存儲為table.json,可以讓不同的數據條目引用同一個數據庫模式,減少重復數據。拆分后的結構需要更高效的數據管理,例如在訓練模型時,根據每個問題的db_id去table.json中查找對應的schema信息。這樣做的好處是當多個問題屬于同一個數據庫時,不需要每次都重復加載schema提高了效率。
column_names
表示數據庫表中每一列的詳細信息。具體來說,column_names
是一個列表,其中每個元素都是一個包含三個部分的子列表:
- 表索引(0):表示該列屬于哪個表。在這個例子中,所有列都屬于第一個表(索引為 0)。
- 列名("name"、"country"、"age"):表示列的名稱。
- 數據類型("text"、"int"):表示該列的數據類型,例如文本(
text
)或整數(int
)。
實現下面邏輯轉換原始數據
def extract_columns_from_sql(sql):# 使用正則表達式匹配 SELECT 語句中的列名match = re.search(r"SELECT\s+(.*?)\s+FROM", sql, re.IGNORECASE)if match:# 提取列名columns = match.group(1).split(",")# 構建 column_names 列表column_names = []for index, column in enumerate(columns):column = column.strip() # 去除多余的空格data_type = "text" # 默認數據類型為 text,可以根據需要修改# 添加到 column_names 列表,假設所有列類型為 textcolumn_names.append([0, column, data_type])return column_namesreturn []# 從 dev.sql 文件讀取數據
def load_sql_data(file_path):data_list = []with open(file_path, 'r', encoding='utf-8') as f: # 指定編碼為 UTF-8lines = f.readlines()for i in range(0, len(lines), 3): # 每三行一組question_line = lines[i].strip()sql_line = lines[i + 1].strip()if not question_line or not sql_line:continue# 提取問題和 SQLquestion = question_line.split(': ', 1)[1].strip() # 獲取問題內容sql = sql_line.split(': ', 1)[1].strip() # 獲取 SQL 查詢# 提取表名db_id = question_line.split('|||')[-1].strip() # 從問題行獲取表名question = question.split('|||')[0].strip()target_sql = preprocess(question, db_id, sql)data_list.append({"input_text": f"Translate to SQL: {question} [SEP] Tables: {db_id}","target_sql": json.dumps(target_sql) # 將目標 SQL 轉換為 JSON 格式字符串})return data_list
選擇Tokenizer.from_pretrained("t5-base")
是用于加載 T5(Text-to-Text Transfer Transformer)模型的分詞器。T5 是一個強大的自然語言處理模型,能夠處理各種文本任務(如翻譯、摘要、問答等),并且將所有任務視為文本到文本的轉換。
from transformers import T5Tokenizertokenizer = T5Tokenizer.from_pretrained("t5-base")def preprocess(question, db_id, sql):# 提取列名column_names = extract_columns_from_sql(sql)# 構建目標格式target_sql = {"db_id": db_id,"question": question,"query": sql,"schema": {"table_names": [db_id],"column_names": column_names}}return target_sql# 示例數據
question = "Show name, country, age for all singers ordered by age from the oldest to the youngest."
schema = "singer(name, country, age)"
sql = "SELECT name, country, age FROM singer ORDER BY age DESC"input_text, target_sql = preprocess(question, schema, sql)
# input_text = "Translate to SQL: Show name... [SEP] Tables: singer(name, country, age)"
# target_sql = "select name, country, age from singer order by age desc"
print('input_text', input_text)
print('target_sql', target_sql)
所有nlp任務都涉及的需要token化,使用t5-base 做tokenize
def tokenize_function(examples):model_inputs = tokenizer(examples["input_text"],max_length=512,truncation=True,padding="max_length")with tokenizer.as_target_tokenizer():labels = tokenizer(examples["target_sql"],max_length=512,truncation=True,padding="max_length")model_inputs["labels"] = labels["input_ids"]return model_inputs
使用 tokenizer.as_target_tokenizer()
上下文管理器,確保目標文本(即 SQL 查詢)被正確處理。目標文本也經過編碼,轉換為 token IDs,并同樣進行填充和截斷。將目標文本的編碼結果(token IDs)存儲在 model_inputs["labels"]
中。這是模型在訓練時需要的輸出,用于計算損失。最終返回一個字典 model_inputs
,它包含了模型的輸入和對應的標簽。這種結構使得模型在訓練時可以直接使用。
最后組織下訓練代碼
tokenized_datasets = dataset.map(tokenize_function, batched=True)# 加載模型
model = T5ForConditionalGeneration.from_pretrained("t5-base")# 訓練參數
training_args = Seq2SeqTrainingArguments(output_dir="./results",evaluation_strategy="epoch",learning_rate=3e-5,per_device_train_batch_size=8,per_device_eval_batch_size=8,num_train_epochs=100,predict_with_generate=True,run_name="spider"
)# 開始訓練
trainer = Seq2SeqTrainer(model=model,args=training_args,train_dataset=tokenized_datasets["train"] if 'train' in tokenized_datasets else tokenized_datasets,eval_dataset=tokenized_datasets["test"] if 'test' in tokenized_datasets else None,data_collator=DataCollatorForSeq2Seq(tokenizer)
)trainer.train()
這里使用的是Seq2SeqTrainer, 它
是 Hugging Face 的 transformers
庫中用于序列到序列(Seq2Seq)任務的訓練器。它為處理諸如翻譯、文本生成和問答等任務提供了一個高層次的接口,簡化了訓練過程。以下是 Seq2SeqTrainer
的主要功能和特點:
-
簡化訓練流程:?
Seq2SeqTrainer
封裝了許多常見的訓練步驟,如數據加載、模型訓練、評估和預測,使得用戶可以更專注于模型和數據,而不必處理繁瑣的訓練細節。 -
支持多種訓練參數: 通過
Seq2SeqTrainingArguments
類,可以靈活配置訓練參數,如學習率、批量大小、訓練輪數、評估策略等。 -
自動處理填充和截斷: 在處理輸入和輸出序列時,
Seq2SeqTrainer
可以自動填充和截斷序列,以確保它們適應模型的輸入要求。 -
集成評估和監控: 支持在訓練過程中進行模型評估,并可以根據評估指標(如損失)監控訓練進度。用戶可以設置評估頻率和評估數據集
開始訓練,進行100次epoch
訓練監控在?Weights & Biases?,Seq2SeqTrainer
能夠向 Weights & Biases (wandb) 傳輸訓練監控數據,主要是因為它內置了與 wandb 的集成。以下是一些關鍵點,解釋了這一過程:
-
自動集成:當你使用
Seq2SeqTrainer
時,它會自動檢測 wandb 的安裝并在初始化時配置相關設置。這意味著你無需手動設置 wandb。 -
回調功能:
Trainer
類提供了回調功能,可以在訓練過程中記錄各種指標(如損失、準確率等)。這些指標會被自動發送到 wandb。 -
配置管理:
training_args
中的參數可以指定 wandb 的項目名稱、運行名稱等,從而更好地組織和管理實驗。 -
訓練循環:在每個訓練和評估周期結束時,
Trainer
會調用相應的回調函數,將重要的訓練信息(如損失、學習率等)記錄到 wandb。 -
可視化:通過 wandb,你可以實時監控訓練過程,包括損失曲線、模型性能等,幫助你更好地理解模型的訓練動態。
多次試驗還可以比較訓練性能
訓練結束, 損失收斂到0.05410315271151268
{'eval_loss': 0.008576861582696438, 'eval_runtime': 1.3883, 'eval_samples_per_second': 74.912, 'eval_steps_per_second': 5.042, 'epoch': 100.0}
{'train_runtime': 2914.0548, 'train_samples_per_second': 31.914, 'train_steps_per_second': 2.025, 'train_loss': 0.05410315271151268, 'epoch': 100.0}
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5900/5900 [48:31<00:00, ?2.03it/s]
wandb:
wandb: 🚀 View run spider at: https://wandb.ai/chenruithinking-4th-paradigm/huggingface/runs/dkccvpp4
wandb: Find logs at: wandb/run-20250207_112702-dkccvpp4/logs
測試下預測能力
import os
from transformers import T5Tokenizer, T5ForConditionalGeneration# 設置 NCCL 環境變量
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"# 加載分詞器
tokenizer = T5Tokenizer.from_pretrained("t5-base")model = T5ForConditionalGeneration.from_pretrained("./results/t5-sql-model")
tokenizer.save_pretrained("./results/t5-sql-model")def generate_sql(question, db_id):input_text = f"Translate to SQL: {question} [SEP] Tables: {db_id}"input_ids = tokenizer.encode(input_text, return_tensors="pt") # 使?~T? PyTorch ?~Z~D?| ?~G~O?| ??~Ooutput = model.generate(input_ids,max_length=512,num_beams=5, # 或者嘗試其他解碼策略early_stopping=True)print('output', output)generated_sql = tokenizer.decode(output[0], skip_special_tokens=True)return generated_sqlquestion = "How many singers do we have ?"
db_id = "concert_singer"
evaluation_output = generate_sql(question, db_id)
print("evaluation_output:", evaluation_output)
輸出結果
evaluation_output: "db_id": "concert_singer", "question": "How many singers do we have ?", "query": "select count(*) from singer", "schema": "table_names": ["concert_singer"], "column_names": [[0, "count(*)", "text"]]