Text-to-SQL将自然语言转换为数据库查询语句

有关Text-To-SQL方法,可以查阅我的另一篇文章,Text-to-SQL方法研究

直接与数据库对话-text2sql

Text2sql就是把文本转换为sql语言,这段时间公司有这方面的需求,调研了一下市面上text2sql的方法,比如阿里的Chat2DB,麻省理工开源的Vanna。试验了一下,最终还是决定自研,基于Vanna的思想,RAG+大模型。

    使用开源的Vanna实现text2sql比较方便,Vanna可以直接连接数据库,但是当用户权限能访问多个数据库的时候,就比较麻烦了,而且Vanna向量化存储之后,新的question作对比时没有区分数据库。因此自己实现了一下text2sq,仍然采用Vanna的思想,提前训练DDL,Sqlques,和数据库document。

这里简单做一下记录,以供后续学习使用。

基本思路

1、数据库DDL语句,SQL-Question,Dcoument信息获取

2、基于用户提问question和数据库Document锁定要分析的数据库

3、模型训练:借助数据库的DDL语句、元数据(描述数据库自身数据的信息)、相关文档说明、参考样例SQL等,训练一个RAG“模型”。

这一模型结合了embedding技术和向量数据库,使得数据库的结构和内容能够被高效地索引和检索。

4、语义检索: 当用户输入自然语言描述的问题时,①会从向量库里面检索,迅速找出与问题相关的内容;②进行BM25算法文本召回,找到与问题 最相关的内容;③分别使用RRF算法和Re-ranking重排序算法,锁定最相关内容

语义匹配:使用算法(如BERT等)来理解查询和文档的语义相似性

文本召回匹配:BM25算法文本召回,找到与问题最相关的内容

rerank结果重排序:对搜索结果进行排序。

5、Prompt构建: 检索到的相关信息会被组装进Prompt中,形成一个结构化的查询描述。这一Prompt随后会被传递给LLM(大型语言模型)用于生成准确的SQL查询。

实现逻辑图

实现架构图:

具体实现方式如下所示:

1.数据库的选择

class DataBaseSearch(object):def __init__(self, _model):self.name = 'DataBaseSearch'self.model = _modelself.instruction = "为这段内容生成表示以用于匹配文本描述:"self.SIZE = 1024self.index = faiss.IndexFlatL2(self.SIZE)self.textdata = []self.subdata = {}self.i2key = {}self.id2ddls = {}self.id2sqlques = {}self.id2docs = {}self.strtexts = {}# self.ddldata = []# self.sqlques_data = []# self.document_data = []self.load_textdata()         # 加载text数据self.load_textdata_vec()     # text数据向量化def load_textdata(self):try:response = requests.post(url="xxx",verify=False)print(response.text)jsonobj = json.loads(response.text)textdatas = jsonobj["data"]for textdata in textdatas:                                 # 提取每一个数据库内容cid = textdata["dataSetID"]cddls = textdata["ddl"]csql_ques = textdata["exp"]cdocuments = textdata["Intro"]self.textdata.append((cid, cddls, csql_ques, cdocuments))   # 整合所有数据except Exception as e:print(e)# print("load textdata ", self.textdata)def load_textdata_vec(self):num0 = 0for recode in self.textdata:_id = recode[0]_ddls = recode[1]_sql_ques = recode[2]_documents = recode[3]# _strtexts = str(_ddls) + str(_sql_ques) + str(_documents)_strtexts = str(_sql_ques) + str(_documents)text_embeddings = self.model.encode([_strtexts], normalize_embeddings=True)self.index.add(text_embeddings)self.i2key[num0] = _idself.strtexts[_id] = _strtextsself.id2ddls[_id] = _ddlsself.id2sqlques[_id] = _sql_quesself.id2docs[_id] = _documentsnum0 += 1# print("init instruction vec", num0)def calculate_score(self, score, question, kws):passdef find_vec_database(self, question, k, theata):# print(question)q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)D, I = self.index.search(q_embeddings, k)result = []for i in range(k):sim_i = I[0][i]uuid = self.i2key.get(sim_i, "none")sim_v = D[0][i]database_texts = self.strtexts.get(uuid, "none")# score = self.calculate_score(sim_v, question, database_texts) # wait implementscore = int(sim_v*1000)if score < theata:doc = {}doc["score"] = scoredoc["dataSetID"] = uuidresult.append(doc)# print(result)return resultif __name__ == '__main__':modelpath = "E:\\module\\bge-large-zh-v1.5"model = SentenceTransformer(modelpath)vs = DataBaseSearch(model)result = vs.find_vec_database("查询济南市第三幼儿园所有小班班级?", 1, 2000)print(result)

2.sql_ques:sql问题训练

class SqlQuesSearch(object):def __init__(self, _model):self.name = "SqlQuesSearch"self.model = _modelself.instruction = "为这段内容生成表示以用于匹配文本描述:"self.SIZE = 1024self.index = faiss.IndexFlatL2(self.SIZE)self.sqlquedata = []self.i2dbid = {}self.i2sqlid = {}self.id2sqlque = {}self.id2que = {}self.id2sql = {}self.dbid2sqlques = {}## self.sqlques = {}## self.i2key = {}## self.id2sqlques = {}## self.num2sqlque = {}# self.ddldata = []# self.sqlques_data = []# self.document_data = []self.load_textdata()  # 加载text数据self.load_textdata_vec()  # text数据向量化def load_textdata(self):try:response = requests.post(url="xxx",verify=False)print(response.text)jsonobj = json.loads(response.text)textdatas = jsonobj["data"]datadatas = jsonobj["data"]for datadata in datadatas:  # 提取每一个数据库sql-ques内容dbid = datadata["dataSetID"]sql_ques = datadata["exp"]self.sqlquedata.append((dbid, sql_ques))  # 整合sql数据except Exception as e:print(e)# print("load textdata ", self.sqlquedata)def load_textdata_vec(self):num0 = 0for recode in self.sqlquedata:db_id = recode[0]sql_ques = recode[1]for sql_que in sql_ques:sql_id = sql_que["sql_id"]question = sql_que["question"]sql = sql_que["sql"]ddl_embeddings = self.model.encode([question], normalize_embeddings=True)self.index.add(ddl_embeddings)self.i2dbid[num0] = db_idself.i2sqlid[num0] = sql_idself.id2que[sql_id] = questionself.id2sql[sql_id] = sqlnum0 += 1print("init sql-que vec", num0)def calculate_score(sim_v, question, sql_ques):passdef find_vec_sqlque(self, question, k, theta, dataSetID, number):q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)D, I = self.index.search(q_embeddings, k)result = []for i in range(k):sim_i = I[0][i]dbid = self.i2dbid.get(sim_i, "none")  # 获取数据库idsqlid = self.i2sqlid.get(sim_i, "none")question = self.id2que.get(sqlid, "none")sql = self.id2sql.get(sqlid, "none")if dbid == dataSetID:sim_v = D[0][i]score = int(sim_v * 1000)if score < theta:doc = {}doc["score"] = scoredoc["question"] = questiondoc["sql"] = sqlresult.append(doc)if len(result) == number:breakreturn resultif __name__ == '__main__':modelpath = "E:\\module\\bge-large-zh-v1.5"model = SentenceTransformer(modelpath)vs = SqlQuesSearch(model)result = vs.find_vec_sqlque("查询7月18日所有的儿童观察记录?", 3, 2000, dataSetID=111)print(result)

3.数据库DDL训练

class DdlQuesSearch(object):def __init__(self, _model):self.name = "DdlQuesSearch"self.model = _modelself.instruction = "为这段内容生成表示以用于匹配文本描述:"self.SIZE = 1024self.index = faiss.IndexFlatL2(self.SIZE)self.ddldata = []self.sqlques = {}self.i2dbid = {}self.i2ddlid = {}self.dbid2ddls = {}self.id2ddl = {}self.ddlid2dbid = {}# self.ddldata = []# self.sqlques_data = []# self.document_data = []self.load_ddldata()  # 加载text数据self.load_ddl_vec()  # text数据向量化def load_ddldata(self):try:response = requests.post(url="xxx",verify=False)print(response.text)jsonobj = json.loads(response.text)for database in databases:db_id = database["dataSetID"]ddls = database["ddl"]self.ddldata.append((db_id, ddls))# print(db_id)# for ddl in database["ddl"]:#     ddl_id = ddl["ddl_id"]#     ddl = ddl['ddl']##     self.id2ddl[ddl_id] = ddl# self.dbid2ddls[db_id] = self.id2ddlexcept Exception as e:print(e)# print("load textdata ", self.ddldata)def load_ddl_vec(self):num0 = 0for recode in self.ddldata:db_id = recode[0]ddls = recode[1]for ddl in ddls:ddl_id = ddl["ddl_id"]ddl_name = ddl["TABLE"]ddl = ddl['ddl']ddl_embeddings = self.model.encode([ddl], normalize_embeddings=True)self.index.add(ddl_embeddings)self.i2dbid[num0] = db_idself.i2ddlid[num0] = ddl_idself.id2ddl[ddl_id] = ddlself.ddlid2dbid[ddl_id] = db_idnum0 += 1self.dbid2ddls[db_id] = self.id2ddlprint("init ddl vec", num0)def find_vec_ddl(self, question, k, theata, dataSetID, number):       # dataSetID:数据库id# self.id2ddls.get(action_id)q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)D, I = self.index.search(q_embeddings, k)result = []for i in range(k):sim_i = I[0][i]dbid = self.i2dbid.get(sim_i, "none")         # 获取数据库idddlid = self.i2ddlid.get(sim_i, "none")if dbid == dataSetID:sim_v = D[0][i]score = int(sim_v * 1000)if score < theata:doc = {}doc["score"] = scoredoc["ddl"] = self.id2ddl.get(ddlid, "none")result.append(doc)if len(result) == number:breakreturn resultif __name__ == '__main__':modelpath = "E:\\module\\bge-large-zh-v1.5"model = SentenceTransformer(modelpath)vs = DdlQuesSearch(model)ss = vs.find_vec_ddl("定时任务执行记录表", 2, 2000, 111)print(ss)

4.数据库document训练

class DocQuesSearch(object):def __init__(self):self.name = "TestDataSearch"self.docdata = []self.load_doc_data()def load_doc_data(self):try:response = requests.post(url="xxx",verify=False)print(response.text)jsonobj = json.loads(response.text)databases = jsonobj["data"]for database in databases:db_id = database["dataSetID"]doc = database["Intro"]self.docdata.append((db_id, doc))except Exception as e:print(e)# print("load ddldata ", self.docdata)def find_similar_doc(self, dataSetID):result = []for recode in self.docdata:dbid = recode[0]doc = recode[1]if dbid == dataSetID:result.append(doc)return resultif __name__ == '__main__':docques_search = DocQuesSearch()result = docques_search.find_similar_doc(222)print(result)

5.生成sql语句,这里使用的qwen-max模型

import re
import random
import os, json
import dashscope
from dashscope.api_entities.dashscope_response import Message
from ddl_engine import DdlQuesSearch
from dashscope import Generation
from sqlques_engine import SqlQuesSearch
from sentence_transformers import SentenceTransformerclass Genarate(object):def __init__(self):self.api_key = os.environ.get('api_key')self.model_name = os.environ.get('model')def system_message(self, message):return {'role': 'system', 'content': message}def user_message(self, message):return {'role': 'user', 'content': message}def assistant_message(self, message):return {'role': 'assistant', 'content': message}def submit_prompt(self, prompt):resp = Generation.call(model=self.model_name,messages=prompt,seed=random.randint(1, 10000),result_format='message',api_key=self.api_key)if resp["status_code"] == 200:answer = resp.output.choices[0].message.contentglobal DEBUG_INFODEBUG_INFO = (prompt, answer)return answerelse:answer = Nonereturn answerdef generate_sql(self, question, sql_result, ddl_result, doc_result):prompt = self.get_sql_prompt(question = question,sql_result = sql_result,ddl_result = ddl_result,doc_result = doc_result)print("SQL Prompt:",prompt)llm_response = self.submit_prompt(prompt)sql = self.extrat_sql(llm_response)return sqldef extrat_sql(self, llm_response):sqls = re.findall(r"WITH.*?;", llm_response, re.DOTALL)if sqls:sql = sqls[-1]return sqlsqls = re.findall(r"SELECT.*?;", llm_response, re.DOTALL)if sqls:sql = sqls[-1]return sqlsqls = re.findall(r"```sql\n(.*)```", llm_response, re.DOTALL)if sqls:sql = sqls[-1]return sqlsqls = re.findall(r"```(.*)```", llm_response, re.DOTALL)if sqls:sql = sqls[-1]return sqlreturn llm_responsedef get_sql_prompt(self, question, sql_result, ddl_result, doc_result):initial_prompt = "You are a SQL expert. " + \"Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "initial_prompt = self.add_ddl_to_prompt( initial_prompt, ddl_result)initial_prompt = self.add_documentation_to_prompt(initial_prompt, doc_result)initial_prompt += ("===Response Guidelines \n""1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n""2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n""3. If the provided context is insufficient, please explain why it can't be generated. \n""4. Please use the most relevant table(s). \n""5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n")message_log = [self.system_message(initial_prompt)]message_log = self.add_sqlques_to_prompt(question, sql_result, message_log)return message_logdef add_ddl_to_prompt(self, initial_prompt, ddl_result):""":param initial_prompt::param ddl_result::return:"""ddl_list = [ ddl_['ddl'] for ddl_ in ddl_result]if len(ddl_list) > 0:initial_prompt += "\n===Tables \n"for ddl in ddl_list:initial_prompt += f"{ddl}\n\n"return initial_promptdef add_sqlques_to_prompt(self, question, sql_result, message_log):""":param sql_result::return:"""if len(sql_result) > 0:for example in sql_result:if example is not None and "question" in example and "sql" in example:message_log.append(self.user_message(example["question"]))message_log.append(self.assistant_message(example["sql"]))message_log.append(self.user_message(question))return message_logdef add_documentation_to_prompt(self, initial_prompt, doc_result):if len(doc_result) > 0:initial_prompt += "\n===Additional Context \n\n"for doc in doc_result:initial_prompt += f"{doc}\n\n"return initial_promptif __name__ == '__main__':modelpath = "E:\\module\\bge-large-zh-v1.5"model = SentenceTransformer(modelpath)vs = DdlQuesSearch(model)ss = vs.find_vec_ddl("定时任务执行记录表", 1, 2000, 111)print(ss)

6.执行结果显示

如图可以看到正确生成了sql,可以正常执行,因为表是拉取到,没有数据,所以查询结果为空。

需要源码的同学,可以留言。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/35702.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MFC开发:图形的绘制

文章目录 一、获取指定窗口的设备上下文二、画笔的介绍和使用三、绘制直线四、画刷的介绍和使用五、绘制扇形六、绘制圆形七、绘制文本 一、获取指定窗口的设备上下文 1.GetDC()函数的作用 GetDC() 是 Windows API 中的一个函数&#xff0c;它用于获取指定窗口的设备上下文&am…

SPI 总线协议

1、协议介绍 SPI&#xff0c;是英语 Serial Peripheral interface 的缩写&#xff0c;顾名思义就是串行外围设备接口。是 Motorola 首先在其 MC68HCXX 系列处理器上定义的。 SPI&#xff0c;是一种高速的&#xff0c;全双工&#xff0c;同步的通信总线。主节点或子节点的数据在…

Qt msvc程序运行

第一个Qt msvc程序 我们一般用qt msvc来编译程序&#xff0c;就是用webview。 第一个Qt msvc webview程序实现如下&#xff1a; 运行结果&#xff1a; 标注&#xff1a; QT版本大于6.0的时候才能用<Webview>模块。 QT版本在大于5.2版本&#xff0c;引入了Webengine模…

Java设计模式建模语言面向对象设计原则

设计模式 设计模式的概念 设计模式最初用于建筑领域的设计中。 软件的设计模式&#xff0c;又称设计模式&#xff0c;是一套被反复使用&#xff0c;多数人知道的&#xff0c;经过分类编目的&#xff0c;代码设计经验的总结。 它描述了在软件设计过程中的一些不断重复发生的…

搜广推校招面经五十四

美团推荐算法 一、手撕Transformer的位置编码 1.1. 位置编码的作用 Transformer 模型没有显式的序列信息&#xff08;如 RNN 的循环结构&#xff09;&#xff0c;因此需要通过位置编码&#xff08;Positional Encoding&#xff09;为输入序列中的每个位置添加位置信息。位置…

深入解析 SQL 事务:确保数据一致性的关键

SQL 事务 什么是 SQL 事务&#xff1f;事务的 ACID 特性原子性&#xff08;Atomicity&#xff09;:示例&#xff1a; 一致性&#xff08;Consistency&#xff09;:示例&#xff1a; 隔离性&#xff08;Isolation&#xff09;:持久性&#xff08;Durability&#xff09;:示例&am…

【软考-架构】11.3、设计模式-新

✨资料&文章更新✨ GitHub地址&#xff1a;https://github.com/tyronczt/system_architect 文章目录 项目中的应用设计模式创建型设计模式结构型设计模式行为型设计模式 &#x1f4af;考试真题题外话 项目中的应用 在实际项目中&#xff0c;我应用过多种设计模式来解决不同…

观察者模式详解:用 Qt 信号与槽机制深入理解

引言 你是否曾遇到这样的需求&#xff1a;一个对象的状态发生变化后&#xff0c;希望通知其他对象进行相应的更新&#xff1f;比如&#xff1a; 新闻订阅系统&#xff1a;当新闻发布后&#xff0c;所有订阅者都会收到通知。股票行情推送&#xff1a;股价变化时&#xff0c;所…

流量分析实践

下载附件使用wireshark打开&#xff0c;发现数据包非常多&#xff0c;一共有1万多条数据&#xff0c;我们点击分析来看一下协议分级 然后我们再来看一下会话&#xff0c;看有哪些ip地址&#xff0c; 我们通过会话结合大部分的流量发现&#xff0c;172.17.0.1一直在请求172.17.0…

新手村:混淆矩阵

新手村&#xff1a;混淆矩阵 一、前置条件 知识点要求学习资源分类模型基础理解分类任务&#xff08;如二分类、多分类&#xff09;和常见分类算法&#xff08;如逻辑回归、决策树&#xff09;。《Hands-On Machine Learning with Scikit-Learn》Python基础熟悉变量、循环、函…

MYSQL库的操作

目录 一、创建数据库 二、字符集和校验规则 1、查看系统默认字符集以及校验规则 2、查看系统支持的所有字符集以及字符集校验规则 3、指定字符集以及校验规则来创建数据库 4、校验规则对数据库的影响 三、操纵数据库 1、查看数据库 2、修改数据库 3、删除数据库 4、数…

Next App Router(下)

五、loading 新增 app/loading.tsx 页面 const Loading () > {return <div>Loading...</div>; }; export default Loading;修改 app/page.tsx页面 /** 假设为一个获取数字的api */ const fetch_getNumber async (): Promise<number> > {return ne…

【JAVA】】深入浅出了解cookie、session、jwt

文章目录 前言一、首先了解http的cookie是什么&#xff1f;Cookie 属性及其含义1. NameValue2. Expires3. Max-Age4. Domain5. Path6. Secure7. HttpOnly8. SameSite示例 Cookie 分类1. Session Cookies2. Persistent Cookies3. First-Party Cookies4. Third-Party Cookies 二、…

【css酷炫效果】纯CSS实现粒子旋转动画

【css酷炫效果】纯CSS实现粒子旋转动画 缘创作背景html结构css样式完整代码效果图 想直接拿走的老板&#xff0c;链接放在这里&#xff1a;https://download.csdn.net/download/u011561335/90492008 缘 创作随缘&#xff0c;不定时更新。 创作背景 刚看到csdn出活动了&…

C++Lambda表达式

Lambda表达式 什么是Lambda表达式 ​ C11的颁布让C丰富了起来&#xff0c;任何一本介绍C11的书籍&#xff0c;都不可能跳过这一个点——Lambda表达式。人们经常称Lambda表达式是一个语法糖&#xff0c;说明这是一个”没有没事&#xff0c;有了更好“的一种语法表达&#xff0…

每天五分钟深度学习框架pytorch:基于pytorch搭建循环神经网络RNN

本文重点 我们前面介绍了循环神经网络RNN,主要分析了它的维度信息,其实它的维度信息是最重要的,一旦我们把维度弄清楚了,一起就很简单了,本文我们正式的来学习一下,如何使用pytorch搭建循环神经网络RNN。 RNN的搭建 在pytorch中我们使用nn.RNN()就可以创建出RNN神经网络…

el-table树形表格合并相同的值

el-table树形表格合并相同的值 el-table树形表格合并相同的值让Ai进行优化后的代码 el-table树形表格合并相同的值 <style lang"scss" scoped> .tableBox {/deep/ &.el-table th:first-child,/deep/ &.el-table td:first-child {padding-left: 0;} } …

2025年3月19日 十二生肖 今日运势

小运播报&#xff1a;2025年3月19日&#xff0c;星期三&#xff0c;农历二月二十 &#xff08;乙巳年己卯月丁亥日&#xff09;&#xff0c;法定工作日。 红榜生肖&#xff1a;兔、虎、羊 需要注意&#xff1a;猪、猴、蛇 喜神方位&#xff1a;正南方 财神方位&#xff1a;…

Git——分布式版本控制工具使用教程

本文主要介绍两种版本控制工具——SVN和Git的概念&#xff0c;接着会讲到Git的安装&#xff0c;Git常用的命令&#xff0c;以及怎么在Vscode中使用Git。帮助新手小白快速上手Git。 1. SVN和Git介绍 1.1 SVN 集中式版本控制工具&#xff0c;版本库是集中存放在中央服务器的&am…

QT5.15.2加载pdf为QGraphicsScene的背景

5.15.2使用pdf 必须要安装QT源码&#xff0c;可以看到编译器lib目录已经有pdf相关的lib文件&#xff0c;d是debug 1.找到源码目录&#xff1a;D:\soft\QT\5.15.2\Src\qtwebengine\include 复制这两个文件夹到编译器的包含目录中:D:\soft\QT\5.15.2\msvc2019_64\include 2.找…