简体中文|不和谐|微信|拥抱脸|社区|纸
文本2SQL |文本2NLU
Text2SQL 评估执行准确性(ex)指标,我们将其移至src/dbgpt_hub_sql
模型 | 方法 | 简单的 | 中等的 | 难的 | 额外的 | 全部 |
---|---|---|---|---|---|---|
根据 | 0 | 0 | 0 | 0 | 0 | |
Llama2-7B-聊天 | 洛拉 | 0.887 | 0.641 | 0.489 | 0.331 | 0.626 |
格洛拉 | 0.847 | 0.623 | 0.466 | 0.361 | 0.608 | |
根据 | 0 | 0 | 0 | 0 | 0 | |
Llama2-13B-聊天 | 洛拉 | 0.907 | 0.729 | 0.552 | 0.343 | 0.68 |
格洛拉 | 0.911 | 0.7 | 0.552 | 0.319 | 0.664 | |
根据 | 0.214 | 0.177 | 0.092 | 0.036 | 0.149 | |
CodeLlama-7B-指令 | 洛拉 | 0.923 | 0.756 | 0.586 | 0.349 | 0.702 |
格洛拉 | 0.911 | 0.751 | 0.598 | 0.331 | 0.696 | |
根据 | 0.698 | 0.601 | 0.408 | 0.271 | 0.539 | |
CodeLlama-13B-指令 | 洛拉 | 0.94 | 0.789 | 0.684 | 0.404 | 0.746 |
格洛拉 | 0.94 | 0.774 | 0.626 | 0.392 | 0.727 | |
根据 | 0.577 | 0.352 | 0.201 | 0.066 | 0.335 | |
百川2-7B-聊天 | 洛拉 | 0.871 | 0.63 | 0.448 | 0.295 | 0.603 |
格洛拉 | 0.891 | 0.637 | 0.489 | 0.331 | 0.624 | |
根据 | 0.581 | 0.413 | 0.264 | 0.187 | 0.392 | |
百川2-13B-聊天 | 洛拉 | 0.903 | 0.702 | 0.569 | 0.392 | 0.678 |
格洛拉 | 0.895 | 0.675 | 0.58 | 0.343 | 0.659 | |
根据 | 0.395 | 0.256 | 0.138 | 0.042 | 0.235 | |
Qwen-7B-聊天 | 洛拉 | 0.855 | 0.688 | 0.575 | 0.331 | 0.652 |
格洛拉 | 0.911 | 0.675 | 0.575 | 0.343 | 0.662 | |
根据 | 0.871 | 0.632 | 0.368 | 0.181 | 0.573 | |
Qwen-14B-聊天 | 洛拉 | 0.895 | 0.702 | 0.552 | 0.331 | 0.663 |
格洛拉 | 0.919 | 0.744 | 0.598 | 0.367 | 0.701 | |
根据 | 0 | 0 | 0 | 0 | 0 | |
聊天GLM3-6b | 洛拉 | 0.855 | 0.605 | 0.477 | 0.271 | 0.59 |
格洛拉 | 0.843 | 0.603 | 0.506 | 0.211 | 0.581 |
DB-GPT-Hub 是一个实验项目,利用大型语言模型 (LLM) 来实现文本到 SQL 的解析。该项目涵盖数据收集、数据预处理、模型选择和构建以及模型权重微调等各个阶段。通过这些流程,我们的目标是增强 Text-to-SQL 能力,同时降低模型训练成本,从而使更多开发人员能够为提高 Text-to-SQL 准确性做出贡献。我们的最终目标是实现基于数据库的自动化问答能力,让用户使用自然语言描述执行复杂的数据库查询。
迄今为止,我们已经成功集成了多个大型模型,并建立了包括数据处理、监督微调(SFT)模型训练、预测输出和评估在内的全面工作流程。为此项目开发的代码可以在项目本身中轻松重用。
截至2023年10月10日,我们已经利用该项目对开源13B大小的模型进行了微调,纳入了更多相关数据。在零样本提示下,利用基于Spider的测试套件,我们对1.27G大小的数据库实现了0.764的执行准确率。另外,Spider官网给出的95M大小的数据库的执行精度为0.825。
我们通过在大型语言模型上应用监督微调(SFT)来增强文本到 SQL 的性能。
该项目示例的主要数据集是Spider数据集:
其他可用的 text2sql 数据集:
WikiSQL:一个大型语义解析数据集,由 80,654 个自然语句表达式和 24,241 个表的 SQL 注释组成。 WikiSQL 中的每个查询仅限于同一个表,不包含排序、分组等复杂操作 WikiSQL 中的查询仅限于同一个表,不包含排序、分组、子查询等复杂操作。
CHASE:一个跨域多轮交互式 text2sql 中文数据集,包含 5,459 个多轮问题的列表,其中包含 280 个不同域数据库的 17,940 个
BIRD-SQL:大规模跨域文本到 SQL 的英文基准测试,特别关注大型数据库内容。该数据集包含 12,751 个文本到 SQL 数据对和 95 个数据库,总大小为 33.4 GB,涵盖 37 个职业领域。 BIRD-SQL 数据集通过探索三个额外的挑战,即处理大型且混乱的数据库值、外部知识推理和优化 SQL 执行效率,弥合了文本到 SQL 研究和实际应用之间的差距。
CoSQL:用于构建跨域会话文本到 SQL 系统的语料库。它是 Spider 和 SParC 任务的对话版本。 CoSQL 包含超过 30k 轮和 10k 多个带注释的 SQL 查询,这些查询来自 Wizard-of-Oz 的 3k 对话集合,查询跨 138 个域的 200 个复杂数据库。每个对话都模拟一个真实的数据库查询场景,其中工作人员以用户身份探索数据库,而 SQL 专家使用 SQL 来检索答案、澄清不明确的问题或以其他方式提供信息。
按照NSQL的处理模板,对数据集进行基础处理,得到约20W的数据集
DB-GPT-Hub目前支持以下基础型号:
该模型使用冗余架构量化学习 (QLoRA) 基于 4 的量化位进行微调。其最低硬件要求可参考如下:
型号参数 | 图形处理器内存 | 中央处理器内存 | 磁盘 |
---|---|---|---|
7b | 6GB | 3.6GB | 36.4GB |
13b | 13.4GB | 5.9GB | 60.2GB |
所有相关参数均设置为最小值,批量大小为1,最大长度为512。根据经验,为了获得更好的性能,建议将相关长度值设置为1024或2048。
git clone https://github.com/eosphoros-ai/DB-GPT-Hub.git
cd DB-GPT-Hub
conda create -n dbgpt_hub python=3.10
conda activate dbgpt_hub
cd src/dbgpt_hub_sql
pip install -e .
首先,使用以下命令安装dbgpt-hub
pip install dbgpt-hub
然后,设置参数并运行整个过程。
from dbgpt_hub_sql . data_process import preprocess_sft_data
from dbgpt_hub_sql . train import start_sft
from dbgpt_hub_sql . predict import start_predict
from dbgpt_hub_sql . eval import start_evaluate
# Config the input datasets
data_folder = "dbgpt_hub_sql/data"
data_info = [
{
"data_source" : "spider" ,
"train_file" : [ "train_spider.json" , "train_others.json" ],
"dev_file" : [ "dev.json" ],
"tables_file" : "tables.json" ,
"db_id_name" : "db_id" ,
"is_multiple_turn" : False ,
"train_output" : "spider_train.json" ,
"dev_output" : "spider_dev.json" ,
}
]
# Config training parameters
train_args = {
"model_name_or_path" : "codellama/CodeLlama-13b-Instruct-hf" ,
"do_train" : True ,
"dataset" : "example_text2sql_train" ,
"max_source_length" : 2048 ,
"max_target_length" : 512 ,
"finetuning_type" : "lora" ,
"lora_target" : "q_proj,v_proj" ,
"template" : "llama2" ,
"lora_rank" : 64 ,
"lora_alpha" : 32 ,
"output_dir" : "dbgpt_hub_sql/output/adapter/CodeLlama-13b-sql-lora" ,
"overwrite_cache" : True ,
"overwrite_output_dir" : True ,
"per_device_train_batch_size" : 1 ,
"gradient_accumulation_steps" : 16 ,
"lr_scheduler_type" : "cosine_with_restarts" ,
"logging_steps" : 50 ,
"save_steps" : 2000 ,
"learning_rate" : 2e-4 ,
"num_train_epochs" : 8 ,
"plot_loss" : True ,
"bf16" : True ,
}
# Config predict parameters
predict_args = {
"model_name_or_path" : "codellama/CodeLlama-13b-Instruct-hf" ,
"template" : "llama2" ,
"finetuning_type" : "lora" ,
"checkpoint_dir" : "dbgpt_hub_sql/output/adapter/CodeLlama-13b-sql-lora" ,
"predict_file_path" : "dbgpt_hub_sql/data/eval_data/dev_sql.json" ,
"predict_out_dir" : "dbgpt_hub_sql/output/" ,
"predicted_out_filename" : "pred_sql.sql" ,
}
# Config evaluation parameters
evaluate_args = {
"input" : "./dbgpt_hub_sql/output/pred/pred_sql_dev_skeleton.sql" ,
"gold" : "./dbgpt_hub_sql/data/eval_data/gold.txt" ,
"gold_natsql" : "./dbgpt_hub_sql/data/eval_data/gold_natsql2sql.txt" ,
"db" : "./dbgpt_hub_sql/data/spider/database" ,
"table" : "./dbgpt_hub_sql/data/eval_data/tables.json" ,
"table_natsql" : "./dbgpt_hub_sql/data/eval_data/tables_for_natsql2sql.json" ,
"etype" : "exec" ,
"plug_value" : True ,
"keep_distict" : False ,
"progress_bar_for_each_datapoint" : False ,
"natsql" : False ,
}
# Run the whole fine-tuning workflow
preprocess_sft_data (
data_folder = data_folder ,
data_info = data_info
)
start_sft ( train_args )
start_predict ( predict_args )
start_evaluate ( evaluate_args )
DB-GPT-Hub采用信息匹配生成方式进行数据准备,即结合表信息的SQL+Repository生成方式。该方法结合数据表信息,可以更好地理解数据表的结构和关系,适合生成符合要求的SQL语句。
从 Spider 数据集链接下载 Spider 数据集。默认情况下,下载并解压数据后,将其放置在 dbgpt_hub_sql/data 目录下,即路径应为dbgpt_hub_sql/data/spider
。
对于数据预处理部分,只需运行以下脚本:
# # generate train and dev(eval) data
sh dbgpt_hub_sql/scripts/gen_train_eval_data.sh
在目录dbgpt_hub_sql/data/
中,您将找到新生成的训练文件 example_text2sql_train.json 和测试文件 example_text2sql_dev.json,分别包含 8659 和 1034 个条目。对于后续微调使用的数据,将参数file_name
值设置为dbgpt_hub_sql/data/dataset_info.json中训练集的文件名,例如example_text2sql_train.json
生成的 JSON 中的数据如下所示:
{
"db_id": "department_management",
"instruction": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.n"n##Instruction:ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.nThe head_ID of management is the foreign key of head_ID of head.nThe department_ID of management is the foreign key of Department_ID of department.nn",
"input": "###Input:nHow many heads of the departments are older than 56 ?nn###Response:",
"output": "SELECT count(*) FROM head WHERE age > 56",
"history": []
},
项目的数据处理代码中已嵌入chase
、 cosql
、 sparc
的数据处理代码。按照上面的链接下载完数据集后,只需要在 dbgpt_hub_sql/configs/config.py in
添加Just loosen the corresponding code comment in SQL_DATA_INFO
。
模型微调支持LoRA和QLoRA方法。我们可以运行以下命令来微调模型。默认情况下,通过参数--quantization_bit,它使用QLoRA微调方法。要切换到 LoRA,只需从脚本中删除相关参数即可。运行命令:
sh dbgpt_hub_sql/scripts/train_sft.sh
微调后,模型权重将默认保存在适配器文件夹中,具体为 dbgpt_hub_sql/output/adapter 目录。
如果您使用多 GPU 训练并希望使用 deepseed ,则应修改 train_sft.sh 中的默认内容。变化是:
CUDA_VISIBLE_DEVICES=0 python dbgpt_hub_sql/train/sft_train.py
--quantization_bit 4
...
改为:
deepspeed --num_gpus 2 dbgpt_hub_sql/train/sft_train.py
--deepspeed dbgpt_hub_sql/configs/ds_config.json
--quantization_bit 4
...
如果您需要订单卡 ID
deepspeed --include localhost:0,1 dbgpt_hub_sql/train/sft_train.py
--deepspeed dbgpt_hub_sql/configs/ds_config.json
--quantization_bit 4
...
其他省略的部分(……)可以保持一致。如果要更改默认的 Deepseed 配置,请进入dbgpt_hub_sql/configs
目录并根据需要更改 ds_config.json,默认为 stage2。
脚本中,微调时,不同模型对应关键参数lora_target和template,如下表所示:
型号名称 | 洛拉目标 | 模板 |
---|---|---|
拉玛-2 | q_proj,v_proj | 骆驼2 |
代码骆马-2 | q_proj,v_proj | 骆驼2 |
百川2 | W_pack | 百川2 |
奎文 | c_attn | 聊天室 |
sqlcoder-7b | q_proj,v_proj | 米斯塔拉尔 |
sqlcoder2-15b | c_attn | 默认 |
实习生LM | q_proj,v_proj | 实习生 |
XVERSE | q_proj,v_proj | 宇宙 |
聊天GLM2 | 查询键值 | 聊天glm2 |
骆驼 | q_proj,v_proj | - |
盛开 | 查询键值 | - |
布卢姆兹 | 查询键值 | - |
百川 | W_pack | 百川 |
鹘 | 查询键值 | - |
在train_sft.sh
中,其他关键参数如下:
quantization_bit:表示是否进行量化,有效值为[4或8]。
model_name_or_path:LLM(大型语言模型)的路径。
dataset:指定训练数据集配置的名称,对应dbgpt_hub_sql/data/dataset_info.json中的外键值,例如example_text2sql。
max_source_length:输入模型的文本长度。如果计算资源允许的话,可以设置得尽可能大,比如1024或者2048。
max_target_length:模型输出的SQL内容的长度; 512一般就足够了。
output_dir:SFT(监督微调)期间 Peft 模块的输出路径,默认设置为dbgpt_hub_sql/output/adapter/
。
per_device_train_batch_size:批次的大小。如果计算资源允许,可以设置得大一些;默认值为 1。
gradient_accumulation_steps:更新前累积梯度的步数。
save_steps:保存模型检查点的步数;默认情况下可以设置为100。
num_train_epochs:训练数据集的纪元数。
在项目目录./dbgpt_hub_sql/output/pred/下,该文件夹是模型预测的默认输出位置(如果不存在,则仅mkdir)。
sh ./dbgpt_hub_sql/scripts/predict_sft.sh
在脚本中,默认使用参数--quantization_bit
,它使用 QLoRA 进行预测。删除它会切换到 LoRA 预测方法。参数predicted_input_filename
的值是您的预测测试数据集文件。 --predicted_out_filename
是模型预测结果的文件名。
第二个对应的模型权重可以从 Huggingface hg-eosphoros-ai 中找到,我们在 10 月份上传了 LoRA 权重,在 Spider 评估集上的执行精度达到了 0.789。
如果需要合并训练好的基础模型和微调后的Peft模块的权重来导出完整的模型,请执行以下模型导出脚本:
sh ./dbgpt_hub_sql/scripts/export_merge.sh
请务必将脚本中的参数路径值替换为与您的项目对应的路径。
要评估数据集上的模型性能,默认为蜘蛛开发数据集。运行以下命令:
python dbgpt_hub_sql/eval/evaluation.py --plug_value --input Your_model_pred_file
您可以在这里找到我们最新的审查结果和部分实验结果
注:默认代码指向的数据库是从【Spider官网】(https://yale-lily.github.io/spider)下载的95M数据库。如果您需要在测试套件中使用Spider数据库(大小1.27G),请先将链接中的数据库下载到自定义目录,然后运行上面的评估命令,其中添加参数和值,如--db Your_download_db_path
。
整个过程我们将分为三个阶段:
第一阶段:
目前,我们提供对以下功能的支持:
第二阶段:
20231010
之前支持通过多种方式微调更多不同模型prompts
第三阶段:
如果我们的工作为您提供了哪怕一点点帮助,请考虑给我们一颗星。您的反馈和支持将成为我们继续发布更多相关工作和改进工作的动力。谢谢你!
我们热烈邀请更多的人加入我们,并积极参与我们项目的各个方面,例如数据集、模型微调、性能评估、论文推荐和代码复制。请随时提出问题或拉取请求 (PR),我们将积极响应您的贡献。
在提交代码之前,请使用以下命令确保其格式符合黑色样式:
black dbgpt_hub
如果您有更多时间对代码执行更详细的类型检查和样式检查,请使用以下命令:
pyright dbgpt_hub
pylint dbgpt_hub
如果您有任何疑问或需要进一步帮助,请随时与我们联系。我们感谢您的参与!
我们的工作主要建立在众多开源贡献的基础上。感谢以下开源项目
感谢所有贡献者,特别是@JBoRu,他提出了这个问题,提醒我们添加一种新的有前途的评估方式,即测试套件。正如论文《SQL-PALM: IMPROVED LARGE LANGUAGE MODEL ADAPTATION FOR TEXT-TO-SQL》中提到的,“我们考虑两个常用的评估指标:执行精度(EX)和测试套件精度(TS)。EX衡量是否SQL 执行结果与真实值 (GT) 匹配,而 TS 衡量 SQL 是否通过了由数据库增强生成的多个测试的所有 EX 评估,因此我们认为 TS 是更可靠的评估指标。
如果您发现DB-GPT-Hub
对您的研究或开发有用,请引用以下论文:
@misc { zhou2024dbgpthub ,
title = { DB-GPT-Hub: Towards Open Benchmarking Text-to-SQL Empowered by Large Language Models } ,
author = { Fan Zhou and Siqiao Xue and Danrui Qi and Wenhui Shi and Wang Zhao and Ganglin Wei and Hongyang Zhang and Caigai Jiang and Gangwei Jiang and Zhixuan Chu and Faqiang Chen } ,
year = { 2024 } ,
eprint = { 2406.11434 } ,
archivePrefix = { arXiv } ,
primaryClass = { id='cs.DB' full_name='Databases' is_active=True alt_name=None in_archive='cs' is_general=False description='Covers database management, datamining, and data processing. Roughly includes material in ACM Subject Classes E.2, E.5, H.0, H.2, and J.1.' }
}
麻省理工学院许可证 (MIT)
我们作为一个社区进行合作,如果您对我们的社区工作有任何想法,请随时与我们联系。如果您有兴趣深入实验并优化DB-GPT-Hub子项目,可以联系微信群中的“旺仔”。我们竭诚欢迎您的贡献,让我们一起变得更好!