大语言模型之十六-基于LongLoRA的长文本上下文微调Llama-2

增加LLM上下文长度可以提升大语言模型在一些任务上的表现,这包括多轮长对话、长文本摘要、视觉-语言Transformer模型的高分辨4k模型的理解力以及代码生成、图像以及音频生成等。

对长上下文场景,在解码阶段,缓存先前token的Key和Value(KV)需要巨大的内存开销,其次主流的LLM模型在推理的时候上下文长度都小于等于训练时的上下文长度。为了约束长文本时缓存先前KV的内存和计算量,很容易想到的方法是对KV进行加窗选择,这样可以限制参与当前token计算的KV历史数量,将内存和计算量约束在可控的范围内。

Llama 2官方支持的标准版模型(下称基座模型)上下文长度是是4k,而Chinese-LLaMA-Alpaca-2基于支持16K上下文,并可通过NTK方法进一步扩展至24K+,在大语言模型之十一 Transformer后继者Retentive Networks (RetNet)博客中知道随着上下文长度的增加,基座模型在训练和推理时Attention模块的内存和算力需求按照文本长度的平方倍增加。

EFFICIENT STREAMING LANGUAGE MODELS WITH ATTENTION SINKS,观察到一个有趣的现象,该论文称之为attention sink,即初始token的KV对于加窗Attention方法的得到的模型性能较为重要(尽管从语义上来看,对于长文本而言初始的token并不重要),StreamingLLM是根据这一思想(保留有限数量的最近KV以及初始的Attention sink)开源的不限上下文长度方法(论文中给的上下文token可长达4百万)。

LongLoRA的paper以Llama-2中给出了文本增加的算力需求增加情况,当上下文长度从2k变为8k的时,模型的自注意力模块计算量将增加16倍。

想要减少长上下文的内存和算力需求,可以从改变Transformer的Attention结构以及优化GPU计算单元这两个方面出发,在充分利用GPU算力基础上(FlashAttention结构),优化Transformer的Attention计算和推理结构(如LongLoRA),此外还有在大语言模型之十一 Transformer后继者Retentive Networks (RetNet)博客中提及微软的RetNet,其将并行计算的Attention改成了RNN结构的Attention结构,既可以在训练时并行又可以在推理的时RNN计算,这使得算力并不会随着上下文长度增加而显著上升,几乎可以做到不限上下文长度。

本篇介绍的基座模型Llama-2 7B将上下文长度从4k扩充到100k,基于FlashAttention和LongLoRA技术,二者利用了GPU和Transformer的Attention结构改进两方面的技术。此外通过finetune来扩展上下文长度还有以下一些方法:

  1. Position Interpolation: 增强版的RoPE将Llama扩展到32K
  2. Focused Transformer使用contrastive learning 训练得到 LongLLaMA
  3. Landmark attention计算量比1和2更高,但是精度损失较大,将长文本压缩成了retrieved tokens
  4. NTK-aware
  5. Yarn
  6. positional Skipping
  7. out-of-distribution related method

StreamingLLM

该方法利用attention sink现象,采用缓存头尾KV的方式,并不需要额外重新训练模型,只需要改变Attention模块的forward计算使用到的历史信息即可,即在基座模型上修改Attention的前向推理即可。该方法的对比原理图如下:

在这里插入图片描述
这是基于LLama-2-13B的测试数据情况,图a是密集注意力模型,即计算当前token时,会缓存先前T个token的KV值,缓存的长度是大于训练时上下文长度的,图b是加窗注意力模型,T-L个token状态将不再缓存,只缓存最近的L个KV值,这使得计算量大大减少,但是PPL却并不是很理想。图c的滑窗方法会重新计算最近L个token缓存的KV值,但是计算复杂度高,耗时长,图d是StreamingLLM中给出的方法,始终保留Attention sink(初始几个token的KV),然后在同保留的最近L个token的KV值一起计算当前token的Attention值,其在计算效率和混淆度上取得了不错的收益。
官方给出了前向推理llama的示例方法,enable_streaming_llm,这一方法并不一定在所有LLM上都能取得该表现,我个人觉得采取这一方法需要较为谨慎。这里分析主要是为了梳理其中的思想方法。
如果对图中每行的意义不是很理解,那么可以参考《大语言模型之四-LlaMA-2从模型到应用》中图三第一个Linear层的“你好!”为例,

  • 第一次输入是你,对应第一行,因为是第一个token,因而先前的缓存的KV没有;
  • 第二次输入是好,对应于第二行,因为是第二个token(深色),浅蓝色则是缓存的“你”输入是的KV,可以依次类推,
  • 第三次输入是!,对应于第三行,两个浅蓝色块是“你好”,从这个下三角可以看出模型是因果的(即当前的输入token只能看到历史的KV)。
    替换Attention模块的forward方法如下:
def enable_llama_pos_shift_attention(model):for name, module in reversed(model._modules.items()):if len(list(module.children())) > 0:enable_llama_pos_shift_attention(module,)if isinstance(module, LlamaAttention):model._modules[name].forward = types.MethodType(llama_pos_shift_attention_forward, model._modules[name])

启用LLAMA模型中的位置偏移注意力机制。它遍历了给定模型的所有模块,如果模块是LlamaAttention类型,则将其前向传递函数替换为llama_pos_shift_attention_forward函数。该函数实现了位置偏移注意力机制,它可以在LLAMA模型中提高性能。

    parser.add_argument("--start_size", type=int, default=4)parser.add_argument("--recent_size", type=int, default=2000)

默认缓存前四个token的KV值,并且缓存最近2000个KV值。

LongLoRA

LongLoRA的paper以Llama-2为实验对象,LongLoRA在Attention层和权重层都节约了算力资源。论文中主要引入了shifted short attention,这个结构和Flash-Attention是兼容的,并且在推理的时候并不需要,论文中将Llama-2的7B模型上下文从4K扩充到了100K,13B基座模型扩充到了64K,70B基座模型上下文扩充到了32K。
LoRA通过低秩矩阵将self-Attention矩阵进行了线性投影(矩阵的低秩分解),即将一个参数量为NxN的矩阵分解为Nxd和dxN的两个矩阵,而这2xd却远小于N,如d取64,而N取4096,这就使得需要训练的参数量大大减少。LoRA方法可行的前提是假设基座模型在迁移的时候本质上是一个低秩(low intrinsic rank)问题,但是这种方法对长上下文情况,单单采用LoRA会导致混淆度(perplexity)较高。在经典的Transformer架构中,LoRA方法通常只改变Attention层权重。

LongLoRA相比LoRA有两点改进,第一点是:类似Attention层一样,embedding和normalization层参数也会参与微调训练调整;第二点是:S2(shift short) Attention机制,这两点的修改显示在paper图中:
请添加图片描述
从右边可以看到LoRA是Self-Attention而言的,而LongLoRA除了有红色的LoRA部分,对Embedding和Norm层都进行参数调整(图中有🔥的部分都是调整的对象),对于算力以及内存的需求对比这里不展示了,有兴趣可以进一步参考论文。
有必要对shift short进行展开一下,假设LLM训练的时候是按照2k上下文进行的,对于一个长度为8k的输入文本token,记为[1,2,…,8192],则在训练的时候会被分为4个组,每个组上下文长度是2k,即:
[1,2,…,2048],[2049,2050,…,4096],[4097,4098,…,6144],[6145,6146,…,8192],这四个组在训练的时候是分开进行的,每个组之间Attention模块并没有进行信息交互,如果按照传统的方式分为n(8192)个组[1,2,…,2048],[2,3,…,2049]…,这样训练的时间又会非常多,S2-Attention就是在这个方式下提出的,方法是按照目标上下文长度的1/2循环移动上下文token,对于上面的8192个token,shift之后为[1025,1026,…,3072],[3073,3074,…,5120],[5121,5122,…,7168],[7169,7170,…,1023,1024],这使得每个shift之后的每个分组有一半是前一个分组,一半是后一个分组,如[1025,1026,…,3072],各有一半在[1,2,…,2048]和[2049,2050,…,4096]。

Deep Vison lab开源了longLoRA的fine-tune代码,fine-tune.py
首先加载模型和tokenizer,这在前几篇博客中反复提及了的,使用的是Huggingface提供的接口。

    # Load model and tokenizermodel = transformers.AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path,config=config,cache_dir=training_args.cache_dir,torch_dtype=torch.bfloat16,)tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path,cache_dir=training_args.cache_dir,model_max_length=training_args.model_max_length,padding_side="right",use_fast=True,)

多线程加载训练数据集

   rank = int(os.environ.get('RANK', -1))if rank > 0:barrier()dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=training_args.cache_dir)dataset = dataset.map(partial(tokenize_fn,tokenizer),batched=True, num_proc=128, remove_columns=["text", "meta"])if rank == 0:barrier()

接来下是loRA参数配置

        config = LoraConfig(r=8,lora_alpha=16,target_modules=targets,lora_dropout=0,bias="none",task_type="CAUSAL_LM",)model = get_peft_model(model, config)

train是常规的train过程

    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, train_dataset=dataset["train"],eval_dataset=None,data_collator=data_collator)trainer.train()

S2-Attention的实现和StreamingLLM的方法类似,都是更改基座模型Attention的forward方法。

def replace_llama_attn(use_flash_attn=True, use_full=False, inference=False):if use_flash_attn:cuda_major, cuda_minor = torch.cuda.get_device_capability()if cuda_major < 8:warnings.warn("Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward.""ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593")if inference:transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask_inferencetransformers.models.llama.modeling_llama.LlamaAttention.forward = forward_flashattn_inferenceelse:transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (_prepare_decoder_attention_mask)transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_flashattn_full if use_full else forward_flashattnelse:transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_noflashattn

其自身实现的方法添加了flashattn方法,这里可以忽略,在计算atten时,多了一个shift操作,修改(就是增加了几行代码)的代码行数并不多。

   # shiftdef shift(qkv, bsz, q_len, group_size, num_heads, head_dim):qkv[:, num_heads // 2:] = qkv[:, num_heads // 2:].roll(-group_size // 2, dims=2)qkv = qkv.transpose(1, 2).reshape(bsz * (q_len // group_size), group_size, num_heads, head_dim).transpose(1, 2)return qkvquery_states = shift(query_states, bsz, q_len, group_size, self.num_heads, self.head_dim)key_states = shift(key_states, bsz, q_len, group_size, self.num_heads, self.head_dim)value_states = shift(value_states, bsz, q_len, group_size, self.num_heads, self.head_dim)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/150342.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

新文件覆盖旧文件还能复原吗,3个方法快速恢复覆盖文件!

iPhone在解压压缩文件时&#xff0c;不小心将同名文件进行了覆盖&#xff0c;怎么撤回&#xff1f; 在使用U盘转移文档时&#xff0c;意外将同名文档进行了替换&#xff0c;怎么恢复&#xff1f; 当误将重名文件进行了替换&#xff0c;如何找回这些被覆盖的旧文件&#xff1f;…

oracle linux8.8上安装oracle 19c集群

1、操作系统版本告警 处理办法&#xff1a;export CV_ASSUME_DISTIDRHEL7.6 2、ssh互信故障 查看ssh版本 [rootdb1 ~]# ssh -V OpenSSH_8.0p1, OpenSSL 1.1.1k FIPS 25 Mar 2021 处理办法-2个节点都需要操作 安装前配置 # mv /usr/bin/scp /usr/bin/scp.orig # echo "…

解决 Jenkins 性能缓慢的问题~转

解决 Jenkins 性能缓慢的问题 Docker中文社区 ​​ 计算机技术与软件专业技术资格持证人 2 人赞同了该文章 没有什么比缓慢的持续集成系统更令人沮丧的了。它减慢了反馈循环并阻止代码快速投入生产。虽然像使用性能更好的服务器可以为您争取时间&#xff0c;但您最终必须投资…

c++day1

#include <iostream> //#预处理 using namespace std; //using :使用命名空间的关键字 //namespace:命名空间的关键字 //std:标准的命名空间//程序入口 int main() {//程序的开始int daxie 0,xiaoxie 0,sum 0, kong 0,other 0;string str;getline(cin , str);for(in…

基于SSM的资源共享平台设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用JSP技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…

Zabbix监控系统 第一部分:zabbix服务部署+自定义监控项+自动发现与自动注册(附详细部署实例)

这里是目录 一、Zabbix概述1.1 简介1.2 zabbix组件1.2.1 zabbix server1.2.2 zabbix agent1.2.3 zabbix proxy1.2.4 zabbix get1.2.5 zabbix sender 1.3 工作原理1.4 端口号1.5 zabbix中预设的键值1.6 自定义监控项1.7 邮件报警的思路1.8 Zabbix自动发现和自动注册1.8.1 zabbix…

[图论]哈尔滨工业大学(哈工大 HIT)学习笔记16-22

视频来源&#xff1a;2.7.1 补图_哔哩哔哩_bilibili 目录 1. 补图 1.1. 补图 2. 双图 2.1. 双图定理 3. 图兰定理/托兰定理 4. 极图理论 5. 欧拉图 5.1. 欧拉迹 5.2. 欧拉闭迹 5.3. 欧拉图 5.4. 欧拉定理 5.5. 伪图 1. 补图 1.1. 补图 &#xff08;1&#xff09;…

使用mysql的cmd窗口,运行项目中的mapper层xml里的sql语句,查看运行结果

使用mysql的cmd窗口&#xff0c;运行项目中的mapper层xml里的sql语句&#xff0c;查看运行结果 项目代码或者从控制台复制sql语句从控制台搜索方式 运行效果或者使用idea的console窗口运行查看结果点击进入&#xff0c;查看表结构与字段 其他技巧根据from 表名寻找对应的sql代码…

微服务学习(十):安装Maven

微服务学习&#xff08;十&#xff09;&#xff1a;安装Maven 1、下载Maven 官网下载 2、将下载后的资源包上传到服务器 3、解压资源包并安装 tar -zxvf apache-maven-3.9.5-bin.tar.gz4、配置环境变量 vi /etc/profileexport MAVEN_HOME/home/maven/apache-maven-3.9.5 …

SSL证书是什么?1分钟get

在当今互联网世界中&#xff0c;保护数据的完整性和隐私性至关重要&#xff0c;由此&#xff0c;在网络数据安全保护领域&#xff0c;作为保护网络传输数据安全的SSL证书越来越频繁出现。那么你知道SSL证书是什么&#xff1f;SSL证书有哪些类型&#xff1f;SSL证书有什么用吗&a…

机器学习---RBM、KL散度、DBN

1. RBM 1.1 BM BM是由Hinton和Sejnowski提出的一种随机递归神经网络&#xff0c;可以看做是一种随机生成的 Hopfield网络&#xff0c;是能够通过学习数据的固有内在表示解决困难学习问题的最早的人工神经网络之 一&#xff0c;因样本分布遵循玻尔兹曼分布而命名为BM。BM由二…

基于Springboot实现旧物置换网站平台演示【项目源码+论文说明】分享

基于Springboot实现旧物置换网站平台演示 摘要 随着时代在一步一步在进步&#xff0c;旧物也成人们的烦恼&#xff0c;许多平台网站都在推广自已的产品像天猫、咸鱼、京东。所以开发出一套关于旧物置换网站成为必需。旧物置换网站主要是借助计算机&#xff0c;通过对用户进行管…

JVM上篇之虚拟机与java虚拟机介绍

目录 虚拟机 java虚拟机 简介 特点 作用 位置 整体结构 类装载子系统 运行时数据区 java执行引擎 Java代码执行流程 jvm架构模型 基于栈式架构 基于寄存器架构 总结 jvm的生命周期 1.启动 2.执行 3.退出 JVM的发展历程 虚拟机 所谓虚拟机&#xff0c;指的…

竞赛选题 深度学习 python opencv 动物识别与检测

文章目录 0 前言1 深度学习实现动物识别与检测2 卷积神经网络2.1卷积层2.2 池化层2.3 激活函数2.4 全连接层2.5 使用tensorflow中keras模块实现卷积神经网络 3 YOLOV53.1 网络架构图3.2 输入端3.3 基准网络3.4 Neck网络3.5 Head输出层 4 数据集准备4.1 数据标注简介4.2 数据保存…

外卖小程序源码vs定制开发:何时选择哪种方式?

在数字餐饮行业的蓬勃发展中&#xff0c;外卖应用程序已经成为餐厅和创业者的必备工具。然而&#xff0c;当涉及到开发外卖应用程序时&#xff0c;您会面临一个重要的决策&#xff1a;是使用外卖小程序源码还是进行定制开发&#xff1f;这两种方法各有优势和劣势&#xff0c;取…

vue3+elementPlus el-input的type=“number“时去除右边的上下箭头

改成 代码如下 <script lang"ts" setup> import {ref} from vue const inputBtn ref() </script> <template><el-input type"number" v-model"inputBtn" style"width: 80px;" class"no_number">…

cartographer-(0)-ubuntu(20.04)-环境安装

1.安装 ROS wiki.ros.org 1.1修改镜像源&#xff1a; 到网站上找与操作系统相匹配的镜像源 ubuntu | 镜像站使用帮助 | 清华大学开源软件镜像站 | Tsinghua Open Source Mirror # 默认注释了源码镜像以提高 apt update 速度&#xff0c;如有需要可自行取消注释 deb htt…

Echarts 实现X轴多维效果

效果图 代码参考地址 https://download.csdn.net/download/Frazier1995/88403104

Android Studio 是如何和我们的手机共享剪贴板的

背景 近期完成了target33的项目适配升级,随着AGP和gradle的版本升级,万年老版本Android Studio(后文简称AS)也顺便升级到了最新版Android Studio Giraffe | 2022.3.1,除了新UI外,最让我好奇的是这次的Running Devices功能(官方也称为Device mirroring)可以控制真机了. 按照操…

东哥录了一些课程,你能想到应该都有了

哈喽&#xff0c;大家好&#xff0c;我是hahaCoderX。 我在B站录制了《快速入门C语言程序设计》、《Python3网络爬虫开发实战》、《机器学习实战》以及我的个人图书案例讲解指南等系列课程&#xff0c;目前正在陆续上传开放中&#xff0c;欢迎大家看我的视频&#xff0c;一块学…