论文地址:https://arxiv.org/abs/2309.12307
github地址:https://github.com/dvlab-research/LongLoRA
1. 背景与挑战
大语言模型(LLMs)通常在预定义的上下文长度下进行训练,例如 LLaMA 的 2048 个 token 和 Llama2 的 4096 个 token。然而,这种预定义的上下文长度限制了模型在处理长文档或回答长问题时性能。
主要挑战:
- 计算成本高昂: 扩展上下文长度会导致自注意力机制的计算成本呈二次增长,显著增加训练时间和 GPU 内存需求。例如,将上下文长度从 2048 扩展到 8192 会使自注意力层的计算成本增加 16 倍。
- 现有微调方法的局限性:
- 全量微调: 虽然效果最佳,但计算成本过高,普通研究者难以承受。例如,Position Interpolation 需要 32 个 A100 GPU 来将 LLaMA 模型从 2k 扩展到 8k 上下文长度,更长的上下文则需要 128 个 A100 GPU。
- 低秩适应(LoRA): 虽然比全量微调更高效,但在长上下文扩展方面效果不佳,困惑度较高。
2. LongLoRA:高效扩展上下文长度的解决方案
LongLoRA 旨在以更低的计算成本高效地扩展预训练 LLMs 的上下文长度,同时保持与全量微调相近的性能。
2.1 主要创新点
-
Shifted Sparse Attention (ΔS²-Attn): 一种高效的注意力机制替代方案,用于在训练过程中近似标准自注意力。
- 工作原理:
- 将上下文长度分割成若干组,在每组内独立进行注意力计算。
- 在一半的注意力头中,将 token 序列偏移半个组大小,以确保相邻组之间的信息流动。
- 例如,使用组大小为 2048 的 ΔS²-Attn 来近似总长度为 8192 的上下文训练。
- 优势:
- 高效性: 显著降低计算成本,与标准注意力相比,训练速度提升高达 1.8 倍。
- 灵活性: 在推理时仍使用标准自注意力,因此可以复用现有的优化和基础设施,例如 Flash-Attention2。
- 易于实现: 仅需两行代码即可实现。
图 1:LongLoRA 的工作流程图,展示了在微调过程中引入的 Shifted Sparse Attention (ΔS²-Attn)。训练后的模型在推理时保留原始的标准自注意力。 - 工作原理:
-
改进的 LoRA: 通过使嵌入层和归一化层可训练,弥合了 LoRA 与全量微调之间的性能差距。
- 原因: 嵌入层和归一化层虽然在整个 LLM 中占比较小,但在长上下文扩展中起着关键作用。例如,在 Llama2 7B 中,归一化层的参数占比仅为 0.004%。
- 效果: 显著提升了长上下文扩展的性能,缩小了与全量微调的差距。
图 2:不同微调方法在长上下文扩展中的困惑度对比。可以看出,改进后的 LoRA (LoRA+) 显著缩小了与全量微调的差距。
2.2 实验结果
-
长序列语言建模: LongLoRA 在 PG19 和 proof-pile 数据集上的表现优于基线模型,表明其高效微调方法的有效性。
- 例如,将上下文窗口大小从 8192 增加到 32768,Llama2 7B 模型的困惑度从 2.72 降低到 2.50。
-
检索任务评估: 在 LongChat 引入的主题检索任务中,LongLoRA 模型与最先进的 LongChat-13B 性能相当,甚至在 16k 评估中略有优势。
-
长上下文基准评估: 在 LongBench 和 LEval 基准测试中,LongLoRA 7B 模型表现出与基于 Llama2 的长上下文模型(如 Vicuna 和 LongChat)相当或更好的性能。
-
效率分析:
- 训练成本: LongLoRA 相比全量微调,训练时间和 GPU 内存需求显著降低。
- 例如,在 65536 上下文长度下,LongLoRA 的训练时间仅为 LoRA 的 56.6%。
- FLOPs 消耗: ΔS²-Attn 显著降低了 FLOPs 消耗,尤其是在长上下文长度下。
图 3:不同上下文长度下 Llama2 7B 模型的 FLOPs 消耗分解。可以看出,随着上下文长度的增加,注意力计算的比例急剧增加,而 ΔS²-Attn 有效降低了注意力计算的 FLOPs。 - 训练成本: LongLoRA 相比全量微调,训练时间和 GPU 内存需求显著降低。
3. 结论
LongLoRA 提供了一种高效扩展 LLMs 上下文长度的方法,具有以下优势:
- 计算效率高: 相比标准全量微调,GPU 内存需求和训练时间更少,同时精度损失最小。
- 架构兼容性: ΔS²-Attn 易于实现,训练后的模型在推理时保留原始的标准注意力架构,便于复用现有基础设施。
- 性能优越: 在长上下文扩展方面,LongLoRA 实现了与全量微调相近的性能。
4. 未来展望
LongLoRA 是一种通用方法,未来可以与更多类型的 LLMs 和位置编码技术相结合,进一步提升其适用性和性能。
5. 补充说明
- 实验设置: 所有实验均在单个 8x A100 机器上进行,使用 PyTorch、DeepSpeed 和 Flash-Attention2 进行训练。
- 数据集: 使用 Redpajama 数据集进行训练,在 PG19 和 proof-pile 数据集上进行评估。
- SFT(监督微调): 为了提高模型的问答能力,LongLoRA 进一步使用自收集的长指令跟随数据集 LongAlpaca 进行 SFT。
通过这些创新和实验结果,LongLoRA 为高效扩展 LLMs 的上下文长度提供了新的思路和方法。