0基础搞AI-NL2SQL数据集处理脚本(用于LLM-fine-tune)

这篇具有很好参考价值的文章主要介绍了0基础搞AI-NL2SQL数据集处理脚本(用于LLM-fine-tune)。希望对大家有所帮助。如果存在错误或未考虑完全的地方,请大家不吝赐教,您也可以点击"举报违法"按钮提交疑问。

        消失了好久好久,这次换了一家公司,然后又在忙于秋招,因此很久没有更新,最近事情也告一段落,因此终于有空回来水博客,今天给大家带来最近的工作,NL2SQL数据集,我们的工作是利用代码生成大模型(类似CodeFuse系列,CodeLlama系列)进行fine-tune,通过用户query和query涉及的数据库表的Schema作为输入,使用fine-tune后的LLM进行推理来得到最后的生成SQL,当然为了工作的方便,所以我们试图将所有的开源数据集进行整合,因此在此处的NL2SQL数据集中,提供了经过模型翻译的Wiki_SQL数据集,Cspider数据集,Du_SQL数据集,如果有大佬有追一科技的数据集请告诉我,需要一些帮助,接下来首先给出NL2SQL数据集的处理脚本:

1、数据集生成(耗时13h,把8w条WikiSQL翻译了)

Data_deal_Script.py

"""
codeer:Jinzhangli
function:数据集处理和构建
relation:2035877994@qq.com
time:2023/11/21 15:23
"""
import json,re

class Cspider_Data_make:
    def Cspider_Schema_load_deal(self):
        Schema={}
        All_DB=self.Cspider_Data_load("Data/Cspider/tables.json")
        for i in range(len(All_DB)):
            DB={}
            column_names=All_DB[i]["column_names"]
            table_names=All_DB[i]['table_names']
            for j in range(len(table_names)):
                DB["_".join(re.split(" ",table_names[j]))]=[column_names[k][1] for k in range(len(column_names)) if column_names[k][0]==j]
            Schema[All_DB[i]["db_id"]]=DB
        return Schema

    def Cspider_Data_load(self,file_path:str):
        dict_data=json.loads(open(file_path,"r",encoding="utf-8").read())
        return dict_data

    def Cspider_Schema_pipe(self,db_name:str,Table_list:list):
        All_Schema=self.Cspider_Schema_load_deal()
        result=[]
        Table_list=[i for i in Table_list if i not in ["("]]
        for i in range(len(Table_list)):
            result.append(All_Schema[db_name][Table_list[i]])
        return result

    def Table_get(self,SQL_token:list)->list:
        Table_list=[SQL_token[i] for i in range(len(SQL_token)) if SQL_token[i-1] in ["from","join"]]
        return Table_list

    def Dict_deal(self,one_dict:dict)->dict:
        query=one_dict["question"]
        SQL=one_dict["query"]
        db_name=one_dict["db_id"]
        return {"query":query,"SQL":SQL,"table_name":"","column_name":"","db_name":db_name}

    def Cspider_Datas_Get(self,Cspider_data):
        Result=[]
        for i in range(len(Cspider_data)):
            if i not in [3097,3153]:
                print("=========正在处理第"+str(i)+",总共有"+str(len(Cspider_data))+"个=========")
                one_dict = self.Dict_deal(Cspider_data[i])
                Table_list = list(set(self.Table_get(Cspider_data[i]["query_toks"])))
                result = self.Cspider_Schema_pipe(one_dict["db_name"], Table_list)
                one_dict["table_name"] = Table_list
                one_dict["column_name"] = result
                Result.append(one_dict)
        return Result

    def Csipder_main(self):
        Cspider_train_data = self.Cspider_Data_load("Data/Cspider/train.json")
        Cspider_dev_data=self.Cspider_Data_load("Data/Cspider/dev.json")
        Cspider_Result=self.Cspider_Datas_Get(Cspider_train_data)+self.Cspider_Datas_Get(Cspider_dev_data)
        return Cspider_Result

class wikiSQL_Data_make:
    def wiki_load(self,file_path):
        file_str=open(file_path,"r",encoding="utf-8").readlines()
        Dict_Data=[eval(file_str[i]) for i in range(len(file_str))]
        return Dict_Data

    def wiki_deal(self,data_path,table_path):
        Dict_data=self.wiki_load(data_path)
        Table_data=self.wiki_load(table_path)
        Wiki_Result,Index=[],0
        Table_dict={Table_data[i]["id"]:[Table_data[i]["header"],Table_data[i]['caption']]
                    for i in range(len(Table_data)) if "caption" in Table_data[i].keys()}
        for i in range(len(Dict_data)):
            table_id=Dict_data[i]["table_id"]
            all_table=Table_dict.keys()
            if table_id in all_table:
                #print("正在处理第" + str(Index) + ",总共有" + str(len(Dict_data)) + "个")
                Index+=1
                query=Dict_data[i]["question"]
                table_name="_".join(re.split(" ",Table_dict[Dict_data[i]["table_id"]][1]))
                SQL=Dict_data[i]["sql"]
                column_name=Table_dict[Dict_data[i]["table_id"]][0]
                for j in range(len(column_name)):
                    column=[]
                    if "/" in column_name[j] and "(" not in column_name[j]:
                        column_name[j]=re.split("/",column_name[j])[0]
                    elif "(" in column_name[j]:
                        for k in column_name[j]:
                            if k!="(":
                                column.append(k)
                            else:
                                column_name[j]=re.split(" ","".join(column))
                                if column_name[j][-1]=="":
                                    column_name[j]="_".join(column_name[j][0:-1])
                                else:
                                    column_name[j] = "_".join(column_name[j])
                                break
                    elif " " in column_name[j]:
                        column_name[j]="_".join(re.split(" ",column_name[j]))
                    elif type(column_name[j])==list:
                        column_name[j]=column_name[j][0]
                SQL=self.SQL_make(SQL,column_name,table_name)
                one_dict={"query": query, "SQL": SQL, "table_name": table_name, "column_name":column_name, "db_name": ""}
                Wiki_Result.append(one_dict)
        return Wiki_Result

    def SQL_make(self,SQL_token,column_name,table_name):
        agg_Action, conds_Acction= ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG'],['=', '>', '<', 'OP']
        SQL="SELECT "+agg_Action[SQL_token["agg"]]+" ( "+column_name[SQL_token["sel"]]+" ) "+"FROM "+table_name
        if len(SQL_token["conds"])==1:
            if type(SQL_token["conds"][0][2])!=str:
                SQL_token["conds"][0][2]=str(SQL_token["conds"][0][2])
            SQL_token["conds"][0][1]=conds_Acction[SQL_token["conds"][0][1]]
            SQL_token["conds"][0][0]=column_name[SQL_token["conds"][0][0]]
            SQL+=" WHERE "+" ".join(SQL_token["conds"][0])
        else:
            conds_list=SQL_token["conds"]
            for i in range(len(conds_list)):
                if type(conds_list[i][2])!=str:
                    conds_list[i][2]=str(conds_list[i][2])
                conds_list[i][0]=column_name[conds_list[i][0]]
                conds_list[i][1]=conds_Acction[conds_list[i][1]]
            for i in range(len(conds_list)):
                if i==len(conds_list)-1:
                    SQL+="and "+" ".join(conds_list[i])
                elif i==0:
                    SQL+="WHERE "+" ".join(conds_list[i])+" "
                else:
                    SQL+="and "+" ".join(conds_list[i])+" "
        return SQL

    def wiki_main(self):
        Wiki_Result=self.wiki_deal("Data/WikiSQL/train.json","Data/WikiSQL/train_tables.json")
        return Wiki_Result

class DuSQL_Data_make:
    def DuSQL_load(self,file_path):
        DuSQL_data=json.loads(open(file_path,"r",encoding="utf-8").read())
        return DuSQL_data

    def Schema_deal(self,DuSQL_schema:list[dict]):
        Schema_dict={}
        for i in range(len(DuSQL_schema)):
            table_names=DuSQL_schema[i]["table_names"]
            column_names=DuSQL_schema[i]["column_names"]
            Schema_dict[DuSQL_schema[i]["db_id"]]={table_names[j]:[column_names[k][1] for k in range(len(column_names)) if column_names[k][0]==j] for j in range(len(table_names))}
        return Schema_dict

    def TableGetFromSQL(self,SQL):
        SQL_List=re.split(" ",SQL)
        Table=list(set([SQL_List[i] for i in range(len(SQL_List)) if i!=0 and SQL_List[i-1] in ["from","join"]]))
        return Table

    def Query_SQL_Schema(self,DUSQL_data:list[dict],DuSQL_Schema):
        Result=[]
        for i in range(len(DUSQL_data)):
            print("=========正在处理第" + str(i) + ",总共有" + str(len(DUSQL_data)) + "个=========")
            SQL=DUSQL_data[i]["sql_query"]
            query=DUSQL_data[i]["question"]
            db_name=DUSQL_data[i]["db_id"]
            table=self.TableGetFromSQL(SQL)[0]
            column=DuSQL_Schema[db_name][table]
            Result.append({"query":query,"SQL":SQL,"table_name":table,"column_name":column,"db_name":db_name})
        return Result

    def DuSQL_main(self):
        DuSQL_data=self.DuSQL_load("Data/DuSQL/sample-data.json")
        DUSQL_Schema=self.DuSQL_load("Data/DuSQL/db-schema.json")
        DUSQL_Schema=self.Schema_deal(DUSQL_Schema)
        DuSQL_Result=self.Query_SQL_Schema(DuSQL_data,DUSQL_Schema)
        return DuSQL_Result

用于翻译的数据接口,这里用了通义千问14B

OutAPI.py

"""
codeer:Jinzhangli
function:接入外部API服务
relation:2035877994@qq.com
time:2023/11/30 15:49
"""
import requests,json
def Qwen14BChat(text,history):
    url="http://172.16.158.247:9899/Qwen14B"
    data=json.dumps({"prompt":text,"history":history})
    response=requests.post(url=url,data=data)
    response=eval(response.text)
    return response

接下来是主控脚本,Tune_main.py

"""
codeer:Jinzhangli
function:主控文件
relation:2035877994@qq.com
time:2023/11/30 18:05
"""
import json
from Data_Deal_Script import *
from OutAPI import *
def LearningDataJson_build():
    wikiSQL_Data = wikiSQL_Data_make()
    print("开始处理WIKI_SQL")
    WIKI_SQL = wikiSQL_Data.wiki_main()
    #英文数据集翻译
    for i in range(len(WIKI_SQL)):
        print("====翻译第"+str(i)+"个句子====")
        WIKI_SQL[i]["query"] = Qwen14BChat("请帮我将以下文本翻译为中文,只输出结果,不要任何解释\n"+WIKI_SQL[i]["query"],[])["response"]
        print(WIKI_SQL[i]["query"])
    Cspider_Data = Cspider_Data_make()
    Dusql_Data = DuSQL_Data_make()
    print("开始处理DU_SQL")
    DU_SQL = Dusql_Data.DuSQL_main()
    print("开始处理Cspider")
    Cspider = Cspider_Data.Csipder_main()
    Result=DU_SQL+Cspider+WIKI_SQL
    with open("result.json", "w", encoding="utf-8") as json_file:
        json.dump(Result,json_file,ensure_ascii=False)
2、基于Swift框架的加载LoRA微调

接下来是LLM微调脚本(基于Swift框架)

首先安装阿里巴巴Swift框架

git clone https://github.com/modelscope/swift.git
cd swift
pip install -e .

然后进入Clone下来的Swift文件夹

cd ../swift/examples/pytorch/llm

使用llm下自带的脚本,也可以自己写,我比较懒直接os.system()来修改

import os
command="""
CUDA_VISIBLE_DEVICES=0 \
python llm_sft.py \
    --model_type qwen-14b \
    --model_cache_dir /home/gpu-user1/JinzhangLi/Qwen-14B \
    --sft_type lora \
    --template_type default-generation \
    --dtype bf16 \
    --output_dir output \
    --dataset dureader-robust-zh \
    --train_dataset_sample -1 \
    --num_train_epochs 1 \
    --max_length 2048 \
    --quantization_bit 4 \
    --bnb_4bit_comp_dtype bf16 \
    --lora_rank 8 \
    --lora_alpha 32 \
    --lora_dropout_p 0. \
    --lora_target_modules ALL \
    --gradient_checkpointing true \
    --batch_size 1 \
    --weight_decay 0. \
    --learning_rate 1e-4 \
    --gradient_accumulation_steps 16 \
    --max_grad_norm 0.5 \
    --warmup_ratio 0.03 \
    --eval_steps 100 \
    --save_steps 100 \
    --save_total_limit 2 \
    --logging_steps 10 \
    --use_flash_attn false \
    --push_to_hub false \
    --hub_model_id qwen-14b-qlora \
    --hub_private_repo true \
    --hub_token 'your-sdk-token' """
os.system(command)
3、数据集样式和链接(根据自己使用的框架微调,不出意外,后面数据集还会变大)

最后给出搞定后的NL2SQL数据集(当然数据集还得调整,只是将数据格式整理如下)

{    
    "query": "创刊时间不早于1989年10月10日的期刊,按出版刊数降序排列给出期刊的名称以及语言", 
    "SQL": "select 名称 , 语言 from 期刊 where 创刊时间 >= '1989-10-10' order by 出版刊数 desc", 
    "table_name": "期刊",
     "column_name": ["词条id", "名称", "语言", "类别", "主办单位", "创刊时间", "国家", "出版刊数"],
     "db_name": "期刊"
}

如想获取数据,请访问我们在modelscope的开源地址

Text2SQL-英文-150K · 数据集 (modelscope.cn)

Text2SQL-中文-180K · 数据集 (modelscope.cn)文章来源地址https://www.toymoban.com/news/detail-847531.html

到了这里,关于0基础搞AI-NL2SQL数据集处理脚本(用于LLM-fine-tune)的文章就介绍完了。如果您还想了解更多内容,请在右上角搜索TOY模板网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持TOY模板网!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请点击违法举报进行投诉反馈,一经查实,立即删除!

领支付宝红包 赞助服务器费用

相关文章

  • 【REST2SQL】05 GO 操作 达梦 数据库

    【REST2SQL】01RDB关系型数据库REST初设计 【REST2SQL】02 GO连接Oracle数据库 【REST2SQL】03 GO读取JSON文件 【REST2SQL】04 REST2SQL第一版Oracle版实现 信创要求用国产数据库,刚好有项目用的达梦,研究一下go如何操作达梦数据库 登录 达梦 官网,有DM8开发版可以下载,我下载的是X86,Win

    2024年02月01日
    浏览(57)
  • 【REST2SQL】07 GO 操作 Mysql 数据库

    【REST2SQL】01RDB关系型数据库REST初设计 【REST2SQL】02 GO连接Oracle数据库 【REST2SQL】03 GO读取JSON文件 【REST2SQL】04 REST2SQL第一版Oracle版实现 【REST2SQL】05 GO 操作 达梦 数据库 【REST2SQL】06 GO 跨包接口重构代码 MySQL是一个关系型数据库管理系统,由瑞典MySQL AB 公司开发,属于 Oracle旗

    2024年01月22日
    浏览(94)
  • LLM在text2sql上的应用

    目前,大模型的一个热门应用方向text2sql它可以帮助用户快速生成想要查询的SQL语句。那对于用户来说,大部分简单的sql都是正确的,但对于一些复杂逻辑来说,需要用户在产出SQL的基础上进行简单修改,Text2SQL应用主要还是帮助用户去解决开发时间,减少开发成本。 Text to

    2024年02月08日
    浏览(42)
  • TEXT2SQL-顶峰:Vanna部署及介绍

    Vanna 是一款采用 MIT 许可的开源 Python RAG (检索增强生成)框架,用于生成 SQL 语句和相关功能。 如何使用 Vanna Vanna 的使用分为两个简单步骤 - 在你的数据上训练一个 RAG \\\"模型\\\",然后提出问题,该问题将返回可设置为自动在你的数据库上运行的 SQL 查询。 1. 在你的数据上训练一

    2024年02月22日
    浏览(41)
  • selenium系列--测试脚本--将Excel文件用于测试(unittest数据驱动实战)

    我们只需要写一个函数方法进行调用即可,读取Excel文件,将值进行返回便于下一个接口使用。 import xlrd class Excel_Login: def excel_login(self): file_name = xlrd.open_workbook(r’F:111.xlsx’) sh1 = file_name.sheet_by_index(0) rows = sh1.nrows datalist = [] for i in range(1, rows): datalist.append(sh1.row_values(i)) retu

    2024年04月11日
    浏览(62)
  • 大模型LLM在 Text2SQL 上的应用实践

    一、前言 目前,大模型的一个热门应用方向Text2SQL,它可以帮助用户快速生成想要查询的SQL语句,再结合可视化技术可以降低使用数据的门槛,更便捷的支持决策。本文将从以下四个方面介绍LLM在Text2SQL应用上的基础实践。 · Text2SQL概述 · LangChain基础知识 · 基于SQLDatabaseCha

    2024年01月16日
    浏览(42)
  • 最强开源Text2SQL大模型本地部署的解决方案

      大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名,CCF比赛第二名,科大讯飞比赛第三名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的见解。曾经辅导过若干个非计算机专业的

    2024年02月08日
    浏览(37)
  • Archery系统调用my2sql读取binlog的功能优化

    Archery系统集成了my2sql工具,可以通过此功能分析MysQL的binlog,方便SQL回滚,还可以协助异常分析,定位问题。 优化点 解析后没有SQL语句返回,可能的原因是解析过程中遇到了错误,而系统没有捕获错误,更没有将错误异常返回给操作者。 此处的优化,就是解决这一信息黑洞

    2024年01月20日
    浏览(44)
  • 大模型 LLM RAG在 Text2SQL 上的应用实践

    1. 前言 在上篇文章中「LLM Agent在Text2SQL应用上的实践」介绍了基于AI Agent来优化LLM的Text2SQL转换效果的实践,除此之外我们还可以使用RAG(Retrieval-Augmented Generation)来优化大模型应用的效果。 本文将从以下4个方面探讨通过RAG来优化LLM的Text2SQL转换效果。 1. RAG概述 2. 基于LangC

    2024年02月02日
    浏览(38)
  • 【REST2SQL】08 日志重构增加输出到文件log.txt

    【REST2SQL】01RDB关系型数据库REST初设计 【REST2SQL】02 GO连接Oracle数据库 【REST2SQL】03 GO读取JSON文件 【REST2SQL】04 REST2SQL第一版Oracle版实现 【REST2SQL】05 GO 操作 达梦 数据库 【REST2SQL】06 GO 跨包接口重构代码 【REST2SQL】07 GO 操作 Mysql 数据库 原来的日志只输出到控制台,关闭控制台

    2024年02月01日
    浏览(42)

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

博客赞助

微信扫一扫打赏

请作者喝杯咖啡吧~博客赞助

支付宝扫一扫领取红包,优惠每天领

二维码1

领取红包

二维码2

领红包