微调Llama2自我认知

一、概述

最近在学习了解大模型微调相关的内容,在学习的过程中也遇到了很多问题,所以将自己的学习过程记录下来,希望对大模型微调感兴趣的小伙伴提供一点帮助,本文主要介绍一下如何通过SFT微调Llama2的自我认知,先看一下微调前后的效果比对:

微调前:

微调后:

通过本文的学习,你将了解如下内容:

  • 如何使用SFT微调Llama2
  • 如何导出微调后的大模型
  • 如何使用FastChat实现 OpenAI 兼容的 RESTful API 接口

二、环境与模型选择

环境配置

使用 nvidia-smi 命令查看 GPU 的配置,微调的GPU配置如下:

$nvidia-smi     
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A800-SXM...  Off  | 00000000:8E:00.0 Off |                    0 |
| N/A   30C    P0    69W / 400W |  17320MiB / 81251MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------++-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

微调 Llama2 需要 1 个GPU,24G 内存,较低的内存会导致加载模型较慢。

开源框架和模型

  • 微调的模型: Chinese-Llama-2-7b
  • 微调框架: LLaMA-Efficient-Tuning
  • 提供openai兼容的RESTful API框架: FastChat
  • 本地知识库问答应用框架: LangChain-Chatchat

由于Llama2本身的中文对齐较弱,这里没有直接使用 meta-llama/Llama-2-7b而是使用 LinkSoul/Chinese-Llama-2-7b进行微调,微调方法是类似的,感兴趣的可以基于 meta-llama/Llama-2-7b 进行微调。下面详细介绍一下微调的步骤。

三、SFT微调

1、下载预训练模型

在 huggingface上面搜索模型名称,可以看到下载模型的方式如下:

新建一个 models 文件夹用来存放下载的大模型,使用下面的命令下载预训练模型:

# 在当前目录新建一个 models 文件夹用来存放大模型
mkdir models
# 使用下面的命令下载模型,模型比较大,下载过程较缓慢,
git lfs install
git clone https://huggingface.co/LinkSoul/Chinese-Llama-2-7b# 设置下面的环境变量,则不会下载大文件,只会下载小文件
GIT_LFS_SKIP_SMUDGE=1

2、下载微调框架

使用如下命令,在当前目录下载微调框架 LLaMA-Efficient-Tuning

git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git

进入 LLaMA-Efficient-Tuning 目录:

cd LLaMA-Efficient-Tuning

3、准备微调数据

进入微调框架LLaMA-Efficient-Tuning目录后,找到存放微调数据的data目录,如下所示:

我们可以查看一下 self_cognition.json自我认知文件内容如下:

可以看到 <NAME><AUTHOR>是占位符,我们只需要复制一份文件,将对应的占位符替换为需要的名称即可,复制一份文件是为了自我认知的模版文件可复用,我替换后的文件内容如下,你可以改成自己的名字:

微调数据准备好了后,需要在 dataset_info.json 中配置如下:

{"self_cognition": {"file_name": "self_cognition.json","file_sha1": "6287a730ada924fc5d9eadc6d8f865e01b7a6f67"}
}

dataset_info.json文件会被转换为 python 的字典,self_cognition就是字典的 key,在微调的时候需要指定的数据集名称就是该 key,file_sha1 文件的摘要可以不填,file_name就是微调文件的名称,如果该微调文件在data目录中,则直接指定名称即可,如果在data目录的子目录中,则需要指定子目录的名字,举例如下:

4、开始SFT微调

微调数据准备好后就可以开始执行微调了,使用如下命令进行微调:

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \--stage sft \--do_train \--dataset self_cognition \--model_name_or_path /ossfs/workspace/models/Chinese-Llama-2-7b \--output_dir /ossfs/workspace/llama2-sft/checkpoint-01 \--template default \--finetuning_type lora \--lora_target q_proj,v_proj \--overwrite_cache \--per_device_train_batch_size 4 \--gradient_accumulation_steps 4 \--lr_scheduler_type cosine \--logging_steps 10 \--save_steps 2000 \--learning_rate 1e-3 \--num_train_epochs 10.0 \--plot_loss \--fp16

下面是对这个大模型训练命令中各个参数的详细解释:


--stage sft: 训练阶段。这里指定为sft,表示进行模型的微调(self-supervised fine-tuning)阶段。--do_train: 是否进行训练,设置为True表示进行训练。还可以设置为(--do_eval:表示评估,--do_predict:表示预测)--dataset self_cognition: 数据集名称。这里指定为self_cognition,表示使用自我认知数据集。--model_name_or_path /ossfs/workspace/models/Chinese-Llama-2-7b: 预训练模型的名称或路径。这里指定为/ossfs/workspace/models/Chinese-Llama-2-7b,表示加载路径下的预训练模型。--output_dir /ossfs/workspace/llama2-sft/checkpoint-01: 训练输出目录。训练过程中生成的模型和日志将保存在该目录下。--template default: 模板名称。这里指定为default,表示使用默认模板。--finetuning_type lora: 微调类型。这里指定为lora,表示使用LoRA(Language Representation with Additive Transformation)微调方法。--lora_target q_proj,v_proj: LoRA微调的目标层。这里指定为q_proj,v_proj,表示只对q_proj和v_proj两个层进行微调。--overwrite_cache: 是否覆盖缓存。设置为True表示覆盖缓存。--per_device_train_batch_size 4: 每个设备的训练批次大小。这里指定为4,表示每个设备上的训练批次大小为4。--gradient_accumulation_steps 4: 梯度累积步数。这里指定为4,表示每4个步骤累积一次梯度。--lr_scheduler_type cosine: 学习率调度器类型。这里指定为cosine,表示使用余弦学习率调度器。--logging_steps 10: 日志记录步数。每训练多少步记录一次训练日志。--save_steps 2000: 模型保存步数。每训练多少步保存一次模型。--learning_rate 1e-3: 学习率。这里指定为1e-3,表示初始学习率为0.001。--num_train_epochs 10.0: 训练轮数。这里指定为10.0,表示进行10轮训练。--plot_loss: 是否绘制损失曲线。设置为True表示绘制损失曲线。--fp16: 是否使用混合精度(half-precision)训练。设置为True表示使用混合精度训练。

以上是对该大模型训练命令中各个参数的解释。根据需求,可以根据实际情况进行相应参数的修改。不同的参数设置会对训练过程和结果产生影响,需要根据具体任务和数据集进行调整。

微调过程比较耗时,需要耐心等待

启动微调命令后,输出日志如下,需要用户输入是否需要 wandb (一个深度学习轻量级可视化工具)将训练结果可视化,我这里选择不可视化训练结果。

微调命令结束后可以看到如下日志输出,从日志中可以看到微调后的模型 checkpoint 的位置、损失曲线的信息以及训练的汇总信息:

查看损失曲线:

损失曲线图像解读:

在大模型训练过程中,train loss 图像是指每个训练批次的损失值随训练轮次的变化情况。这个图像可以用来解读训练过程中模型的收敛情况和学习进展。

train loss 图像的纵轴表示损失值,横轴表示训练轮次或训练批次。通常,初始阶段的损失值较高,随着训练的进行,损失值会逐渐下降。如果损失值趋向于稳定,说明模型已经收敛,训练效果良好。如果损失值下降很慢,可能需要更多的训练轮次或调整模型超参数。如果损失值波动较大,可能存在过拟合或其他问题,需要进一步调整模型或数据。

解读train loss 图像时,可以观察以下几个方面:

  1. 初始阶段的损失值高低,较高的初始损失值可能表明模型初始化不合适,需要调整初始化方法。
  2. 损失值下降的速率,较快的下降速率可能表明模型对数据的学习能力较强,但也可能存在过拟合的风险。
  3. 损失值的稳定性,稳定的损失值说明模型已经收敛,训练效果较好。如果损失值在一定范围内波动,可以考虑增加训练轮次或使用正则化等方法进一步优化模型。
  4. 训练过程中的异常情况,如损失值突然上升或跳跃,可能表明出现了问题,需要检查模型或数据是否存在异常。

总之,train loss 图像可以提供对模型训练过程的直观理解,帮助调整模型和优化训练策略,以达到更好的训练效果。

train loss 的值下降到什么范围表示模型的训练效果较好?

train loss 的值下降到一个较低的范围可以表示模型的训练效果较好。具体的判断标准可以根据具体的任务和数据集来确定,没有一个统一的阈值。

一种常见的做法是观察 train loss 图像的趋势,如果随着训练的进行,train loss 不断下降并趋于稳定,说明模型对训练数据的拟合效果较好,训练效果较好。

此外,可以根据验证集的表现来评估模型的训练效果。如果验证集的损失值也在下降并趋于稳定,且与训练集的损失值相近,说明模型在训练集和验证集上都能取得较好的效果,训练效果较好。

需要注意的是,train loss 仅仅是一个指标,不能完全代表模型的训练效果。还需要综合考虑模型在其他指标上的表现,如准确率、精确率、召回率等,以及在实际应用场景中的效果。

5、测试微调后的模型

微调框架 LLaMA-Efficient-Tuning中提供了三种测试使用微调模型的方式,如下所示:

  • api_demo.py:使用api的方式调用微调模型
  • cli_demo.py:在命令行中调用微调模型
  • web_demo.py:在web页面中调用微调模型

由于我这里的服务器没有外网访问的地址,所以使用 cli_demo.py命令行的方式手动测试微调后的模型,启动命令如下:

CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \--model_name_or_path /ossfs/workspace/models/Chinese-Llama-2-7b \--checkpoint_dir /ossfs/workspace/llama2-sft/checkpoint-01\--template llama2

查看 cli_demo.py 源码,调用了 ChatModel.stream_chat(query, history)ChatModel的构造方法中调用了 get_infer_args(args),如下所示:

get_infer_args(args)如下所示,可以看到只有 LoRA 的微调支持指定多个 checkpoint

在模型参数中可以看到指定 checkpoint_dir 时可以使用 分隔多个 checkpoint:

所以如果你使用 LoRA 进行微调,那么当有多个微调任务,生成多个 checkpoint 时,多个 checkpoint 可以使用 ,分隔,假设你微调了两个checkpoint:/ossfs/workspace/llama2-sft/checkpoint-01/ossfs/workspace/llama2-sft/checkpoint-02,那么你可以使用下面的命令测试两个微调后的模型,如下所示:

CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \--model_name_or_path /ossfs/workspace/models/Chinese-Llama-2-7b \--checkpoint_dir /ossfs/workspace/llama2-sft/checkpoint-01,/ossfs/workspace/llama2-sft/checkpoint-02\--template default

因为我这里只微调了自我认知,并且将在微调的时候指定 --output_dir /ossfs/workspace/llama2-sft/checkpoint-01,所以使用下面的命令来测试微调后的模型即可:

CUDA_VISIBLE_DEVICES=0 python src/cli_demo.py \--model_name_or_path /ossfs/workspace/models/Chinese-Llama-2-7b \--checkpoint_dir /ossfs/workspace/llama2-sft/checkpoint-01\--template default

运行测试命令需要再次加载模型,比较耗时,需要耐心等待,运行成功后可以看到如下输出:

接下来就可以问一些自我认知的问题进行验证了,如下所示:

6、导出微调后的模型

经过前面的微调,如果微调后的模型通过了测试就可以将微调后的模型导出,使用如下命令即可:

CUDA_VISIBLE_DEVICES=0 python src/export_model.py \--model_name_or_path /ossfs/workspace/models/Chinese-Llama-2-7b \--checkpoint_dir /ossfs/workspace/llama2-sft/checkpoint-01\--output_dir /ossfs/workspace/sft-models/my-llama5 \--template default

我这里使用 --output_dir 将模型导出到 /ossfs/workspace/sft-models/my-llama5目录中,可以看到目录中包括如下内容:

7、微调模型提供RESTful API接口

经过前面的步骤,我们已经将自己微调后的模型导出了,现在我们可以使用 FastChat 将模型发布为 openai 兼容的RESTful API以便外部服务使用。

FastChat 为其支持的模型提供与 OpenAI 兼容的 API,因此您可以使用 FastChat 作为 OpenAI API 的本地直接替代品。FastChat 服务器与openai-python库和 cURL 命令兼容。

支持以下 OpenAI API:

  • Chat Completions
  • Completions
  • Embeddings

RESTful API 服务器

首先,启动控制器

python3 -m fastchat.serve.controller

运行命令输出内容如下:

然后,启动模型,通过 --model-path 指定模型的路径,这里我们指定前面微调后的模型路径 /ossfs/workspace/sft-models/my-llama5

python3 -m fastchat.serve.model_worker --model-path /ossfs/workspace/sft-models/my-llama5

运行命令输出内容如下:

最后,启动 RESTful API 服务器

python3 -m fastchat.serve.openai_api_server --host localhost --port 8000

运行命令输出内容如下:

现在,让我们测试 API 服务器。

OpenAI官方SDK

目标openai_api_server.py是实现一个完全兼容 OpenAI 的 API 服务器,因此模型可以直接与openai-python库一起使用。

首先,安装openai-python:

pip install --upgrade openai

然后使用下面的代码与模型进行测试:

import openai
# to get proper authentication, make sure to use a valid key that's listed in
# the --api-keys flag. if no flag value is provided, the `api_key` will be ignored.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"# 这里指定微调的模型名字,也就是保存模型文件的文件夹名称
model = "my-llama5"# create a chat completion
completion = openai.ChatCompletion.create(model=model,messages=[{"role": "user", "content": "你是谁"}]
)
# print the completion
print(completion.choices[0].message.content)

在jupyterlab中运行上面的代码,输出结果如下:

8、微调模型和本地知识库整合

因为后面打算学习了解一下将大模型和知识库整合,所以我这里先使用本地知识库问答应用框架: LangChain-Chatchat ****和微调后的模型整合。下面我先简单介绍一下整合的步骤,后面会再写一篇文章详细介绍一下大模型和本地知识库相关的内容。

首先下载 Langchain-Chatchat,使用如下命令:

git clone https://github.com/chatchat-space/Langchain-Chatchat.git

进入 Langchain-Chatchat,使用下面的命令安装 python 依赖库:

cd Langchain-Chatchatpip install -r requirements.txtpip install -r requirements_api.txt

使用下面的命令复制一份配置文件:

cp configs/model_config.py.example configs/model_config.py

如下所示:

model_config.py 配置文件需要修改如下内容,在llm_model_dict指定模型的地址,并且设置LLM_MODEL的名称和 llm_model_dict的 key 对应,如下所示:

llm_model_dict = {"llama2": {"local_model_path": "/ossfs/workspace/sft-models/my-llama5","api_base_url": "http://localhost:8888/v1",  # 修改为fastchat服务中的"api_base_url""api_key": "EMPTY"}
}# LLM 名称
LLM_MODEL = "llama2"

接下来就可以使用下面的命令启动 llm_api.py

python server/llm_api.py

启动成功后可以使用下面的代码在 jupyterlab 中进行验证:

# 服务启动后接口调用示例:
import openai
openai.api_key = "EMPTY" # Not support yet
openai.api_base = "http://localhost:8888/v1"model = "llama2"def get_answer(content):# create a chat completioncompletion = openai.ChatCompletion.create(model=model,messages=[{"role": "user", "content": content}])print('用户:', content)# print the completionprint('模型:',completion.choices[0].message.content)get_answer('你是谁')
get_answer('你叫什么名字')

验证输出结果如下:

参考文档

https://github.com/hiyouga/LLaMA-Efficient-Tuning

https://github.com/chatchat-space/Langchain-Chatchat

https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md

https://huggingface.co/LinkSoul/Chinese-Llama-2-7b

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/105098.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

什么是网络中的服务质量 (QoS),其相关技术和关键指标有哪些?

QoS&#xff08;Quality of Service&#xff0c;服务质量&#xff09;指一个网络能够利用各种基础技术&#xff0c;为指定的网络通信提供更好的服务能力&#xff0c;是网络的一种安全机制&#xff0c;是用来解决网络延迟和阻塞等问题的一种技术。QoS的保证对于容量有限的网络来…

学习笔记230816---vue项目中使用第三方组件{el-dropdown}如何设置禁止事件功能

问题描述 使用第三方组件elementui&#xff0c;在导航菜单el-menu的el-menu-item中嵌入一个下拉菜框el-dropdown。点击...icon弹出下拉菜单el-dropdown-menu&#xff0c;那么这时会触发事件冒泡&#xff0c;el-menu-item菜单项的点击事件也会触发。 解决方法 阻止事件冒泡&am…

Java【手撕双指针】LeetCode 57. “两数之和“, 图文详解思路分析 + 代码

文章目录 前言一、两数之和1, 题目2, 思路分析3, 代码展示 前言 各位读者好, 我是小陈, 这是我的个人主页, 希望我的专栏能够帮助到你: &#x1f4d5; JavaSE基础: 基础语法, 类和对象, 封装继承多态, 接口, 综合小练习图书管理系统等 &#x1f4d7; Java数据结构: 顺序表, 链表…

c++ day3

#include <iostream>using namespace std; class per {string name;int age;int *p;int *q; public:per(string name,int age,int a,int b){this->name(name);this->ageage;pnew int(a);qnew int(b);*qb;*pa;cout << "有参构造"<<endl;}void…

如何提取视频的音频到手机?这个音频提取方法很简单

提取视频中的音频可以帮助您获得视频的声音部分&#xff0c;而无需观看整个视频。这对于那些只想听视频的声音或想将视频的声音与其他音频内容混合使用的人来说非常方便。此外&#xff0c;提取音频也可以为需要创建音频剪辑或混音的音频制作者提供帮助。那么怎么提取呢&#xf…

Ant Design Vue 日期选择器DatePicker传给后台日期参数格式问题

花了一个下午才解决&#xff0c;官方组件文档里面是没有处理方案说明的。 项目版本&#xff1a;Ant Design Vue 2.0.2 前端部分代码&#xff1a; <template><a-modal:visible"visible":width"windowWidth":height"800":title"tit…

Golang Gorm 一对多关系 关系表创建

一对多关系 我们先从一对多开始多表关系的学习因为一对多的关系生活中到处都是&#xff0c;例如&#xff1a; 老板与员工女神和添狗老师和学生班级与学生用户与文章 在创建的时候先将没有依赖的创建。表名称ID就是外键。外键要和关联的外键的数据类型要保持一致。 package ma…

创建和分析二维桁架和梁结构研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

最小二乘法,残差,线性模型-线性回归

目录 什么是最小二乘法 残差是什么意思 线性模型 线性回归 方法一&#xff1a;解析解法 代码实战&#xff1a; 方法二&#xff1a;数值解法 代码实战&#xff1a; 解析法&#xff08;最小二乘&#xff09;还是数值法&#xff08;梯度下降&#xff09;&#xff0c;如何…

Unity 之 Start 与Update 方法的区别

文章目录 当谈论Unity中的 Start和 Update方法时&#xff0c;我们实际上是在讨论MonoBehaviour类中的两个常用方法&#xff0c;用于编写游戏逻辑。这两个方法在不同的时机被调用&#xff0c;因此您可以根据需要选择在哪个方法中编写特定的代码。 Start 方法&#xff1a; Start…

23款奔驰GLS450升级原厂电动吸合门,体验绅士的关门状态

电吸门的工作原理是在门框(或门板边缘)上安装一个电磁线圈。当门打开时&#xff0c;电流会流过线圈&#xff0c;形成电磁场。这样&#xff0c;由于磁力的作用&#xff0c;当门靠近门框关闭时&#xff0c;门会自动关闭。 另外&#xff0c;电吸门也有有用的一面。如果下车&#…

在线ppt转pdf如何转换?用这一个方法就够了

在线PPT转PDF是一种将PPT文件转换为PDF格式的便捷且常用的工具。随着科技的发展&#xff0c;PPT已经成为了商务、教育等领域中最常用的演示工具之一。PDF格式具有较好的稳定性和兼容性。PPT文件可能因为不同的操作系统、软件版本或字体缺失等原因而导致显示不一致或乱码等问题&…

计算机竞赛 基于CNN实现谣言检测 - python 深度学习 机器学习

文章目录 1 前言1.1 背景 2 数据集3 实现过程4 CNN网络实现5 模型训练部分6 模型评估7 预测结果8 最后 1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 基于CNN实现谣言检测 该项目较为新颖&#xff0c;适合作为竞赛课题方向&#xff0c;学长非常推荐&am…

自动化测试之Selenium

自动化测试Selenium介绍环境搭建如何操作浏览器定位元素css类选择器定位元素xpath定位元素css选择语法xpath选择语法 常用操作添加等待打印信息浏览器更多操作键盘事件鼠标事件特殊场景只选复选框iframe标签下拉框处理弹窗显示上传文件 关闭浏览器切换窗口截图 自动化测试 自动…

一个程序员的工作日记--每天就干两件事,一年后让别人刮目相看

文章目录 成功源于专注一、早上布局二、晚上复盘三、技术细节四、专注与成功五、专注的重要性六、忙碌和赚钱七、结论以嵌入式开发为例&#xff1a;一、早上布局二、晚上复盘三、技术细节四、专注与成功五、忙碌和赚钱六、结论在嵌入式软件开发中&#xff0c;我们需要按照以下步…

elment-ui中使用el-steps案例

el-steps案例 样式 代码 <div class"active-box"><div class"active-title">请完善</div><el-steps :active"active" finish-status"success" align-center><el-step title"第一步" /><…

c语言练习题26:调整数组使奇数位于偶数前面

调整数组使奇数位于偶数前面 题目&#xff1a; 思路&#xff1a; 代码&#xff1a; #include<stdio.h> #include<string.h> void func(int* arr, int len) {int left 0;int right len - 1;while (left < right) {while (left < right && arr[lef…

【Vue框架】基本的login登录

前言 最近事情比较多&#xff0c;只能抽时间看了&#xff0c;放几天就把之前弄的都忘了&#xff0c;现在只挑着核心的部分看。现在铺垫了这么久&#xff0c;终于可以看前端最基本的登录了&#x1f602;。 1、views\login\index.vue 由于代码比较长&#xff0c;这里将vue和js…

7、Idea下载安装与激活

1、下载 1.1 官网地址 官网地址 https://www.jetbrains.com/idea/ 点击访问 1.2 官网首页 1.3 点击右上角dowload进入以下页面选择版本 1.4 选择需要的版本进行下载 2、安装

189. 轮转数组

189. 轮转数组 class Solution { public:void rotate(vector<int>& nums, int k) {int n nums.size();k k % n;reverse(nums.begin(),nums.end());reverse(nums.begin(),nums.begin()k);reverse(nums.begin()k,nums.end());} };