LLM微调(四)| 微调Llama 2实现Text-to-SQL,并使用LlamaIndex在数据库上进行推理

        Llama 2是开源LLM发展的一个巨大里程碑。最大模型及其经过微调的变体位居Hugging Face Open LLM排行榜(https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard)前列。多个基准测试表明,就性能而言,它正在接近GPT-3.5(在某些情况下甚至超过它)。所有这些都意味着,对于从RAG系统到Agent的复杂LLM应用程序,开源LLM是一种越来越可行和可靠的选择。

一、Llama-2–7B不擅长从文本到SQL

       最小的Llama 2模型(7B参数)有一个缺点是它不太擅长生成SQL,因此它不适用于结构化分析示例。例如,我们尝试在给定以下提示模板的情况下提示Llama 2生成正确的SQL语句:

You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question.### Input:{input}### Context:{context}### Response:

         在这里,我们使用sqlcreatecontext数据集(https://huggingface.co/datasets/b-mc2/sql-create-context)的一个示例来测试一下效果:

input: In 1981 which team picked overall 148?context: CREATE TABLE table_name_8 (team VARCHAR, year VARCHAR, overall_pick VARCHAR)

         同时,这里是生成的输出与正确输出的对比:

Generated output: SELECT * FROM `table_name_8` WHERE '1980' = YEAR AND TEAM = "Boston Celtics" ORDER BY OVERALL_PICK DESC LIMIT 1;Correct output: SELECT team FROM table_name_8 WHERE year = 1981 AND overall_pick = "148"

       这显然并不理想。与ChatGPT和GPT-4不同,原始的Llama 2不能生成期望的的格式和正确的SQL

      这正是微调的作用所在——如果有一个合适的文本到SQL数据的语料库,我们可以教Llama 2更好地从自然语言生成SQL输出。微调有不同的方法,可以更新模型的所有参数(比如:全量微调),也可以冻结大模型参数仅微调附加参数(比如:LoRA)。

二、微调Llama-2–7B,使其可以从文本生成SQL

       接下来,我们将展示如何在文本到SQL数据集上微调Llama 2,然后使用LlamaIndex的功能对任何SQL数据库进行结构化分析。

准备工作:

微调数据集:来自Hugging Face的b-mc2/sql-create-context(https://huggingface.co/datasets/b-mc2/sql-create-context)

base模型:OpenLLaMa 的open_lama_7b_v2(https://github.com/openlm-research/open_llama)

步骤1:加载微调LLaMa的训练数据

PS:1)以下代码来自doppel-bot:https://github.com/modal-labs/doppel-bot;2)许多Python代码都包含在src目录中;3)需要设置一个Modal帐户,并生成token。

!pip install -r requirements.txt

       首先,我们使用Modal加载b-mc2/sql-create-context数据集,并将其格式化为.jsonl文件。

modal run src.load_data_sql --data-dir "data_sql"

结果如下所示:

# Modal stubs allow our function to run remotely@stub.function(    retries=Retries(        max_retries=3,        initial_delay=5.0,        backoff_coefficient=2.0,    ),    timeout=60 * 60 * 2,    network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},    cloud="gcp",)def load_data_sql(data_dir: str = "data_sql"):    from datasets import load_dataset    dataset = load_dataset("b-mc2/sql-create-context")    dataset_splits = {"train": dataset["train"]}    out_path = get_data_path(data_dir)    out_path.parent.mkdir(parents=True, exist_ok=True)    for key, ds in dataset_splits.items():        with open(out_path, "w") as f:            for item in ds:                newitem = {                    "input": item["question"],                    "context": item["context"],                    "output": item["answer"],                }                f.write(json.dumps(newitem) + "\n")

步骤2:运行微调脚本

在微调数据集微调llama2模型,代码如下:

modal run src.finetune_sql --data-dir "data_sql" --model-dir "model_sql"

微调脚本会执行以下步骤:

将数据集拆分为训练和验证拆分

train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)

       将每个拆分为元组的格式(输入Prompt、标签):输入query和上下文被格式化为输入Prompt,然后对输入Prompt和标签进行 tokenize,模型采用自回归的方法预测下一个token来进行训练。

def generate_and_tokenize_prompt(data_point):  full_prompt = generate_prompt_sql(      data_point["input"],      data_point["context"],      data_point["output"],  )  tokenized_full_prompt = tokenize(full_prompt)  if not train_on_inputs:      raise NotImplementedError("not implemented yet")  return tokenized_full_prompt

PS:输入Prompt与开始测试llama2的格式完全相同。

       运行微调脚本时,模型将保存在model_dir指定的远程云目录中(如果未指定,则设置为默认值)。

步骤3:评估微调后模型

       该模型已经进行了微调,可以从云端提供服务。下面我们使用b-mc2/sql-create-context中的示例数据进行一些基本评估,比较微调后模型与原始Llama 2模型的性能。

modal run src.eval_sql::main

结果表明,微调后的模型有了巨大的改进:

Input 1: {'input': 'Which region (year) has Abigail at number 7, Sophia at number 1 and Aaliyah at number 5?', 'context': 'CREATE TABLE table_name_12 (region__year_ VARCHAR, no_5 VARCHAR, no_7 VARCHAR, no_1 VARCHAR)', 'output': 'SELECT region__year_ FROM table_name_12 WHERE no_7 = "abigail" AND no_1 = "sophia" ANDno_5 = "aaliyah"'}Output 1 (finetuned model): SELECT region__year_ FROM table_name_12 WHERE no_7 = "abigail" AND no_1 = "aaliyah" AND no_5 = "sophia"Output 1 (base model): SELECT * FROM table_name_12 WHERE region__year = '2018' AND no_5 = 'Abigail' AND no_7 = 'Sophia' AND no_1 = 'Aaliyah';Input 2: {'input': 'Name the result/games for 54741', 'context': 'CREATE TABLE table_21436373_11 (result_games VARCHAR, attendance VARCHAR)', 'output': 'SELECT result_games FROM table_21436373_11 WHERE attendance = 54741'}Output 2 (finetuned model): SELECT result_games FROM table_21436373_11 WHERE attendance = "54741"Output 2 (base model): SELECT * FROM table_21436373_11 WHERE result_games = 'name' AND attendance > 0;

步骤4:将微调模型与LlamaIndex集成

       我们现在可以在LlamaIndex中使用这个模型,在任何数据库上进行文本到SQL。

       我们首先定义一个测试SQL数据库,然后可以使用该数据库来测试模型的推理能力。

       我们创建了一个玩具city_stats表,其中包含城市名称、人口和国家信息,并用几个示例城市填充它。

db_file = "cities.db"engine = create_engine(f"sqlite:///{db_file}")metadata_obj = MetaData()# create city SQL tabletable_name = "city_stats"city_stats_table = Table(    table_name,    metadata_obj,    Column("city_name", String(16), primary_key=True),    Column("population", Integer),    Column("country", String(16), nullable=False),)metadata_obj.create_all(engine)

这存储在cities.db文件中。

     然后,我们可以使用Modal将微调后的模型和该数据库文件加载到LlamaIndex中的NLSQLTableQueryEngine中——该查询引擎允许用户轻松地开始在给定的数据库上执行文本到SQL。

modal run src.inference_sql_llamaindex::main --query "Which city has the highest population?" --sqlite-file-path "nbs/cities.db" --model-dir "model_sql" --use-finetuned-model True

我们得到如下回复:

SQL Query: SELECT MAX(population) FROM city_stats WHERE country = "United States"Response: [(2679000,)]

三、结论

        本文提供了一种非常高级的方法来开始微调生成SQL语句的Llama 2模型,并展示了如何使用LlamaIndex将其端到端插入到文本到SQL工作流中。

参考文献:

[1] https://blog.llamaindex.ai/easily-finetune-llama-2-for-your-text-to-sql-applications-ecd53640e10d

[2] https://github.com/run-llama/modal_finetune_sql

[3] https://github.com/run-llama/modal_finetune_sql/blob/main/tutorial.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/223974.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

光耦继电器

光耦继电器(光电继电器) AQW282SX 282SZ 280SX 280SZ 284SX 284SZ 212S 212SX 21 2SZ 文章目录 光耦继电器(光电继电器)前言一、光耦继电器是什么二、光耦继电器的类型三、光电耦合器的应用总结前言 光耦继电器在工业控制、通讯、医疗设备、家电及汽车电子等领域得到广泛应…

【隐私保护】Presidio简化了PII匿名化

自我介绍 做一个简单介绍,酒架年近48 ,有20多年IT工作经历,目前在一家500强做企业架构.因为工作需要,另外也因为兴趣涉猎比较广,为了自己学习建立了三个博客,分别是【全球IT瞭望】,【…

YOLOv8改进 | 2023注意力篇 | MSDA多尺度空洞注意力(附多位置添加教程)

一、本文介绍 本文给大家带来的改进机制是MSDA(多尺度空洞注意力)发表于今年的中科院一区(算是国内计算机领域的最高期刊了),其全称是"DilateFormer: Multi-Scale Dilated Transformer for Visual Recognition"。MSDA的主要思想是…

STM32F407-14.3.10-表73具有有断路功能的互补通道OCx和OCxN的输出控制位-1x111

如上表所示,MOE1,OSSR1,CCxE1,CCxNE1时,OCx与OCxN对应端口的输出状态取决于OCx_REF与极性选择(CCxP,CCxNP) 死区。 -------------------------------------------------------------…

记pbcms网站被攻击,很多标题被篡改(1)

记得定期打开网站看看哦! 被攻击后的网站异常表现:网页内容缺失或变更,页面布局破坏,按钮点击无效,...... 接着查看HTML、CSS、JS文件,发现嵌入了未知代码! 攻击1:index.html 或其他html模板页面的标题、关键词、描述被篡改(俗称,被挂马...),如下: 攻击2:在ht…

【PostGIS】PostgreSQL15+对应PostGIS安装教程及空间数据可视化

一、PostgreSQL15与对应PostGIS安装 PostgreSQL15安装:下载地址PostGIS安装:下载地址(选择倒数第二个) 1、PostgreSQL安装 下载安装包;开始安装,这里使用默认安装,一直next直到安装完成&…

ubuntu下docker安装,配置python运行环境

参考自: 1.最详细ubuntu安装docker教程 2.使用docker搭建python环境 首先假设已经安装了docker,卸载原来的docker 在命令行中运行: sudo apt-get updatesudo apt-get remove docker docker-engine docker.io containerd runc 安装docker依赖 apt-get…

饥荒Mod 开发(二一):超大便携背包,超大物品栏,永久保鲜

饥荒Mod 开发(二十):显示打怪伤害值 饥荒Mod 开发(二二):显示物品信息 源码 游戏中的物品栏容量实在太小了,虽然可以放在箱子里面但是真的很不方便,外出一趟不容易看到东西都不能捡。实在是虐心。 游戏中的食物还有变质机制&#…

SSTI模板注入基础(Flask+Jinja2)

文章目录 一、前置知识1.1 模板引擎1.2 渲染 二、SSTI模板注入2.1 原理2.2 沙箱逃逸沙箱逃逸payload讲解其他重要payload 2.3 过滤绕过点.被过滤下划线_被过滤单双引号 "被过滤中括号[]被过滤关键字被过滤 三、PasecaCTF-2019-Web-Flask SSTI参考文献 一、前置知识 1.1 模…

力扣:51. N 皇后

题目: 按照国际象棋的规则,皇后可以攻击与之处在同一行或同一列或同一斜线上的棋子。 n 皇后问题 研究的是如何将 n 个皇后放置在 nn 的棋盘上,并且使皇后彼此之间不能相互攻击。 给你一个整数 n ,返回所有不同的 n 皇后问题 的…

多维时序 | MATLAB实现SSA-CNN-SVM麻雀算法优化卷积神经网络-支持向量机多变量时间序列预测

多维时序 | MATLAB实现SSA-CNN-SVM麻雀算法优化卷积神经网络-支持向量机多变量时间序列预测 目录 多维时序 | MATLAB实现SSA-CNN-SVM麻雀算法优化卷积神经网络-支持向量机多变量时间序列预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本介绍 多维时序 | MATLAB实现…

ubuntu22.04 下载路径

ftp下载路径 csdn下载 ubuntu22.04下载路径ubuntu-22.04-desktop-amd64.7z.001资源-CSDN文库 ubuntu22.04下载路径ubuntu-22.04-desktop-amd64.7z.002资源-CSDN文库 【免费】ubuntu-22.04-desktop-amd64.7z.003资源-CSDN文库 【免费】ubuntu-22.04-desktop-amd64.7z.004资源-…

大数据应用开发1——配置基础环境

一、基础环境配置 1.配置虚拟网络 1.1、点击1、编辑2和3, 1.2、点开4,编辑网关 2、配置虚拟机环境 1.1、安装一台虚拟机,使用root用户登录,打开终端 1.2修改主机名 终端输入: vim /etc/hostname使用vim编辑/etc/ho…

linux异步IO的几种方法及重点案例

异步IO的方法 在Linux下,有几种常见的异步I/O(Asynchronous I/O)机制可供选择。以下是其中一些主要的异步I/O机制: POSIX AIO(Asynchronous I/O):POSIX AIO是一种标准的异步I/O机制&#xff0c…

三道C语言中常见的笔试题及答案(一)

题目一&#xff1a; 问题&#xff1a; 解释以下代码中的#define预处理指令的作用&#xff0c;并说明其优点和缺点。 #include <stdio.h> #define PI 3.14159 #define CALCULATE_AREA(r) (PI * r * r) int main() { double radius 5.0; double area CALCULATE_AREA(r…

基于STM32的DS1302实时时钟模块应用

DS1302是一款低功耗的实时时钟芯片&#xff0c;被广泛应用于各种电子产品中。它具有准确计时、多种时间格式表示、定时报警等功能&#xff0c;适用于记录时间、日期和闹钟。在本文中&#xff0c;我们将介绍如何在基于STM32的开发环境中使用DS1302实时时钟模块&#xff0c;并给出…

设计模式--命令模式

实验16&#xff1a;命令模式 本次实验属于模仿型实验&#xff0c;通过本次实验学生将掌握以下内容&#xff1a; 1、理解命令模式的动机&#xff0c;掌握该模式的结构&#xff1b; 2、能够利用命令模式解决实际问题。 [实验任务]&#xff1a;多次撤销和重复的命令模式 某系…

智能优化算法应用:基于孔雀算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于孔雀算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于孔雀算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.孔雀算法4.实验参数设定5.算法结果6.参考文献7.MA…

Prompt-to-Prompt:基于 cross-attention 控制的图像编辑技术

Hertz A, Mokady R, Tenenbaum J, et al. Prompt-to-prompt image editing with cross attention control[J]. arXiv preprint arXiv:2208.01626, 2022. Prompt-to-Prompt 是 Google 提出的一种全新的图像编辑方法&#xff0c;不同于任何传统方法需要用户指定编辑区域&#xff…

Nginx快速入门:实现企业安全防护|nginx部署https,ssl证书(七)

0. 引言 之前我们讲到nginx的一大核心作用就是实现企业安全防护&#xff0c;而实现安全防护的原理就是通过部署https证书&#xff0c;以此实现参数加密访问&#xff0c;从而加强企业网站的安全能力。 nginx作为各类服务的统一入口&#xff0c;只需要在入口处部署一个证书&…