揭秘LoRA与QLoRA:百次实验告诉你如何微调LLM!

原文链接:揭秘LoRA与QLoRA:百次实验告诉你如何微调LLM!​​​​​​​

LoRA(低秩适应)是目前应用最广泛、参数效率最高的自定义大型语言模型(LLM)微调技术之一。本文不仅介绍了使用QLoRA节省内存的方法,还讨论了选择最佳LoRA设置的实用技巧,为有兴趣应用此技术的读者提供了实践洞见。

如何充分利用LoRA

过去几个月里,我进行了数百次甚至上千次涉及LoRA的实验。几周前,我花时间深入研究了一些超参数的选择。

这篇文章更像是按时间顺序呈现的实验日记。我希望它对一些人有用。具体来说,我想回答关于QLoRA价值的问题,是否应该用SGD替换AdamW,潜在的使用调度器,以及如何调整LoRA超参数。

关于实验方面有很多内容需要讨论,因此我会简短介绍LoRA。

简而言之,LoRA(Hu等人,2021年提出的低秩适应)通过向模型添加少量可训练参数,同时保持原始模型参数不变,实现了功能。

LoRA通过将一个大的权重矩阵分解为两个较小的权重矩阵,如下图所示,以更高的参数效率近似实现完全的有监督微调。

评估任务和数据集

本文的重点是选择最优设置。为了保持合理的范围,我固定了数据集,仅专注于大型语言模型(LLM)的有监督指令微调。

对于模型评估,我从Eleuther AI的评估工具包中选取了一小部分任务,包括TruthfulQA、BLiMP因果关系、MMLU全球事实,以及两位数(算术2ds)和四位数(算术4ds)的简单算术任务。

在每个基准测试中,模型性能得分在0到1之间标准化,其中1为满分。TruthfulQA报告两个得分,定义如下:

  • MC1(单一真实):给定一个问题和4-5个答案选项,选择唯一正确的答案。模型的选择是它分配给紧随问题之后最高对数概率完成的答案选项,独立于其他答案选项。分数是所有问题的简单准确率。
  • MC2(多重真实):给定一个问题和多个真/假参考答案,得分是分配给一组真实答案的标准化总概率。

作为参考,175B GPT-3模型的TruthfulQA MC1和MC2值分别为0.21和0.33。

下面是两个例子,用以说明算术2ds和算术4ds之间的区别:

  • 算术2ds:“59减38是多少?”答案:“21”。
  • 算术4ds:“2762加2751是多少?”答案:“5513”。

如上所述,固定了数据集,使用了广为研究或常用的Alpaca数据集进行有监督指令微调。当然,还有许多其他用于指令微调的数据集,包括LIMA、Dolly、LongForm、FLAN等。然而,未来的研究中,探索在多个数据集和数据集组合上的训练将是一个有趣的话题。

数据集样例数据如下图所示:

代码框架

Lit-GPT:GitHub - Lightning-AI/lit-gpt: Hackable implementation of state-of-the-art open-source LLMs based on nanoGPT. Supports flash attention, 4-bit and 8-bit quantization, LoRA and LLaMA-Adapter fine-tuning, pre-training. Apache 2.0-licensed.

我在这篇文章中使用的自定义大型语言模型(LLM)微调代码基于开源的Lit-GPT仓库。为了使文章的前言简洁,我不会深入讨论使用细节,但你可以在Lit-GPT教程部分找到更详细的指南。

简要来说,使用方法如下:

  1. 克隆相关仓库和安装相关依赖
git clone https://github.com/Lightning-AI/lit-gptcd lit-gptpip install -r requirements.txt
  1. 下载模型ckpt文件
python scripts/download.py \
--repo_id mistralai/Mistral-7B-Instruct-v0.1
# there are many other supported models
python scripts/convert_hf_checkpoint.py \
--checkpoint_dir checkpoints/mistralai/Mistral-7B-Instruct-v0.1
  1. 数据准备
python scripts/prepare_alpaca.py \
--checkpoint_dir checkpoints/mistralai/Mistral-7B-Instruct-v0.1
# or from a custom CSV file
python scripts/prepare_csv.py \
--csv_dir MyDataset.csv \
--checkpoint_dir checkpoints/mistralai/Mistral-7B-Instruct-v0.1
  1. 进行监督微调
python finetune/lora.py \
--checkpoint_dir checkpoints/mistralai/Mistral-7B-Instruct-v0.1/ \
--precision bf16-true
  1. 将Lora权重合并到原始模型上
python scripts/merge_lora.py \
--checkpoint_dir "checkpoints/mistralai/Mistral-7B-Instruct-v0.1" \
--lora_path "out/lora/alpaca/Mistral-7B-Instruct-v0.1/lit_model_lora_finetuned.pth" \
--out_dir "out/lora_merged/Mistral-7B-Instruct-v0.1/"cp checkpoints/mistralai/Mistral-7B-Instruct-v0.1/*.json \
out/lora_merged/Mistral-7B-Instruct-v0.1/
  1. 效果评估
python eval/lm_eval_harness.py \
--checkpoint_dir "out/lora_merged/Mistral-7B-Instruct-v0.1/" \
--eval_tasks "[arithmetic_2ds, ..., truthfulqa_mc]" \
--precision "bf16-true" \
--batch_size 4 \
--num_fewshot 0 \
--save_filepath "results.json"
  1. 模型使用
python chat/base.py \ 
--checkpoint_dir "out/lora_merged/Mistral-7B-Instruct-v0.1/"

选择一个好的基础模型

首先,我需要为LoRA实验选择一个合适的基础模型。在此,我关注的是那些尚未经过指令微调的模型:phi-1.5 1.3B、Mistral 7B、Llama 2 7B、Llama 2 13B和Falcon 40B。值得注意的是,所有实验都是在单个A100 GPU上运行的。

从上表我们可以看出,Mistral 7B模型在数学基准测试上表现非常出色。与此同时,考虑到其相对较小的规模,phi-1.5 1.3B模型在TruthfulQA MC2上展现了令人印象深刻的性能。出于某种原因,Llama 2 13B在算术基准测试中表现欠佳,而较小的Llama 2 7B在这方面的表现显著优于它。

由于研究人员和从业者目前推测phi-1.5 1.3B和Mistral 7B可能已在基准测试数据上进行了训练,所以我选择不在我的实验中使用它们。此外,我认为选择剩余模型中最小的一个将在保持较低硬件要求的同时提供最大的改进空间。因此,本文的剩余部分将聚焦于Llama 2 7B。

评估LoRA的默认设置

首先,我使用以下默认设置评估了LoRA的微调(这些设置可以在finetune/lora.py脚本中更改):

Lit-GPT: GitHub - Lightning-AI/lit-gpt: Hackable implementation of state-of-the-art open-source LLMs based on nanoGPT. Supports flash attention, 4-bit and 8-bit quantization, LoRA and LLaMA-Adapter fine-tuning, pre-training. Apache 2.0-licensed.

# Hyperparameters
learning_rate = 3e-4
batch_size = 128
micro_batch_size = 1
max_iters = 50000  # train dataset size
weight_decay = 0.01
lora_r = 8
lora_alpha = 16
lora_dropout = 0.05
lora_query = True
lora_key = False
lora_value = True
lora_projection = False
lora_mlp = False
lora_head = False
warmup_steps = 100

(请注意,批处理大小为128,但我们使用带有1个微批处理的梯度累积来节省内存;这导致了与常规使用128批处理大小的训练相同的训练轨迹。)

这个配置训练了4,194,304个LoRA参数,总共有6,738,415,616个可训练参数,并且在我使用单个A100的机器上大约花费了1.8小时。最大内存使用量为21.33 GB。

为了衡量差异,我重复进行了三次实验,观察了不同运行之间性能的波动。

正如我们在上表中看到的,不同运行之间的性能非常一致和稳定。同样值得注意的是,LoRA默认模型在算术任务上表现非常差,但这可能是因为据我所知,Alpaca数据集并没有(或很少有)算术任务。

此外,我还研究了Meta使用RLHF对7B Llama 2版本进行指令微调后的模型。根据下表,Meta的Llama 2 Chat模型在算术性能上也更差。然而,Chat模型在其他基准测试(除BLiMP外)上有了显著改进,我们可以将其作为我们想要通过LoRA微调接近的参考。

使用QLoRA节省内存

在我们开始调整LoRA超参数之前,我想探索QLoRA(Dettmers等人提出的流行的量化LoRA技术)在模型性能和内存节省之间的权衡。

我们可以通过在Lit-GPT中使用–quantize标志(这里使用4位正常浮点类型)来启用QLoRA,如下所示:

此外,我还尝试了4位浮点精度作为对照。以下是对训练时间和最大内存使用量的影响:

默认LoRA(使用bfloat-16):

  • 训练时间:6685.75秒
  • 内存使用:21.33 GB

通过–-quantize “bnb.nf4”启用的QLoRA:

  • 训练时间:10059.53秒
  • 内存使用:14.18 GB

通过–quantize “bnb.fp4”启用的QLoRA:

  • 训练时间:9334.45秒
  • 内存使用:14.19 GB

我们可以看到,QLoRA将内存需求减少了近6 GB。然而,代价是训练时间延长了30%,这是由于额外的量化和反量化步骤所致。

接下来,让我们看看QLoRA训练如何影响模型性能:

从上表中可以看出,与常规QLoRA相比,QLoRA对模型性能确实有一些影响。模型在算术基准测试中有所改进,但在MMLU全球事实基准测试中有所下降。

由于内存节省相当可观(这通常会超过较长的训练时间,因为它允许用户在较小的GPU上运行模型),我将在本文的其余部分使用QLoRA。

学习率调度器和SGD

我在之前的所有实验中都使用了AdamW优化器,因为它是LLM训练的常见选择。然而,众所周知,Adam优化器可能非常占用内存。这是因为它为每个模型参数引入并跟踪两个额外的参数(动量m和v)。大型语言模型(LLM)有许多模型参数;例如,我们的Llama 2模型有70亿个模型参数。

本节将探讨用SGD优化器替换AdamW是否值得。然而,对于SGD优化器,引入学习率调度器尤为重要。我选择了一个余弦退火调度,它在每次批量更新后降低学习率。

不幸的是,将AdamW替换为SGD只节省了少量内存。

  • AdamW:14.18 GB
  • SGD:14.15 GB

这可能是因为大部分内存被用于大型矩阵乘法,而不是存储额外的参数。

但这种小差异或许是意料之中的。在当前选择的LoRA配置(r=8)下,我们有4,194,304个可训练参数。如果Adam为每个模型参数添加2个额外值,并且以16位浮点数存储,那么我们有4,194,304 * 2 * 16位 = 134.22兆比特 = 16.78兆字节。

当我们将LoRA的r增加到256时,我们可以观察到更大的差异,这一点我们稍后会做。在r=256的情况下,我们有648,871,936个可训练参数,使用上述同样的计算方法,相当于2.6 GB。实际测量结果显示有3.4 GB的差异,可能是由于存储和复制优化器状态的一些额外开销。

底线是,对于少量的可训练参数,例如在LoRA和低r(秩)值的情况下,与预训练相比,其中我们训练了更多的参数,使用SGD替换AdamW的内存收益可能非常小。

尽管SGD在这里没有提供显著的内存节省,但让我们还是快速看一下结果模型的性能:

看来,SGD优化器的性能与AdamW相当。有趣的是,当向AdamW添加调度器时,在TruthfulQA MC2和MMLU全球事实性能上有所提高,但算术性能有所下降。(注:TruthfulQA MC2是其他公共排行榜上广为认可的基准测试。)目前,我们不会过多强调算术性能,将在本文的剩余实验中使用带调度器的AdamW。

如果您想复制这些实验,我发现最佳的AdamW学习率是3e-4,衰减率为0.01。最佳的SGD学习率是0.1,动量为0.9。在这两种情况下,我都使用了额外的100步学习率热身。

(基于这些实验,余弦调度器已被添加到Lit-GPT中,并且现在默认启用。)

多次迭代数据集

到目前为止,我已经用50k次迭代训练了所有模型——Alpaca数据集有50k个训练示例。一个明显的问题是,我们是否可以通过多次迭代训练集来提高模型性能,所以我用100k次迭代运行了之前的实验,这是两倍的增加:

有趣的是,增加的迭代次数导致了整体性能的下降。下降最显著的是算术基准测试。我的假设是,Alpaca数据集不包含任何相关的算术任务,当模型更多地关注其他任务时,它会主动忘记基本的算术运算。

不管怎样,如果我说这个结果不令人欣慰,那是撒谎。这样一来,我可以在本文的剩余部分继续进行较短的50k次迭代实验。

LoRA超参数调整第一部分:对所有层启用LoRA

既然我们已经探索了围绕LoRA微调脚本的基本设置,现在让我们关注LoRA超参数本身。默认情况下,LoRA只针对多头自注意力块中的Key和Query矩阵启用。现在,我们还将其用于Value矩阵、投影层和线性层:

LoRA超参数调整第二部分:增加R

LoRA参数中最重要的一个是“r”,它决定了LoRA矩阵的秩或维度,直接影响模型的复杂度和容量。较高的“r”意味着更强的表达能力,但可能导致过拟合,而较低的“r”可以减少过拟合,但代价是表达能力的降低。保持对所有层启用LoRA,我们将r从8增加到16,看看这对性能有什么影响:

我们可以看到,仅仅增加r本身使结果变差了,那么发生了什么呢?让我们在下一节中找出答案。

LoRA超参数调整第三部分:改变Alpha

在上一节中,我们增加了矩阵秩r,而保持LoRA的alpha参数不变。较高的“alpha”将更多地强调低秩结构或正则化,而较低的“alpha”将减少其影响,使模型更多地依赖原始参数。调整“alpha”有助于在拟合数据和通过正则化模型来防止过拟合之间找到平衡。

作为一个经验法则,微调LLM时通常选择一个alpha,其大小是秩的两倍(注意,这在处理扩散模型时有所不同)。让我们尝试一下,看看将alpha增加一倍会发生什么:

我们可以看到,将alpha增加到32现在产生了迄今为止最好的模型!但是我们又以更多的可训练参数为代价获得了这一改进:

r=8:

  • 可训练参数数量:20,277,248
  • 不可训练参数数量:6,738,415,616
  • 内存使用量:16.42 GB

r=16:

  • 可训练参数数量:40,554,496
  • 不可训练参数数量:6,738,415,616
  • 内存使用量:16.47 GB

然而,可训练参数的数量仍然足够小,以至于不会明显影响峰值内存需求。

无论如何,我们现在终于开始取得一些成果,通过更明显的幅度改进模型性能。那么,让我们继续前进,看看通过增加秩和alpha能够达到多远:

我还进行了一些使用异常大的秩(512、1024和2048)的额外实验,但这些实验的结果较差。有些运行甚至在训练期间没有收敛到接近零的损失,这就是为什么我没有将它们添加到表格中。

到目前为止,我们可以注意到最后一行的r=256和alpha=512模型在总体上表现最佳。作为额外的对照实验,我重复了使用alpha为1的运行,并注意到大的alpha值对于良好的性能确实是必要的:

我还重复了使用alpha值为16和32的实验,我观察到与选择alpha值为秩的两倍相比,性能同样更差。

LoRA超参数调整第四部分:非常大的R

对于本文的最后一个调整实验,我想进一步优化上一节中最佳模型的alpha值(r=256,最后一行),怀疑它可能有点过大。

正如上表所示,当增加秩时,选择较大的alpha值似乎是至关重要的。

对于r=256和a=512的QLoRA模型,很明显我们的模型相比基础模型有了显著的改进。唯一的区域是微调模型与基础模型相比在四位数算术上的表现不足。然而,考虑到Alpaca数据集可能没有包含这样的训练示例,这是可以理解的。

上面我们看到,选择alpha为秩的两倍(例如,r=256和alpha=512)的常见建议确实产生了最佳结果,较小的alpha值导致了更差的结果。但是,将alpha增加到“秩的两倍”建议之外会怎样呢?

根据上表提供的结果,选择alpha值超过“秩的两倍”建议也使基准测试结果变差。

结论

本文探索了使用LoRA训练自定义LLM时可以调整的各种设置。我们发现QLoRA是一个很好的内存节省器,尽管它增加了运行时间成本。此外,尽管学习率调度器可能有益,但在AdamW和SGD优化器之间选择影响不大。而且,多次迭代数据集甚至可能使结果更糟。通过优化LoRA设置(包括秩)可以获得最佳性价比。增加秩将导致更多的可训练参数,可能导致更高程度的过拟合和运行成本。然而,增加秩时选择合适的alpha值很重要。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/235243.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

用判断对齐大语言模型

1、写作动机: 目前的从反馈中学习方法仅仅使用判断来促使LLMs产生更好的响应,然后将其作为新的示范用于监督训练。这种对判断的间接利用受到无法从错误中学习的限制,这是从反馈中学习的核心精神,并受到LLMs的改进能力的制约。 2…

科学和统计分析软件GraphPad Prism mac介绍说明

GraphPad Prism for Mac是一款科学和统计分析软件,旨在帮助研究者、科学家和学生更轻松地处理和可视化数据。 GraphPad Prism for Mac是一款功能强大、易于使用的科学和统计分析软件,适用于各种类型的数据处理和可视化需求。无论您是进行基础研究、临床试…

给自己创建的GPTs添加Action(查天气)

前言 在这篇文章中,我将分享如何利用ChatGPT 4.0辅助论文写作的技巧,并根据网上的资料和最新的研究补充更多好用的咒语技巧。 GPT4的官方售价是每月20美元,很多人并不是天天用GPT,只是偶尔用一下。 如果调用官方的GPT4接口&…

Django的数据库模型的CharField字段的max_length参数与中文字符数的关系探索(参数max_length的单位是字符个数还是字节数?)

01-清理干净之前的数据库迁移信息 02-根据setting.py中的信息删除掉之前建立的数据库 03-删除之后重新创建数据库 04-models.py中创建数据库模型 from django.db import modelsclass User(models.Model):username models.CharField(max_length4)email models.EmailField(uni…

接口测试管理续集

今天应大家需要,接着谈app端数据返回层面的用例设计方法。第二部分给大家安利一个“接口管理平台”,以帮助大家解决接口文档维护、接口测试数据Mock、接口自动化测试等问题。希望对小伙伴们有用。 言归正传,进入今天的话题。 一、用例设计 …

深度学习算法应用实战 | 利用 CLIP 模型进行“零样本图像分类”

文章目录 1. 零样本图像分类简介1.1 什么是零样本图像分类?1.2 通俗一点的解释 2. 模型原理图3. 环境配置4. 代码实战5. Gradio前端页面5.1 什么是 Gradio ? 6 进阶操作7. 总结 1. 零样本图像分类简介 1.1 什么是零样本图像分类? “零样本图像分类”(Zero-shot …

自学Python,需要注意哪些?

为什么要学习Python? 在学习Python之前,你不要担心自己没基础或“脑子笨”,我始终认为,只要你想学并为之努力,就能学好,就能用Python去做很多事情。在这个喧嚣的时代,很多技术或概念会不断兴起…

配置安装nginx

目录 一、yum安装 1.进入nginx官网 nginx.org 2、进入下载列表 以主线版为例 3、服务器安装工具包集合 4、设置 yum 存储库 5、配置成功后,默认情况下,使用稳定的 nginx 包的存储库。如果要使用主线 nginx 包,需要运行以下命令&#xf…

Asp .Net Web应用程序(.Net Framework4.8)网站发布到IIS

开启IIS 如果已开启跳过这步 打开控制面板-程序 打开IIS 发布Web程序(.Net Framework 4.8 web网页) 进入IIS管理器新建一个应用池 新建一个网站 网站创建完毕 为文件夹添加访问权限 如果不添加访问权限,运行时将会得到如下错误 设置权限 勾…

PHP开发日志 ━━ 不同方法判断某个数组中是否存在指定的键名,测试哪种方法效率高

我们可以用isset($arr[a]) 或者 array_key_exists(a, $arr) 来判断a键名是否存在与$arr数组。 那么这两种方式哪个运行速度快呢? 不多废话了,现在我们写一段代码来测试一下: $array [a > 1, b > 2, c > 3];$start microtime(tru…

YOLOv8优化策略:轻量化改进 | 华为Ghostnet,超越谷歌MobileNet | CVPR2020

🚀🚀🚀本文改进:Ghost bottleneck为堆叠Ghost模块 ,与YOLOV8建立轻量C2f_GhostBottleneck 🚀🚀🚀YOLOv8改进专栏:http://t.csdnimg.cn/hGhVK 学姐带你学习YOLOv8,从入门到创新,轻轻松松搞定科研; 1.Ghostnet介绍 论文: https://arxiv.org/pdf/1911.11907.…

Logstash:迁移数据到 Elasticsearch

在生产环境中,不使用 Apache Kafka 等流平台进行数据迁移并不是一个好的做法。 在这篇文章中,我们将详细探讨 Apache Kafka 和 Logstash 的关系。 但首先让我们简单了解一下 Apache Kafka 的含义。 Apache Kafka 是分布式流平台,擅长实时数据…

机器学习笔记一之入门概念

目录 一 基本分类二 按模型分类概率模型(Probabilistic Models)非概率模型(Non-Probabilistic Models)对比结论线性模型 (Linear Models)非线性模型 (Non-linear Models)对比 三 按算法分类1.批量学习(Batch Learning&…

前端开发Docker了解

1,docker简介 docker主要解决了最初软件开发环境配置的困难,完善了虚拟机部署的资源占用多,启动慢等缺点,保证了一致的运行环境,可以更轻松的维护和扩展。docker在linux容器的基础上进行了进一步的封装,提…

电脑USB接口不同颜色的含义

当你看到笔记本电脑或台式机的USB端口时,你会发现USB端口的颜色很多;这些颜色可不只是为了好看,实际上不同颜色代表着不同的性能,那么这些带颜色的USB端口都是什么含义呢,下面就具体介绍下不同颜色代表的含义。-----吴…

僵尸毁灭工程手动存档工具

介绍 这是一个可以对僵毁游戏存档进行备份的小工具,其基本原理是对僵毁存档中数以万计的小文件做哈希值计算并保存下来,下一次备份时再对存档文件进行哈希值计算,每次备份只对两次计算结果中存在差异的文件进行复制与替换从而忽略掉大部分未…

1.10 Unity中的数据存储 JSON

一、介绍 Json是最常用也是目前用的比较多的一种,超轻量级,可便捷性使用,平时用到比较多的都是解析Json和往Json中添加数据、修改数据等等JSON(JavaScript Object Notation,JS对象标记)是一种轻量级的数据交换格式,它基于ECMAScr…

Spark---RDD序列化

文章目录 1 什么是序列化2.RDD中的闭包检查3.Kryo 序列化框架 1 什么是序列化 序列化是指 将对象的状态信息转换为可以存储或传输的形式的过程。 在序列化期间,对象将其当前状态写入到临时或持久性存储区。以后,可以通过从存储区中读取或反序列化对象的…

[算法与数据结构][c++]:Static关键字和全局变量

Static关键字和全局变量 1. 生命周期、作用域和初始化时机2. 全局变量3. Static 关键字3.1 面向过程3.1.1 静态全局变量3.1.2 静态局部变量(单例中会使用)3.1.3 静态函数 3.2 面向对象3.2.1 类内静态成员变量3.2.2 类内静态成员函数 Reference 写在前面&…

详细讲解MybatisPlus实现逻辑删除

目录 前言1. 基本知识2. 实战Demo3. 拓展 前言 对于MybatisPlus的相关知识可在我的博客进行搜索 对应的CRUD相关知识也可看我这篇文章:【Java项目】实战CRUD的功能整理(持续更新) 在讲述逻辑删除这个概念时,先引入另外一个概念&…