Spider 數據集上實現nlp2sql訓練任務

NLP2SQL(自然語言處理到 SQL 查詢的轉換)是一個重要的自然語言處理(NLP)任務,其目標是將用戶的自然語言問題轉換為相應的 SQL 查詢。這一任務在許多場景下具有廣泛的應用,尤其是在與數據庫交互的場景中,例如數據分析、業務智能和問答系統。

任務目標
  • 理解自然語言: 理解用戶輸入的自然語言問題,包括意圖、實體和上下文。
  • 生成 SQL 查詢: 將理解后的信息轉換為正確的 SQL 查詢,以從數據庫中檢索所需的數據。

例如

輸入: 用戶的自然語言問題,“獲取 Gelderland 區的總人口。”

輸出: 對應的 SQL 查詢,SELECT population FROM districts WHERE name = 'Gelderland';

Spider?是一個難度最大數據集

耶魯大學在2018年新提出的一個大規模的NL2SQL(Text-to-SQL)數據集。
該數據集包含了10,181條自然語言問句、分布在200個獨立數據庫中的5,693條SQL,內容覆蓋了138個不同的領域。
涉及的SQL語法最全面,是目前難度最大的NL2SQL數據集。

下載查看spider數據集內容

Question 1: How many singers do we have ? ||| concert_singer
SQL: select count(*) from singer

Question 2: What is the total number of singers ? ||| concert_singer
SQL: select count(*) from singer

Question 3: Show name , country , age for all singers ordered by age from the oldest to the youngest . ||| concert_singer
SQL: select name , country , age from singer order by age desc

...

首先需要轉換為Spider的標準格式(參考tables.jsontrain.json):

{"db_id": "concert_singer","question": "Show name, country, age...","query": "SELECT name, country, age FROM singer ORDER BY age DESC","schema": {"table_names": ["singer"],"column_names": [[0, "name", "text"],[0, "country", "text"],[0, "age", "int"]]}
}

拆分為table.json的原因可能涉及到數據組織和重用。每個數據庫的結構(表、列、外鍵)在多個問題中都會被重復使用。如果每個問題都附帶完整的schema信息,會導致數據冗余,增加存儲和處理的開銷。所以,將schema單獨存儲為table.json,可以讓不同的數據條目引用同一個數據庫模式,減少重復數據。拆分后的結構需要更高效的數據管理,例如在訓練模型時,根據每個問題的db_id去table.json中查找對應的schema信息。這樣做的好處是當多個問題屬于同一個數據庫時,不需要每次都重復加載schema提高了效率。

column_names 表示數據庫表中每一列的詳細信息。具體來說,column_names 是一個列表,其中每個元素都是一個包含三個部分的子列表:

  1. 表索引(0):表示該列屬于哪個表。在這個例子中,所有列都屬于第一個表(索引為 0)。
  2. 列名("name"、"country"、"age"):表示列的名稱。
  3. 數據類型("text"、"int"):表示該列的數據類型,例如文本(text)或整數(int)。

實現下面邏輯轉換原始數據

def extract_columns_from_sql(sql):# 使用正則表達式匹配 SELECT 語句中的列名match = re.search(r"SELECT\s+(.*?)\s+FROM", sql, re.IGNORECASE)if match:# 提取列名columns = match.group(1).split(",")# 構建 column_names 列表column_names = []for index, column in enumerate(columns):column = column.strip()  # 去除多余的空格data_type = "text"  # 默認數據類型為 text,可以根據需要修改# 添加到 column_names 列表,假設所有列類型為 textcolumn_names.append([0, column, data_type])return column_namesreturn []# 從 dev.sql 文件讀取數據
def load_sql_data(file_path):data_list = []with open(file_path, 'r', encoding='utf-8') as f:  # 指定編碼為 UTF-8lines = f.readlines()for i in range(0, len(lines), 3):  # 每三行一組question_line = lines[i].strip()sql_line = lines[i + 1].strip()if not question_line or not sql_line:continue# 提取問題和 SQLquestion = question_line.split(': ', 1)[1].strip()  # 獲取問題內容sql = sql_line.split(': ', 1)[1].strip()  # 獲取 SQL 查詢# 提取表名db_id = question_line.split('|||')[-1].strip()  # 從問題行獲取表名question = question.split('|||')[0].strip()target_sql = preprocess(question, db_id, sql)data_list.append({"input_text": f"Translate to SQL: {question} [SEP] Tables: {db_id}","target_sql": json.dumps(target_sql)  # 將目標 SQL 轉換為 JSON 格式字符串})return data_list

選擇Tokenizer.from_pretrained("t5-base") 是用于加載 T5(Text-to-Text Transfer Transformer)模型的分詞器。T5 是一個強大的自然語言處理模型,能夠處理各種文本任務(如翻譯、摘要、問答等),并且將所有任務視為文本到文本的轉換。

from transformers import T5Tokenizertokenizer = T5Tokenizer.from_pretrained("t5-base")def preprocess(question, db_id, sql):# 提取列名column_names = extract_columns_from_sql(sql)# 構建目標格式target_sql = {"db_id": db_id,"question": question,"query": sql,"schema": {"table_names": [db_id],"column_names": column_names}}return target_sql# 示例數據
question = "Show name, country, age for all singers ordered by age from the oldest to the youngest."
schema = "singer(name, country, age)"
sql = "SELECT name, country, age FROM singer ORDER BY age DESC"input_text, target_sql = preprocess(question, schema, sql)
# input_text = "Translate to SQL: Show name... [SEP] Tables: singer(name, country, age)"
# target_sql = "select name, country, age from singer order by age desc"
print('input_text', input_text)
print('target_sql', target_sql)

所有nlp任務都涉及的需要token化,使用t5-base 做tokenize

def tokenize_function(examples):model_inputs = tokenizer(examples["input_text"],max_length=512,truncation=True,padding="max_length")with tokenizer.as_target_tokenizer():labels = tokenizer(examples["target_sql"],max_length=512,truncation=True,padding="max_length")model_inputs["labels"] = labels["input_ids"]return model_inputs

使用 tokenizer.as_target_tokenizer() 上下文管理器,確保目標文本(即 SQL 查詢)被正確處理。目標文本也經過編碼,轉換為 token IDs,并同樣進行填充和截斷。將目標文本的編碼結果(token IDs)存儲在 model_inputs["labels"] 中。這是模型在訓練時需要的輸出,用于計算損失。最終返回一個字典 model_inputs,它包含了模型的輸入和對應的標簽。這種結構使得模型在訓練時可以直接使用。

最后組織下訓練代碼

tokenized_datasets = dataset.map(tokenize_function, batched=True)# 加載模型
model = T5ForConditionalGeneration.from_pretrained("t5-base")# 訓練參數
training_args = Seq2SeqTrainingArguments(output_dir="./results",evaluation_strategy="epoch",learning_rate=3e-5,per_device_train_batch_size=8,per_device_eval_batch_size=8,num_train_epochs=100,predict_with_generate=True,run_name="spider"
)# 開始訓練
trainer = Seq2SeqTrainer(model=model,args=training_args,train_dataset=tokenized_datasets["train"] if 'train' in tokenized_datasets else tokenized_datasets,eval_dataset=tokenized_datasets["test"] if 'test' in tokenized_datasets else None,data_collator=DataCollatorForSeq2Seq(tokenizer)
)trainer.train()

這里使用的是Seq2SeqTrainer, 它是 Hugging Face 的 transformers 庫中用于序列到序列(Seq2Seq)任務的訓練器。它為處理諸如翻譯、文本生成和問答等任務提供了一個高層次的接口,簡化了訓練過程。以下是 Seq2SeqTrainer 的主要功能和特點:

  1. 簡化訓練流程:?Seq2SeqTrainer 封裝了許多常見的訓練步驟,如數據加載、模型訓練、評估和預測,使得用戶可以更專注于模型和數據,而不必處理繁瑣的訓練細節。

  2. 支持多種訓練參數: 通過 Seq2SeqTrainingArguments 類,可以靈活配置訓練參數,如學習率、批量大小、訓練輪數、評估策略等。

  3. 自動處理填充和截斷: 在處理輸入和輸出序列時,Seq2SeqTrainer 可以自動填充和截斷序列,以確保它們適應模型的輸入要求。

  4. 集成評估和監控: 支持在訓練過程中進行模型評估,并可以根據評估指標(如損失)監控訓練進度。用戶可以設置評估頻率和評估數據集

開始訓練,進行100次epoch

訓練監控在?Weights & Biases?,Seq2SeqTrainer 能夠向 Weights & Biases (wandb) 傳輸訓練監控數據,主要是因為它內置了與 wandb 的集成。以下是一些關鍵點,解釋了這一過程:

  1. 自動集成:當你使用 Seq2SeqTrainer 時,它會自動檢測 wandb 的安裝并在初始化時配置相關設置。這意味著你無需手動設置 wandb。

  2. 回調功能Trainer 類提供了回調功能,可以在訓練過程中記錄各種指標(如損失、準確率等)。這些指標會被自動發送到 wandb。

  3. 配置管理training_args 中的參數可以指定 wandb 的項目名稱、運行名稱等,從而更好地組織和管理實驗。

  4. 訓練循環:在每個訓練和評估周期結束時,Trainer 會調用相應的回調函數,將重要的訓練信息(如損失、學習率等)記錄到 wandb。

  5. 可視化:通過 wandb,你可以實時監控訓練過程,包括損失曲線、模型性能等,幫助你更好地理解模型的訓練動態。

多次試驗還可以比較訓練性能

訓練結束, 損失收斂到0.05410315271151268

{'eval_loss': 0.008576861582696438, 'eval_runtime': 1.3883, 'eval_samples_per_second': 74.912, 'eval_steps_per_second': 5.042, 'epoch': 100.0}
{'train_runtime': 2914.0548, 'train_samples_per_second': 31.914, 'train_steps_per_second': 2.025, 'train_loss': 0.05410315271151268, 'epoch': 100.0}
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5900/5900 [48:31<00:00, ?2.03it/s]
wandb:
wandb: 🚀 View run spider at: https://wandb.ai/chenruithinking-4th-paradigm/huggingface/runs/dkccvpp4
wandb: Find logs at: wandb/run-20250207_112702-dkccvpp4/logs

測試下預測能力

import os
from transformers import T5Tokenizer, T5ForConditionalGeneration# 設置 NCCL 環境變量
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"# 加載分詞器
tokenizer = T5Tokenizer.from_pretrained("t5-base")model = T5ForConditionalGeneration.from_pretrained("./results/t5-sql-model")
tokenizer.save_pretrained("./results/t5-sql-model")def generate_sql(question, db_id):input_text = f"Translate to SQL: {question} [SEP] Tables: {db_id}"input_ids = tokenizer.encode(input_text, return_tensors="pt")  # 使?~T? PyTorch ?~Z~D?| ?~G~O?| ??~Ooutput = model.generate(input_ids,max_length=512,num_beams=5,  # 或者嘗試其他解碼策略early_stopping=True)print('output', output)generated_sql = tokenizer.decode(output[0], skip_special_tokens=True)return generated_sqlquestion = "How many singers do we have ?"
db_id = "concert_singer"
evaluation_output = generate_sql(question, db_id)
print("evaluation_output:", evaluation_output)

輸出結果

evaluation_output: "db_id": "concert_singer", "question": "How many singers do we have ?", "query": "select count(*) from singer", "schema": "table_names": ["concert_singer"], "column_names": [[0, "count(*)", "text"]]

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/web/69260.shtml
繁體地址,請注明出處:http://hk.pswp.cn/web/69260.shtml
英文地址,請注明出處:http://en.pswp.cn/web/69260.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

IDEA+DeepSeek讓Java開發起飛

1.獲取DeepSeek秘鑰 登錄DeepSeek官網 : https://www.deepseek.com/ 進入API開放平臺&#xff0c;第一次需要注冊一個賬號 進去之后需要創建一個API KEY&#xff0c;然后把APIkey記錄保存下來 接著我們獲取DeepSeek的API對話接口地址&#xff0c;點擊左邊的&#xff1a;接口…

k8s常見面試題2

k8s常見面試題2 安全與權限RBAC配置如何保護 Kubernetes 集群的 API Server&#xff1f;如何管理集群中的敏感信息&#xff08;如密碼、密鑰&#xff09;&#xff1f;如何限制容器的權限&#xff08;如使用 SecurityContext&#xff09;&#xff1f;如何防止容器逃逸&#xff0…

flutter安卓打包簽名

flutter安卓打包簽名 1.創建簽名文件 keytool -genkeypair -v -keystore my-release-key.jks -keyalg RSA -keysize 2048 -validity 10000 -alias my-key-aliaskeytool 是一個用于管理密鑰和證書的命令行工具&#xff0c;通常與 Java 開發工具包 (JDK) 一起使用。my-release-…

React - jsx 語法

在 React 中&#xff0c;JSX&#xff08;JavaScript XML&#xff09;是一種語法擴展&#xff0c;它允許開發者在 JavaScript 代碼中使用類似 HTML 的語法。JSX 提升了代碼的可讀性和可維護性&#xff0c;使得編寫和構建用戶界面更加直觀。它被廣泛應用于 React 組件的定義。 一…

intra-mart實現簡易登錄頁面筆記

一、前言 最近在學習intra-mart框架&#xff0c;在此總結下筆記。 intra-mart是一個前后端不分離的框架&#xff0c;開發時主要用的就是xml、html、js這幾個文件&#xff1b; xml文件當做配置文件&#xff0c;html當做前端頁面文件&#xff0c;js當做后端文件&#xff08;js里…

Linux+Docer 容器化部署之 Shell 語法入門篇 【Shell 替代】

&#x1f380;&#x1f380;Shell語法入門篇 系列篇 &#x1f380;&#x1f380; LinuxDocer 容器化部署之 Shell 語法入門篇 【準備階段】LinuxDocer 容器化部署之 Shell 語法入門篇 【Shell變量】LinuxDocer 容器化部署之 Shell 語法入門篇 【Shell數組與函數】LinuxDocer 容…

Intellij IDEA如何查看當前文件的類

快捷鍵&#xff1a;CtrlF12&#xff0c;我個人感覺記快捷鍵很麻煩&#xff0c;知道具體的位置更簡單&#xff0c;如果忘了快捷鍵&#xff08;KeyMap&#xff09;看一下就記起來了&#xff0c;不需要再Google or Baidu or GPT啥的&#xff0c;位置&#xff1a;Navigate > Fi…

C++----繼承

一、繼承的基本概念 本質&#xff1a;代碼復用類關系建模&#xff08;是多態的基礎&#xff09; class Person { /*...*/ }; class Student : public Person { /*...*/ }; // public繼承 派生類繼承基類成員&#xff08;數據方法&#xff09;&#xff0c;可以通過監視窗口檢…

已驗證正常,Java輸入字符串生成PDF文件

Java輸入字符串生成PDF文件過程&#xff1a; 在Java開發中&#xff0c;如何將字符串轉換為 PDF 是一個常見的需求。網上找了很多例子都無法生成&#xff0c;經過多次嘗試&#xff0c;終于實現了&#xff0c;特此記錄一下。 1、引入pom.xml 添加所需的依賴 <dependency>&…

Mac M1 Comfyui 使用MMAudio遇到的問題解決?

問題1: AssertionError: Torch not compiled with CUDA enabled&#xff1f; 解決辦法&#xff1a;修改代碼以 CPU 運行 第一步&#xff1a;找到 /ComfyUI/custom_nodes/ComfyUI-MMAudio/mmaudio/ext/autoencoder/vae.py文件中的下面這兩行代碼 self.data_mean nn.Buffer(t…

從 .NET Framework 升級到 .NET 8 后 SignalR 問題處理與解決方案

隨著 .NET Framework 向 .NET 8 的遷移&#xff0c;許多開發者在使用 SignalR 時遇到了一些前后端連接、配置、調用等方面的問題。尤其是在處理 SignalR 實時通信功能時&#xff0c;升級后的一些兼容性問題可能導致應用程序無法正常工作。本文將介紹在從 .NET Framework 升級到…

2025.2.5——五、[網鼎杯 2020 青龍組]AreUSerialz 代碼審計|反序列化

題目來源&#xff1a;BUUCTF [網鼎杯 2020 青龍組]AreUSerialz 目錄 一、打開靶機&#xff0c;整理信息 二、解題思路 step 1&#xff1a;代碼審計 step 2&#xff1a;開始解題 突破protected訪問修飾符限制 三、小結 一、打開靶機&#xff0c;整理信息 直接得到一串ph…

Docker深度解析:安裝各大環境

安裝 Nginx 實現負載均衡&#xff1a; 掛載 nginx html 文件&#xff1a; 創建過載目錄&#xff1a; mkdir -p /data/nginx/{conf,conf.d,html,logs} 注意&#xff1a;在掛載前需要對 conf/nginx.conf 文件進行編寫 worker_processes 1;events {worker_connections 1024; …

docker啟動報錯code=exited, status=1/FAILURE——問題排查

問題 在某臺centos7機器上&#xff0c;啟動docker服務 sudo systemctl start docker報下列錯誤&#xff1a; ● docker.service - Docker Application Container EngineLoaded: loaded (/usr/lib/systemd/system/docker.service; enabled; vendor preset: disabled)Active: …

基于SpringBoot養老院平臺系統功能實現五

一、前言介紹&#xff1a; 1.1 項目摘要 隨著全球人口老齡化的不斷加劇&#xff0c;養老服務需求日益增長。特別是在中國&#xff0c;隨著經濟的快速發展和人民生活水平的提高&#xff0c;老年人口數量不斷增加&#xff0c;對養老服務的質量和效率提出了更高的要求。傳統的養…

PostGIS:使用shp2pgsql、pgsql2shp、OGR2OGR函數進行數據導入、導出

數據導入與導出函數 數據庫數據導入與導出可以通過多個函數完成&#xff0c;QGIS文檔介紹了3個函數&#xff1a; shp2pgsql、pgsql2shp、OGR2OGR&#xff0c;分別用于shp導入數據庫、數據庫文件導出為shp、數據轉換為多種數據格式。 &#xff08;1&#xff09;shp2pgsql 在l…

【AIGC魔童】DeepSeek v3推理部署:vLLM/SGLang/LMDeploy

【AIGC魔童】DeepSeek v3推理部署&#xff1a;vLLM/SGLang/LMDeploy &#xff08;1&#xff09;使用vLLM推理部署DeepSeek&#xff08;2&#xff09;使用SGLang推理部署DeepSeek&#xff08;3&#xff09;使用LMDeploy推理部署DeepSeek &#xff08;1&#xff09;使用vLLM推理部…

《AI “造臉術”:生成對抗網絡打造超真實虛擬人臉》

在科技飛速發展的當下&#xff0c;人工智能的浪潮席卷而來&#xff0c;其中生成對抗網絡&#xff08;GANs&#xff09;技術以其獨特的魅力&#xff0c;成為了生成高度真實感虛擬人臉的強大引擎。無論是影視制作中虛擬角色的塑造&#xff0c;還是游戲領域中多樣化角色形象的構建…

C語言的靈魂——指針(2)

前言&#xff1a;上期我們介紹了如何理解地址&#xff0c;內存&#xff0c;以及指針的一些基礎知識和運算&#xff1b;這期我們來介紹一下const修飾指針&#xff0c;野指針&#xff0c;assert斷言&#xff0c;指針的傳址調用。 上一篇指針&#xff08;1&#xff09; 文章目錄 一…

Android studio 創建aar包給Unity使用

1、aar 是什么&#xff1f; 和 Jar有什么區別 aar 和 jar包 都是壓縮包&#xff0c;可以使用壓縮軟件打開 jar包 用于封裝 Java 類及其相關資源 aar 文件是專門為 Android 平臺設計的 &#xff0c;可以包含Android的專有內容&#xff0c;比如AndroidManifest.xml 文件 &#…