聊聊拉长LLaMA的一些经验

Sequence Length是指LLM能够处理的文本的最大长度,越长,自然越有优势:

  1. 更强的记忆性。更多轮的历史对话被拼接到对话中,减少出现遗忘现象

  2. 长文本场景下体验更佳。比如文档问答、小说续写等

当今开源LLM中的当红炸子鸡——LLaMA,第一版上下文长度是2048,第二版长度是4096。相比之下ChatGPT、GPT4已经支持到16k,Claude甚至支持到了100k。足以见得将LLaMA拉长是如此的任重而道远。本文将会介绍三种在旋转位置编码(RoPE)基础上扩充上下文的高性价比方案,在文末会介绍我的实践经验。

线性插值法

Kaiokendev的博客[1]中提到了方法,和Meta的一篇工作[2]不谋而合,其思想主要是将目标长度压缩到原始长度。如下图所示,LLaMA-1预训练的长度为2048,如果我们想把它拉长到4096:

  • 方法一:推理时直接拉长到4096。这考虑位置编码的外推性(即在短文本上训练,长文本上推理的能力[2]),而RoPE的外推性则是相当一般[2]。由于训练时长度都是小于2048的,超过2048部分Attention分数会飙升,导致困惑度急剧上升。

  • 方法二:在原始模型基础上做长度为4096的继续训练。这里先岔开介绍另一款模型——MPT-30B的做法,根据官方博客[3]的介绍:

    As mentioned earlier, MPT-30B was trained with a long context window of 8k tokens (vs. 2k for LLaMa and Falcon) and can handle arbitrarily long context windows via ALiBi or with fine-tuning. To build 8k support into MPT-30B efficiently, we first pre-trained on 1T tokens using sequences that were 2k tokens long, and continued training for an additional 50B tokens using sequences that were 8k tokens long.

    MPT-30B采用ALiBi位置编码(外推性优于RoPE),在2k的长度进行1T token的训练,然后在8k长度上进行50B token的预训练——这是在外推性强于RoPE的ALiBi上的情况。LLaMA-1预训练的token数是1T以上,想要在长度为4096样本上效果不下降,那需要训练足够多的token数才行,这就需要较大的成本了。

  • 方法三:另一种思路则是将4096的位置编码通过线性插值法压缩到2048内,这样只需要在少量的长度为4096的数据上进行继续预训练,便可达到不错的效果。

来自论文[2]

来自论文[2]

代码实现

线性插值法的实现代码相当的简单,这需要在原始RoPE上进行微小的改动,即加上下图的scale参数。

来自[7],scaled_rope/LlamaLinearScaledRotaryEmbedding.py

来自[7],scaled_rope/LlamaLinearScaledRotaryEmbedding.py

效果

Meta的工作[2]中进行了充足实验和公式推导证明,如果想看具体的代码,建议看lmsys.org(Vicuna的出品方)的一篇工作[4]:How Long Can Open-Source LLMs Truly Promise on Context Length? 他们对比了商用模型、号称支持长文本的开源模型和“Vicuna+线性插值法”的效果,并给出了几个结论:

  1. 商用模型在长文本的效果上很能打!而那些号称支持长文本的开源模型,在长文本上则表现不佳。

  2. 随着文本长度的增加,越接近边界,Vicuna+线性插值法的效果降低越明显。这可能是因为训练数据存在短文本的情况。

来自[4]

来自[4]

来自[4]

来自[4]

写这篇文章的同时,ChatGLM团队更新了ChatGLM2-6B-32K,也是使用了插值法。同时推出了长文本的中英评测集LongBench,在这个评测集上ChatGLM2-6B-32K展示了强大的实力,但值得注意的是,该评测集的评测方式是使用ChatGLM2-6B来进行评估的。

NTK插值法

NTK插值法的提出于一篇Reddit帖子[5],它提出使用Neural Tangent Kernel (NTK)来解决这个问题。

if you apply Neural Tangent Kernel (NTK) theory to this problem, it becomes clear that simply interpolating the RoPE's fourier space "linearly" is very sub-optimal, as it prevents the network to distinguish the order and positions of tokens that are very close by. Borrowing from NTK literature, scaling down the fourier features too much will eventually even prevent succesful finetunes (this is corroborated by the recent paper by Meta that suggests an upper bound of ~600x)

Instead of the simple linear interpolation scheme, I've tried to design a nonlinear interpolation scheme using tools from NTK literature. Basically this interpolation scheme changes the base of the RoPE instead of the scale, which intuitively changes the "spinning" speed which each of the RoPE's dimension vectors compared to the next. Because it does not scale the fourier features directly, all the positions are perfectly distinguishable from eachother, even when taken to the extreme (eg. streched 1million times, which is effectively a context size of 2 Billion)

帖子中作者还用了时钟的例子来解释线性插值和NTK插值的异同:

RoPE behaves like a clock. Your 12 hours wall clock is basically a RoPE of dimension 3 with a base of 60. So for each second, the minute hand turns 1/60th of a minute, and for each minute, the hour hand turns 1/60th.

Now if you slowed down time by a factor of 4x, that is a linear RoPE scaling used in SuperHOT. Unfortunately now it is really really hard to distinguish each second, because now the seconds hand barely moves each second. So if someone gave you two different times, which is only different by a single second, you won't be able to distinguish them from afar (let's say the NNs have myopia because that's basically what NTK predicts)

Now NTK-Aware RoPE scaling does not slow down the seconds. One second is still one second, but it slows down minutes by a factor of let's say 1.5, and the hours by a factor of 2. This way you can fit 90 minutes in a hour, and fit 24 hours in half a day. So now you basically have a clock that can measure 129.6k seconds instead of 43.2k seconds.

Because you don't need a precise measurement of the hour hand when looking at the time, scaling the hours more compared to seconds is crucial. You don't want to lose the precision of the seconds hand, but you can afford to lose precision on the minutes hand and even more on the hours hand.

Then, it's just a matter of deriving the base change formula in order to obtain such a scaling. (where less precise dimensions are scaled more and more)

代码实现

NTK的实现则更加简单了,根据超参数alpha,对应修改base变量即可:

来自[7],scaled-rope/scaled_rope/LlamaNTKScaledRotaryEmbedding.py

来自[7],scaled-rope/scaled_rope/LlamaNTKScaledRotaryEmbedding.py

效果

在效果上,帖子中也给出了NTK插值法和线性插值法的PPL比较,可以看到,在二者都不做Finetune的情况下,NTK插值法具备更低的PPL。

来自[5]

来自[5]

动态插值法

动态插值法同样出自于一篇Reddit帖子[6],它的出发点很简单:

My idea was to use the exact position values for the first 2k context (after all, why mess with a good thing?) and then re-calculate the position vector for every new sequence length as the model generates token by token.

这种做法可以和先前的两种方法相结合,[7]中也给出了详细的代码实现。

效果

作者和前两种方法做了对比,展示了动态插值法在PPL下降上的优势。

实践经验

我在实践的过程中,评估效果主要使用longChat[4]中使用的评估方式,以下是一些takeaway tips,欢迎大家一起交流。

  1. 线性插值法具备完整的理论支持和大量的实验证明,在我的实践中,“线性插值法+Finetune”取得了最佳效果。

  2. NTK插值法的实验中,对比的是不做Finetune的情况,在我的实践中,“NTK插值+Finetune”效果会明显优于单独的“NTK插值”,但它的收敛速度会慢于“线性插值法+Finetune”。

  3. 动态插值法的实验同样是在不做Finetune的情况对比的,目前为止我并没有尝试过这种方法。在Reddit的评论区有人提出一个很好的问题:如果采取这种方法,逐token推理时,文本的长度是在变化的,则导致无法使用kv-cache,这会对性能产生很大的影响。

最后,拉长LLaMA的方案可以不从RoPE入手(如:LongLLaMA[8]),但“线性插值法+Finetune”无疑是一种性价比很高的方案,推荐大家尝试!

——2023.07.31

Reference

[1] Extending Context is Hard…but not Impossible†

[2] EXTENDING CONTEXT WINDOW OF LARGE LANGUAGE MODELS VIA POSITION INTERPOLATION

[3] MPT-30B: Raising the bar for open-source foundation models

[4] How Long Can Open-Source LLMs Truly Promise on Context Length?

[5] NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.

[6] Dynamically Scaled RoPE further increases performance of long context LLaMA with zero fine-tuning

[7] GitHub - jquesnelle/scaled-rope

[8] Focused Transformer: Contrastive Training for Context Scaling

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/76629.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用docker部署Wordpress

文章目录 1.创建网络2.创建volume存储3.拉取镜像4.创建mysql容器mysql修改密码 5.创建wordpress容器6.访问localhost:80就可以直接使用啦 1.创建网络 docker network create --subnet172.18.0.0/24 pro-net2.创建volume存储 # mysql 存储 docker volume create volume_mysql…

latex插入不连续的中线

在上面的示例中,\cmidrule(lr){1-3} 表示在第 1 列到第 3 列之间添加不连续的中线,\cmidrule(lr){4-6} 表示在第 4 列到第 6 列之间添加不连续的中线,以此类推。你可以根据需要调整列的范围和对齐方式。

面试总结-Redis篇章(十)——Redis哨兵模式、集群脑裂

Redis哨兵模式、集群脑裂 哨兵模式哨兵的作用服务状态监控 Redis集群(哨兵模式)脑裂解决办法 哨兵模式 为了保证Redis的高可用,Redis提供了哨兵模式 哨兵的作用 服务状态监控 Redis集群(哨兵模式)脑裂 假设由于网络原…

指针进阶详解续---C语言

❤博主CSDN:啊苏要学习 ▶专栏分类:C语言◀ C语言的学习,是为我们今后学习其它语言打好基础,C生万物! 开始我们的C语言之旅吧!✈ 目录 前言: 一.函数指针数组 二.指向函数指针数组的指针 三.回调函数 …

Linux系统部署Python语言开发运行环境

目录 Ubuntu自带python Debian安装python 安装 pip 库列表 安装第三方库 使用国内镜像站 实装 tkinter 库 编写运行代码 测试代码1 1. 创建项目 2. 创建源码文件 3. 写入源代码 4. 修改权限 5. 运行代码 测试代码2 本文的使用环境是Windows的Linux 子系统&…

微前端中的 CSS

本文为翻译 本文译者为 360 奇舞团前端开发工程师原文标题:CSS in Micro Frontends 原文作者:Florian Rappl 原文地址:https://dev.to/florianrappl/css-in-micro-frontends-4jai 我被问得最多的问题之一是如何在微前端中处理 CSS。毕竟&…

键入网址到网页显示,期间发生了什么

HTTP 浏览器做的第一步工作是解析URL 首先浏览器做的第一步工作就是要对URL进行解析,从而生成发送给 web 服务器的请求信息。 所以图中长长的URL实际上是请求服务器里的文件资源。 如果图中的蓝色部分URL元素省略了,那应该请求哪个文件呢? 当…

HTML5注册页面

分析 注册界面实际上是一个表格&#xff08;对齐&#xff09;&#xff0c;一行有两个单元格。 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevic…

Linux:在使用UEFI固件的计算机上内核是如何被启动的

前言 启动计算机通常不是一件难事&#xff1a;按下电源键&#xff0c;稍等片刻&#xff0c;你就能看到一个登录界面&#xff0c;再输入正确的密码&#xff0c;就可以开启一天的网上冲浪之旅了。 但偶尔这件事没那么顺利&#xff0c;有时候迎接你的不是熟悉的登录界面&#xf…

java 数组的使用

数组 基本介绍 数组可以存放多个同一类型的数据&#xff0c;数组也是一种数据类型&#xff0c;是引用类型。 即&#xff1a;数组就是一组数据。 数组的使用 1、数组的定义 方法一 -> 单独声明 数据类型[] 数组名 new 数据类型[大小] 说明&#xff1a;int[] a new int…

上海亚商投顾:沪指震荡微涨 金融、地产午后大幅走强

上海亚商投顾前言&#xff1a;无惧大盘涨跌&#xff0c;解密龙虎榜资金&#xff0c;跟踪一线游资和机构资金动向&#xff0c;识别短期热点和强势个股。 市场情绪 三大指数早盘震荡&#xff0c;午后集体拉升反弹&#xff0c;创业板指涨超1%。券商等大金融板块午后再度走强&#…

学习线性表:掌握基本操作和应用

线性表 前言 欢迎来到本博客的线性表部分&#xff01;&#x1f604;&#x1f389;在这篇博文中&#xff0c;我将带您深入探索数据结构中的线性表&#xff0c;这是计算机科学中最基础也是最重要的数据结构之一。 线性表是一种线性数据结构&#xff0c;它由一组连续的元素组成…

socket server服务器开发常见的并发模型

两种高效的事件处理模式 服务器程序通常需要处理三类事件&#xff1a;I/O 事件、信号及定时事件。有两种高效的事件处理模式&#xff1a;Reactor和 Proactor&#xff0c;同步 I/O 模型通常用于实现Reactor 模式&#xff0c;异步 I/O 模型通常用于实现 Proactor 模式。 无论是 …

Flink开发环境准备: centos-jdk8

linux-jdk8 - Flink开发环境准备 一、基本介绍二、环境准备1.1 JDK环境1.2 开发工具1.3 Maven环境 三、flink下载安装配置3.1 Flink下载3.2 flink本地模式安装 - linux3.3 常用配置3.4 日志的查看和配置 四、单机 Standalone 的方式运行 Flink 一、基本介绍 Flink底层源码是基于…

python+django+mysql项目实践二(前端及数据库)

python项目实践 环境说明&#xff1a; Pycharm 开发环境 Django 前端 MySQL 数据库 Navicat 数据库管理 前端模板 添加模板 在templates下创建 views文件中添加 创建数据库 连接数据库 在setting文件中进行配置 创建表

Redis持久化机制

AOF 把对redis的写操作记录下来&#xff0c;先执行命令&#xff0c;再执行写入&#xff0c;优势在于&#xff1a; 当然也有风险&#xff1a;丢失和对下一个命令造成阻塞 丢失的原因是执行写操作和记录日志是两个过程 下一个命令造成阻塞的原因是两个过程是同步的 第二个问…

python GUI nicegui初识一(登录界面创建)

最近尝试了python的nicegui库&#xff0c;虽然可能也有一些不足&#xff0c;但个人感觉对于想要开发不过对ui设计感到很麻烦的人来说是很友好的了&#xff0c;毕竟nicegui可以利用TailwindCSS和Quasar进行ui开发&#xff0c;并且也支持定制自己的css样式。 这里记录一下自己利…

【在英伟达nvidia的jetson-orin-nx上使用调试can基础收发-硬件连接-开机自启动can-初步调试】

【在英伟达nvidia的jetson-orin-nx上使用调试can基础收发-硬件连接-开机自启动can-初步调试】 1、概述2、实验环境3、自我学习4-1、实验过程1、硬件原理图焊接连接模块2、输入命令3、测试过程 4-2、目前遗留问题# 1-1、发送可以发送&#xff0c;但是PC发送数据收不到。# 1-2、接…

Maven项目中Lifecycle和Plugins下的install的区别

在Maven中&#xff0c;如果你的web和service在不同的模块下&#xff0c;如果直接用用tomcat插件运行web层&#xff0c;那么运行时会报错 Failed to execute goal org.apache.maven.plugins:maven-install-plugin:2.5.2:install (default-cli) on project springboot: The pack…

ResNet-残差网络二

文章目录 残差结构的一般表达形式残差结构中的信息传播clean path propagation前向传播反向传播 h(x)为恒等映射的重要性h(x)的实验证明 激活层的位置 和其他网络的对比 上一篇讲了 ResNet 论文中的第一篇&#xff1a;Deep Residual Learning for Image Recognition&#xff0c…